From ed5dc6a0776ce96c6a6cfe990395f3aca44106d5 Mon Sep 17 00:00:00 2001 From: Alexandros Solanos Date: Tue, 16 Dec 2025 15:28:35 +0100 Subject: [PATCH 001/539] Improve model repetition detection performance --- .../litellm_core_utils/streaming_handler.py | 52 ++++---- .../test_streaming_handler.py | 117 ++++++++++++++++++ 2 files changed, 147 insertions(+), 22 deletions(-) diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index d92af41717..b8240839d7 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -145,6 +145,7 @@ class CustomStreamWrapper: self.chunks: List = ( [] ) # keep track of the returned chunks - used for calculating the input/output tokens for stream options + self._repeated_messages_count = 1 self.is_function_call = self.check_is_function_call(logging_obj=logging_obj) self.created: Optional[int] = None @@ -190,7 +191,7 @@ class CustomStreamWrapper: except Exception as e: raise e - def safety_checker(self) -> None: + def raise_on_model_repetition(self) -> None: """ Fixes - https://github.com/BerriAI/litellm/issues/5158 @@ -198,29 +199,36 @@ class CustomStreamWrapper: Raises - InternalServerError, if LLM enters infinite loop while streaming """ - if len(self.chunks) >= litellm.REPEATED_STREAMING_CHUNK_LIMIT: - # Get the last n chunks - last_chunks = self.chunks[-litellm.REPEATED_STREAMING_CHUNK_LIMIT :] + if len(self.chunks) < 2: + return - # Extract the relevant content from the chunks - last_contents = [chunk.choices[0].delta.content for chunk in last_chunks] + last_content = self.chunks[-1].choices[0].delta.content - # Check if all extracted contents are identical - if all(content == last_contents[0] for content in last_contents): - if ( - last_contents[0] is not None - and isinstance(last_contents[0], str) - and len(last_contents[0]) > 2 - ): # ignore empty content - 
https://github.com/BerriAI/litellm/issues/5158#issuecomment-2287156946 - # All last n chunks are identical - raise litellm.InternalServerError( - message="The model is repeating the same chunk = {}.".format( - last_contents[0] - ), - model="", - llm_provider="", - ) + if ( + last_content is None + or not isinstance(last_content, str) + or len(last_content) <= 2 + ): # ignore empty content - https://github.com/BerriAI/litellm/issues/5158#issuecomment-2287156946 + self._repeated_messages_count = 1 + return + second_to_last_content = self.chunks[-2].choices[0].delta.content + + if last_content == second_to_last_content: + self._repeated_messages_count += 1 + else: + self._repeated_messages_count = 1 + + if self._repeated_messages_count >= litellm.REPEATED_STREAMING_CHUNK_LIMIT: + # All last n chunks are identical + raise litellm.InternalServerError( + message="The model is repeating the same chunk = {}.".format( + last_content + ), + model="", + llm_provider="", + ) + def check_special_tokens(self, chunk: str, finish_reason: Optional[str]): """ Output parse / special tokens for sagemaker + hf streaming. 
@@ -879,7 +887,7 @@ class CustomStreamWrapper: if ( is_chunk_non_empty ): # cannot set content of an OpenAI Object to be an empty string - self.safety_checker() + self.raise_on_model_repetition() hold, model_response_str = self.check_special_tokens( chunk=completion_obj["content"], finish_reason=model_response.choices[0].finish_reason, diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index 6a528fef8f..fcce6ddc67 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -1162,3 +1162,120 @@ def test_is_chunk_non_empty_with_valid_tool_calls( ) is True ) + + +def _make_chunk(content: Optional[str]) -> ModelResponseStream: + return ModelResponseStream( + id="test", + created=1741037890, + model="test-model", + choices=[StreamingChoices(index=0, delta=Delta(content=content))], + ) + + +def _build_chunks(pattern: list[str], N: int) -> list[ModelResponseStream]: + """ + Build a list of chunks based on a pattern specification. 
+ """ + chunks = [] + for i, p in enumerate(pattern): + if p == "same": + chunks.append(_make_chunk("same_chunk")) + elif p == "diff": + chunks.append(_make_chunk(f"chunk_{i}")) + else: + chunks.append(_make_chunk(p)) + return chunks + +_REPETITION_TEST_CASES = [ + # Basic cases + pytest.param( + ["same"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT, + True, + id="all_identical_raises", + ), + pytest.param( + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - 1), + False, + id="below_threshold_no_raise", + ), + pytest.param( + [None] * litellm.REPEATED_STREAMING_CHUNK_LIMIT, + False, + id="none_content_no_raise", + ), + pytest.param( + [""] * litellm.REPEATED_STREAMING_CHUNK_LIMIT, + False, + id="empty_content_no_raise", + ), + # Short content (len <= 2) should not raise + pytest.param( + ["##"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT, + False, + id="short_content_2chars_no_raise", + ), + pytest.param( + ["{"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT, + False, + id="short_content_1char_no_raise", + ), + pytest.param( + ["ab"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT, + False, + id="short_content_2chars_ab_no_raise", + ), + # All different chunks + pytest.param( + ["diff"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT, + False, + id="all_different_no_raise", + ), + # One chunk different at various positions + pytest.param( + ["different_first"] + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - 1), + False, + id="first_chunk_different_no_raise", + ), + pytest.param( + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - 1) + ["different_last"], + False, + id="last_chunk_different_no_raise", + ), + pytest.param( + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT // 2 + 1) + ["different_mid"] + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - litellm.REPEATED_STREAMING_CHUNK_LIMIT // 2 + 1), + False, + id="middle_chunk_different_no_raise", + ), + pytest.param( + ["same"] * (litellm.REPEATED_STREAMING_CHUNK_LIMIT - 2) + ["diff", "diff"], + False, + 
id="last_two_different_no_raise", + ), + pytest.param( + ["diff"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT + ["same"] * litellm.REPEATED_STREAMING_CHUNK_LIMIT + ["diff"], + True, + id="in_between_same_and_diff_raise", + ), +] + + +@pytest.mark.parametrize("chunks_pattern,should_raise", _REPETITION_TEST_CASES) +def test_raise_on_model_repetition( + initialized_custom_stream_wrapper: CustomStreamWrapper, + chunks_pattern: list, + should_raise: bool, +): + wrapper = initialized_custom_stream_wrapper + chunks = _build_chunks(chunks_pattern, len(chunks_pattern)) + + if should_raise: + with pytest.raises(litellm.InternalServerError) as exc_info: + for chunk in chunks: + wrapper.chunks.append(chunk) + wrapper.raise_on_model_repetition() + assert "repeating the same chunk" in str(exc_info.value) + else: + for chunk in chunks: + wrapper.chunks.append(chunk) + wrapper.raise_on_model_repetition() \ No newline at end of file From f6e3baafc519a38eec8eb5079b2d8b684a4351b9 Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Feb 2026 22:52:23 -0300 Subject: [PATCH 002/539] fix(anthropic): preserve thinking.summary when routing to OpenAI Responses API Read summary from the original thinking dict instead of hardcoding "detailed" in _route_openai_thinking_to_responses_api_if_needed(). This preserves the user's chosen summary value (e.g. "concise", "auto") for non-Claude models routed through the Anthropic Messages adapter to OpenAI's Responses API. 
Fixes #20998 --- .../adapters/handler.py | 3 +- ...erimental_pass_through_messages_handler.py | 75 ++++++++++++++++++- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py index 73e74c228b..c0b77798f2 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py @@ -78,9 +78,10 @@ class LiteLLMMessagesToCompletionTransformationHandler: reasoning_effort = completion_kwargs.get("reasoning_effort") if isinstance(reasoning_effort, str) and reasoning_effort: + summary = thinking.get("summary", "detailed") if isinstance(thinking, dict) else "detailed" completion_kwargs["reasoning_effort"] = { "effort": reasoning_effort, - "summary": "detailed", + "summary": summary, } elif isinstance(reasoning_effort, dict): if ( diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py index 376d14416a..e2639a3126 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py @@ -209,12 +209,83 @@ class TestThinkingParameterTransformation: from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import ( LiteLLMAnthropicMessagesAdapter, ) - + thinking = {"type": "enabled", "budget_tokens": 1024} result = LiteLLMAnthropicMessagesAdapter.translate_thinking_for_model( thinking=thinking, model="openai/gpt-5.2", ) - + assert result == {"reasoning_effort": "minimal"} assert "thinking" not in result + + 
+class TestThinkingSummaryPreservation: + """Tests for issue #20998: thinking.summary must be preserved when routing to OpenAI Responses API.""" + + def test_thinking_summary_concise_preserved_for_openai(self): + """User-provided summary='concise' should not be replaced with 'detailed'.""" + from litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( + LiteLLMMessagesToCompletionTransformationHandler, + ) + + thinking = {"type": "enabled", "budget_tokens": 5000, "summary": "concise"} + completion_kwargs = {"model": "openai/gpt-5.1", "reasoning_effort": "medium"} + LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( + completion_kwargs, thinking=thinking + ) + assert completion_kwargs["reasoning_effort"] == {"effort": "medium", "summary": "concise"} + + def test_thinking_summary_auto_preserved_for_openai(self): + """User-provided summary='auto' should be preserved.""" + from litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( + LiteLLMMessagesToCompletionTransformationHandler, + ) + + thinking = {"type": "enabled", "budget_tokens": 10000, "summary": "auto"} + completion_kwargs = {"model": "openai/gpt-5.1", "reasoning_effort": "high"} + LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( + completion_kwargs, thinking=thinking + ) + assert completion_kwargs["reasoning_effort"] == {"effort": "high", "summary": "auto"} + + def test_thinking_without_summary_defaults_to_detailed(self): + """When no summary is provided, default 'detailed' should still be used.""" + from litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( + LiteLLMMessagesToCompletionTransformationHandler, + ) + + thinking = {"type": "enabled", "budget_tokens": 5000} + completion_kwargs = {"model": "openai/gpt-5.1", "reasoning_effort": "medium"} + LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( + 
completion_kwargs, thinking=thinking + ) + assert completion_kwargs["reasoning_effort"] == {"effort": "medium", "summary": "detailed"} + + def test_openai_model_with_thinking_summary_end_to_end(self): + """End-to-end: anthropic_messages_handler should preserve thinking.summary for OpenAI models.""" + from litellm.llms.anthropic.experimental_pass_through.messages.handler import ( + anthropic_messages_handler, + ) + + with patch("litellm.completion", return_value="test-response") as mock_completion: + try: + anthropic_messages_handler( + max_tokens=1024, + messages=[{"role": "user", "content": "What is 2+2?"}], + model="openai/gpt-5.2", + api_key="test-api-key", + thinking={ + "type": "enabled", + "budget_tokens": 5000, + "summary": "concise", + }, + ) + except Exception: + pass + + mock_completion.assert_called_once() + call_kwargs = mock_completion.call_args.kwargs + reasoning_effort = call_kwargs["reasoning_effort"] + assert reasoning_effort["summary"] == "concise", \ + f"Expected summary='concise', got summary='{reasoning_effort.get('summary')}'" From ece032523498929d7bde4f52368b18ec247a8750 Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 4 Mar 2026 10:45:28 -0300 Subject: [PATCH 003/539] fix(anthropic): make thinking.summary opt-in, don't hardcode default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove hardcoded summary="detailed" injection — summary is opt-in per OpenAI spec and increases costs. Users opt-in per-request via LiteLLM extension: thinking={"type": "enabled", "budget_tokens": N, "summary": "concise"}. Also preserve summary in translate_thinking_for_model() which previously dropped it when converting thinking → reasoning_effort for non-Claude models. 
Fixes #20998 --- .../adapters/handler.py | 20 +++++----- .../adapters/transformation.py | 3 ++ ...erimental_pass_through_messages_handler.py | 37 ++++++++++++++++--- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py index c0b77798f2..01c8f39ee8 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py @@ -44,8 +44,9 @@ class LiteLLMMessagesToCompletionTransformationHandler: For OpenAI models, Chat Completions typically does not return reasoning text (only token accounting). To return a thinking-like content block in the - Anthropic response format, we route the request through OpenAI's Responses API - and request a reasoning summary. + Anthropic response format, we route the request through OpenAI's Responses API. + If the user provides a `summary` field in the thinking dict, it is passed + through to the OpenAI reasoning params (opt-in per OpenAI spec). 
""" custom_llm_provider = completion_kwargs.get("custom_llm_provider") if custom_llm_provider is None: @@ -77,19 +78,20 @@ class LiteLLMMessagesToCompletionTransformationHandler: completion_kwargs["model"] = f"responses/{model}" reasoning_effort = completion_kwargs.get("reasoning_effort") + summary = thinking.get("summary") if isinstance(thinking, dict) else None if isinstance(reasoning_effort, str) and reasoning_effort: - summary = thinking.get("summary", "detailed") if isinstance(thinking, dict) else "detailed" - completion_kwargs["reasoning_effort"] = { - "effort": reasoning_effort, - "summary": summary, - } + reasoning_dict: Dict[str, Any] = {"effort": reasoning_effort} + if summary: + reasoning_dict["summary"] = summary + completion_kwargs["reasoning_effort"] = reasoning_dict elif isinstance(reasoning_effort, dict): if ( - "summary" not in reasoning_effort + summary + and "summary" not in reasoning_effort and "generate_summary" not in reasoning_effort ): updated_reasoning_effort = dict(reasoning_effort) - updated_reasoning_effort["summary"] = "detailed" + updated_reasoning_effort["summary"] = summary completion_kwargs["reasoning_effort"] = updated_reasoning_effort @staticmethod diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py index efbac13735..2af5e35112 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py @@ -673,6 +673,9 @@ class LiteLLMAnthropicMessagesAdapter: thinking ) if reasoning_effort: + summary = thinking.get("summary") if isinstance(thinking, dict) else None + if summary: + return {"reasoning_effort": {"effort": reasoning_effort, "summary": summary}} return {"reasoning_effort": reasoning_effort} return {} diff --git 
a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py index e2639a3126..c6541ccac4 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py @@ -177,8 +177,8 @@ def test_openai_model_with_thinking_converts_to_reasoning_effort(): # Verify reasoning_effort is set (converted from thinking) assert "reasoning_effort" in call_kwargs, "reasoning_effort should be passed to completion" - # reasoning_effort is transformed into a dict with effort and summary fields - expected_reasoning_effort = {"effort": "minimal", "summary": "detailed"} + # reasoning_effort is a dict with effort only (summary is opt-in per OpenAI spec) + expected_reasoning_effort = {"effort": "minimal"} assert call_kwargs["reasoning_effort"] == expected_reasoning_effort, \ f"reasoning_effort should be {expected_reasoning_effort} for budget_tokens=1024, got {call_kwargs.get('reasoning_effort')}" @@ -249,8 +249,8 @@ class TestThinkingSummaryPreservation: ) assert completion_kwargs["reasoning_effort"] == {"effort": "high", "summary": "auto"} - def test_thinking_without_summary_defaults_to_detailed(self): - """When no summary is provided, default 'detailed' should still be used.""" + def test_thinking_without_summary_does_not_inject_summary(self): + """When no summary is provided, no summary should be injected (opt-in per OpenAI spec).""" from litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( LiteLLMMessagesToCompletionTransformationHandler, ) @@ -260,7 +260,8 @@ class TestThinkingSummaryPreservation: 
LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( completion_kwargs, thinking=thinking ) - assert completion_kwargs["reasoning_effort"] == {"effort": "medium", "summary": "detailed"} + assert completion_kwargs["reasoning_effort"] == {"effort": "medium"} + assert "summary" not in completion_kwargs["reasoning_effort"] def test_openai_model_with_thinking_summary_end_to_end(self): """End-to-end: anthropic_messages_handler should preserve thinking.summary for OpenAI models.""" @@ -289,3 +290,29 @@ class TestThinkingSummaryPreservation: reasoning_effort = call_kwargs["reasoning_effort"] assert reasoning_effort["summary"] == "concise", \ f"Expected summary='concise', got summary='{reasoning_effort.get('summary')}'" + + def test_translate_thinking_for_model_preserves_summary(self): + """translate_thinking_for_model should include summary in reasoning_effort dict when user provides it.""" + from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import ( + LiteLLMAnthropicMessagesAdapter, + ) + + thinking = {"type": "enabled", "budget_tokens": 5000, "summary": "concise"} + result = LiteLLMAnthropicMessagesAdapter.translate_thinking_for_model( + thinking=thinking, + model="openai/gpt-5.2", + ) + assert result == {"reasoning_effort": {"effort": "medium", "summary": "concise"}} + + def test_translate_thinking_for_model_no_summary_when_not_provided(self): + """translate_thinking_for_model should return plain string reasoning_effort when no summary provided.""" + from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import ( + LiteLLMAnthropicMessagesAdapter, + ) + + thinking = {"type": "enabled", "budget_tokens": 5000} + result = LiteLLMAnthropicMessagesAdapter.translate_thinking_for_model( + thinking=thinking, + model="openai/gpt-5.2", + ) + assert result == {"reasoning_effort": "medium"} From ba5d32b6b81c8dd707d3949a10f80bbbe14f2885 Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 
4 Mar 2026 10:51:59 -0300 Subject: [PATCH 004/539] =?UTF-8?q?fix:=20address=20review=20feedback=20?= =?UTF-8?q?=E2=80=94=20remove=20redundant=20guard,=20preserve=20summary=20?= =?UTF-8?q?in=20translate=5Fanthropic=5Fto=5Fopenai?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove redundant isinstance(thinking, dict) check in handler.py since early return on line 64 guarantees thinking is a dict at that point - Preserve summary in translate_anthropic_to_openai() for consistency across all code paths (adapter, guardrail, main.py) --- .../anthropic/experimental_pass_through/adapters/handler.py | 2 +- .../experimental_pass_through/adapters/transformation.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py index 01c8f39ee8..4935c65f40 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py @@ -78,7 +78,7 @@ class LiteLLMMessagesToCompletionTransformationHandler: completion_kwargs["model"] = f"responses/{model}" reasoning_effort = completion_kwargs.get("reasoning_effort") - summary = thinking.get("summary") if isinstance(thinking, dict) else None + summary = thinking.get("summary") if isinstance(reasoning_effort, str) and reasoning_effort: reasoning_dict: Dict[str, Any] = {"effort": reasoning_effort} if summary: diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py index 2af5e35112..ca1a94237a 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py @@ -892,7 +892,11 @@ class LiteLLMAnthropicMessagesAdapter: cast(Dict[str, Any], thinking) ) 
if reasoning_effort: - new_kwargs["reasoning_effort"] = reasoning_effort + summary = thinking.get("summary") if isinstance(thinking, dict) else None + if summary: + new_kwargs["reasoning_effort"] = {"effort": reasoning_effort, "summary": summary} + else: + new_kwargs["reasoning_effort"] = reasoning_effort ## CONVERT OUTPUT_FORMAT to RESPONSE_FORMAT if "output_format" in anthropic_message_request: From 57c0b466e1c78350d24e8a5ecb90feb472358750 Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 4 Mar 2026 22:04:15 -0300 Subject: [PATCH 005/539] docs: add thinking.summary field to /v1/messages and reasoning_content docs Document the `summary` optional field in the `thinking` object for the Anthropic `/v1/messages` adapter, and add a section on summary preservation when routing to non-Anthropic providers via the adapter. --- docs/my-website/docs/anthropic_unified/index.md | 9 ++++++--- docs/my-website/docs/reasoning_content.md | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/my-website/docs/anthropic_unified/index.md b/docs/my-website/docs/anthropic_unified/index.md index 9981547ce1..f8a50e14da 100644 --- a/docs/my-website/docs/anthropic_unified/index.md +++ b/docs/my-website/docs/anthropic_unified/index.md @@ -506,12 +506,15 @@ Request body will be in the Anthropic messages API format. **litellm follows the A system prompt providing context or specific instructions to the model. - **temperature** (number): Controls randomness in the model's responses. Valid range: `0 < temperature < 1`. -- **thinking** (object): +- **thinking** (object): Configuration for enabling extended thinking. If enabled, it includes: - - **budget_tokens** (integer): + - **budget_tokens** (integer): Minimum of 1024 tokens (and less than `max_tokens`). - - **type** (enum): + - **type** (enum): E.g., `"enabled"`. + - **summary** (string, optional): + Enables the summary style for thinking blocks. 
Possible values: `"auto"`, `"concise"`, `"detailed"`, `"disabled"`. + When routing to non-Anthropic providers (e.g., `openai/gpt-5.1`), the `summary` value is preserved and forwarded to the downstream API. - **tool_choice** (object): Instructs how the model should utilize any provided tools. - **tools** (array of objects): diff --git a/docs/my-website/docs/reasoning_content.md b/docs/my-website/docs/reasoning_content.md index b5a5809bd4..05c374f38d 100644 --- a/docs/my-website/docs/reasoning_content.md +++ b/docs/my-website/docs/reasoning_content.md @@ -675,3 +675,19 @@ response = litellm.completion( reasoning_effort={"effort": "low", "summary": "detailed"}, # Explicit control ) ``` + +### Summary Preservation via `/v1/messages` Adapter + +When using the Anthropic `/v1/messages` adapter to route non-Claude models (e.g., `openai/gpt-5.1`), the `thinking.summary` value is preserved and forwarded to the downstream provider. For example: + +```python +import litellm + +response = await litellm.anthropic.messages.acreate( + model="openai/gpt-5.1", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=8096, + thinking={"type": "enabled", "budget_tokens": 5000, "summary": "concise"}, +) +# The summary="concise" is preserved when routing to OpenAI's Responses API +``` From b3a17596fe6b214b4bf0c7136336eada4c23358e Mon Sep 17 00:00:00 2001 From: Gustavo Martin Alvarez <55332916+gustipardo@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:52:08 -0300 Subject: [PATCH 006/539] fix(gemini): resolve image token undercounting in usage metadata (#22608) * fix(gemini): ensure image token accumulation in usage metadata Fixed an issue where image tokens were being overwritten instead of accumulated in Gemini responses. Added support for both camelCase and snake_case token count keys. Fixes #22082. 
* test: add regression test for image token accumulation and cleanup files * fix(gemini): ensure consistent accumulation for responseTokensDetails * fix(gemini): harden token count parsing and add vertex accumulation test Parse tokenCount/token_count as int-safe values to satisfy mypy and avoid None/object arithmetic. Add regression test for duplicate modality accumulation in Vertex _calculate_usage. --- .../gemini/image_generation/transformation.py | 13 ++-- .../vertex_and_google_ai_studio_gemini.py | 75 ++++++++++++------- .../test_gemini_image_usage.py | 55 ++++++++++++++ ...test_vertex_and_google_ai_studio_gemini.py | 37 ++++++++- 4 files changed, 148 insertions(+), 32 deletions(-) diff --git a/litellm/llms/gemini/image_generation/transformation.py b/litellm/llms/gemini/image_generation/transformation.py index 73aef15e4c..6716d9a138 100644 --- a/litellm/llms/gemini/image_generation/transformation.py +++ b/litellm/llms/gemini/image_generation/transformation.py @@ -92,12 +92,15 @@ class GoogleImageGenConfig(BaseImageGenerationConfig): tokens_details = usage_metadata.get("promptTokensDetails", []) for details in tokens_details: if isinstance(details, dict): - modality = details.get("modality") - token_count = details.get("tokenCount", 0) + modality = str(details.get("modality", "")).upper() + raw_token_count = details.get( + "tokenCount", details.get("token_count", 0) + ) + token_count = raw_token_count if isinstance(raw_token_count, int) else 0 if modality == "TEXT": - input_tokens_details.text_tokens = token_count + input_tokens_details.text_tokens += token_count elif modality == "IMAGE": - input_tokens_details.image_tokens = token_count + input_tokens_details.image_tokens += token_count return ImageUsage( input_tokens=usage_metadata.get("promptTokenCount", 0), @@ -274,4 +277,4 @@ class GoogleImageGenConfig(BaseImageGenerationConfig): b64_json=prediction.get("bytesBase64Encoded", None), url=None, # Google AI returns base64, not URLs )) - return model_response \ 
No newline at end of file + return model_response diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index eb2d5ad51c..4d10dbf7a0 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -1623,6 +1623,11 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): response_tokens: Optional[int] = None response_tokens_details: Optional[CompletionTokensDetailsWrapper] = None usage_metadata = completion_response["usageMetadata"] + + def _get_token_count(detail: dict) -> int: + raw_token_count = detail.get("tokenCount", detail.get("token_count", 0)) + return raw_token_count if isinstance(raw_token_count, int) else 0 + if "cachedContentTokenCount" in usage_metadata: cached_tokens = usage_metadata["cachedContentTokenCount"] @@ -1632,10 +1637,16 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): if "responseTokensDetails" in usage_metadata: response_tokens_details = CompletionTokensDetailsWrapper() for detail in usage_metadata["responseTokensDetails"]: - if detail["modality"] == "TEXT": - response_tokens_details.text_tokens = detail.get("tokenCount", 0) - elif detail["modality"] == "AUDIO": - response_tokens_details.audio_tokens = detail.get("tokenCount", 0) + modality = str(detail.get("modality", "")).upper() + token_count = _get_token_count(detail) + if modality == "TEXT": + response_tokens_details.text_tokens = ( + response_tokens_details.text_tokens or 0 + ) + token_count + elif modality == "AUDIO": + response_tokens_details.audio_tokens = ( + response_tokens_details.audio_tokens or 0 + ) + token_count ######################################################### @@ -1644,16 +1655,24 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): if response_tokens_details is None: response_tokens_details = CompletionTokensDetailsWrapper() for detail in 
usage_metadata["candidatesTokensDetails"]: - modality = detail.get("modality") - token_count = detail.get("tokenCount", 0) + modality = str(detail.get("modality", "")).upper() + token_count = _get_token_count(detail) if modality == "TEXT": - response_tokens_details.text_tokens = token_count + response_tokens_details.text_tokens = ( + response_tokens_details.text_tokens or 0 + ) + token_count elif modality == "AUDIO": - response_tokens_details.audio_tokens = token_count + response_tokens_details.audio_tokens = ( + response_tokens_details.audio_tokens or 0 + ) + token_count elif modality == "IMAGE": - response_tokens_details.image_tokens = token_count + response_tokens_details.image_tokens = ( + response_tokens_details.image_tokens or 0 + ) + token_count elif modality == "VIDEO": - response_tokens_details.video_tokens = token_count + response_tokens_details.video_tokens = ( + response_tokens_details.video_tokens or 0 + ) + token_count # Calculate text_tokens if not explicitly provided in candidatesTokensDetails # candidatesTokenCount includes all modalities, so: text = total - (image + audio + video) @@ -1677,14 +1696,16 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): ## Parse promptTokensDetails (total tokens by modality, includes cached + non-cached) if "promptTokensDetails" in usage_metadata: for detail in usage_metadata["promptTokensDetails"]: - if detail["modality"] == "AUDIO": - prompt_audio_tokens = detail.get("tokenCount", 0) - elif detail["modality"] == "TEXT": - prompt_text_tokens = detail.get("tokenCount", 0) - elif detail["modality"] == "IMAGE": - prompt_image_tokens = detail.get("tokenCount", 0) - elif detail["modality"] == "VIDEO": - prompt_video_tokens = detail.get("tokenCount", 0) + modality = str(detail.get("modality", "")).upper() + token_count = _get_token_count(detail) + if modality == "AUDIO": + prompt_audio_tokens = (prompt_audio_tokens or 0) + token_count + elif modality == "TEXT": + prompt_text_tokens = (prompt_text_tokens or 0) + 
token_count + elif modality == "IMAGE": + prompt_image_tokens = (prompt_image_tokens or 0) + token_count + elif modality == "VIDEO": + prompt_video_tokens = (prompt_video_tokens or 0) + token_count ## Parse cacheTokensDetails (breakdown of cached tokens by modality) ## When explicit caching is used, Gemini provides this field to show which modalities were cached @@ -1695,14 +1716,16 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): if "cacheTokensDetails" in usage_metadata: for detail in usage_metadata["cacheTokensDetails"]: - if detail["modality"] == "AUDIO": - cached_audio_tokens = detail.get("tokenCount", 0) - elif detail["modality"] == "TEXT": - cached_text_tokens = detail.get("tokenCount", 0) - elif detail["modality"] == "IMAGE": - cached_image_tokens = detail.get("tokenCount", 0) - elif detail["modality"] == "VIDEO": - cached_video_tokens = detail.get("tokenCount", 0) + modality = str(detail.get("modality", "")).upper() + token_count = _get_token_count(detail) + if modality == "AUDIO": + cached_audio_tokens = (cached_audio_tokens or 0) + token_count + elif modality == "TEXT": + cached_text_tokens = (cached_text_tokens or 0) + token_count + elif modality == "IMAGE": + cached_image_tokens = (cached_image_tokens or 0) + token_count + elif modality == "VIDEO": + cached_video_tokens = (cached_video_tokens or 0) + token_count ## Calculate non-cached tokens by subtracting cached from total (per modality) ## This is necessary because promptTokensDetails includes both cached and non-cached tokens diff --git a/tests/llm_translation/test_gemini_image_usage.py b/tests/llm_translation/test_gemini_image_usage.py index 8c7f05d38e..0497d7fd9d 100644 --- a/tests/llm_translation/test_gemini_image_usage.py +++ b/tests/llm_translation/test_gemini_image_usage.py @@ -4,9 +4,11 @@ Test for Gemini image generation usage metadata extraction. 
This test verifies the fix for issue #18323 where image_generation() was returning usage=0 while completion() returned proper token usage. """ +import os import pytest from unittest.mock import patch, MagicMock import litellm +from litellm.llms.gemini.image_generation.transformation import GoogleImageGenConfig from litellm.types.utils import ImageResponse, ImageObject, ImageUsage @@ -211,3 +213,56 @@ def test_gemini_imagen_models_no_usage_extraction(): # For Imagen models, we don't extract usage from the predictions format # This test just ensures we don't crash + + +def test_gemini_image_generation_accumulates_multiple_image_prompt_token_details(): + """ + Regression test: promptTokensDetails can include multiple IMAGE entries. + These must be accumulated instead of overwritten. + """ + previous_local_model_cost_map = os.environ.get("LITELLM_LOCAL_MODEL_COST_MAP") + previous_model_cost = litellm.model_cost + try: + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + model = "gemini/gemini-3-pro-image-preview" + config = GoogleImageGenConfig() + + usage_metadata = { + "promptTokenCount": 200, + "candidatesTokenCount": 0, + "totalTokenCount": 200, + "promptTokensDetails": [ + {"modality": "TEXT", "tokenCount": 10}, + {"modality": "IMAGE", "tokenCount": 90}, + {"modality": "IMAGE", "tokenCount": 100}, + ], + } + + parsed_usage = config._transform_image_usage(usage_metadata) + image_response = ImageResponse( + data=[ImageObject(b64_json="fake_image_data")], + usage=parsed_usage, + ) + + observed_cost = litellm.completion_cost( + completion_response=image_response, + model=model, + custom_llm_provider="gemini", + ) + + model_info = litellm.get_model_info(model=model, custom_llm_provider="gemini") + expected_image_tokens = 190 + expected_total_prompt_tokens = 200 + expected_prompt_cost = expected_total_prompt_tokens * model_info["input_cost_per_token"] + + assert parsed_usage.input_tokens_details.image_tokens 
== expected_image_tokens + assert parsed_usage.input_tokens_details.text_tokens == 10 + assert observed_cost == pytest.approx(expected_prompt_cost, rel=1e-12) + finally: + if previous_local_model_cost_map is None: + os.environ.pop("LITELLM_LOCAL_MODEL_COST_MAP", None) + else: + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = previous_local_model_cost_map + litellm.model_cost = previous_model_cost diff --git a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py index 8beb19bf1a..0f8ae71543 100644 --- a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py +++ b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py @@ -862,6 +862,42 @@ def test_vertex_ai_usage_metadata_with_image_tokens_in_prompt(): ) +def test_vertex_ai_usage_metadata_accumulates_duplicate_modalities(): + """Ensure _calculate_usage accumulates repeated modality entries.""" + v = VertexGeminiConfig() + usage_metadata = { + "promptTokenCount": 210, + "candidatesTokenCount": 50, + "totalTokenCount": 260, + "promptTokensDetails": [ + {"modality": "TEXT", "tokenCount": 20}, + {"modality": "IMAGE", "tokenCount": 90}, + {"modality": "IMAGE", "token_count": 100}, + ], + "candidatesTokensDetails": [ + {"modality": "IMAGE", "tokenCount": 30}, + {"modality": "TEXT", "tokenCount": 15}, + {"modality": "TEXT", "token_count": 5}, + ], + "cacheTokensDetails": [ + {"modality": "TEXT", "tokenCount": 4}, + {"modality": "IMAGE", "tokenCount": 40}, + {"modality": "IMAGE", "token_count": 10}, + ], + } + usage_metadata = UsageMetadata(**usage_metadata) + result = v._calculate_usage(completion_response={"usageMetadata": usage_metadata}) + + # prompt details are total - cached per modality + assert result.prompt_tokens_details.text_tokens == 16 # 20 - 4 + assert result.prompt_tokens_details.image_tokens == 140 # (90 + 100) - (40 + 10) + + # candidates 
details accumulate duplicate modalities + assert result.completion_tokens_details.text_tokens == 20 # 15 + 5 + assert result.completion_tokens_details.image_tokens == 30 + assert result.completion_tokens == 50 + + def test_vertex_ai_map_thinking_param_with_budget_tokens_0(): """ If budget_tokens is 0, do not set includeThoughts to True @@ -3723,4 +3759,3 @@ def test_vertex_ai_usage_metadata_video_tokens_with_caching(): "Prompt video tokens should be 10240 - 5120 (cached) = 5120" assert result.prompt_tokens_details.text_tokens == 9 assert result.prompt_tokens_details.audio_tokens == 200 - From 607a9683a4a5d3d8cdb1b2f5463eaa875966d68e Mon Sep 17 00:00:00 2001 From: Chesars Date: Thu, 5 Mar 2026 14:02:07 +0000 Subject: [PATCH 007/539] feat(anthropic): add opt-out flag for default reasoning summary Add `litellm.disable_default_reasoning_summary` flag (default False) and env var `LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY` to allow users to opt out of the automatic `summary="detailed"` injection when routing Anthropic thinking requests to OpenAI's Responses API. Default behavior is preserved (summary="detailed" is always added), but users who don't want to pay for summary tokens can now disable it. 
https://claude.ai/code/session_01VJU9EwVvgvmeCe3Yu1aULa --- litellm/__init__.py | 1 + .../adapters/handler.py | 19 ++- .../responses_adapters/transformation.py | 9 ++ ...erimental_pass_through_messages_handler.py | 134 +++++++++++++----- .../test_responses_adapters_transformation.py | 32 +++++ 5 files changed, 156 insertions(+), 39 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 57e9cb25f4..ceadb983bd 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -293,6 +293,7 @@ llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" guardrail_name_config_map: Dict[str, GuardrailItem] = {} include_cost_in_streaming_usage: bool = False reasoning_auto_summary: bool = False +disable_default_reasoning_summary: bool = False ### PROMPTS #### from litellm.types.prompts.init_prompts import PromptSpec diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py index 4935c65f40..a0954d765c 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py @@ -1,3 +1,4 @@ +import os from typing import ( TYPE_CHECKING, Any, @@ -77,22 +78,30 @@ class LiteLLMMessagesToCompletionTransformationHandler: # Prefix model with "responses/" to route to OpenAI Responses API completion_kwargs["model"] = f"responses/{model}" + summary_disabled = ( + litellm.disable_default_reasoning_summary + or os.getenv("LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY", "false").lower() == "true" + ) + reasoning_effort = completion_kwargs.get("reasoning_effort") summary = thinking.get("summary") if isinstance(reasoning_effort, str) and reasoning_effort: reasoning_dict: Dict[str, Any] = {"effort": reasoning_effort} if summary: reasoning_dict["summary"] = summary + elif not summary_disabled: + reasoning_dict["summary"] = "detailed" completion_kwargs["reasoning_effort"] = reasoning_dict 
elif isinstance(reasoning_effort, dict): if ( - summary - and "summary" not in reasoning_effort + "summary" not in reasoning_effort and "generate_summary" not in reasoning_effort ): - updated_reasoning_effort = dict(reasoning_effort) - updated_reasoning_effort["summary"] = summary - completion_kwargs["reasoning_effort"] = updated_reasoning_effort + effective_summary = summary if summary else ("detailed" if not summary_disabled else None) + if effective_summary: + updated_reasoning_effort = dict(reasoning_effort) + updated_reasoning_effort["summary"] = effective_summary + completion_kwargs["reasoning_effort"] = updated_reasoning_effort @staticmethod def _prepare_completion_kwargs( diff --git a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py index 497809b05f..ec855acd23 100644 --- a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py @@ -6,8 +6,11 @@ path used for OpenAI and Azure models. 
""" import json +import os from typing import Any, Dict, List, Optional, Union, cast +import litellm + from litellm.types.llms.anthropic import ( AllAnthropicToolsValues, AnthopicMessagesAssistantMessageParam, @@ -241,10 +244,16 @@ class LiteLLMAnthropicToResponsesAPIAdapter: effort = "low" else: effort = "minimal" + summary_disabled = ( + litellm.disable_default_reasoning_summary + or os.getenv("LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY", "false").lower() == "true" + ) result: Dict[str, Any] = {"effort": effort} summary = thinking.get("summary") if summary: result["summary"] = summary + elif not summary_disabled: + result["summary"] = "detailed" return result def translate_request( diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py index 7864c5f59a..384c4a97c1 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py @@ -180,7 +180,8 @@ def test_openai_model_with_thinking_converts_to_reasoning(): assert "reasoning" in call_kwargs, "reasoning should be passed to litellm.responses" # budget_tokens=1024 -> effort="minimal" (< 2000 threshold) - expected_reasoning = {"effort": "minimal"} + # summary="detailed" added by default unless disable_default_reasoning_summary is set + expected_reasoning = {"effort": "minimal", "summary": "detailed"} assert call_kwargs["reasoning"] == expected_reasoning, ( f"reasoning should be {expected_reasoning} for budget_tokens=1024, " f"got {call_kwargs.get('reasoning')}" @@ -225,7 +226,7 @@ class TestThinkingParameterTransformation: class TestThinkingSummaryPreservation: - """Tests for issue #20998: 
thinking.summary must be preserved when routing to OpenAI Responses API.""" + """Tests for thinking.summary preservation and disable_default_reasoning_summary flag.""" def test_thinking_summary_concise_preserved_for_openai(self): """User-provided summary='concise' should not be replaced with 'detailed'.""" @@ -253,26 +254,98 @@ class TestThinkingSummaryPreservation: ) assert completion_kwargs["reasoning_effort"] == {"effort": "high", "summary": "auto"} - def test_thinking_without_summary_does_not_inject_summary(self): - """When no summary is provided, no summary should be injected (opt-in per OpenAI spec).""" + def test_summary_added_by_default_when_no_user_summary(self): + """When no user summary and flag is off, summary='detailed' is added by default.""" + import litellm from litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( LiteLLMMessagesToCompletionTransformationHandler, ) - thinking = {"type": "enabled", "budget_tokens": 5000} - completion_kwargs = {"model": "openai/gpt-5.1", "reasoning_effort": "medium"} - LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( - completion_kwargs, thinking=thinking + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = False + completion_kwargs = { + "model": "responses/gpt-5.2", + "custom_llm_provider": "openai", + "reasoning_effort": "medium", + } + LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( + completion_kwargs, thinking={"type": "enabled", "budget_tokens": 5000} + ) + assert completion_kwargs["reasoning_effort"] == {"effort": "medium", "summary": "detailed"} + finally: + litellm.disable_default_reasoning_summary = original + + def test_summary_excluded_when_disable_flag_set_string_reasoning(self): + """When disable_default_reasoning_summary is True, summary is not added for string reasoning_effort.""" + import litellm + from 
litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( + LiteLLMMessagesToCompletionTransformationHandler, ) - assert completion_kwargs["reasoning_effort"] == {"effort": "medium"} - assert "summary" not in completion_kwargs["reasoning_effort"] + + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = True + completion_kwargs = { + "model": "responses/gpt-5.2", + "custom_llm_provider": "openai", + "reasoning_effort": "high", + } + LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( + completion_kwargs, thinking={"type": "enabled", "budget_tokens": 10000} + ) + assert completion_kwargs["reasoning_effort"] == {"effort": "high"} + assert "summary" not in completion_kwargs["reasoning_effort"] + finally: + litellm.disable_default_reasoning_summary = original + + def test_summary_excluded_when_disable_flag_set_dict_reasoning(self): + """When disable_default_reasoning_summary is True, summary is not injected into dict reasoning_effort.""" + import litellm + from litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( + LiteLLMMessagesToCompletionTransformationHandler, + ) + + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = True + completion_kwargs = { + "model": "responses/gpt-5.2", + "custom_llm_provider": "openai", + "reasoning_effort": {"effort": "medium"}, + } + LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( + completion_kwargs, thinking={"type": "enabled", "budget_tokens": 5000} + ) + assert completion_kwargs["reasoning_effort"] == {"effort": "medium"} + assert "summary" not in completion_kwargs["reasoning_effort"] + finally: + litellm.disable_default_reasoning_summary = original + + def test_user_provided_summary_preserved_even_when_flag_off(self): + """When user already set summary in dict reasoning_effort, it's 
preserved regardless of flag.""" + import litellm + from litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( + LiteLLMMessagesToCompletionTransformationHandler, + ) + + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = False + completion_kwargs = { + "model": "responses/gpt-5.2", + "custom_llm_provider": "openai", + "reasoning_effort": {"effort": "high", "summary": "concise"}, + } + LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( + completion_kwargs, thinking={"type": "enabled", "budget_tokens": 10000} + ) + assert completion_kwargs["reasoning_effort"]["summary"] == "concise" + finally: + litellm.disable_default_reasoning_summary = original def test_openai_model_with_thinking_summary_end_to_end(self): - """End-to-end: anthropic_messages_handler should preserve thinking.summary for OpenAI models. - - OpenAI models are routed to litellm.responses(), so we verify the - reasoning dict passed to it contains the user's summary value. 
- """ + """End-to-end: anthropic_messages_handler should preserve thinking.summary for OpenAI models.""" from litellm.llms.anthropic.experimental_pass_through.messages.handler import ( anthropic_messages_handler, ) @@ -309,16 +382,22 @@ class TestThinkingSummaryPreservation: result = LiteLLMAnthropicToResponsesAPIAdapter.translate_thinking_to_reasoning(thinking) assert result == {"effort": "medium", "summary": "concise"} - def test_responses_adapter_no_summary_when_not_provided(self): - """translate_thinking_to_reasoning should not include summary when not provided.""" + def test_responses_adapter_no_summary_when_disabled(self): + """translate_thinking_to_reasoning should not include summary when flag is set and no user summary.""" + import litellm from litellm.llms.anthropic.experimental_pass_through.responses_adapters.transformation import ( LiteLLMAnthropicToResponsesAPIAdapter, ) - thinking = {"type": "enabled", "budget_tokens": 5000} - result = LiteLLMAnthropicToResponsesAPIAdapter.translate_thinking_to_reasoning(thinking) - assert result == {"effort": "medium"} - assert "summary" not in result + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = True + thinking = {"type": "enabled", "budget_tokens": 5000} + result = LiteLLMAnthropicToResponsesAPIAdapter.translate_thinking_to_reasoning(thinking) + assert result == {"effort": "medium"} + assert "summary" not in result + finally: + litellm.disable_default_reasoning_summary = original def test_translate_thinking_for_model_preserves_summary(self): """translate_thinking_for_model should include summary in reasoning_effort dict when user provides it.""" @@ -332,16 +411,3 @@ class TestThinkingSummaryPreservation: model="openai/gpt-5.2", ) assert result == {"reasoning_effort": {"effort": "medium", "summary": "concise"}} - - def test_translate_thinking_for_model_no_summary_when_not_provided(self): - """translate_thinking_for_model should return plain string 
reasoning_effort when no summary provided.""" - from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import ( - LiteLLMAnthropicMessagesAdapter, - ) - - thinking = {"type": "enabled", "budget_tokens": 5000} - result = LiteLLMAnthropicMessagesAdapter.translate_thinking_for_model( - thinking=thinking, - model="openai/gpt-5.2", - ) - assert result == {"reasoning_effort": "medium"} diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/responses_adapters/test_responses_adapters_transformation.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/responses_adapters/test_responses_adapters_transformation.py index 252ba230ff..7f77394552 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/responses_adapters/test_responses_adapters_transformation.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/responses_adapters/test_responses_adapters_transformation.py @@ -658,6 +658,38 @@ class TestTranslateThinkingToReasoning: result = _ADAPTER.translate_thinking_to_reasoning({"type": "enabled"}) assert result == {"effort": "minimal", "summary": "detailed"} + def test_summary_excluded_when_disable_flag_set(self): + """When disable_default_reasoning_summary is True, summary is not included.""" + import litellm + + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = True + result = _ADAPTER.translate_thinking_to_reasoning( + {"type": "enabled", "budget_tokens": 10000} + ) + assert result == {"effort": "high"} + assert "summary" not in result + finally: + litellm.disable_default_reasoning_summary = original + + def test_summary_excluded_when_env_var_set(self): + """When LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY env var is true, summary is not included.""" + import litellm + + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = False + os.environ["LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY"] 
= "true" + result = _ADAPTER.translate_thinking_to_reasoning( + {"type": "enabled", "budget_tokens": 5000} + ) + assert result == {"effort": "medium"} + assert "summary" not in result + finally: + litellm.disable_default_reasoning_summary = original + os.environ.pop("LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY", None) + # --------------------------------------------------------------------------- # translate_request – broader coverage From 3e9ea6f49bf9b552834ba0a21a827fecc131854a Mon Sep 17 00:00:00 2001 From: Chesars Date: Thu, 5 Mar 2026 11:24:18 -0300 Subject: [PATCH 008/539] refactor: extract summary_disabled logic into shared helper and add missing env var test - Extract duplicated summary_disabled evaluation from handler.py and transformation.py into a shared is_default_reasoning_summary_disabled() helper in utils.py to prevent future divergence. - Add test_summary_excluded_when_env_var_set to handler test class to close env-var test coverage gap flagged by Greptile. --- .../adapters/handler.py | 9 +++---- .../responses_adapters/transformation.py | 9 +++---- .../experimental_pass_through/utils.py | 12 +++++++++ ...erimental_pass_through_messages_handler.py | 25 +++++++++++++++++++ 4 files changed, 45 insertions(+), 10 deletions(-) create mode 100644 litellm/llms/anthropic/experimental_pass_through/utils.py diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py index a0954d765c..e7effb77ca 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/handler.py @@ -1,4 +1,3 @@ -import os from typing import ( TYPE_CHECKING, Any, @@ -13,6 +12,9 @@ from typing import ( ) import litellm +from litellm.llms.anthropic.experimental_pass_through.utils import ( + is_default_reasoning_summary_disabled, +) from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import ( 
AnthropicAdapter, ) @@ -78,10 +80,7 @@ class LiteLLMMessagesToCompletionTransformationHandler: # Prefix model with "responses/" to route to OpenAI Responses API completion_kwargs["model"] = f"responses/{model}" - summary_disabled = ( - litellm.disable_default_reasoning_summary - or os.getenv("LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY", "false").lower() == "true" - ) + summary_disabled = is_default_reasoning_summary_disabled() reasoning_effort = completion_kwargs.get("reasoning_effort") summary = thinking.get("summary") diff --git a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py index ec855acd23..5b4f2a19f2 100644 --- a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py @@ -6,10 +6,12 @@ path used for OpenAI and Azure models. """ import json -import os from typing import Any, Dict, List, Optional, Union, cast import litellm +from litellm.llms.anthropic.experimental_pass_through.utils import ( + is_default_reasoning_summary_disabled, +) from litellm.types.llms.anthropic import ( AllAnthropicToolsValues, @@ -244,10 +246,7 @@ class LiteLLMAnthropicToResponsesAPIAdapter: effort = "low" else: effort = "minimal" - summary_disabled = ( - litellm.disable_default_reasoning_summary - or os.getenv("LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY", "false").lower() == "true" - ) + summary_disabled = is_default_reasoning_summary_disabled() result: Dict[str, Any] = {"effort": effort} summary = thinking.get("summary") if summary: diff --git a/litellm/llms/anthropic/experimental_pass_through/utils.py b/litellm/llms/anthropic/experimental_pass_through/utils.py new file mode 100644 index 0000000000..4a40e4629e --- /dev/null +++ b/litellm/llms/anthropic/experimental_pass_through/utils.py @@ -0,0 +1,12 @@ +import os + +import litellm + + +def 
is_default_reasoning_summary_disabled() -> bool: + """Check whether the default 'summary: detailed' injection should be suppressed.""" + return ( + litellm.disable_default_reasoning_summary + or os.getenv("LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY", "false").lower() + == "true" + ) diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py index 384c4a97c1..0bc28cc086 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py @@ -322,6 +322,31 @@ class TestThinkingSummaryPreservation: finally: litellm.disable_default_reasoning_summary = original + def test_summary_excluded_when_env_var_set(self): + """When LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY env var is true, summary is not added.""" + import litellm + from litellm.llms.anthropic.experimental_pass_through.adapters.handler import ( + LiteLLMMessagesToCompletionTransformationHandler, + ) + + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = False + os.environ["LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY"] = "true" + completion_kwargs = { + "model": "responses/gpt-5.2", + "custom_llm_provider": "openai", + "reasoning_effort": "high", + } + LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed( + completion_kwargs, thinking={"type": "enabled", "budget_tokens": 10000} + ) + assert completion_kwargs["reasoning_effort"] == {"effort": "high"} + assert "summary" not in completion_kwargs["reasoning_effort"] + finally: + litellm.disable_default_reasoning_summary = original + 
os.environ.pop("LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY", None) + def test_user_provided_summary_preserved_even_when_flag_off(self): """When user already set summary in dict reasoning_effort, it's preserved regardless of flag.""" import litellm From 5b904f6054849640df07d2d4c05284ac0420970a Mon Sep 17 00:00:00 2001 From: Chesars Date: Thu, 5 Mar 2026 12:22:18 -0300 Subject: [PATCH 009/539] fix(anthropic): align translate_thinking_for_model with default summary injection + add docs - Update translate_thinking_for_model (3rd code path) to inject summary="detailed" by default, consistent with the other two paths - Add disable_default_reasoning_summary flag check via shared helper - Add tests for flag enabled/disabled and user-provided summary - Document disable_default_reasoning_summary in reasoning_content.md --- docs/my-website/docs/reasoning_content.md | 50 +++++++++++++++++++ .../adapters/transformation.py | 7 +++ ...erimental_pass_through_messages_handler.py | 34 ++++++++++++- 3 files changed, 90 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/reasoning_content.md b/docs/my-website/docs/reasoning_content.md index 05c374f38d..3693ab3315 100644 --- a/docs/my-website/docs/reasoning_content.md +++ b/docs/my-website/docs/reasoning_content.md @@ -691,3 +691,53 @@ response = await litellm.anthropic.messages.acreate( ) # The summary="concise" is preserved when routing to OpenAI's Responses API ``` + +### Default Summary Injection for `/v1/messages` Adapter + +When the Anthropic `/v1/messages` adapter translates `thinking` parameters to OpenAI `reasoning_effort` for non-Claude models, `summary="detailed"` is automatically injected by default. This ensures that reasoning text is returned in the response (matching the Anthropic thinking behavior). 
+ +To **disable** this default injection, use the `disable_default_reasoning_summary` flag: + + + + +```python +import litellm + +# Disable default summary="detailed" injection +litellm.disable_default_reasoning_summary = True + +response = await litellm.anthropic.messages.acreate( + model="openai/gpt-5.1", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=8096, + thinking={"type": "enabled", "budget_tokens": 5000}, +) +# No summary will be injected — only reasoning_effort is forwarded +``` + + + + + +```bash +export LITELLM_DISABLE_DEFAULT_REASONING_SUMMARY=true +``` + + + + + +```yaml +litellm_settings: + disable_default_reasoning_summary: true +``` + + + + +:::info + +This flag only affects the automatic injection of `summary="detailed"` when no user-provided summary is present. If you explicitly pass `thinking.summary` (e.g., `"concise"` or `"auto"`), your value is always preserved regardless of this flag. + +::: diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py index b0138ecbf5..049b876332 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py @@ -13,6 +13,10 @@ from typing import ( cast, ) +from litellm.llms.anthropic.experimental_pass_through.utils import ( + is_default_reasoning_summary_disabled, +) + # OpenAI has a 64-character limit for function/tool names # Anthropic does not have this limit, so we need to truncate long names OPENAI_MAX_TOOL_NAME_LENGTH = 64 @@ -694,8 +698,11 @@ class LiteLLMAnthropicMessagesAdapter: ) if reasoning_effort: summary = thinking.get("summary") if isinstance(thinking, dict) else None + summary_disabled = is_default_reasoning_summary_disabled() if summary: return {"reasoning_effort": {"effort": reasoning_effort, "summary": summary}} + elif not summary_disabled: + return 
{"reasoning_effort": {"effort": reasoning_effort, "summary": "detailed"}} return {"reasoning_effort": reasoning_effort} return {} diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py index 0bc28cc086..203b6dacea 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/messages/test_anthropic_experimental_pass_through_messages_handler.py @@ -221,9 +221,41 @@ class TestThinkingParameterTransformation: model="openai/gpt-5.2", ) - assert result == {"reasoning_effort": "minimal"} + assert result == {"reasoning_effort": {"effort": "minimal", "summary": "detailed"}} assert "thinking" not in result + def test_translate_thinking_for_model_no_summary_when_disabled(self): + """When disable_default_reasoning_summary is True, no summary is injected.""" + import litellm + from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import ( + LiteLLMAnthropicMessagesAdapter, + ) + + original = litellm.disable_default_reasoning_summary + try: + litellm.disable_default_reasoning_summary = True + thinking = {"type": "enabled", "budget_tokens": 5000} + result = LiteLLMAnthropicMessagesAdapter.translate_thinking_for_model( + thinking=thinking, + model="openai/gpt-5.2", + ) + assert result == {"reasoning_effort": "medium"} + finally: + litellm.disable_default_reasoning_summary = original + + def test_translate_thinking_for_model_preserves_user_summary(self): + """User-provided summary is always preserved regardless of flag.""" + from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import ( + LiteLLMAnthropicMessagesAdapter, + ) + + thinking = {"type": "enabled", 
"budget_tokens": 10000, "summary": "concise"} + result = LiteLLMAnthropicMessagesAdapter.translate_thinking_for_model( + thinking=thinking, + model="openai/gpt-5.2", + ) + assert result == {"reasoning_effort": {"effort": "high", "summary": "concise"}} + class TestThinkingSummaryPreservation: """Tests for thinking.summary preservation and disable_default_reasoning_summary flag.""" From c1b36403104224a888b939429921707aced0154c Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Mon, 9 Mar 2026 09:48:10 -0700 Subject: [PATCH 010/539] feat: add audit log export to external callbacks (S3, Datadog, etc.) Audit logs (CRUD events on keys, teams, users, models) were only stored in the Prisma DB. This adds a pluggable callback system so audit logs can be forwarded to external services like S3 for ingestion into security monitoring tools. New config key `audit_log_callbacks` under `litellm_settings` reuses the existing callback infrastructure. Any CustomLogger subclass can opt in by overriding `async_log_audit_log_event()`. S3Logger (s3_v2) is implemented as the first handler, storing audit logs under `audit_logs/{date}/` prefix. 
Co-Authored-By: Claude Opus 4.6 --- litellm/__init__.py | 1 + litellm/integrations/custom_logger.py | 7 + litellm/integrations/s3_v2.py | 34 ++- .../proxy/management_helpers/audit_logs.py | 84 +++++- litellm/proxy/proxy_server.py | 13 + litellm/types/utils.py | 14 + .../test_audit_log_callbacks.py | 270 ++++++++++++++++++ 7 files changed, 420 insertions(+), 3 deletions(-) create mode 100644 tests/test_litellm/proxy/management_helpers/test_audit_log_callbacks.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 4fc71e1270..e66a9b6b08 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -96,6 +96,7 @@ input_callback: List[CALLBACK_TYPES] = [] success_callback: List[CALLBACK_TYPES] = [] failure_callback: List[CALLBACK_TYPES] = [] service_callback: List[CALLBACK_TYPES] = [] +audit_log_callbacks: List[CALLBACK_TYPES] = [] # logging_callback_manager is lazy-loaded via __getattr__ _custom_logger_compatible_callbacks_literal = Literal[ "lago", diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index c244363e38..8867adfe86 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -27,6 +27,7 @@ from litellm.types.utils import ( LLMResponseTypes, ModelResponse, ModelResponseStream, + StandardAuditLogPayload, StandardCallbackDynamicParams, StandardLoggingPayload, ) @@ -177,6 +178,12 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): pass + async def async_log_audit_log_event( + self, audit_log: "StandardAuditLogPayload" + ): + """Called when an audit log is created. 
Override in subclasses to handle.""" + pass + #### PROMPT MANAGEMENT HOOKS #### async def async_get_chat_completion_prompt( diff --git a/litellm/integrations/s3_v2.py b/litellm/integrations/s3_v2.py index eddc80dbc1..134766aa18 100644 --- a/litellm/integrations/s3_v2.py +++ b/litellm/integrations/s3_v2.py @@ -22,7 +22,7 @@ from litellm.llms.custom_httpx.http_handler import ( httpxSpecialProvider, ) from litellm.types.integrations.s3_v2 import s3BatchLoggingElement -from litellm.types.utils import StandardLoggingPayload +from litellm.types.utils import StandardAuditLogPayload, StandardLoggingPayload from .custom_batch_logger import CustomBatchLogger @@ -244,6 +244,38 @@ class S3Logger(CustomBatchLogger, BaseAWSLLM): ) pass + async def async_log_audit_log_event( + self, audit_log: StandardAuditLogPayload + ) -> None: + """Batch audit logs and upload to S3 under audit_logs/ prefix.""" + try: + from datetime import timezone + + now = datetime.now(timezone.utc) + audit_log_id = audit_log.get("id", "unknown") + + s3_path = cast(Optional[str], self.s3_path) or "" + s3_path = s3_path.rstrip("/") + "/" if s3_path else "" + + s3_object_key = ( + f"{s3_path}audit_logs/" + f"{now.strftime('%Y-%m-%d')}/" + f"{now.strftime('%H-%M-%S')}_{audit_log_id}.json" + ) + + element = s3BatchLoggingElement( + payload=dict(audit_log), + s3_object_key=s3_object_key, + s3_object_download_filename=f"audit-{audit_log_id}.json", + ) + + self.log_queue.append(element) + + if len(self.log_queue) >= self.batch_size: + await self.flush_queue() + except Exception as e: + verbose_logger.exception("S3 audit log error: %s", e) + async def _async_log_event_base(self, kwargs, response_obj, start_time, end_time): try: verbose_logger.debug( diff --git a/litellm/proxy/management_helpers/audit_logs.py b/litellm/proxy/management_helpers/audit_logs.py index ea082f468a..dee7c82ee5 100644 --- a/litellm/proxy/management_helpers/audit_logs.py +++ b/litellm/proxy/management_helpers/audit_logs.py @@ -2,12 +2,15 @@ 
Functions to create audit logs for LiteLLM Proxy """ +import asyncio import json -from litellm._uuid import uuid from datetime import datetime, timezone +from typing import Dict import litellm from litellm._logging import verbose_proxy_logger +from litellm._uuid import uuid +from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import ( AUDIT_ACTIONS, LiteLLM_AuditLogs, @@ -15,6 +18,82 @@ from litellm.proxy._types import ( Optional, UserAPIKeyAuth, ) +from litellm.types.utils import StandardAuditLogPayload + +_audit_log_callback_cache: Dict[str, CustomLogger] = {} + + +def _resolve_audit_log_callback(name: str) -> Optional[CustomLogger]: + """Resolve a string callback name to a CustomLogger instance, with caching.""" + if name in _audit_log_callback_cache: + return _audit_log_callback_cache[name] + + from litellm.litellm_core_utils.litellm_logging import ( + _init_custom_logger_compatible_class, + ) + + instance = _init_custom_logger_compatible_class( + logging_integration=name, # type: ignore + internal_usage_cache=None, + llm_router=None, + ) + + if instance is not None: + _audit_log_callback_cache[name] = instance + return instance + + +def _build_audit_log_payload( + request_data: LiteLLM_AuditLogs, +) -> StandardAuditLogPayload: + """Convert LiteLLM_AuditLogs to StandardAuditLogPayload for callback dispatch.""" + updated_at = "" + if request_data.updated_at is not None: + updated_at = request_data.updated_at.isoformat() + + table_name = request_data.table_name + if isinstance(table_name, LitellmTableNames): + table_name = table_name.value + + return StandardAuditLogPayload( + id=request_data.id, + updated_at=updated_at, + changed_by=request_data.changed_by or "", + changed_by_api_key=request_data.changed_by_api_key or "", + action=request_data.action, + table_name=str(table_name), + object_id=request_data.object_id, + before_value=request_data.before_value, + updated_values=request_data.updated_values, + ) + + +async def 
_dispatch_audit_log_to_callbacks( + request_data: LiteLLM_AuditLogs, +) -> None: + """Dispatch audit log to all registered audit_log_callbacks.""" + if not litellm.audit_log_callbacks: + return + + payload = _build_audit_log_payload(request_data) + + for callback in litellm.audit_log_callbacks: + try: + resolved = callback + if isinstance(callback, str): + resolved = _resolve_audit_log_callback(callback) + if resolved is None: + verbose_proxy_logger.warning( + "Could not resolve audit log callback: %s", callback + ) + continue + + if isinstance(resolved, CustomLogger): + asyncio.create_task(resolved.async_log_audit_log_event(payload)) + except Exception as e: + verbose_proxy_logger.error( + "Failed dispatching audit log to callback: %s", e + ) async def create_object_audit_log( @@ -104,4 +183,5 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): # [Non-Blocking Exception. Do not allow blocking LLM API call] verbose_proxy_logger.error(f"Failed Creating audit log {e}") - return + # Dispatch to external audit log callbacks + await _dispatch_audit_log_to_callbacks(request_data) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index bd7b21c3b5..aa5adc054d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2946,6 +2946,19 @@ class ProxyConfig: print( # noqa f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}" ) # noqa + elif key == "audit_log_callbacks": + litellm.audit_log_callbacks = [] + + for callback in value: + if "." 
in callback: + litellm.audit_log_callbacks.append( + get_instance_fn(value=callback) + ) + else: + litellm.audit_log_callbacks.append(callback) + print( # noqa + f"{blue_color_code} Initialized Audit Log Callbacks - {litellm.audit_log_callbacks} {reset_color_code}" + ) # noqa elif key == "cache_params": # this is set in the cache branch # see usage here: https://docs.litellm.ai/docs/proxy/caching diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 8ae0cf2892..dab40ffc6e 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -2761,6 +2761,20 @@ class StandardLoggingPayloadStatusFields(TypedDict, total=False): """ +class StandardAuditLogPayload(TypedDict): + """Payload for audit log events dispatched to external callbacks.""" + + id: str + updated_at: str # ISO-8601 + changed_by: str + changed_by_api_key: str + action: str # "created" | "updated" | "deleted" | "blocked" | "rotated" + table_name: str + object_id: str + before_value: Optional[str] + updated_values: Optional[str] + + class StandardLoggingPayload(TypedDict): id: str trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries) diff --git a/tests/test_litellm/proxy/management_helpers/test_audit_log_callbacks.py b/tests/test_litellm/proxy/management_helpers/test_audit_log_callbacks.py new file mode 100644 index 0000000000..cfdf6ac033 --- /dev/null +++ b/tests/test_litellm/proxy/management_helpers/test_audit_log_callbacks.py @@ -0,0 +1,270 @@ +""" +Tests for audit log callback dispatch. 
+ +Tests the flow: create_audit_log_for_update -> _dispatch_audit_log_to_callbacks -> CustomLogger.async_log_audit_log_event +""" + +import asyncio +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import litellm +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import LiteLLM_AuditLogs, LitellmTableNames +from litellm.proxy.management_helpers.audit_logs import ( + _build_audit_log_payload, + _dispatch_audit_log_to_callbacks, + create_audit_log_for_update, +) +from litellm.types.utils import StandardAuditLogPayload + + +@pytest.fixture(autouse=True) +def reset_audit_log_callbacks(): + """Reset audit_log_callbacks before and after each test.""" + original = litellm.audit_log_callbacks + litellm.audit_log_callbacks = [] + yield + litellm.audit_log_callbacks = original + + +def _make_audit_log( + action: str = "created", + table_name: LitellmTableNames = LitellmTableNames.TEAM_TABLE_NAME, +) -> LiteLLM_AuditLogs: + return LiteLLM_AuditLogs( + id="test-audit-id", + updated_at=datetime(2026, 3, 9, 12, 0, 0, tzinfo=timezone.utc), + changed_by="user-123", + changed_by_api_key="sk-abc", + action=action, + table_name=table_name, + object_id="team-456", + updated_values=json.dumps({"name": "new-team"}), + before_value=json.dumps({"name": "old-team"}), + ) + + +class TestBuildAuditLogPayload: + def test_builds_correct_payload(self): + audit_log = _make_audit_log() + payload = _build_audit_log_payload(audit_log) + + assert payload["id"] == "test-audit-id" + assert payload["updated_at"] == "2026-03-09T12:00:00+00:00" + assert payload["changed_by"] == "user-123" + assert payload["changed_by_api_key"] == "sk-abc" + assert payload["action"] == "created" + assert payload["table_name"] == "LiteLLM_TeamTable" + assert payload["object_id"] == "team-456" + assert payload["updated_values"] == json.dumps({"name": "new-team"}) + assert payload["before_value"] == 
json.dumps({"name": "old-team"}) + + def test_handles_none_values(self): + audit_log = LiteLLM_AuditLogs( + id="test-id", + updated_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + changed_by=None, + changed_by_api_key=None, + action="deleted", + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id="key-789", + updated_values=None, + before_value=None, + ) + payload = _build_audit_log_payload(audit_log) + + assert payload["changed_by"] == "" + assert payload["changed_by_api_key"] == "" + assert payload["before_value"] is None + assert payload["updated_values"] is None + + +class TestDispatchAuditLogToCallbacks: + @pytest.mark.asyncio + async def test_dispatches_to_custom_logger_instance(self): + mock_logger = MagicMock(spec=CustomLogger) + mock_logger.async_log_audit_log_event = AsyncMock() + litellm.audit_log_callbacks = [mock_logger] + + audit_log = _make_audit_log() + await _dispatch_audit_log_to_callbacks(audit_log) + + # Let asyncio.create_task run + await asyncio.sleep(0.1) + + mock_logger.async_log_audit_log_event.assert_called_once() + payload = mock_logger.async_log_audit_log_event.call_args[0][0] + assert payload["id"] == "test-audit-id" + assert payload["action"] == "created" + + @pytest.mark.asyncio + async def test_no_dispatch_when_callbacks_empty(self): + litellm.audit_log_callbacks = [] + audit_log = _make_audit_log() + # Should return immediately without error + await _dispatch_audit_log_to_callbacks(audit_log) + + @pytest.mark.asyncio + async def test_resolves_string_callback(self): + mock_logger = MagicMock(spec=CustomLogger) + mock_logger.async_log_audit_log_event = AsyncMock() + + litellm.audit_log_callbacks = ["s3_v2"] + + with patch( + "litellm.proxy.management_helpers.audit_logs._resolve_audit_log_callback", + return_value=mock_logger, + ): + audit_log = _make_audit_log() + await _dispatch_audit_log_to_callbacks(audit_log) + await asyncio.sleep(0.1) + + mock_logger.async_log_audit_log_event.assert_called_once() + + @pytest.mark.asyncio + 
async def test_nonblocking_on_callback_failure(self): + """Callback errors should not propagate.""" + mock_logger = MagicMock(spec=CustomLogger) + mock_logger.async_log_audit_log_event = AsyncMock( + side_effect=RuntimeError("boom") + ) + litellm.audit_log_callbacks = [mock_logger] + + audit_log = _make_audit_log() + # Should not raise + await _dispatch_audit_log_to_callbacks(audit_log) + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + async def test_skips_unresolvable_string_callback(self): + litellm.audit_log_callbacks = ["nonexistent_callback"] + + with patch( + "litellm.proxy.management_helpers.audit_logs._resolve_audit_log_callback", + return_value=None, + ): + audit_log = _make_audit_log() + # Should not raise + await _dispatch_audit_log_to_callbacks(audit_log) + + +class TestCreateAuditLogForUpdateWithCallbacks: + @pytest.mark.asyncio + async def test_dispatches_to_callbacks_after_db_write(self): + mock_logger = MagicMock(spec=CustomLogger) + mock_logger.async_log_audit_log_event = AsyncMock() + litellm.audit_log_callbacks = [mock_logger] + + with patch("litellm.proxy.proxy_server.premium_user", True), patch( + "litellm.store_audit_logs", True + ), patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma: + mock_prisma.db.litellm_auditlog.create = AsyncMock() + + audit_log = _make_audit_log() + await create_audit_log_for_update(audit_log) + await asyncio.sleep(0.1) + + # DB write should happen + mock_prisma.db.litellm_auditlog.create.assert_called_once() + # Callback should also be called + mock_logger.async_log_audit_log_event.assert_called_once() + + @pytest.mark.asyncio + async def test_no_dispatch_when_not_premium(self): + mock_logger = MagicMock(spec=CustomLogger) + mock_logger.async_log_audit_log_event = AsyncMock() + litellm.audit_log_callbacks = [mock_logger] + + with patch("litellm.proxy.proxy_server.premium_user", False), patch( + "litellm.store_audit_logs", True + ): + audit_log = _make_audit_log() + await 
create_audit_log_for_update(audit_log) + await asyncio.sleep(0.1) + + mock_logger.async_log_audit_log_event.assert_not_called() + + @pytest.mark.asyncio + async def test_no_dispatch_when_store_audit_logs_false(self): + mock_logger = MagicMock(spec=CustomLogger) + mock_logger.async_log_audit_log_event = AsyncMock() + litellm.audit_log_callbacks = [mock_logger] + + with patch("litellm.store_audit_logs", False): + audit_log = _make_audit_log() + await create_audit_log_for_update(audit_log) + await asyncio.sleep(0.1) + + mock_logger.async_log_audit_log_event.assert_not_called() + + +class TestS3LoggerAuditLogEvent: + @pytest.mark.asyncio + async def test_queues_audit_log_with_correct_s3_key(self): + with patch( + "litellm.integrations.s3_v2.S3Logger.__init__", return_value=None + ): + from litellm.integrations.s3_v2 import S3Logger + + logger = S3Logger() + logger.s3_path = "my-prefix" + logger.log_queue = [] + logger.batch_size = 100 + + audit_log = StandardAuditLogPayload( + id="audit-123", + updated_at="2026-03-09T12:00:00+00:00", + changed_by="user-1", + changed_by_api_key="sk-abc", + action="created", + table_name="LiteLLM_TeamTable", + object_id="team-1", + before_value=None, + updated_values='{"name": "new"}', + ) + + await logger.async_log_audit_log_event(audit_log) + + assert len(logger.log_queue) == 1 + element = logger.log_queue[0] + assert element.s3_object_key.startswith("my-prefix/audit_logs/") + assert "audit-123" in element.s3_object_key + assert element.s3_object_key.endswith(".json") + assert element.s3_object_download_filename == "audit-audit-123.json" + assert element.payload["id"] == "audit-123" + assert element.payload["action"] == "created" + + @pytest.mark.asyncio + async def test_s3_key_format_no_path(self): + with patch( + "litellm.integrations.s3_v2.S3Logger.__init__", return_value=None + ): + from litellm.integrations.s3_v2 import S3Logger + + logger = S3Logger() + logger.s3_path = None + logger.log_queue = [] + logger.batch_size = 100 + + 
audit_log = StandardAuditLogPayload( + id="audit-456", + updated_at="2026-03-09T12:00:00+00:00", + changed_by="user-1", + changed_by_api_key="sk-abc", + action="deleted", + table_name="LiteLLM_VerificationToken", + object_id="key-1", + before_value=None, + updated_values=None, + ) + + await logger.async_log_audit_log_event(audit_log) + + assert len(logger.log_queue) == 1 + element = logger.log_queue[0] + assert element.s3_object_key.startswith("audit_logs/") + assert "audit-456" in element.s3_object_key From 364f5c6e03dfe72c2223482d731c81ffd5858b56 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 11 Mar 2026 11:27:48 -0700 Subject: [PATCH 011/539] fix: ensure audit log callbacks fire even when DB is unavailable Two fixes based on PR feedback: 1. Move callback dispatch before the prisma_client check so audit logs still reach S3/Datadog even if the DB is down. Also changed the prisma_client=None case from raising an exception to logging an error and returning gracefully. 2. Attach a done_callback to asyncio tasks created for audit log callbacks so exceptions are logged through verbose_proxy_logger instead of silently swallowed. 
Co-Authored-By: Claude Opus 4.6 --- .../proxy/management_helpers/audit_logs.py | 32 ++++++-- .../test_audit_log_callbacks.py | 75 +++++++++++++++++++ 2 files changed, 100 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/management_helpers/audit_logs.py b/litellm/proxy/management_helpers/audit_logs.py index dee7c82ee5..a26a38bc4c 100644 --- a/litellm/proxy/management_helpers/audit_logs.py +++ b/litellm/proxy/management_helpers/audit_logs.py @@ -68,6 +68,18 @@ def _build_audit_log_payload( ) +def _audit_log_task_done_callback(task: asyncio.Task) -> None: + """Log exceptions from audit log callback tasks so they don't slip through silently.""" + try: + exc = task.exception() + except asyncio.CancelledError: + return + if exc is not None: + verbose_proxy_logger.error( + "Audit log callback task failed: %s", exc, exc_info=exc + ) + + async def _dispatch_audit_log_to_callbacks( request_data: LiteLLM_AuditLogs, ) -> None: @@ -89,7 +101,10 @@ async def _dispatch_audit_log_to_callbacks( continue if isinstance(resolved, CustomLogger): - asyncio.create_task(resolved.async_log_audit_log_event(payload)) + task = asyncio.create_task( + resolved.async_log_audit_log_event(payload) + ) + task.add_done_callback(_audit_log_task_done_callback) except Exception as e: verbose_proxy_logger.error( "Failed dispatching audit log to callback: %s", e @@ -160,9 +175,6 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): if premium_user is not True: return - if prisma_client is None: - raise Exception("prisma_client is None, no DB connected") - verbose_proxy_logger.debug("creating audit log for %s", request_data) if isinstance(request_data.updated_values, dict): @@ -171,6 +183,15 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): if isinstance(request_data.before_value, dict): request_data.before_value = json.dumps(request_data.before_value) + # Dispatch to external audit log callbacks regardless of DB availability + await 
_dispatch_audit_log_to_callbacks(request_data) + + if prisma_client is None: + verbose_proxy_logger.error( + "prisma_client is None, cannot write audit log to DB" + ) + return + _request_data = request_data.model_dump(exclude_none=True) try: @@ -182,6 +203,3 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): except Exception as e: # [Non-Blocking Exception. Do not allow blocking LLM API call] verbose_proxy_logger.error(f"Failed Creating audit log {e}") - - # Dispatch to external audit log callbacks - await _dispatch_audit_log_to_callbacks(request_data) diff --git a/tests/test_litellm/proxy/management_helpers/test_audit_log_callbacks.py b/tests/test_litellm/proxy/management_helpers/test_audit_log_callbacks.py index cfdf6ac033..85eda4368c 100644 --- a/tests/test_litellm/proxy/management_helpers/test_audit_log_callbacks.py +++ b/tests/test_litellm/proxy/management_helpers/test_audit_log_callbacks.py @@ -15,6 +15,7 @@ import litellm from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import LiteLLM_AuditLogs, LitellmTableNames from litellm.proxy.management_helpers.audit_logs import ( + _audit_log_task_done_callback, _build_audit_log_payload, _dispatch_audit_log_to_callbacks, create_audit_log_for_update, @@ -201,6 +202,80 @@ class TestCreateAuditLogForUpdateWithCallbacks: mock_logger.async_log_audit_log_event.assert_not_called() + @pytest.mark.asyncio + async def test_dispatches_even_when_prisma_client_is_none(self): + """Callbacks should fire even if DB is unavailable.""" + mock_logger = MagicMock(spec=CustomLogger) + mock_logger.async_log_audit_log_event = AsyncMock() + litellm.audit_log_callbacks = [mock_logger] + + with patch("litellm.proxy.proxy_server.premium_user", True), patch( + "litellm.store_audit_logs", True + ), patch("litellm.proxy.proxy_server.prisma_client", None): + audit_log = _make_audit_log() + await create_audit_log_for_update(audit_log) + await asyncio.sleep(0.1) + + # Callback should still 
be called despite no DB + mock_logger.async_log_audit_log_event.assert_called_once() + + @pytest.mark.asyncio + async def test_dispatches_even_when_db_write_fails(self): + """Callbacks should fire even if the DB write raises.""" + mock_logger = MagicMock(spec=CustomLogger) + mock_logger.async_log_audit_log_event = AsyncMock() + litellm.audit_log_callbacks = [mock_logger] + + with patch("litellm.proxy.proxy_server.premium_user", True), patch( + "litellm.store_audit_logs", True + ), patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma: + mock_prisma.db.litellm_auditlog.create = AsyncMock( + side_effect=RuntimeError("DB connection lost") + ) + + audit_log = _make_audit_log() + await create_audit_log_for_update(audit_log) + await asyncio.sleep(0.1) + + # Callback should still be called despite DB failure + mock_logger.async_log_audit_log_event.assert_called_once() + + +class TestAuditLogTaskDoneCallback: + def test_logs_exception_from_failed_task(self): + """Done callback should log task exceptions.""" + mock_task = MagicMock(spec=asyncio.Task) + mock_task.exception.return_value = RuntimeError("callback failed") + + with patch( + "litellm.proxy.management_helpers.audit_logs.verbose_proxy_logger" + ) as mock_logger: + _audit_log_task_done_callback(mock_task) + mock_logger.error.assert_called_once() + assert "callback failed" in str(mock_logger.error.call_args) + + def test_no_log_on_success(self): + """Done callback should not log when task succeeds.""" + mock_task = MagicMock(spec=asyncio.Task) + mock_task.exception.return_value = None + + with patch( + "litellm.proxy.management_helpers.audit_logs.verbose_proxy_logger" + ) as mock_logger: + _audit_log_task_done_callback(mock_task) + mock_logger.error.assert_not_called() + + def test_handles_cancelled_task(self): + """Done callback should handle cancelled tasks gracefully.""" + mock_task = MagicMock(spec=asyncio.Task) + mock_task.exception.side_effect = asyncio.CancelledError() + + with patch( + 
"litellm.proxy.management_helpers.audit_logs.verbose_proxy_logger" + ) as mock_logger: + _audit_log_task_done_callback(mock_task) + mock_logger.error.assert_not_called() + class TestS3LoggerAuditLogEvent: @pytest.mark.asyncio From 9c3fab24adb3a36fcfa02dc2dfa4f87d7893d0f4 Mon Sep 17 00:00:00 2001 From: michelligabriele Date: Thu, 12 Mar 2026 20:55:56 +0100 Subject: [PATCH 012/539] fix(proxy): restore per-entity breakdown in aggregated daily activity endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #21613 optimized the /user/daily/activity/aggregated endpoint by replacing find_many with a SQL GROUP BY query, but omitted entity_id from the SELECT/GROUP BY clauses and hardcoded entity_id_field=None in the call to _aggregate_spend_records. This caused breakdown.entities to always be empty in the response. Restore entity_id in the SQL query and forward entity_id_field and entity_metadata_field to the aggregation step. The GROUP BY performance benefit is preserved — the query still aggregates at the database level instead of fetching all individual rows into Python. --- .../common_daily_activity.py | 18 ++- .../test_common_daily_activity.py | 119 ++++++++++++++++++ 2 files changed, 126 insertions(+), 11 deletions(-) diff --git a/litellm/proxy/management_endpoints/common_daily_activity.py b/litellm/proxy/management_endpoints/common_daily_activity.py index 011d2f7485..b05a183df0 100644 --- a/litellm/proxy/management_endpoints/common_daily_activity.py +++ b/litellm/proxy/management_endpoints/common_daily_activity.py @@ -475,10 +475,8 @@ def _build_aggregated_sql_query( ) -> Tuple[str, List[Any]]: """Build a parameterized SQL GROUP BY query for aggregated daily activity. - Groups by (date, api_key, model, model_group, custom_llm_provider, + Groups by (entity_id, date, api_key, model, model_group, custom_llm_provider, mcp_namespaced_tool_name, endpoint) with SUMs on all metric columns. 
- The entity_id column is intentionally omitted from GROUP BY to collapse - rows across entities — this is where the biggest row reduction comes from. Returns: Tuple of (sql_query, params_list) ready for prisma_client.db.query_raw(). @@ -539,6 +537,7 @@ def _build_aggregated_sql_query( sql_query = f""" SELECT + "{entity_id_field}", date, api_key, model, @@ -556,8 +555,8 @@ def _build_aggregated_sql_query( SUM(failed_requests)::bigint AS failed_requests FROM "{pg_table}" WHERE {where_clause} - GROUP BY date, api_key, model, model_group, custom_llm_provider, - mcp_namespaced_tool_name, endpoint + GROUP BY "{entity_id_field}", date, api_key, model, model_group, + custom_llm_provider, mcp_namespaced_tool_name, endpoint ORDER BY date DESC """ @@ -735,8 +734,7 @@ async def get_daily_activity_aggregated( """Aggregated variant that returns the full result set (no pagination). Uses SQL GROUP BY to aggregate rows in the database rather than fetching - all individual rows into Python. This collapses rows across entities - (users/teams/orgs), reducing ~150k rows to ~2-3k grouped rows. + all individual rows into Python, preserving per-entity granularity. Matches the response model of the paginated endpoint so the UI does not need to transform. 
""" @@ -773,13 +771,11 @@ async def get_daily_activity_aggregated( # Convert dicts to objects for compatibility with _aggregate_spend_records records = [SimpleNamespace(**row) for row in rows] - # entity_id_field=None skips entity breakdown (entity dimension was - # collapsed by the GROUP BY, so per-entity data is not available) aggregated = await _aggregate_spend_records( prisma_client=prisma_client, records=records, - entity_id_field=None, - entity_metadata_field=None, + entity_id_field=entity_id_field, + entity_metadata_field=entity_metadata_field, ) return SpendAnalyticsPaginatedResponse( diff --git a/tests/test_litellm/proxy/management_endpoints/test_common_daily_activity.py b/tests/test_litellm/proxy/management_endpoints/test_common_daily_activity.py index 54fbac1264..c81792b0b2 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_common_daily_activity.py +++ b/tests/test_litellm/proxy/management_endpoints/test_common_daily_activity.py @@ -86,6 +86,7 @@ async def test_get_daily_activity_aggregated_with_endpoint_breakdown(): # query_raw returns list of dicts (pre-aggregated by GROUP BY) mock_rows = [ { + "user_id": "user-1", "date": "2024-01-01", "endpoint": "/v1/chat/completions", "api_key": "key-1", @@ -103,6 +104,7 @@ async def test_get_daily_activity_aggregated_with_endpoint_breakdown(): "failed_requests": 0, }, { + "user_id": "user-1", "date": "2024-01-01", "endpoint": "/v1/embeddings", "api_key": "key-2", @@ -452,6 +454,7 @@ async def test_aggregated_activity_preserves_metadata_for_deleted_keys(): # query_raw returns list of dicts (pre-aggregated by GROUP BY) mock_rows = [ { + "user_id": "user-1", "date": "2024-01-01", "endpoint": "/v1/chat/completions", "api_key": "deleted-key-hash", @@ -507,3 +510,119 @@ async def test_aggregated_activity_preserves_metadata_for_deleted_keys(): assert key_data.metadata.key_alias == "toto-test-2" assert key_data.metadata.team_id == "69cd4b77-b095-4489-8c46-4f2f31d840a2" assert key_data.metrics.spend == 10.0 + 
+ +@pytest.mark.asyncio +async def test_get_daily_activity_aggregated_preserves_entity_breakdown(): + """Test that aggregated daily activity preserves per-entity breakdown. + + Regression test for PR #21613: the GROUP BY optimization omitted entity_id + from the query and hardcoded entity_id_field=None, causing breakdown.entities + to always be empty. + """ + mock_prisma = MagicMock() + mock_prisma.db = MagicMock() + + # query_raw returns rows WITH user_id (entity_id_field included in GROUP BY) + mock_rows = [ + { + "user_id": "user-alice", + "date": "2024-01-01", + "api_key": "key-1", + "model": "gpt-4", + "model_group": None, + "custom_llm_provider": "openai", + "mcp_namespaced_tool_name": None, + "endpoint": "/v1/chat/completions", + "spend": 20.0, + "prompt_tokens": 200, + "completion_tokens": 100, + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0, + "api_requests": 5, + "successful_requests": 5, + "failed_requests": 0, + }, + { + "user_id": "user-bob", + "date": "2024-01-01", + "api_key": "key-2", + "model": "gpt-4", + "model_group": None, + "custom_llm_provider": "openai", + "mcp_namespaced_tool_name": None, + "endpoint": "/v1/chat/completions", + "spend": 8.0, + "prompt_tokens": 80, + "completion_tokens": 40, + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0, + "api_requests": 2, + "successful_requests": 2, + "failed_requests": 0, + }, + { + "user_id": None, + "date": "2024-01-01", + "api_key": "key-3", + "model": "gpt-4", + "model_group": None, + "custom_llm_provider": "openai", + "mcp_namespaced_tool_name": None, + "endpoint": "/v1/chat/completions", + "spend": 2.0, + "prompt_tokens": 20, + "completion_tokens": 10, + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0, + "api_requests": 1, + "successful_requests": 1, + "failed_requests": 0, + }, + ] + + mock_prisma.db.query_raw = AsyncMock(return_value=mock_rows) + mock_prisma.db.litellm_verificationtoken = MagicMock() + 
mock_prisma.db.litellm_verificationtoken.find_many = AsyncMock(return_value=[]) + + result = await get_daily_activity_aggregated( + prisma_client=mock_prisma, + table_name="litellm_dailyuserspend", + entity_id_field="user_id", + entity_id=None, + entity_metadata_field=None, + start_date="2024-01-01", + end_date="2024-01-01", + model=None, + api_key=None, + ) + + # Verify per-entity breakdown is populated (not empty) + daily_data = result.results[0] + assert len(daily_data.breakdown.entities) == 3, ( + "breakdown.entities should contain per-user entries" + ) + + # Verify individual entity metrics + assert "user-alice" in daily_data.breakdown.entities + assert daily_data.breakdown.entities["user-alice"].metrics.spend == 20.0 + assert daily_data.breakdown.entities["user-alice"].metrics.prompt_tokens == 200 + assert daily_data.breakdown.entities["user-alice"].metrics.api_requests == 5 + + assert "user-bob" in daily_data.breakdown.entities + assert daily_data.breakdown.entities["user-bob"].metrics.spend == 8.0 + assert daily_data.breakdown.entities["user-bob"].metrics.prompt_tokens == 80 + assert daily_data.breakdown.entities["user-bob"].metrics.api_requests == 2 + + # Verify NULL entity_id is mapped to "Unassigned" (line 294-296) + assert "Unassigned" in daily_data.breakdown.entities + assert daily_data.breakdown.entities["Unassigned"].metrics.spend == 2.0 + + # Verify per-entity API key breakdown + assert "key-1" in daily_data.breakdown.entities["user-alice"].api_key_breakdown + assert "key-2" in daily_data.breakdown.entities["user-bob"].api_key_breakdown + assert "key-3" in daily_data.breakdown.entities["Unassigned"].api_key_breakdown + + # Verify totals still correct + assert result.metadata.total_spend == 30.0 + assert result.metadata.total_api_requests == 8 From 78927138d288e6a780c83816b05f677b39b9e38e Mon Sep 17 00:00:00 2001 From: michelligabriele Date: Thu, 12 Mar 2026 22:30:52 +0100 Subject: [PATCH 013/539] fix(proxy): add team_member_budget_duration to 
NewTeamRequest NewTeamRequest was missing the team_member_budget_duration field, causing Pydantic to silently drop the value when creating a team via POST /team/new. The template budget row was created without budget_duration or budget_reset_at, so the ResetBudgetJob never found it and team member spend was never reset. Add the field to NewTeamRequest and pass it through to should_create_budget and create_team_member_budget_table in the new_team handler (matching the existing update_team path which already works correctly). Fixes #16057 --- litellm/proxy/_types.py | 1 + .../management_endpoints/team_endpoints.py | 2 + .../test_team_endpoints.py | 50 +++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 1174740948..9b94a81eb0 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1669,6 +1669,7 @@ class NewTeamRequest(TeamBase): int ] = None # allow user to set TPM limit for all team members team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m" + team_member_budget_duration: Optional[str] = None # e.g. 
"30d", "1mo" allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None enforced_batch_output_expires_after: Optional[dict] = None enforced_file_expires_after: Optional[dict] = None diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 9c8e6f7282..18e41793df 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -898,6 +898,7 @@ async def new_team( # noqa: PLR0915 team_member_budget=data.team_member_budget, team_member_rpm_limit=data.team_member_rpm_limit, team_member_tpm_limit=data.team_member_tpm_limit, + team_member_budget_duration=data.team_member_budget_duration, ): data_json = await TeamMemberBudgetHandler.create_team_member_budget_table( data=data, @@ -906,6 +907,7 @@ async def new_team( # noqa: PLR0915 team_member_budget=data.team_member_budget, team_member_rpm_limit=data.team_member_rpm_limit, team_member_tpm_limit=data.team_member_tpm_limit, + team_member_budget_duration=data.team_member_budget_duration, ) ## ADD TO TEAM TABLE diff --git a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py index 1aee1d4965..8a033fca96 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py @@ -6157,3 +6157,53 @@ async def test_list_team_v1_batches_key_queries(): assert result[0].keys == [key1, key2] assert result[1].team_id == "team-2" assert result[1].keys == [key3] + + +def test_new_team_request_accepts_team_member_budget_duration(): + """Test that NewTeamRequest does not silently drop team_member_budget_duration.""" + from litellm.proxy._types import NewTeamRequest + + request = NewTeamRequest( + team_member_budget=20.0, + team_member_budget_duration="30d", + ) + assert request.team_member_budget == 20.0 + assert 
request.team_member_budget_duration == "30d" + + +@pytest.mark.asyncio +async def test_create_team_member_budget_table_with_duration(): + """Verify that create_team_member_budget_table passes budget_duration + through to the new_budget call when team_member_budget_duration is provided.""" + from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, LitellmUserRoles + from litellm.proxy.management_endpoints.team_endpoints import TeamMemberBudgetHandler + + mock_budget_response = MagicMock(budget_id="budget-abc") + mock_admin = UserAPIKeyAuth( + user_id="admin", user_role=LitellmUserRoles.PROXY_ADMIN + ) + + data = NewTeamRequest( + team_alias="test-team", + team_member_budget=20.0, + team_member_budget_duration="30d", + ) + + with patch( + "litellm.proxy.management_endpoints.budget_management_endpoints.new_budget", + new_callable=AsyncMock, + return_value=mock_budget_response, + ) as mock_new_budget: + result = await TeamMemberBudgetHandler.create_team_member_budget_table( + data=data, + new_team_data_json={"metadata": None}, + user_api_key_dict=mock_admin, + team_member_budget=20.0, + team_member_budget_duration="30d", + ) + + mock_new_budget.assert_awaited_once() + budget_request = mock_new_budget.call_args.kwargs["budget_obj"] + assert budget_request.budget_duration == "30d" + assert budget_request.max_budget == 20.0 + assert result["metadata"]["team_member_budget_id"] == "budget-abc" From b74571214fd0c123ca9379232f1519e4751dc001 Mon Sep 17 00:00:00 2001 From: Chesars Date: Sat, 14 Mar 2026 00:53:56 -0300 Subject: [PATCH 014/539] chore: remove debug scripts and unused import Remove 8 development scripts from scripts/ that were accidentally committed. Remove unused `import litellm` from responses_adapters/transformation.py. 
--- .../responses_adapters/transformation.py | 1 - scripts/test_gpt54_reasoning_tools.py | 77 -------- scripts/test_perplexity_regression.py | 83 -------- scripts/test_perplexity_responses.py | 179 ------------------ scripts/test_reasoning_none_tools.py | 39 ---- scripts/test_reasoning_tools.py | 19 -- scripts/test_tool_choice_responses.py | 45 ----- scripts/test_tool_search_chat.py | 57 ------ scripts/test_tool_search_responses.py | 101 ---------- 9 files changed, 601 deletions(-) delete mode 100644 scripts/test_gpt54_reasoning_tools.py delete mode 100644 scripts/test_perplexity_regression.py delete mode 100644 scripts/test_perplexity_responses.py delete mode 100644 scripts/test_reasoning_none_tools.py delete mode 100644 scripts/test_reasoning_tools.py delete mode 100644 scripts/test_tool_choice_responses.py delete mode 100644 scripts/test_tool_search_chat.py delete mode 100644 scripts/test_tool_search_responses.py diff --git a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py index 54dd86eaba..2fef5dee46 100644 --- a/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/responses_adapters/transformation.py @@ -8,7 +8,6 @@ path used for OpenAI and Azure models. import json from typing import Any, Dict, List, Optional, Union, cast -import litellm from litellm.llms.anthropic.experimental_pass_through.utils import ( is_default_reasoning_summary_disabled, ) diff --git a/scripts/test_gpt54_reasoning_tools.py b/scripts/test_gpt54_reasoning_tools.py deleted file mode 100644 index df6dea9ce7..0000000000 --- a/scripts/test_gpt54_reasoning_tools.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -Repro script: verify that gpt-5.4 drops reasoning_effort when tools are present. -Expected: the call succeeds (reasoning_effort is silently dropped). 
-If the bug were still present, OpenAI would return an error like: - "reasoning_effort is not supported with function calling" -""" - -import os -from dotenv import load_dotenv - -load_dotenv() - -import litellm - -litellm.set_verbose = True - -tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get current weather for a city", - "parameters": { - "type": "object", - "properties": { - "city": {"type": "string", "description": "City name"}, - }, - "required": ["city"], - }, - }, - } -] - -print("=== Test: gpt-5.4 + reasoning_effort='medium' + tools ===") -try: - response = litellm.completion( - model="gpt-5.4", - messages=[{"role": "user", "content": "What's the weather in Buenos Aires?"}], - reasoning_effort="medium", - tools=tools, - drop_params=True, - ) - print(f"SUCCESS - model: {response.model}") - print(f"Choice: {response.choices[0].message}") - if response.choices[0].message.tool_calls: - print(f"Tool calls: {response.choices[0].message.tool_calls}") - print("\nreasoning_effort was correctly dropped (no error from OpenAI)") -except Exception as e: - print(f"FAILED: {e}") - -print("\n=== Test: gpt-5.4 + reasoning_effort='high' + tools ===") -try: - response = litellm.completion( - model="gpt-5.4", - messages=[{"role": "user", "content": "What's 2+2?"}], - reasoning_effort="high", - tools=tools, - drop_params=True, - ) - print(f"SUCCESS - model: {response.model}") - print(f"reasoning_effort was correctly dropped (no error from OpenAI)") -except Exception as e: - print(f"FAILED: {e}") - -print("\n=== Test: gpt-5.4 + reasoning_effort='none' + tools (should KEEP reasoning_effort) ===") -try: - response = litellm.completion( - model="gpt-5.4", - messages=[{"role": "user", "content": "Say hello"}], - reasoning_effort="none", - tools=tools, - drop_params=True, - ) - print(f"SUCCESS - model: {response.model}") - print(f"reasoning_effort='none' correctly kept (OpenAI allows this)") -except Exception as e: - print(f"FAILED: 
{e}") diff --git a/scripts/test_perplexity_regression.py b/scripts/test_perplexity_regression.py deleted file mode 100644 index 67a01e2dcd..0000000000 --- a/scripts/test_perplexity_regression.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -Simple regression test: call Perplexity through LiteLLM -to verify chat completions and responses API both work. -""" -import os -import sys -from dotenv import load_dotenv - -load_dotenv() - -import litellm - -# Show which branch we're on -branch = os.popen("git rev-parse --abbrev-ref HEAD 2>/dev/null || echo unknown").read().strip() -print(f"=== Branch: {branch} ===\n") - -# 1. Chat completions -print("--- Test 1: Chat Completions ---") -try: - resp = litellm.completion( - model="perplexity/sonar", - messages=[{"role": "user", "content": "Say hello in 3 words"}], - max_tokens=20, - ) - print(f"OK: {resp.choices[0].message.content[:80]}") - print(f" model: {resp.model}") - print(f" usage: {resp.usage}") -except Exception as e: - print(f"FAIL: {e}") - -# 2. Responses API (string input) -print("\n--- Test 2: Responses API (string input) ---") -try: - resp = litellm.responses( - model="perplexity/sonar", - input="Say hello in 3 words", - max_output_tokens=20, - ) - print(f"OK: {resp.output[0].content[0].text[:80]}") - print(f" model: {resp.model}") -except Exception as e: - print(f"FAIL: {e}") - -# 3. Responses API (list input - the _format_input concern) -print("\n--- Test 3: Responses API (list input without type field) ---") -try: - resp = litellm.responses( - model="perplexity/sonar", - input=[{"role": "user", "content": "Say hello in 3 words"}], - max_output_tokens=20, - ) - print(f"OK: {resp.output[0].content[0].text[:80]}") -except Exception as e: - print(f"FAIL: {e}") - -# 4. 
Check which config class is resolved for chat -print("\n--- Test 4: Config class resolution ---") -from litellm.utils import ProviderConfigManager -from litellm.types.utils import LlmProviders - -chat_config = ProviderConfigManager.get_provider_chat_config( - model="perplexity/sonar", provider=LlmProviders.PERPLEXITY -) -print(f"Chat config class: {type(chat_config).__name__}") -print(f" module: {type(chat_config).__module__}") - -resp_config = ProviderConfigManager.get_provider_responses_api_config( - provider=LlmProviders.PERPLEXITY -) -print(f"Responses config class: {type(resp_config).__name__}") -print(f" module: {type(resp_config).__module__}") - -# 5. Check supported params include preset/models for responses -print("\n--- Test 5: Supported params ---") -if resp_config: - params = resp_config.get_supported_openai_params("sonar") - print(f"Responses supported params: {params}") - has_preset = "preset" in params - has_models = "models" in params - print(f" Has 'preset': {has_preset}") - print(f" Has 'models': {has_models}") - -print("\n=== Done ===") diff --git a/scripts/test_perplexity_responses.py b/scripts/test_perplexity_responses.py deleted file mode 100644 index 5616ada380..0000000000 --- a/scripts/test_perplexity_responses.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -Live test: Perplexity Responses API via LiteLLM. -Tests: non-streaming, streaming, preset models, models fallback param, -chat completions (regression check), and cost dict→float parsing. - -DO NOT COMMIT this file. -""" - -import os -import traceback - -from dotenv import load_dotenv - -load_dotenv() - -import litellm - -# litellm.set_verbose = True - - -def test_non_streaming_preset(): - """Test non-streaming with preset model.""" - print("=" * 60) - print("TEST 1: Non-streaming preset/pro-search") - print("=" * 60) - - response = litellm.responses( - model="perplexity/preset/pro-search", - input="What is 2 + 2? 
Answer in one word.", - ) - - print(f" Response ID: {response.id}") - print(f" Model: {response.model}") - print(f" Status: {response.status}") - - assert response.status == "completed", f"FAIL: status={response.status}" - assert response.output, "FAIL: no output" - print(" PASS: non-streaming preset works") - - if response.usage and response.usage.cost is not None: - assert isinstance(response.usage.cost, (int, float)), ( - f"FAIL: cost is {type(response.usage.cost)}: {response.usage.cost}" - ) - print(f" PASS: cost={response.usage.cost} (float, not dict)") - print() - - -def test_streaming_preset(): - """Test streaming with preset model.""" - print("=" * 60) - print("TEST 2: Streaming preset/pro-search") - print("=" * 60) - - response = litellm.responses( - model="perplexity/preset/pro-search", - input="What is the capital of France? One word.", - stream=True, - ) - - chunks = 0 - completed = False - for chunk in response: - chunks += 1 - event_type = getattr(chunk, "type", "unknown") - if event_type == "response.output_text.delta": - print(f" delta: {chunk.delta}", end="", flush=True) - elif event_type == "response.completed": - completed = True - print(f"\n [completed] model={chunk.response.model}") - if chunk.response.usage and chunk.response.usage.cost is not None: - cost = chunk.response.usage.cost - assert isinstance(cost, (int, float)), ( - f"FAIL: streaming cost is {type(cost)}: {cost}" - ) - print(f" PASS: streaming cost={cost} (float)") - - assert chunks > 0, "FAIL: no chunks received" - assert completed, "FAIL: never got response.completed event" - print(f" Total chunks: {chunks}") - print(" PASS: streaming preset works") - print() - - -def test_models_fallback_param(): - """Test that 'models' param (Perplexity fallback chain) is forwarded.""" - print("=" * 60) - print("TEST 3: models param (fallback chain)") - print("=" * 60) - - response = litellm.responses( - model="perplexity/openai/gpt-5.1", - input="Say 'hello' and nothing else.", - 
models=["openai/gpt-5-mini", "openai/gpt-5.1"], - ) - - print(f" Response ID: {response.id}") - print(f" Model used: {response.model}") - print(f" Status: {response.status}") - - assert response.status == "completed", f"FAIL: status={response.status}" - print(" PASS: models fallback param works") - print() - - -def test_chat_completions_not_broken(): - """Regression: Perplexity chat completions must still use PerplexityChatConfig.""" - print("=" * 60) - print("TEST 4: Chat completions regression check") - print("=" * 60) - - response = litellm.completion( - model="perplexity/sonar", - messages=[{"role": "user", "content": "Say 'hi' and nothing else."}], - max_tokens=10, - ) - - print(f" Model: {response.model}") - print(f" Content: {response.choices[0].message.content[:50]}") - - assert response.choices, "FAIL: no choices" - assert response.choices[0].message.content, "FAIL: empty content" - print(" PASS: chat completions still work (no regression)") - print() - - -def test_with_instructions(): - """Test instructions param.""" - print("=" * 60) - print("TEST 5: instructions param") - print("=" * 60) - - response = litellm.responses( - model="perplexity/preset/pro-search", - input="What is Python?", - instructions="Answer in exactly 5 words.", - ) - - print(f" Status: {response.status}") - # Extract text from output - for item in response.output: - if hasattr(item, "content"): - for c in item.content: - if hasattr(c, "text"): - print(f" Answer: {c.text}") - break - - assert response.status == "completed", f"FAIL: status={response.status}" - print(" PASS: instructions param works") - print() - - -if __name__ == "__main__": - api_key = os.environ.get("PERPLEXITYAI_API_KEY", "NOT SET") - print(f"Using PERPLEXITYAI_API_KEY: {api_key[:10]}...") - print() - - tests = [ - test_non_streaming_preset, - test_streaming_preset, - test_models_fallback_param, - test_chat_completions_not_broken, - test_with_instructions, - ] - - passed = 0 - failed = 0 - for test in tests: - try: 
- test() - passed += 1 - except Exception as e: - failed += 1 - print(f" FAIL: {e}") - traceback.print_exc() - print() - - print("=" * 60) - print(f"Results: {passed} passed, {failed} failed out of {len(tests)}") - print("=" * 60) diff --git a/scripts/test_reasoning_none_tools.py b/scripts/test_reasoning_none_tools.py deleted file mode 100644 index 30d7bd1dab..0000000000 --- a/scripts/test_reasoning_none_tools.py +++ /dev/null @@ -1,39 +0,0 @@ -""" -Repro: gpt-5.4 + reasoning_effort='none' + tools -Current behavior: reasoning_effort='none' is NOT dropped, but OpenAI rejects it. -""" - -import os -from dotenv import load_dotenv - -load_dotenv() - -import litellm - -tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get current weather", - "parameters": { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - }, - }, - } -] - -print("=== gpt-5.4 + reasoning_effort='none' + tools ===") -try: - response = litellm.completion( - model="gpt-5.4", - messages=[{"role": "user", "content": "What's the weather in Buenos Aires?"}], - reasoning_effort="none", - tools=tools, - ) - print(f"SUCCESS - model: {response.model}") - print(f"Choice: {response.choices[0].message}") -except Exception as e: - print(f"FAILED: {e}") diff --git a/scripts/test_reasoning_tools.py b/scripts/test_reasoning_tools.py deleted file mode 100644 index b925d0ade2..0000000000 --- a/scripts/test_reasoning_tools.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Test gpt-5.4 with reasoning_effort + tools to see OpenAI's response.""" -import os -from dotenv import load_dotenv -load_dotenv() - -import litellm - -try: - response = litellm.completion( - model="gpt-5.4", - messages=[{"role": "user", "content": "What's the weather in SF?"}], - tools=[{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}], - reasoning_effort="high", - ) 
- print("SUCCESS:") - print(response) -except Exception as e: - print(f"ERROR ({type(e).__name__}):") - print(e) diff --git a/scripts/test_tool_choice_responses.py b/scripts/test_tool_choice_responses.py deleted file mode 100644 index dab599933a..0000000000 --- a/scripts/test_tool_choice_responses.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Post-fix verification for #23423: tool_choice with responses/ prefix.""" -import os -from dotenv import load_dotenv -load_dotenv() - -import litellm - -# Verify supports_tool_choice resolves correctly -from litellm.utils import supports_tool_choice -print("supports_tool_choice('gpt-5.4'):", supports_tool_choice("gpt-5.4")) -print("supports_tool_choice('openai/responses/gpt-5.4'):", supports_tool_choice("openai/responses/gpt-5.4")) - -# Verify tool_choice is in supported params -params = litellm.get_supported_openai_params(model="openai/responses/gpt-5.4", custom_llm_provider="openai") -print("tool_choice in supported params:", "tool_choice" in params) - -# Real API call with tool_choice -tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the weather for a city", - "parameters": { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - }, - }, - } -] - -response = litellm.completion( - model="openai/responses/gpt-4.1-nano", # cheaper model - messages=[{"role": "user", "content": "What's the weather in Buenos Aires?"}], - tools=tools, - tool_choice="required", -) - -print("\nResponse:") -print(" tool_calls:", response.choices[0].message.tool_calls) -print(" finish_reason:", response.choices[0].finish_reason) - -has_tool_call = response.choices[0].message.tool_calls is not None -print("\nVERDICT:", "PASS - tool_choice works" if has_tool_call else "FAIL - tool_choice dropped") diff --git a/scripts/test_tool_search_chat.py b/scripts/test_tool_search_chat.py deleted file mode 100644 index 9cfecae983..0000000000 --- a/scripts/test_tool_search_chat.py +++ 
/dev/null @@ -1,57 +0,0 @@ -"""Test the Chat Completions Bridge tool search example from docs (line 856-887)""" -import os -from dotenv import load_dotenv -load_dotenv() - -import litellm - -try: - response = litellm.completion( - model="openai/responses/gpt-5.4", - messages=[{"role": "user", "content": "Look up invoice INV-2024-001"}], - tools=[ - {"type": "tool_search"}, - { - "type": "namespace", - "name": "billing", - "description": "Billing and invoicing tools", - "tools": [ - { - "type": "function", - "name": "get_invoice", - "description": "Get an invoice by ID", - "parameters": { - "type": "object", - "properties": {"invoice_id": {"type": "string"}}, - "required": ["invoice_id"], - }, - "defer_loading": True, - }, - ], - }, - ], - ) - - print("=== Raw response ===") - print(f"tool_calls value: {response.choices[0].message.tool_calls}") - print(f"tool_calls is None? {response.choices[0].message.tool_calls is None}") - print() - - # Test the docs code exactly as written - print("=== Testing docs code (no None guard) ===") - try: - for tool_call in response.choices[0].message.tool_calls: - print(f"Called: {tool_call.function.name}({tool_call.function.arguments})") - except TypeError as e: - print(f" !!! 
TypeError: {e}") - print(f" Greptile was RIGHT - need 'or []' guard") - - # Test with the fix - print() - print("=== Testing with fix (or [] guard) ===") - for tool_call in (response.choices[0].message.tool_calls or []): - print(f"Called: {tool_call.function.name}({tool_call.function.arguments})") - print(" OK - no crash") - -except Exception as e: - print(f"API Error: {type(e).__name__}: {e}") diff --git a/scripts/test_tool_search_responses.py b/scripts/test_tool_search_responses.py deleted file mode 100644 index 58c7f1e2ac..0000000000 --- a/scripts/test_tool_search_responses.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Test the Responses API tool search example from docs (line 705-783)""" -import os -from dotenv import load_dotenv -load_dotenv() - -import litellm -import json - -# Define namespaces with deferred tools -tools = [ - {"type": "tool_search"}, # Enable tool search - { - "type": "namespace", - "name": "crm", - "description": "CRM tools for customer management", - "tools": [ - { - "type": "function", - "name": "get_customer", - "description": "Get customer details by ID", - "parameters": { - "type": "object", - "properties": { - "customer_id": {"type": "string"} - }, - "required": ["customer_id"], - }, - "defer_loading": True, - }, - { - "type": "function", - "name": "list_customers", - "description": "List customers with optional filters", - "parameters": { - "type": "object", - "properties": { - "status": {"type": "string", "enum": ["active", "inactive"]}, - }, - }, - "defer_loading": True, - }, - ], - }, - { - "type": "namespace", - "name": "billing", - "description": "Billing and invoicing tools", - "tools": [ - { - "type": "function", - "name": "get_invoice", - "description": "Get an invoice by ID", - "parameters": { - "type": "object", - "properties": { - "invoice_id": {"type": "string"} - }, - "required": ["invoice_id"], - }, - "defer_loading": True, - }, - ], - }, -] - -try: - response = litellm.responses( - model="openai/gpt-5.4", - input="Look up 
invoice INV-2024-001 from the billing system", - tools=tools, - ) - - print("=== Raw response.output ===") - print(response.output) - print() - - # Test the parsing code from the docs - print("=== Parsing output items ===") - for item in response.output: - print(f" item type: {type(item)}") - if isinstance(item, dict): - print(f" dict keys: {item.keys()}") - if item["type"] == "tool_search_call": - print(f"Searched namespaces: {item['arguments']['paths']}") - elif item["type"] == "tool_search_output": - print(f"Loaded {len(item['tools'])} tool(s)") - elif item["type"] == "function_call": - print(f"Called: {item.get('namespace', '')}.{item['name']}({item['arguments']})") - else: - print(f" object attrs: {dir(item)}") - if item.type == "function_call": - # Greptile says this will fail if namespace is missing - print(f" Has 'namespace' attr? {hasattr(item, 'namespace')}") - try: - print(f"Called: {item.namespace}.{item.name}({item.arguments})") - except AttributeError as e: - print(f" !!! AttributeError: {e}") - print(f" Greptile was RIGHT - need getattr fallback") - -except Exception as e: - print(f"API Error: {type(e).__name__}: {e}") From 0bdfd95ad8793e6f62e401a1172d8ae8fbdfdde8 Mon Sep 17 00:00:00 2001 From: "Ethan T." Date: Sat, 14 Mar 2026 15:12:48 +0800 Subject: [PATCH 015/539] fix: map Chat Completion file type to Responses API input_file When bridging /chat/completions to the Responses API, content items with type 'file' were falling through to the default handler and being stringified as input_text. This caused the model to receive the Python dict representation as plain text instead of the actual file content. 
Add explicit handling for type 'file' that correctly maps: {"type": "file", "file": {"file_data": "...", "filename": "..."}} to: {"type": "input_file", "file_data": "...", "filename": "..."} Fixes BerriAI/litellm#23588 --- .../transformation.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/litellm/completion_extras/litellm_responses_transformation/transformation.py b/litellm/completion_extras/litellm_responses_transformation/transformation.py index 4b31bcfc28..ab69e9ca13 100644 --- a/litellm/completion_extras/litellm_responses_transformation/transformation.py +++ b/litellm/completion_extras/litellm_responses_transformation/transformation.py @@ -693,6 +693,20 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge): verbose_logger.debug( f"Chat provider: image -> {converted}" ) + elif item_type == "file": + # Map Chat Completion file to Responses API input_file + # {"type": "file", "file": {"file_data": "...", "filename": "..."}} + # -> {"type": "input_file", "file_data": "...", "filename": "..."} + file_data = item.get("file", {}) + converted = {"type": "input_file"} + if isinstance(file_data, dict): + for key in ["file_id", "file_data", "filename"]: + if key in file_data: + converted[key] = file_data[key] + result.append(converted) + verbose_logger.debug( + f"Chat provider: file -> {converted}" + ) elif item_type in [ "input_text", "input_image", From 71c9ba0b1b620ce40b5149376e18cf87bd07f000 Mon Sep 17 00:00:00 2001 From: "Ethan T." 
Date: Sat, 14 Mar 2026 15:12:54 +0800 Subject: [PATCH 016/539] test: add tests for file type to input_file mapping --- ...responses_transformation_transformation.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py b/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py index da38353269..c1d102f65e 100644 --- a/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py +++ b/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py @@ -1995,3 +1995,91 @@ def test_map_optional_params_preserves_reasoning_summary(): assert responses_api_request["reasoning"] == {"effort": "high", "summary": "detailed"} assert responses_api_request["reasoning"]["effort"] == "high" assert responses_api_request["reasoning"]["summary"] == "detailed" + + +def test_convert_chat_completion_file_type_to_input_file(): + """ + Test that Chat Completion content with type 'file' is correctly mapped + to Responses API 'input_file' format, not stringified as 'input_text'. 
+ + Regression test for https://github.com/BerriAI/litellm/issues/23588 + """ + from litellm.completion_extras.litellm_responses_transformation.transformation import ( + LiteLLMResponsesTransformationHandler, + ) + + handler = LiteLLMResponsesTransformationHandler() + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this PDF?"}, + { + "type": "file", + "file": { + "file_data": "data:application/pdf;base64,JVBERi0xLjQK", + "filename": "test.pdf", + }, + }, + ], + } + ] + + input_items, instructions = handler.convert_chat_completion_messages_to_responses_api( + messages + ) + + assert len(input_items) == 1 + msg = input_items[0] + assert msg["type"] == "message" + assert msg["role"] == "user" + + content = msg["content"] + assert len(content) == 2 + + # First item should be the text + assert content[0]["type"] == "input_text" + assert content[0]["text"] == "What is in this PDF?" + + # Second item should be input_file, NOT input_text with stringified dict + assert content[1]["type"] == "input_file" + assert content[1]["file_data"] == "data:application/pdf;base64,JVBERi0xLjQK" + assert content[1]["filename"] == "test.pdf" + # Ensure it does NOT have the nested 'file' key + assert "file" not in content[1] + + +def test_convert_chat_completion_file_type_with_file_id(): + """ + Test that Chat Completion content with type 'file' using file_id is correctly mapped. 
+ """ + from litellm.completion_extras.litellm_responses_transformation.transformation import ( + LiteLLMResponsesTransformationHandler, + ) + + handler = LiteLLMResponsesTransformationHandler() + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Summarize this file."}, + { + "type": "file", + "file": { + "file_id": "file-abc123", + }, + }, + ], + } + ] + + input_items, instructions = handler.convert_chat_completion_messages_to_responses_api( + messages + ) + + content = input_items[0]["content"] + assert content[1]["type"] == "input_file" + assert content[1]["file_id"] == "file-abc123" + assert "file_data" not in content[1] From 6658a8ffb3216a8af1fc1f2cdfa2bd9feae8b39e Mon Sep 17 00:00:00 2001 From: "Ethan T." Date: Sat, 14 Mar 2026 21:18:35 +0800 Subject: [PATCH 017/539] style: apply black formatting to transformation.py --- .../transformation.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/litellm/completion_extras/litellm_responses_transformation/transformation.py b/litellm/completion_extras/litellm_responses_transformation/transformation.py index ab69e9ca13..f5856ab1f4 100644 --- a/litellm/completion_extras/litellm_responses_transformation/transformation.py +++ b/litellm/completion_extras/litellm_responses_transformation/transformation.py @@ -240,10 +240,10 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge): if key in ("max_tokens", "max_completion_tokens"): responses_api_request["max_output_tokens"] = value elif key == "tools" and value is not None: - responses_api_request[ - "tools" - ] = self._convert_tools_to_responses_format( - cast(List[Dict[str, Any]], value) + responses_api_request["tools"] = ( + self._convert_tools_to_responses_format( + cast(List[Dict[str, Any]], value) + ) ) elif key == "response_format": text_format = self._transform_response_format_to_text_format(value) @@ -398,6 +398,7 @@ class 
LiteLLMResponsesTransformationHandler(CompletionTransformationBridge): ResponseOutputMessage, ResponseReasoningItem, ) + try: from openai.types.responses.response_output_item import ( ResponseApplyPatchToolCall, @@ -460,7 +461,9 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge): accumulated_tool_calls.append(tool_call_dict) tool_call_index += 1 - elif ResponseApplyPatchToolCall is not None and isinstance(item, ResponseApplyPatchToolCall): + elif ResponseApplyPatchToolCall is not None and isinstance( + item, ResponseApplyPatchToolCall + ): from litellm.responses.litellm_completion_transformation.transformation import ( LiteLLMCompletionResponsesConfig, ) @@ -1069,9 +1072,9 @@ class OpenAiResponsesToChatCompletionStreamIterator(BaseModelResponseIterator): ) if provider_specific_fields: - function_chunk[ - "provider_specific_fields" - ] = provider_specific_fields + function_chunk["provider_specific_fields"] = ( + provider_specific_fields + ) tool_call_index = parsed_chunk.get("output_index", 0) tool_call_chunk = ChatCompletionToolCallChunk( @@ -1144,9 +1147,9 @@ class OpenAiResponsesToChatCompletionStreamIterator(BaseModelResponseIterator): # Add provider_specific_fields to function if present if provider_specific_fields: - function_chunk[ - "provider_specific_fields" - ] = provider_specific_fields + function_chunk["provider_specific_fields"] = ( + provider_specific_fields + ) tool_call_index = parsed_chunk.get("output_index", 0) tool_call_chunk = ChatCompletionToolCallChunk( From 98890e771de1d6850dcf9030b8becb99115c0849 Mon Sep 17 00:00:00 2001 From: "Ethan T." 
Date: Sat, 14 Mar 2026 21:19:03 +0800 Subject: [PATCH 018/539] style: apply black formatting to test file --- ...responses_transformation_transformation.py | 309 ++++++++++++------ 1 file changed, 205 insertions(+), 104 deletions(-) diff --git a/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py b/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py index c1d102f65e..8bc6ffc050 100644 --- a/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py +++ b/tests/test_litellm/completion_extras/litellm_responses_transformation/test_completion_extras_litellm_responses_transformation_transformation.py @@ -9,7 +9,9 @@ from unittest.mock import ANY, MagicMock, Mock, patch import httpx import pytest -sys.path.insert(0, os.path.abspath("../../..")) # Adds the parent directory to the system-path +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system-path import litellm @@ -117,7 +119,9 @@ def test_convert_chat_completion_messages_to_responses_api_tool_result_with_imag function_call_output = item break - assert function_call_output is not None, "function_call_output not found in response" + assert ( + function_call_output is not None + ), "function_call_output not found in response" assert function_call_output["call_id"] == "call_abc123" # Check that the output is correctly transformed @@ -127,8 +131,12 @@ def test_convert_chat_completion_messages_to_responses_api_tool_result_with_imag image_item = output[0] # Should be transformed to Responses API format - assert image_item["type"] == "input_image", f"Expected type 'input_image', got '{image_item.get('type')}'" - assert image_item["image_url"] == test_image_base64, "image_url should be a flat string, not a nested object" 
+ assert ( + image_item["type"] == "input_image" + ), f"Expected type 'input_image', got '{image_item.get('type')}'" + assert ( + image_item["image_url"] == test_image_base64 + ), "image_url should be a flat string, not a nested object" assert "detail" in image_item, "detail field should be present" print("✓ Tool result with image correctly transformed to Responses API format") @@ -190,7 +198,9 @@ def test_convert_chat_completion_messages_to_responses_api_tool_result_with_text function_call_output = item break - assert function_call_output is not None, "function_call_output not found in response" + assert ( + function_call_output is not None + ), "function_call_output not found in response" assert function_call_output["call_id"] == "call_abc123" # Check that the output is correctly transformed to use input_text, not output_text @@ -200,12 +210,16 @@ def test_convert_chat_completion_messages_to_responses_api_tool_result_with_text text_item = output[0] # Should be transformed to use input_text for tool results in Responses API format - assert text_item["type"] == "input_text", ( - f"Expected type 'input_text' for tool result, got '{text_item.get('type')}'" - ) - assert text_item["text"] == "15 degrees", f"Expected text '15 degrees', got '{text_item.get('text')}'" + assert ( + text_item["type"] == "input_text" + ), f"Expected type 'input_text' for tool result, got '{text_item.get('type')}'" + assert ( + text_item["text"] == "15 degrees" + ), f"Expected text '15 degrees', got '{text_item.get('text')}'" - print("✓ Tool result with text correctly transformed to use input_text for Responses API format") + print( + "✓ Tool result with text correctly transformed to use input_text for Responses API format" + ) def test_openai_responses_chunk_parser_reasoning_summary(): @@ -214,7 +228,9 @@ def test_openai_responses_chunk_parser_reasoning_summary(): ) from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices - iterator = 
OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunk = { "delta": "**Compar", @@ -246,7 +262,9 @@ def test_chunk_parser_string_output_text_delta_produces_text(): ) from litellm.types.utils import ModelResponseStream - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunk = {"type": "response.output_text.delta", "delta": "literal text"} @@ -267,7 +285,9 @@ def test_chunk_parser_enum_output_text_delta_produces_text(): from litellm.types.llms.openai import ResponsesAPIStreamEvents from litellm.types.utils import ModelResponseStream - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunk = {"type": ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA, "delta": "enum text"} @@ -288,7 +308,9 @@ def test_chunk_parser_function_call_added_produces_tool_use(): from litellm.types.llms.openai import ResponsesAPIStreamEvents from litellm.types.utils import ModelResponseStream - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunk = { "type": ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED, @@ -373,7 +395,9 @@ Tomorrow will bring its petitions and promises, but for now the city breathes slow and wide, and I learn to carry this small calm home.""" - output_text = ResponseOutputText(annotations=[], text=poem_text, type="output_text", logprobs=[]) + output_text = ResponseOutputText( + annotations=[], text=poem_text, type="output_text", logprobs=[] + ) output_message = 
ResponseOutputMessage( id="msg_04c8021b8b3188a00068e9ae0b92f4819dac64d85b4abb67ec", content=[output_text], @@ -385,7 +409,9 @@ and I learn to carry this small calm home.""" # Create usage information usage = ResponseAPIUsage( input_tokens=16, - input_tokens_details=InputTokensDetails(audio_tokens=None, cached_tokens=0, text_tokens=None), + input_tokens_details=InputTokensDetails( + audio_tokens=None, cached_tokens=0, text_tokens=None + ), output_tokens=195, output_tokens_details=OutputTokensDetails(reasoning_tokens=0, text_tokens=None), total_tokens=211, @@ -597,7 +623,9 @@ def test_transform_request_single_char_keys_not_matched(): assert result_correct.get("metadata") == {"user_id": "123"} assert result_correct.get("previous_response_id") == "resp_abc" - print("✓ Single-character keys are not incorrectly matched to metadata/previous_response_id") + print( + "✓ Single-character keys are not incorrectly matched to metadata/previous_response_id" + ) # ============================================================================= @@ -617,7 +645,9 @@ def test_message_done_does_not_emit_is_finished(): OpenAiResponsesToChatCompletionStreamIterator, ) - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunk = { "type": "response.output_item.done", @@ -629,9 +659,9 @@ def test_message_done_does_not_emit_is_finished(): # After the fix, message completion should NOT set finish_reason # ModelResponseStream doesn't have is_finished - check finish_reason instead assert len(result.choices) > 0, "result should have choices" - assert result.choices[0].finish_reason is None or result.choices[0].finish_reason == "", ( - "message completion should not emit finish_reason" - ) + assert ( + result.choices[0].finish_reason is None or result.choices[0].finish_reason == "" + ), "message completion should not emit finish_reason" def 
test_response_completed_emits_is_finished(): @@ -643,7 +673,9 @@ def test_response_completed_emits_is_finished(): OpenAiResponsesToChatCompletionStreamIterator, ) - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunk = {"type": "response.completed"} @@ -651,7 +683,9 @@ def test_response_completed_emits_is_finished(): # response.completed should emit finish_reason='stop' assert len(result.choices) > 0, "result should have choices" - assert result.choices[0].finish_reason == "stop", "response.completed should emit finish_reason='stop'" + assert ( + result.choices[0].finish_reason == "stop" + ), "response.completed should emit finish_reason='stop'" def test_response_completed_with_function_calls_emits_tool_calls_finish_reason(): @@ -670,7 +704,9 @@ def test_response_completed_with_function_calls_emits_tool_calls_finish_reason() OpenAiResponsesToChatCompletionStreamIterator, ) - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) # Simulate a response.completed event with function_call in output # This matches what Azure/OpenAI sends for gpt-5.1-codex-mini and similar models @@ -696,9 +732,9 @@ def test_response_completed_with_function_calls_emits_tool_calls_finish_reason() # response.completed with function_call should emit finish_reason='tool_calls' assert len(result.choices) > 0, "result should have choices" - assert result.choices[0].finish_reason == "tool_calls", ( - "response.completed with function_call output should emit finish_reason='tool_calls'" - ) + assert ( + result.choices[0].finish_reason == "tool_calls" + ), "response.completed with function_call output should emit finish_reason='tool_calls'" def 
test_response_completed_with_message_only_emits_stop_finish_reason(): @@ -709,7 +745,9 @@ def test_response_completed_with_message_only_emits_stop_finish_reason(): OpenAiResponsesToChatCompletionStreamIterator, ) - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) # Simulate a response.completed event with only message output chunk = { @@ -733,10 +771,9 @@ def test_response_completed_with_message_only_emits_stop_finish_reason(): # response.completed with only message should emit finish_reason='stop' assert len(result.choices) > 0, "result should have choices" - assert result.choices[0].finish_reason == "stop", ( - "response.completed with only message output should emit finish_reason='stop'" - ) - + assert ( + result.choices[0].finish_reason == "stop" + ), "response.completed with only message output should emit finish_reason='stop'" def test_response_completed_preserves_usage_with_cached_tokens(): @@ -752,7 +789,9 @@ def test_response_completed_preserves_usage_with_cached_tokens(): OpenAiResponsesToChatCompletionStreamIterator, ) - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunk = { "type": "response.completed", @@ -781,12 +820,18 @@ def test_response_completed_preserves_usage_with_cached_tokens(): result = iterator.chunk_parser(chunk) assert result.usage is not None, "usage should be set on response.completed chunk" - assert result.usage.prompt_tokens == 1226, "prompt_tokens should map from input_tokens" - assert result.usage.completion_tokens == 5, "completion_tokens should map from output_tokens" - assert result.usage.prompt_tokens_details is not None, "prompt_tokens_details should be set" - assert 
result.usage.prompt_tokens_details.cached_tokens == 1024, ( - "cached_tokens should be preserved from input_tokens_details" - ) + assert ( + result.usage.prompt_tokens == 1226 + ), "prompt_tokens should map from input_tokens" + assert ( + result.usage.completion_tokens == 5 + ), "completion_tokens should map from output_tokens" + assert ( + result.usage.prompt_tokens_details is not None + ), "prompt_tokens_details should be set" + assert ( + result.usage.prompt_tokens_details.cached_tokens == 1024 + ), "cached_tokens should be preserved from input_tokens_details" def test_function_call_done_emits_is_finished(): @@ -800,7 +845,9 @@ def test_function_call_done_emits_is_finished(): OpenAiResponsesToChatCompletionStreamIterator, ) - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunk = { "type": "response.output_item.done", @@ -820,9 +867,9 @@ def test_function_call_done_emits_is_finished(): "output_item.done for function_call must not emit finish_reason; " "response.completed is responsible for the terminal finish_reason" ) - assert not result.choices[0].delta.tool_calls, ( - "output_item.done for function_call must not include a duplicate tool_calls delta" - ) + assert not result.choices[ + 0 + ].delta.tool_calls, "output_item.done for function_call must not include a duplicate tool_calls delta" def test_text_plus_tool_calls_sequence(): @@ -837,7 +884,9 @@ def test_text_plus_tool_calls_sequence(): OpenAiResponsesToChatCompletionStreamIterator, ) - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) # Simulate the sequence from OpenAI Responses API chunks = [ @@ -876,23 +925,28 @@ def test_text_plus_tool_calls_sequence(): # Check message done (index 2) 
does NOT have finish_reason set message_done_result = results[2] assert len(message_done_result.choices) > 0, "message done should have choices" - assert message_done_result.choices[0].finish_reason is None or message_done_result.choices[0].finish_reason == "", ( - "message done should not have finish_reason" - ) + assert ( + message_done_result.choices[0].finish_reason is None + or message_done_result.choices[0].finish_reason == "" + ), "message done should not have finish_reason" # Check function_call done (index 5) does NOT have finish_reason set # (response.completed is responsible for the terminal finish_reason) function_done_result = results[5] - assert len(function_done_result.choices) > 0, "function_call done should have choices" - assert function_done_result.choices[0].finish_reason is None, ( - "output_item.done for function_call must not emit finish_reason" - ) + assert ( + len(function_done_result.choices) > 0 + ), "function_call done should have choices" + assert ( + function_done_result.choices[0].finish_reason is None + ), "output_item.done for function_call must not emit finish_reason" # Check response.completed (index 6) has finish_reason='stop' # (the mock chunk has no nested 'response' data, so has_function_calls is False → 'stop') completed_result = results[6] assert len(completed_result.choices) > 0, "response.completed should have choices" - assert completed_result.choices[0].finish_reason == "stop", "response.completed should have finish_reason='stop'" + assert ( + completed_result.choices[0].finish_reason == "stop" + ), "response.completed should have finish_reason='stop'" # ============================================================================= @@ -958,7 +1012,9 @@ def test_tool_message_output_uses_input_text_not_output_text(): output = function_call_output["output"] assert isinstance(output, list), f"output should be a list, got {type(output)}" assert len(output) == 1 - assert output[0]["type"] == "input_text", f"Expected input_text, 
got {output[0].get('type')}" + assert ( + output[0]["type"] == "input_text" + ), f"Expected input_text, got {output[0].get('type')}" assert output[0]["text"] == '{"temperature": 15, "condition": "sunny"}' print("✓ Tool message output correctly uses input_text type") @@ -1144,9 +1200,13 @@ def test_map_reasoning_effort_adds_summary_detailed(): assert result is not None, f"Result should not be None for effort={effort}" assert result["effort"] == effort, f"Effort should be {effort}" - assert "summary" not in result, f"Summary should NOT be present by default for effort={effort}" + assert ( + "summary" not in result + ), f"Summary should NOT be present by default for effort={effort}" - print(f"✓ reasoning_effort='{effort}' correctly maps to effort='{effort}' (no summary by default)") + print( + f"✓ reasoning_effort='{effort}' correctly maps to effort='{effort}' (no summary by default)" + ) # Test 2: With flag enabled - summary IS added litellm.reasoning_auto_summary = True @@ -1156,9 +1216,9 @@ def test_map_reasoning_effort_adds_summary_detailed(): assert result is not None, f"Result should not be None for effort={effort}" assert result["effort"] == effort, f"Effort should be {effort}" - assert result["summary"] == "detailed", ( - f"Summary should be 'detailed' when flag is enabled for effort={effort}" - ) + assert ( + result["summary"] == "detailed" + ), f"Summary should be 'detailed' when flag is enabled for effort={effort}" print( f"✓ reasoning_effort='{effort}' correctly maps to effort='{effort}', summary='detailed' (flag enabled)" @@ -1169,7 +1229,9 @@ def test_map_reasoning_effort_adds_summary_detailed(): os.environ["LITELLM_REASONING_AUTO_SUMMARY"] = "true" result = handler._map_reasoning_effort("high") - assert result["summary"] == "detailed", "Summary should be 'detailed' when env var is enabled" + assert ( + result["summary"] == "detailed" + ), "Summary should be 'detailed' when env var is enabled" print("✓ LITELLM_REASONING_AUTO_SUMMARY env var works 
correctly") # Test 4: Dict input is passed through as-is (no modification) @@ -1188,7 +1250,9 @@ def test_map_reasoning_effort_adds_summary_detailed(): assert result_unknown is None print("✓ Unknown reasoning_effort values return None") - print("✓ All reasoning_effort behaviors work correctly with flag/env var control") + print( + "✓ All reasoning_effort behaviors work correctly with flag/env var control" + ) finally: # Restore original values @@ -1264,7 +1328,9 @@ def test_transform_response_preserves_annotations(): # Create usage information usage = ResponseAPIUsage( input_tokens=10, - input_tokens_details=InputTokensDetails(audio_tokens=None, cached_tokens=0, text_tokens=None), + input_tokens_details=InputTokensDetails( + audio_tokens=None, cached_tokens=0, text_tokens=None + ), output_tokens=20, output_tokens_details=OutputTokensDetails(reasoning_tokens=0, text_tokens=None), total_tokens=30, @@ -1351,9 +1417,13 @@ def test_transform_response_preserves_annotations(): assert choice.message.content == "Here is some information with citations." 
# Check that annotations are preserved - assert hasattr(choice.message, "annotations"), "Message should have annotations attribute" + assert hasattr( + choice.message, "annotations" + ), "Message should have annotations attribute" assert choice.message.annotations is not None, "Annotations should not be None" - assert len(choice.message.annotations) == 2, f"Expected 2 annotations, got {len(choice.message.annotations)}" + assert ( + len(choice.message.annotations) == 2 + ), f"Expected 2 annotations, got {len(choice.message.annotations)}" # Verify annotation content annotation1 = choice.message.annotations[0] @@ -1375,7 +1445,9 @@ def test_transform_response_preserves_annotations(): assert result.usage.completion_tokens == 20 assert result.usage.total_tokens == 30 - print("✓ Annotations from Responses API are correctly preserved in Chat Completions format") + print( + "✓ Annotations from Responses API are correctly preserved in Chat Completions format" + ) def test_apply_patch_tool_call_converted_to_chat_completion_tool_call(): @@ -1512,6 +1584,8 @@ def test_apply_patch_tool_call_converted_to_chat_completion_tool_call(): assert args["type"] == "create_file" assert args["path"] == "hello.py" assert "print('hello world')" in args["diff"] + + def test_multi_tool_call_stream_no_premature_finish(): """ Regression test for multi-tool-call streaming bug. 
@@ -1538,18 +1612,26 @@ def test_multi_tool_call_stream_no_premature_finish(): OpenAiResponsesToChatCompletionStreamIterator, ) - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) chunks = [ # 0: response created - {"type": "response.created", "response": {"id": "resp_001", "status": "in_progress"}}, + { + "type": "response.created", + "response": {"id": "resp_001", "status": "in_progress"}, + }, # 1: first tool call added { "type": "response.output_item.added", "item": {"type": "function_call", "name": "read_file", "call_id": "call_1"}, }, # 2: first tool call arguments delta - {"type": "response.function_call_arguments.delta", "delta": '{"path":"/etc/hostname"}'}, + { + "type": "response.function_call_arguments.delta", + "delta": '{"path":"/etc/hostname"}', + }, # 3: first tool call done ← must NOT emit finish_reason { "type": "response.output_item.done", @@ -1608,10 +1690,12 @@ def test_multi_tool_call_stream_no_premature_finish(): r = results[done_idx] assert r is not None, f"{label}: chunk_parser must return a result" assert len(r.choices) > 0, f"{label}: result must have choices" - assert r.choices[0].finish_reason is None, ( - f"{label}: output_item.done must not emit finish_reason (stream would terminate prematurely)" - ) - assert not r.choices[0].delta.tool_calls, ( + assert ( + r.choices[0].finish_reason is None + ), f"{label}: output_item.done must not emit finish_reason (stream would terminate prematurely)" + assert not r.choices[ + 0 + ].delta.tool_calls, ( f"{label}: output_item.done must not include a duplicate tool_calls delta" ) @@ -1623,12 +1707,12 @@ def test_multi_tool_call_stream_no_premature_finish(): r = results[added_idx] if r is not None and r.choices and r.choices[0].delta.tool_calls: tc = r.choices[0].delta.tool_calls[0] - assert tc.function.name == expected_name, ( - 
f"output_item.added for {expected_name}: tool_call name mismatch" - ) - assert tc.id == expected_call_id, ( - f"output_item.added for {expected_name}: call_id mismatch" - ) + assert ( + tc.function.name == expected_name + ), f"output_item.added for {expected_name}: tool_call name mismatch" + assert ( + tc.id == expected_call_id + ), f"output_item.added for {expected_name}: call_id mismatch" # 3. argument delta events (indices 2 and 5) should carry arguments for delta_idx, expected_args, label in [ @@ -1638,17 +1722,17 @@ def test_multi_tool_call_stream_no_premature_finish(): r = results[delta_idx] if r is not None and r.choices and r.choices[0].delta.tool_calls: tc = r.choices[0].delta.tool_calls[0] - assert tc.function.arguments == expected_args, ( - f"{label}: argument delta mismatch" - ) + assert ( + tc.function.arguments == expected_args + ), f"{label}: argument delta mismatch" # 4. Only response.completed (index 7) emits the terminal finish_reason completed_result = results[7] assert completed_result is not None, "response.completed must return a result" assert len(completed_result.choices) > 0, "response.completed must have choices" - assert completed_result.choices[0].finish_reason == "tool_calls", ( - "response.completed with function_call outputs must emit finish_reason='tool_calls'" - ) + assert ( + completed_result.choices[0].finish_reason == "tool_calls" + ), "response.completed with function_call outputs must emit finish_reason='tool_calls'" # 5. 
No chunk before the last one should have finish_reason set for idx, r in enumerate(results[:-1]): @@ -1658,7 +1742,9 @@ def test_multi_tool_call_stream_no_premature_finish(): f"— only response.completed should terminate the stream" ) - print("✓ Multi-tool-call stream completes without premature finish_reason termination") + print( + "✓ Multi-tool-call stream completes without premature finish_reason termination" + ) # ============================================================================= @@ -1790,7 +1876,10 @@ def test_parallel_tool_calls_comprehensive_streaming_integration(): chunks = [ # 0: response.created - {"type": "response.created", "response": {"id": "resp_001", "status": "in_progress"}}, + { + "type": "response.created", + "response": {"id": "resp_001", "status": "in_progress"}, + }, # 1: call_1 (read_file) added — output_index=0 { "type": "response.output_item.added", @@ -1873,7 +1962,9 @@ def test_parallel_tool_calls_comprehensive_streaming_integration(): }, ] - iterator = OpenAiResponsesToChatCompletionStreamIterator(streaming_response=None, sync_stream=True) + iterator = OpenAiResponsesToChatCompletionStreamIterator( + streaming_response=None, sync_stream=True + ) results = [iterator.chunk_parser(chunk) for chunk in chunks] # 1. 
output_item.done events (indices 4 and 8) must NOT emit finish_reason @@ -1885,7 +1976,9 @@ def test_parallel_tool_calls_comprehensive_streaming_integration(): f"{label}: output_item.done must not emit finish_reason " f"(would prematurely terminate stream before subsequent tool calls arrive)" ) - assert not r.choices[0].delta.tool_calls, ( + assert not r.choices[ + 0 + ].delta.tool_calls, ( f"{label}: output_item.done must not emit a duplicate tool_calls delta" ) @@ -1919,7 +2012,9 @@ def test_parallel_tool_calls_comprehensive_streaming_integration(): for tc in tool_calls: if tc.function and tc.function.arguments: idx = tc.index - assembled_args[idx] = assembled_args.get(idx, "") + tc.function.arguments + assembled_args[idx] = ( + assembled_args.get(idx, "") + tc.function.arguments + ) # delta 1 = '{"path":' + delta 2 = '"/etc/foo"}' → '{"path":"/etc/foo"}' assert assembled_args.get(0) == '{"path":"/etc/foo"}', ( @@ -1938,16 +2033,16 @@ def test_parallel_tool_calls_comprehensive_streaming_integration(): for i, r in enumerate(results) if r is not None and r.choices and r.choices[0].finish_reason ] - assert len(finish_events) == 1, ( - f"Expected exactly 1 finish event, got {len(finish_events)}: {finish_events}" - ) + assert ( + len(finish_events) == 1 + ), f"Expected exactly 1 finish event, got {len(finish_events)}: {finish_events}" assert finish_events[0][0] == len(chunks) - 1, ( f"Finish event must be at the last chunk (index {len(chunks) - 1}), " f"but was at index {finish_events[0][0]}" ) - assert finish_events[0][1] == "tool_calls", ( - f"Terminal finish_reason must be 'tool_calls', got '{finish_events[0][1]}'" - ) + assert ( + finish_events[0][1] == "tool_calls" + ), f"Terminal finish_reason must be 'tool_calls', got '{finish_events[0][1]}'" # 5. 
Parallel tool calls have distinct indices matching output_index (0 and 1) # Collect indices from output_item.added chunks only (they carry the call id) @@ -1958,16 +2053,19 @@ def test_parallel_tool_calls_comprehensive_streaming_integration(): for tc in r.choices[0].delta.tool_calls if tc.id # output_item.added chunks carry the id; argument deltas do not ] - assert set(added_tool_call_indices) == {0, 1}, ( - f"Parallel tool calls must have distinct indices {{0, 1}}, got: {set(added_tool_call_indices)}" - ) + assert set(added_tool_call_indices) == { + 0, + 1, + }, f"Parallel tool calls must have distinct indices {{0, 1}}, got: {set(added_tool_call_indices)}" - print("✓ Parallel tool calls with split argument deltas stream correctly end-to-end") + print( + "✓ Parallel tool calls with split argument deltas stream correctly end-to-end" + ) def test_map_optional_params_preserves_reasoning_summary(): """Test that reasoning_effort dict with summary field is preserved. - + Regression test for: User reported that summary field was being dropped when routing to Responses API. The dict format should be fully preserved. 
""" @@ -1992,7 +2090,10 @@ def test_map_optional_params_preserves_reasoning_summary(): # Verify reasoning_effort dict with summary was fully preserved assert "reasoning" in responses_api_request - assert responses_api_request["reasoning"] == {"effort": "high", "summary": "detailed"} + assert responses_api_request["reasoning"] == { + "effort": "high", + "summary": "detailed", + } assert responses_api_request["reasoning"]["effort"] == "high" assert responses_api_request["reasoning"]["summary"] == "detailed" @@ -2026,8 +2127,8 @@ def test_convert_chat_completion_file_type_to_input_file(): } ] - input_items, instructions = handler.convert_chat_completion_messages_to_responses_api( - messages + input_items, instructions = ( + handler.convert_chat_completion_messages_to_responses_api(messages) ) assert len(input_items) == 1 @@ -2075,8 +2176,8 @@ def test_convert_chat_completion_file_type_with_file_id(): } ] - input_items, instructions = handler.convert_chat_completion_messages_to_responses_api( - messages + input_items, instructions = ( + handler.convert_chat_completion_messages_to_responses_api(messages) ) content = input_items[0]["content"] From 5acceaed32e03c2c44e6c292ebc2fad83f3c0ba9 Mon Sep 17 00:00:00 2001 From: Chesars Date: Mon, 16 Mar 2026 11:49:21 -0300 Subject: [PATCH 019/539] fix(model-prices): restore gpt-4-0314 entry lost in merge conflict The entry was accidentally dropped in commit 6bd7cd7 during a merge conflict resolution. The model is deprecated but still accessible for existing users until its shutdown date of 2026-03-26 per OpenAI docs. 
Fixes #23738 --- model_prices_and_context_window.json | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 6786fc3359..1ab96a445d 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -16864,6 +16864,19 @@ "supports_system_messages": true, "supports_tool_choice": true }, + "gpt-4-0314": { + "deprecation_date": "2026-03-26", + "input_cost_per_token": 3e-05, + "litellm_provider": "openai", + "max_input_tokens": 8192, + "max_output_tokens": 4096, + "max_tokens": 4096, + "mode": "chat", + "output_cost_per_token": 6e-05, + "supports_prompt_caching": true, + "supports_system_messages": true, + "supports_tool_choice": true + }, "gpt-4-0613": { "deprecation_date": "2025-06-06", "input_cost_per_token": 3e-05, From 84b4af40fa7c8fd5a249ced6203da4c09c7dad87 Mon Sep 17 00:00:00 2001 From: Awais Qureshi Date: Tue, 17 Mar 2026 10:30:18 +0500 Subject: [PATCH 020/539] fix(fireworks): skip #transform=inline for base64 data URLs (#23729) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(fireworks): skip #transform=inline for base64 data URLs Closes #23583 Appending #transform=inline to a data: URL corrupted the base64 payload, causing binascii.Error (Incorrect padding) when Fireworks AI attempted to decode the image. Data URLs are already inlined so the fragment is a no-op anyway — guard both the str and dict image_url branches to skip the suffix when the URL starts with "data:". 
Co-Authored-By: Claude Sonnet 4.6 * fix(fireworks): skip #transform=inline for base64 data URLs Closes #23583 * fix(fireworks): skip #transform=inline for base64 data URLs Closes #23583 * fix(fireworks): skip #transform=inline for base64 data URLs Closes #23583 --------- Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Krish Dholakia --- .../llms/fireworks_ai/chat/transformation.py | 13 ++++++--- .../test_fireworks_ai_translation.py | 18 ++++++++++++ .../test_fireworks_ai_chat_transformation.py | 29 +++++++++++++++++++ 3 files changed, 56 insertions(+), 4 deletions(-) diff --git a/litellm/llms/fireworks_ai/chat/transformation.py b/litellm/llms/fireworks_ai/chat/transformation.py index 8407e8ab69..6b654ebdfd 100644 --- a/litellm/llms/fireworks_ai/chat/transformation.py +++ b/litellm/llms/fireworks_ai/chat/transformation.py @@ -185,11 +185,16 @@ class FireworksAIConfig(OpenAIGPTConfig): ): # allow user to toggle this feature. return content if isinstance(content["image_url"], str): - content["image_url"] = f"{content['image_url']}#transform=inline" + # Skip base64 data URLs — appending #transform=inline corrupts the + # base64 payload and causes an "Incorrect padding" decode error on + # the Fireworks side. Data URLs are already inlined by definition. + # Lower-case before checking: URI schemes are case-insensitive (RFC 3986). 
+ if not content["image_url"].lower().startswith("data:"): + content["image_url"] = f"{content['image_url']}#transform=inline" elif isinstance(content["image_url"], dict): - content["image_url"][ - "url" - ] = f"{content['image_url']['url']}#transform=inline" + url = content["image_url"]["url"] + if not url.lower().startswith("data:"): + content["image_url"]["url"] = f"{url}#transform=inline" return content def _transform_tools( diff --git a/tests/llm_translation/test_fireworks_ai_translation.py b/tests/llm_translation/test_fireworks_ai_translation.py index b9abbd501d..24c0d546e2 100644 --- a/tests/llm_translation/test_fireworks_ai_translation.py +++ b/tests/llm_translation/test_fireworks_ai_translation.py @@ -161,6 +161,24 @@ def test_document_inlining_example(disable_add_transform_inline_image_block): "vision-gpt", "http://example.com/image.png", ), + # data: URLs must never have #transform=inline appended — doing so + # corrupts the base64 payload (fixes #23583). + # URI schemes are case-insensitive (RFC 3986) so check all variants. 
+ ( + {"image_url": "data:image/png;base64,iVBORw0KGgo="}, + "gpt-4", + "data:image/png;base64,iVBORw0KGgo=", + ), + ( + {"image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ=="}}, + "gpt-4", + {"url": "data:image/jpeg;base64,/9j/4AAQ=="}, + ), + ( + {"image_url": "Data:image/png;base64,iVBORw0KGgo="}, + "gpt-4", + "Data:image/png;base64,iVBORw0KGgo=", + ), ], ) def test_transform_inline(content, model, expected_url): diff --git a/tests/test_litellm/llms/fireworks_ai/chat/test_fireworks_ai_chat_transformation.py b/tests/test_litellm/llms/fireworks_ai/chat/test_fireworks_ai_chat_transformation.py index 5d5aaa64c8..2b71b88356 100644 --- a/tests/test_litellm/llms/fireworks_ai/chat/test_fireworks_ai_chat_transformation.py +++ b/tests/test_litellm/llms/fireworks_ai/chat/test_fireworks_ai_chat_transformation.py @@ -110,6 +110,35 @@ def test_get_supported_openai_params_reasoning_effort(): assert "reasoning_effort" not in unsupported_params +def test_add_transform_inline_image_block_skips_data_urls(): + """ + data: URLs must not have #transform=inline appended — doing so corrupts the + base64 payload and raises binascii.Error: Incorrect padding on the Fireworks side. 
+ Regression test for https://github.com/BerriAI/litellm/issues/23583 + """ + config = FireworksAIConfig() + data_url = "data:image/jpeg;base64,/9j/4AAQSkZJRgAB" + + # str branch + str_content = {"type": "image_url", "image_url": data_url} + result = config._add_transform_inline_image_block( + str_content, model="non-vision-model", disable_add_transform_inline_image_block=False + ) + assert result["image_url"] == data_url, "data URL must not be modified (str branch)" + + # dict branch + dict_content = {"type": "image_url", "image_url": {"url": data_url}} + result = config._add_transform_inline_image_block( + dict_content, model="non-vision-model", disable_add_transform_inline_image_block=False + ) + assert result["image_url"]["url"] == data_url, "data URL must not be modified (dict branch)" + + # regular https URL should still get the suffix + https_content = {"type": "image_url", "image_url": "https://example.com/image.jpg"} + result = config._add_transform_inline_image_block( + https_content, model="non-vision-model", disable_add_transform_inline_image_block=False + ) + assert result["image_url"].endswith("#transform=inline"), "https URL should get #transform=inline" @pytest.mark.parametrize( "api_base, expected_url_prefix", [ From e9291a97c32303a2ca18313bede88a75a52ce258 Mon Sep 17 00:00:00 2001 From: Miguel Miranda Dias <7780875+pandego@users.noreply.github.com> Date: Tue, 17 Mar 2026 06:34:15 +0100 Subject: [PATCH 021/539] fix(langsmith): avoid no running event loop during sync init (#23727) * fix(langsmith): skip periodic flush task without event loop * fix(langsmith): lazily start periodic flush task * test(langsmith): tighten flush task coverage * test(langsmith): cover lazy failure flush startup * refactor(langsmith): keep flush startup private --- litellm/integrations/langsmith.py | 43 ++++-- .../integrations/test_langsmith_init.py | 125 ++++++++++++++---- 2 files changed, 129 insertions(+), 39 deletions(-) diff --git a/litellm/integrations/langsmith.py 
b/litellm/integrations/langsmith.py index 03845af521..df2e3c1e2b 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -83,7 +83,26 @@ class LangsmithLogger(CustomBatchLogger): if _batch_size: self.batch_size = int(_batch_size) self.log_queue: List[LangsmithQueueObject] = [] - asyncio.create_task(self.periodic_flush()) + self._flush_task: Optional[asyncio.Task[Any]] = self._start_periodic_flush_task() + + def _start_periodic_flush_task(self) -> Optional[asyncio.Task[Any]]: + """Start the periodic flush task only when an event loop is already running.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + verbose_logger.debug( + "Langsmith logger init: no running event loop, skipping periodic flush task startup" + ) + return None + + return loop.create_task(self.periodic_flush()) + + def _ensure_periodic_flush_task(self) -> None: + # This helper is intentionally synchronous. In asyncio's cooperative + # execution model, there is no await between the check and assignment, + # so one caller cannot interleave here and create a duplicate task. + if self._flush_task is None or self._flush_task.done(): + self._flush_task = self._start_periodic_flush_task() def get_credentials_from_env( self, @@ -255,6 +274,7 @@ class LangsmithLogger(CustomBatchLogger): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): try: + self._ensure_periodic_flush_task() sampling_rate = self._get_sampling_rate_to_use_for_request(kwargs=kwargs) random_sample = random.random() if random_sample > sampling_rate: @@ -296,17 +316,18 @@ class LangsmithLogger(CustomBatchLogger): ) async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): - sampling_rate = self._get_sampling_rate_to_use_for_request(kwargs=kwargs) - random_sample = random.random() - if random_sample > sampling_rate: - verbose_logger.info( - "Skipping Langsmith logging. 
Sampling rate={}, random_sample={}".format( - sampling_rate, random_sample - ) - ) - return # Skip logging - verbose_logger.info("Langsmith Failure Event Logging!") try: + self._ensure_periodic_flush_task() + sampling_rate = self._get_sampling_rate_to_use_for_request(kwargs=kwargs) + random_sample = random.random() + if random_sample > sampling_rate: + verbose_logger.info( + "Skipping Langsmith logging. Sampling rate={}, random_sample={}".format( + sampling_rate, random_sample + ) + ) + return # Skip logging + verbose_logger.info("Langsmith Failure Event Logging!") credentials = self._get_credentials_to_use_for_request(kwargs=kwargs) data = self._prepare_log_data( kwargs=kwargs, diff --git a/tests/test_litellm/integrations/test_langsmith_init.py b/tests/test_litellm/integrations/test_langsmith_init.py index 9f7db4095b..edc827033f 100644 --- a/tests/test_litellm/integrations/test_langsmith_init.py +++ b/tests/test_litellm/integrations/test_langsmith_init.py @@ -16,13 +16,9 @@ class TestLangsmithLoggerInit: Note: The current implementation has some edge cases in the sampling rate logic. 
""" - @patch("asyncio.create_task") @patch.dict(os.environ, {"LANGSMITH_SAMPLING_RATE": "1"}, clear=False) - def test_langsmith_sampling_rate_parameter_respected_with_valid_env( - self, mock_create_task - ): + def test_langsmith_sampling_rate_parameter_respected_with_valid_env(self): """Test that langsmith_sampling_rate parameter is properly set when env var condition is met.""" - # When there's a valid integer in env var, the parameter should be used due to 'or' logic sampling_rate = 0.5 logger = LangsmithLogger( langsmith_api_key="test-key", @@ -30,58 +26,47 @@ class TestLangsmithLoggerInit: langsmith_sampling_rate=sampling_rate, ) - # With the current 'or' logic and valid env var, the parameter should be used assert ( logger.sampling_rate == sampling_rate ), f"Expected sampling_rate to be {sampling_rate}, got {logger.sampling_rate}" - @patch("asyncio.create_task") @patch.dict(os.environ, {"LANGSMITH_SAMPLING_RATE": "1"}, clear=False) - def test_langsmith_sampling_rate_zero_parameter_falls_back_to_env( - self, mock_create_task - ): + def test_langsmith_sampling_rate_zero_parameter_falls_back_to_env(self): """Test that 0.0 parameter falls back to env var due to falsy value.""" - # This demonstrates the current behavior where 0.0 is falsy and falls back to env logger = LangsmithLogger( langsmith_api_key="test-key", langsmith_project="test-project", - langsmith_sampling_rate=0.0, # This is falsy! 
+ langsmith_sampling_rate=0.0, ) - # Due to current 'or' logic, 0.0 falls back to env var assert ( logger.sampling_rate == 1.0 ), f"Expected sampling_rate to fall back to 1.0 from env, got {logger.sampling_rate}" - @patch("asyncio.create_task") @patch.dict(os.environ, {"LANGSMITH_SAMPLING_RATE": "1"}, clear=False) - def test_langsmith_sampling_rate_from_integer_env_var(self, mock_create_task): + def test_langsmith_sampling_rate_from_integer_env_var(self): """Test that sampling rate uses environment variable when parameter not provided and env var is integer.""" logger = LangsmithLogger( langsmith_api_key="test-key", langsmith_project="test-project" ) - # Should use env var since it's a valid integer assert ( logger.sampling_rate == 1.0 ), f"Expected sampling_rate to be 1.0 from env var, got {logger.sampling_rate}" - @patch("asyncio.create_task") @patch.dict(os.environ, {"LANGSMITH_SAMPLING_RATE": "0.8"}, clear=False) - def test_langsmith_sampling_rate_decimal_env_var_ignored(self, mock_create_task): + def test_langsmith_sampling_rate_decimal_env_var_ignored(self): """Test that decimal environment variables are ignored due to isdigit() check.""" logger = LangsmithLogger( langsmith_api_key="test-key", langsmith_project="test-project" ) - # Decimal env vars are ignored due to isdigit() check, falls back to 1.0 assert ( logger.sampling_rate == 1.0 ), f"Expected sampling_rate to default to 1.0 (decimal env ignored), got {logger.sampling_rate}" - @patch("asyncio.create_task") @patch.dict(os.environ, {}, clear=True) - def test_langsmith_sampling_rate_default_value(self, mock_create_task): + def test_langsmith_sampling_rate_default_value(self): """Test that sampling rate defaults to 1.0 when no parameter or env var provided.""" logger = LangsmithLogger( langsmith_api_key="test-key", langsmith_project="test-project" @@ -91,9 +76,8 @@ class TestLangsmithLoggerInit: logger.sampling_rate == 1.0 ), f"Expected default sampling_rate to be 1.0, got {logger.sampling_rate}" - 
@patch("asyncio.create_task") @patch.dict(os.environ, {"LANGSMITH_SAMPLING_RATE": "invalid"}, clear=False) - def test_langsmith_sampling_rate_invalid_env_var_defaults(self, mock_create_task): + def test_langsmith_sampling_rate_invalid_env_var_defaults(self): """Test that invalid environment variable falls back to default value.""" logger = LangsmithLogger( langsmith_api_key="test-key", langsmith_project="test-project" @@ -103,9 +87,8 @@ class TestLangsmithLoggerInit: logger.sampling_rate == 1.0 ), f"Expected sampling_rate to default to 1.0 with invalid env var, got {logger.sampling_rate}" - @patch("asyncio.create_task") @patch.dict(os.environ, {"LANGSMITH_SAMPLING_RATE": ""}, clear=False) - def test_langsmith_sampling_rate_empty_env_var_defaults(self, mock_create_task): + def test_langsmith_sampling_rate_empty_env_var_defaults(self): """Test that empty environment variable falls back to default value.""" logger = LangsmithLogger( langsmith_api_key="test-key", langsmith_project="test-project" @@ -115,14 +98,12 @@ class TestLangsmithLoggerInit: logger.sampling_rate == 1.0 ), f"Expected sampling_rate to default to 1.0 with empty env var, got {logger.sampling_rate}" - @patch("asyncio.create_task") - def test_langsmith_sampling_rate_attribute_exists(self, mock_create_task): + def test_langsmith_sampling_rate_attribute_exists(self): """Test that the sampling_rate attribute is always set on the logger instance.""" logger = LangsmithLogger( langsmith_api_key="test-key", langsmith_project="test-project" ) - # Verify the attribute exists and is a float assert hasattr( logger, "sampling_rate" ), "LangsmithLogger should have sampling_rate attribute" @@ -132,3 +113,91 @@ class TestLangsmithLoggerInit: assert ( logger.sampling_rate >= 0.0 ), f"sampling_rate should be non-negative, got {logger.sampling_rate}" + + @patch.object(LangsmithLogger, "_start_periodic_flush_task", return_value=None) + def test_langsmith_init_skips_periodic_flush_without_running_loop( + self, 
mock_start_periodic_flush_task + ): + """Test that sync initialization leaves the periodic flush task unset.""" + logger = LangsmithLogger( + langsmith_api_key="test-key", langsmith_project="test-project" + ) + + assert logger is not None + mock_start_periodic_flush_task.assert_called_once() + assert logger._flush_task is None + + @patch("asyncio.get_running_loop", side_effect=RuntimeError("no running event loop")) + def test_start_periodic_flush_task_returns_none_without_running_loop( + self, mock_get_running_loop + ): + """Test that helper returns None when no running event loop exists.""" + with patch.object(LangsmithLogger, "_start_periodic_flush_task", return_value=None): + logger = LangsmithLogger( + langsmith_api_key="test-key", + langsmith_project="test-project", + ) + + mock_get_running_loop.reset_mock() + + assert logger._start_periodic_flush_task() is None + mock_get_running_loop.assert_called_once() + + @patch("asyncio.get_running_loop") + def test_langsmith_init_starts_periodic_flush_with_running_loop( + self, mock_get_running_loop + ): + """Test that init schedules periodic flush when a running loop exists.""" + mock_loop = MagicMock() + mock_task = MagicMock() + mock_loop.create_task.return_value = mock_task + mock_get_running_loop.return_value = mock_loop + + logger = LangsmithLogger( + langsmith_api_key="test-key", langsmith_project="test-project" + ) + + assert logger._flush_task == mock_task + mock_loop.create_task.assert_called_once() + scheduled_coro = mock_loop.create_task.call_args.args[0] + scheduled_coro.close() + + @pytest.mark.asyncio + async def test_async_log_success_event_lazily_starts_periodic_flush(self): + """Test that async logging lazily starts periodic flush after sync init.""" + with patch.object(LangsmithLogger, "_start_periodic_flush_task", return_value=None): + logger = LangsmithLogger( + langsmith_api_key="test-key", + langsmith_project="test-project", + ) + logger._get_sampling_rate_to_use_for_request = 
MagicMock(return_value=1.0) + logger._get_credentials_to_use_for_request = MagicMock( + return_value=logger.default_credentials + ) + logger._prepare_log_data = MagicMock(return_value={"id": "run-id"}) + logger._start_periodic_flush_task = MagicMock(return_value=MagicMock()) + + await logger.async_log_success_event({}, {}, None, None) + + logger._start_periodic_flush_task.assert_called_once() + assert len(logger.log_queue) == 1 + + @pytest.mark.asyncio + async def test_async_log_failure_event_lazily_starts_periodic_flush(self): + """Test that async failure logging lazily starts periodic flush after sync init.""" + with patch.object(LangsmithLogger, "_start_periodic_flush_task", return_value=None): + logger = LangsmithLogger( + langsmith_api_key="test-key", + langsmith_project="test-project", + ) + logger._get_sampling_rate_to_use_for_request = MagicMock(return_value=1.0) + logger._get_credentials_to_use_for_request = MagicMock( + return_value=logger.default_credentials + ) + logger._prepare_log_data = MagicMock(return_value={"id": "run-id"}) + logger._start_periodic_flush_task = MagicMock(return_value=MagicMock()) + + await logger.async_log_failure_event({}, {}, None, None) + + logger._start_periodic_flush_task.assert_called_once() + assert len(logger.log_queue) == 1 From 186c2adb326050587f4577f09e7dedcaafc982d8 Mon Sep 17 00:00:00 2001 From: Awais Qureshi Date: Tue, 17 Mar 2026 10:38:16 +0500 Subject: [PATCH 022/539] fix(gemini): support images in tool_results for /v1/messages routing (#23724) * fix(gemini): support images in tool_results for /v1/messages routing convert_to_gemini_tool_call_result() dropped images in two cases: - data-URL strings (data:image/...;base64,...) treated as plain text - Anthropic image blocks in list content skipped Add detection and convert both to Gemini inline_data BlobType so image bytes are preserved. Fixes #23712. 
* fix(gemini): support images in tool_results for /v1/messages routing convert_to_gemini_tool_call_result() dropped images in two cases: - data-URL strings (data:image/...;base64,...) treated as plain text - Anthropic image blocks in list content skipped Add detection and convert both to Gemini inline_data BlobType so image bytes are preserved. Fixes #23712. * fix(gemini): support images in tool_results for /v1/messages routing convert_to_gemini_tool_call_result() dropped images in two cases: - data-URL strings (data:image/...;base64,...) treated as plain text - Anthropic image blocks in list content skipped Add detection and convert both to Gemini inline_data BlobType so image bytes are preserved. Fixes #23712. * fix(fireworks): skip #transform=inline for base64 data URLs Closes #23583 --- .../prompt_templates/factory.py | 59 ++++-- ...llm_core_utils_prompt_templates_factory.py | 172 ++++++++++++++++++ 2 files changed, 219 insertions(+), 12 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index 47272b38ad..6c4c98ebf9 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -1498,17 +1498,49 @@ def convert_to_gemini_tool_call_result( # noqa: PLR0915 from litellm.types.llms.vertex_ai import BlobType content_str: str = "" - inline_data: Optional[BlobType] = None + inline_data_list: List[BlobType] = [] if "content" in message: if isinstance(message["content"], str): content_str = message["content"] + # Detect data-URL images (e.g. from Anthropic tool_result with a single image block + # that was serialised as a plain string by translate_anthropic_messages_to_openai) + # and promote them to inline_data so Gemini receives actual image bytes. 
+ if content_str.startswith("data:") and ";base64," in content_str: + try: + mime_rest = content_str[5:].split(";base64,", 1) + if len(mime_rest) == 2 and mime_rest[0].startswith("image/"): + # Strip any extra parameters (e.g. ";charset=UTF-8") from the MIME segment + clean_mime = mime_rest[0].split(";")[0].strip() + inline_data_list.append( + BlobType(data=mime_rest[1], mime_type=clean_mime) + ) + content_str = "" + except Exception as e: + verbose_logger.warning( + f"Failed to parse data URL in tool response: {e}" + ) elif isinstance(message["content"], List): content_list = message["content"] for content in content_list: content_type = content.get("type", "") if content_type == "text": content_str += content.get("text", "") + elif content_type == "image": + # Anthropic-native image block: {"type": "image", "source": {"type": "base64", ...}} + source = content.get("source", {}) + if isinstance(source, dict) and source.get("type") == "base64": + try: + inline_data_list.append( + BlobType( + data=source.get("data", ""), + mime_type=source.get("media_type", "image/jpeg"), + ) + ) + except Exception as e: + verbose_logger.warning( + f"Failed to process Anthropic image block in tool response: {e}" + ) elif content_type in ("input_image", "image_url"): # Extract image for inline_data (for Computer Use screenshots and tool results) image_url_data = content.get("image_url", "") @@ -1524,9 +1556,11 @@ def convert_to_gemini_tool_call_result( # noqa: PLR0915 image_obj = convert_to_anthropic_image_obj( image_url, format=None ) - inline_data = BlobType( - data=image_obj["data"], - mime_type=image_obj["media_type"], + inline_data_list.append( + BlobType( + data=image_obj["data"], + mime_type=image_obj["media_type"], + ) ) except Exception as e: verbose_logger.warning( @@ -1551,9 +1585,11 @@ def convert_to_gemini_tool_call_result( # noqa: PLR0915 file_obj = convert_to_anthropic_image_obj( file_data, format=None ) - inline_data = BlobType( - data=file_obj["data"], - 
mime_type=file_obj["media_type"], + inline_data_list.append( + BlobType( + data=file_obj["data"], + mime_type=file_obj["media_type"], + ) ) except Exception as e: verbose_logger.warning( @@ -1607,13 +1643,12 @@ def convert_to_gemini_tool_call_result( # noqa: PLR0915 # Create part with function_response, and optionally inline_data for images (Computer Use) _part: VertexPartType = {"function_response": _function_response} - # For Computer Use, if we have an image, we need separate parts: + # For Computer Use, if we have images/files, we need separate parts: # - One part with function_response - # - One part with inline_data + # - One part per inline_data item # Gemini's PartType is a oneof, so we can't have both in the same part - if inline_data: - image_part: VertexPartType = {"inline_data": inline_data} - return [_part, image_part] + if inline_data_list: + return [_part] + [{"inline_data": d} for d in inline_data_list] return _part diff --git a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py index 8d68539564..5c5cd5bdc3 100644 --- a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py +++ b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py @@ -1,3 +1,4 @@ +import base64 import json from unittest.mock import MagicMock, patch @@ -9,9 +10,11 @@ from litellm.litellm_core_utils.prompt_templates.factory import ( BedrockConverseMessagesProcessor, BedrockImageProcessor, _convert_to_bedrock_tool_call_invoke, + convert_to_gemini_tool_call_result, ollama_pt, sanitize_messages_for_tool_calling, ) +from litellm.types.llms.openai import ChatCompletionToolMessage def test_ollama_pt_simple_messages(): @@ -550,6 +553,175 @@ def test_convert_gemini_tool_call_result_with_image_url(): assert isinstance(result2, 
list) and any("inline_data" in p for p in result2) +def test_convert_gemini_tool_call_result_with_anthropic_image_block(): + """ + Test that Anthropic-native image blocks in tool_result list content are + converted to Gemini inline_data instead of being silently dropped. + Fixes: https://github.com/BerriAI/litellm/issues/23712 + """ + tiny_png_b64 = base64.b64encode(b"PNG_PLACEHOLDER").decode() + + message = ChatCompletionToolMessage( + role="tool", + tool_call_id="call_123", + content=[ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": tiny_png_b64, + }, + } + ], + ) + last_message_with_tool_calls = { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "index": 0, + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + } + + result = convert_to_gemini_tool_call_result( + message=message, + last_message_with_tool_calls=last_message_with_tool_calls, + ) + assert isinstance(result, list), "expected a list of parts" + inline_parts = [p for p in result if "inline_data" in p] + assert len(inline_parts) == 1, "expected exactly one inline_data part" + assert inline_parts[0]["inline_data"]["mime_type"] == "image/png" + assert inline_parts[0]["inline_data"]["data"] == tiny_png_b64 + + +def test_convert_gemini_tool_call_result_with_multiple_anthropic_image_blocks(): + """ + Test that multiple Anthropic-native image blocks in a single tool_result + are all preserved as separate inline_data parts instead of only the last + one being kept. 
+ Fixes: https://github.com/BerriAI/litellm/issues/23712 + """ + png_b64 = base64.b64encode(b"PNG_PLACEHOLDER").decode() + jpeg_b64 = base64.b64encode(b"JPEG_PLACEHOLDER").decode() + + message = ChatCompletionToolMessage( + role="tool", + tool_call_id="call_multi", + content=[ + {"type": "text", "text": "here are two images"}, + { + "type": "image", + "source": {"type": "base64", "media_type": "image/png", "data": png_b64}, + }, + { + "type": "image", + "source": {"type": "base64", "media_type": "image/jpeg", "data": jpeg_b64}, + }, + ], + ) + last_message_with_tool_calls = { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_multi", + "type": "function", + "index": 0, + "function": {"name": "screenshot", "arguments": "{}"}, + } + ], + } + + result = convert_to_gemini_tool_call_result( + message=message, + last_message_with_tool_calls=last_message_with_tool_calls, + ) + assert isinstance(result, list), "expected a list of parts" + inline_parts = [p for p in result if "inline_data" in p] + assert len(inline_parts) == 2, f"expected 2 inline_data parts, got {len(inline_parts)}" + mime_types = {p["inline_data"]["mime_type"] for p in inline_parts} + assert mime_types == {"image/png", "image/jpeg"} + + +def test_convert_gemini_tool_call_result_with_data_url_string(): + """ + Test that a data-URL string in tool_result content is converted to + Gemini inline_data instead of being passed as plain text. 
+ Fixes: https://github.com/BerriAI/litellm/issues/23712 + """ + tiny_png_b64 = base64.b64encode(b"PNG_PLACEHOLDER").decode() + + message = ChatCompletionToolMessage( + role="tool", + tool_call_id="call_456", + content=f"data:image/png;base64,{tiny_png_b64}", + ) + last_message_with_tool_calls = { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_456", + "type": "function", + "index": 0, + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + } + + result = convert_to_gemini_tool_call_result( + message=message, + last_message_with_tool_calls=last_message_with_tool_calls, + ) + assert isinstance(result, list), "expected a list of parts" + inline_parts = [p for p in result if "inline_data" in p] + assert len(inline_parts) == 1, "data-URL image string was not converted to inline_data" + assert inline_parts[0]["inline_data"]["mime_type"] == "image/png" + assert inline_parts[0]["inline_data"]["data"] == tiny_png_b64 + + +def test_convert_gemini_tool_call_result_with_data_url_extra_params(): + """ + Test that a data-URL with extra MIME parameters (e.g. charset) produces + a clean mime_type without the extra parameters. 
+ """ + tiny_png_b64 = base64.b64encode(b"PNG_PLACEHOLDER").decode() + + message = ChatCompletionToolMessage( + role="tool", + tool_call_id="call_extra", + content=f"data:image/png;charset=UTF-8;base64,{tiny_png_b64}", + ) + last_message_with_tool_calls = { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_extra", + "type": "function", + "index": 0, + "function": {"name": "read_file", "arguments": "{}"}, + } + ], + } + + result = convert_to_gemini_tool_call_result( + message=message, + last_message_with_tool_calls=last_message_with_tool_calls, + ) + assert isinstance(result, list), "expected a list of parts" + inline_parts = [p for p in result if "inline_data" in p] + assert len(inline_parts) == 1 + assert inline_parts[0]["inline_data"]["mime_type"] == "image/png", ( + f"expected clean 'image/png', got '{inline_parts[0]['inline_data']['mime_type']}'" + ) + + def test_bedrock_tools_unpack_defs(): """ Test that the unpack_defs method handles nested $ref inside anyOf items correctly From 0bc609affdc09c4071c72f4035591289685af8e4 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 11:23:47 +0530 Subject: [PATCH 023/539] fix(vertex-ai): support batch cancel via Vertex API Add Vertex batch cancellation support in LiteLLM batch APIs, route proxy cancel fallback using request provider headers, and return post-cancel batch state via retrieve to keep response shape compatible. 
Made-with: Cursor --- litellm/batches/main.py | 32 ++++- litellm/llms/vertex_ai/batches/handler.py | 113 ++++++++++++++++++ litellm/proxy/batches_endpoints/endpoints.py | 14 ++- .../test_vertex_ai_batch_transformation.py | 78 ++++++++++++ 4 files changed, 229 insertions(+), 8 deletions(-) diff --git a/litellm/batches/main.py b/litellm/batches/main.py index e176dc4292..36093d071b 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -884,7 +884,7 @@ def list_batches( async def acancel_batch( batch_id: str, model: Optional[str] = None, - custom_llm_provider: Literal["openai", "azure"] = "openai", + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", metadata: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, @@ -930,7 +930,7 @@ async def acancel_batch( def cancel_batch( batch_id: str, model: Optional[str] = None, - custom_llm_provider: Union[Literal["openai", "azure"], str] = "openai", + custom_llm_provider: Union[Literal["openai", "azure", "vertex_ai"], str] = "openai", metadata: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, @@ -1047,9 +1047,35 @@ def cancel_batch( cancel_batch_data=_cancel_batch_request, litellm_params=litellm_params, ) + elif custom_llm_provider == "vertex_ai": + api_base = optional_params.api_base or "" + vertex_ai_project = ( + optional_params.vertex_project + or litellm.vertex_project + or get_secret_str("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.vertex_location + or litellm.vertex_location + or get_secret_str("VERTEXAI_LOCATION") + ) + vertex_credentials = optional_params.vertex_credentials or get_secret_str( + "VERTEXAI_CREDENTIALS" + ) + + response = vertex_ai_batches_instance.cancel_batch( + _is_async=_is_async, + batch_id=batch_id, + api_base=api_base, + vertex_project=vertex_ai_project, + vertex_location=vertex_ai_location, + 
vertex_credentials=vertex_credentials, + timeout=timeout, + max_retries=optional_params.max_retries, + ) else: raise litellm.exceptions.BadRequestError( - message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai' and 'azure' are supported.".format( + message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai', 'azure', and 'vertex_ai' are supported.".format( custom_llm_provider ), model="n/a", diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index f0b181c9a6..a24bfd89f0 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -376,3 +376,116 @@ class VertexAIBatchPrediction(VertexLLM): response=_json_response ) return vertex_batch_response + + def cancel_batch( + self, + _is_async: bool, + batch_id: str, + api_base: Optional[str], + vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES], + vertex_project: Optional[str], + vertex_location: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: + sync_handler = _get_httpx_client() + + access_token, project_id = self._ensure_access_token( + credentials=vertex_credentials, + project_id=vertex_project, + custom_llm_provider="vertex_ai", + ) + + default_api_base = self.create_vertex_batch_url( + vertex_location=vertex_location or "us-central1", + vertex_project=vertex_project or project_id, + ) + + default_api_base = f"{default_api_base}/{batch_id}:cancel" + + if len(default_api_base.split(":")) > 1: + endpoint = default_api_base.split(":")[-1] + else: + endpoint = "" + + _, api_base = self._check_custom_proxy( + api_base=api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint=endpoint, + stream=None, + auth_header=None, + url=default_api_base, + model=None, + vertex_project=vertex_project or project_id, + vertex_location=vertex_location or "us-central1", + vertex_api_version="v1", + ) + 
+ headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {access_token}", + } + + if _is_async is True: + return self._async_cancel_batch( + api_base=api_base, + retrieve_api_base=api_base.rsplit(":cancel", 1)[0], + headers=headers, + ) + + response = sync_handler.post( + url=api_base, + headers=headers, + data=json.dumps({}), + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + retrieve_response = sync_handler.get( + url=api_base.rsplit(":cancel", 1)[0], + headers=headers, + ) + if retrieve_response.status_code != 200: + raise Exception( + f"Error: {retrieve_response.status_code} {retrieve_response.text}" + ) + + _json_response = retrieve_response.json() + vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response( + response=_json_response + ) + return vertex_batch_response + + async def _async_cancel_batch( + self, + api_base: str, + retrieve_api_base: str, + headers: Dict[str, str], + ) -> LiteLLMBatch: + client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.VERTEX_AI, + ) + response = await client.post( + url=api_base, + headers=headers, + data=json.dumps({}), + ) + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + retrieve_response = await client.get( + url=retrieve_api_base, + headers=headers, + ) + if retrieve_response.status_code != 200: + raise Exception( + f"Error: {retrieve_response.status_code} {retrieve_response.text}" + ) + + _json_response = retrieve_response.json() + vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response( + response=_json_response + ) + return vertex_batch_response diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 740e63b7f1..06254b57fc 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ 
b/litellm/proxy/batches_endpoints/endpoints.py @@ -5,7 +5,7 @@ ###################################################################### import asyncio -from typing import Dict, Optional, cast +from typing import Any, Dict, Optional, cast from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response @@ -654,7 +654,7 @@ async def list_batches( managed_files_obj, "list_user_batches" ): verbose_proxy_logger.debug("Using managed objects table for batch listing") - response = await managed_files_obj.list_user_batches( + response = await cast(Any, managed_files_obj).list_user_batches( user_api_key_dict=user_api_key_dict, limit=limit, after=after, @@ -685,8 +685,9 @@ async def list_batches( # Encode batch IDs in the list response so clients can use # them for retrieve/cancel/file downloads through the proxy. - if response and hasattr(response, "data") and response.data: - for batch in response.data: + response_data = getattr(response, "data", None) + if response_data: + for batch in response_data: encode_batch_response_ids(batch, model=model_param) verbose_proxy_logger.debug(f"Listed batches using model: {model_param}") @@ -896,7 +897,10 @@ async def cancel_batch( # SCENARIO 3: Fallback to custom_llm_provider (uses env variables) else: custom_llm_provider = ( - provider or data.pop("custom_llm_provider", None) or "openai" + provider + or get_custom_llm_provider_from_request_headers(request=request) + or data.pop("custom_llm_provider", None) + or "openai" ) # Extract batch_id from data to avoid "multiple values for keyword argument" error # data was cast from CancelBatchRequest which already contains batch_id diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py index 1aab74ddc2..d27cd8ba8a 100644 --- a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py 
@@ -1,3 +1,9 @@ +from unittest.mock import MagicMock, patch + +import pytest + +import litellm +from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction from litellm.llms.vertex_ai.batches.transformation import VertexAIBatchTransformation @@ -36,3 +42,75 @@ def test_output_file_id_falls_back_to_output_uri_prefix_with_predictions_jsonl() output_file_id == "gs://test-bucket/litellm-vertex-files/publishers/google/models/gemini-2.5-pro/prediction-model-456/predictions.jsonl" ) + + +@pytest.mark.asyncio +async def test_vertex_ai_cancel_batch(): + """Test that vertex_ai cancel_batch calls the correct API endpoint""" + handler = VertexAIBatchPrediction(gcs_bucket_name="test-bucket") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "name": "projects/test-project/locations/us-central1/batchPredictionJobs/123456", + "state": "JOB_STATE_CANCELLING", + "createTime": "2024-03-17T10:00:00.000000Z", + "inputConfig": { + "gcsSource": { + "uris": ["gs://test-bucket/input.jsonl"] + } + }, + "outputConfig": { + "gcsDestination": { + "outputUriPrefix": "gs://test-bucket/output" + } + } + } + + with patch("litellm.llms.vertex_ai.batches.handler._get_httpx_client") as mock_client: + mock_client.return_value.post.return_value = mock_response + mock_client.return_value.get.return_value = mock_response + + with patch.object(handler, "_ensure_access_token") as mock_auth: + mock_auth.return_value = ("fake-token", "test-project") + + response = handler.cancel_batch( + _is_async=False, + batch_id="123456", + api_base=None, + vertex_credentials=None, + vertex_project="test-project", + vertex_location="us-central1", + timeout=600.0, + max_retries=None, + ) + + assert response.id == "123456" + assert response.status == "cancelling" + + mock_client.return_value.post.assert_called_once() + mock_client.return_value.get.assert_called_once() + call_args = mock_client.return_value.post.call_args + assert ":cancel" in 
call_args.kwargs["url"] + + +@pytest.mark.asyncio +async def test_litellm_cancel_batch_vertex_ai(): + """Test that litellm.cancel_batch works with vertex_ai provider""" + mock_response = MagicMock() + mock_response.id = "batch_123" + mock_response.status = "cancelling" + + with patch.object(litellm.batches.main, "vertex_ai_batches_instance") as mock_instance: + mock_instance.cancel_batch.return_value = mock_response + + response = litellm.cancel_batch( + batch_id="batch_123", + custom_llm_provider="vertex_ai", + vertex_project="test-project", + vertex_location="us-central1", + ) + + assert mock_instance.cancel_batch.called + assert response.id == "batch_123" + assert response.status == "cancelling" From d8e3abf3cef95365fa6b1cc0af68e32a33319bd1 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 11:34:50 +0530 Subject: [PATCH 024/539] fix(vertex-ai): apply review updates for batch cancel Incorporate follow-up changes to Vertex batch cancel handling and proxy provider resolution, including config updates used for local verification. 
Made-with: Cursor --- litellm/llms/vertex_ai/batches/handler.py | 3 +- litellm/proxy/batches_endpoints/endpoints.py | 1 + proxy_server_config.yaml | 231 +------------------ 3 files changed, 4 insertions(+), 231 deletions(-) diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index a24bfd89f0..c7b9287c08 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -388,8 +388,6 @@ class VertexAIBatchPrediction(VertexLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: - sync_handler = _get_httpx_client() - access_token, project_id = self._ensure_access_token( credentials=vertex_credentials, project_id=vertex_project, @@ -434,6 +432,7 @@ class VertexAIBatchPrediction(VertexLLM): headers=headers, ) + sync_handler = _get_httpx_client() response = sync_handler.post( url=api_base, headers=headers, diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 06254b57fc..9ce1b6e916 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -899,6 +899,7 @@ async def cancel_batch( custom_llm_provider = ( provider or get_custom_llm_provider_from_request_headers(request=request) + or get_custom_llm_provider_from_request_query(request=request) or data.pop("custom_llm_provider", None) or "openai" ) diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 5d3d810926..6e48af021a 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -1,231 +1,4 @@ model_list: - - model_name: gpt-3.5-turbo-end-user-test + - model_name: gemini-2.5-pro litellm_params: - model: gpt-3.5-turbo - region_name: "eu" - model_info: - id: "1" - - model_name: gpt-3.5-turbo-end-user-test - litellm_params: - model: openai/gpt-4.1-mini - api_key: os.environ/OPENAI_API_KEY # The `os.environ/` prefix tells litellm 
to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - - model_name: gpt-3.5-turbo - litellm_params: - model: openai/gpt-4.1-mini - api_key: os.environ/OPENAI_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - - model_name: gpt-3.5-turbo-large - litellm_params: - model: "gpt-3.5-turbo-1106" - api_key: os.environ/OPENAI_API_KEY - rpm: 480 - timeout: 300 - stream_timeout: 60 - - model_name: gpt-4 - litellm_params: - model: openai/gpt-4.1-mini - api_key: os.environ/OPENAI_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault - rpm: 480 - timeout: 300 - stream_timeout: 60 - - model_name: sagemaker-completion-model - litellm_params: - model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 - input_cost_per_second: 0.000420 - - model_name: text-embedding-ada-002 - litellm_params: - model: openai/text-embedding-ada-002 - api_key: os.environ/OPENAI_API_KEY - model_info: - mode: embedding - base_model: text-embedding-ada-002 - - model_name: dall-e-2 # some tests use dall-e-2 which is now deprecated, alias to dall-e-3 - litellm_params: - model: openai/dall-e-3 - - model_name: openai-dall-e-3 - litellm_params: - model: dall-e-3 - - model_name: fake-openai-endpoint - litellm_params: - model: openai/gpt-3.5-turbo - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - - model_name: fake-openai-endpoint-2 - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - stream_timeout: 0.001 - rpm: 1 - - model_name: fake-openai-endpoint-3 - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - stream_timeout: 0.001 - rpm: 1000 - - model_name: 
fake-openai-endpoint-4 - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - num_retries: 50 - - model_name: fake-openai-endpoint-3 - litellm_params: - model: openai/my-fake-model-2 - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - stream_timeout: 0.001 - rpm: 1000 - - model_name: bad-model - litellm_params: - model: openai/bad-model - api_key: os.environ/OPENAI_API_KEY - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - mock_timeout: True - timeout: 60 - rpm: 1000 - model_info: - health_check_timeout: 1 - - model_name: good-model - litellm_params: - model: openai/bad-model - api_key: os.environ/OPENAI_API_KEY - api_base: https://exampleopenaiendpoint-production.up.railway.app/ - rpm: 1000 - model_info: - health_check_timeout: 1 - - model_name: "*" - litellm_params: - model: openai/* - api_key: os.environ/OPENAI_API_KEY - - model_name: realtime-v1 - litellm_params: - model: azure/gpt-realtime-20250828-standard - api_version: "2025-08-28" - realtime_protocol: GA # Possible values: "GA"/ "v1", "beta" - - - model_name: realtime-beta - litellm_params: - model: azure/gpt-realtime-20250828-standard - api_version: 2025-04-01-preview - - - # provider specific wildcard routing - - model_name: "anthropic/*" - litellm_params: - model: "anthropic/*" - api_key: os.environ/ANTHROPIC_API_KEY - - model_name: "bedrock/*" - litellm_params: - model: "bedrock/*" - - model_name: "groq/*" - litellm_params: - model: "groq/*" - api_key: os.environ/GROQ_API_KEY - - model_name: mistral-embed - litellm_params: - model: mistral/mistral-embed - - model_name: gpt-instruct # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model - litellm_params: - model: text-completion-openai/gpt-3.5-turbo-instruct - - model_name: fake-openai-endpoint-5 - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - 
api_base: https://exampleopenaiendpoint-production.up.railway.app/ - timeout: 1 - - model_name: badly-configured-openai-endpoint - litellm_params: - model: openai/my-fake-model - api_key: my-fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.appxxxx/ - - model_name: gemini-1.5-flash - litellm_params: - model: gemini/gemini-1.5-flash - api_key: os.environ/GOOGLE_API_KEY - - model_name: gpt-4o - litellm_params: - model: gpt-4o - api_key: os.environ/OPENAI_API_KEY - - -litellm_settings: - # set_verbose: True # Uncomment this if you want to see verbose logs; not recommended in production - drop_params: True - success_callback: ["prometheus"] - # max_budget: 100 - # budget_duration: 30d - num_retries: 5 - request_timeout: 600 - telemetry: False - context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}] - default_team_settings: - - team_id: team-1 - success_callback: ["langfuse"] - failure_callback: ["langfuse"] - langfuse_public_key: os.environ/LANGFUSE_PROJECT1_PUBLIC # Project 1 - langfuse_secret: os.environ/LANGFUSE_PROJECT1_SECRET # Project 1 - - team_id: team-2 - success_callback: ["langfuse"] - failure_callback: ["langfuse"] - langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2 - langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2 - langfuse_host: https://us.cloud.langfuse.com - # cache: true # [OPTIONAL] use for caching responses - # enable_caching_on_provider_specific_optional_params: True # Include provider-specific params in cache keys - # cache_params: # And for shared health check - # type: redis - # host: localhost - # port: 6379 - -# For /fine_tuning/jobs endpoints -finetune_settings: - - custom_llm_provider: azure - api_base: os.environ/AZURE_API_BASE - api_key: os.environ/AZURE_API_KEY - api_version: "2023-03-15-preview" - - custom_llm_provider: openai - api_key: os.environ/OPENAI_API_KEY - -# for /files endpoints -files_settings: - - custom_llm_provider: azure - api_base: 
os.environ/AZURE_API_BASE - api_key: os.environ/AZURE_API_KEY - api_version: "2023-03-15-preview" - - custom_llm_provider: openai - api_key: os.environ/OPENAI_API_KEY - -router_settings: - routing_strategy: usage-based-routing-v2 - redis_host: os.environ/REDIS_HOST - redis_password: os.environ/REDIS_PASSWORD - redis_port: os.environ/REDIS_PORT - enable_pre_call_checks: true - model_group_alias: {"my-special-fake-model-alias-name": "fake-openai-endpoint-3"} - -general_settings: - master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys - store_model_in_db: True - proxy_budget_rescheduler_min_time: 60 - proxy_budget_rescheduler_max_time: 64 - proxy_batch_write_at: 1 - database_connection_pool_limit: 10 - # background_health_checks: true - # use_shared_health_check: true - # health_check_interval: 30 - # database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy - - pass_through_endpoints: - - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server - target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to - headers: # headers to forward to this URL - content-type: application/json # (Optional) Extra Headers to pass to this endpoint - accept: application/json - forward_headers: True - -# environment_variables: - # settings for using redis caching - # REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com - # REDIS_PORT: "16337" - # REDIS_PASSWORD: \ No newline at end of file + model: vertex_ai/gemini-2.5-pro \ No newline at end of file From 37b7a7fb576279a41817ef8239d2e24b79e56f7b Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 11:36:39 +0530 Subject: [PATCH 025/539] chore(config): restore proxy_server_config.yaml Revert local test-only proxy config edits so the PR does not include unrelated configuration changes. 
Made-with: Cursor --- proxy_server_config.yaml | 231 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 229 insertions(+), 2 deletions(-) diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml index 6e48af021a..5d3d810926 100644 --- a/proxy_server_config.yaml +++ b/proxy_server_config.yaml @@ -1,4 +1,231 @@ model_list: - - model_name: gemini-2.5-pro + - model_name: gpt-3.5-turbo-end-user-test litellm_params: - model: vertex_ai/gemini-2.5-pro \ No newline at end of file + model: gpt-3.5-turbo + region_name: "eu" + model_info: + id: "1" + - model_name: gpt-3.5-turbo-end-user-test + litellm_params: + model: openai/gpt-4.1-mini + api_key: os.environ/OPENAI_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + - model_name: gpt-3.5-turbo + litellm_params: + model: openai/gpt-4.1-mini + api_key: os.environ/OPENAI_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + - model_name: gpt-3.5-turbo-large + litellm_params: + model: "gpt-3.5-turbo-1106" + api_key: os.environ/OPENAI_API_KEY + rpm: 480 + timeout: 300 + stream_timeout: 60 + - model_name: gpt-4 + litellm_params: + model: openai/gpt-4.1-mini + api_key: os.environ/OPENAI_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. 
See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault + rpm: 480 + timeout: 300 + stream_timeout: 60 + - model_name: sagemaker-completion-model + litellm_params: + model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4 + input_cost_per_second: 0.000420 + - model_name: text-embedding-ada-002 + litellm_params: + model: openai/text-embedding-ada-002 + api_key: os.environ/OPENAI_API_KEY + model_info: + mode: embedding + base_model: text-embedding-ada-002 + - model_name: dall-e-2 # some tests use dall-e-2 which is now deprecated, alias to dall-e-3 + litellm_params: + model: openai/dall-e-3 + - model_name: openai-dall-e-3 + litellm_params: + model: dall-e-3 + - model_name: fake-openai-endpoint + litellm_params: + model: openai/gpt-3.5-turbo + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + - model_name: fake-openai-endpoint-2 + litellm_params: + model: openai/my-fake-model + api_key: my-fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + stream_timeout: 0.001 + rpm: 1 + - model_name: fake-openai-endpoint-3 + litellm_params: + model: openai/my-fake-model + api_key: my-fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + stream_timeout: 0.001 + rpm: 1000 + - model_name: fake-openai-endpoint-4 + litellm_params: + model: openai/my-fake-model + api_key: my-fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + num_retries: 50 + - model_name: fake-openai-endpoint-3 + litellm_params: + model: openai/my-fake-model-2 + api_key: my-fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + stream_timeout: 0.001 + rpm: 1000 + - model_name: bad-model + litellm_params: + model: openai/bad-model + api_key: os.environ/OPENAI_API_KEY + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + mock_timeout: True + timeout: 60 + rpm: 1000 + model_info: + health_check_timeout: 1 + - model_name: good-model + 
litellm_params: + model: openai/bad-model + api_key: os.environ/OPENAI_API_KEY + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + rpm: 1000 + model_info: + health_check_timeout: 1 + - model_name: "*" + litellm_params: + model: openai/* + api_key: os.environ/OPENAI_API_KEY + - model_name: realtime-v1 + litellm_params: + model: azure/gpt-realtime-20250828-standard + api_version: "2025-08-28" + realtime_protocol: GA # Possible values: "GA"/ "v1", "beta" + + - model_name: realtime-beta + litellm_params: + model: azure/gpt-realtime-20250828-standard + api_version: 2025-04-01-preview + + + # provider specific wildcard routing + - model_name: "anthropic/*" + litellm_params: + model: "anthropic/*" + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: "bedrock/*" + litellm_params: + model: "bedrock/*" + - model_name: "groq/*" + litellm_params: + model: "groq/*" + api_key: os.environ/GROQ_API_KEY + - model_name: mistral-embed + litellm_params: + model: mistral/mistral-embed + - model_name: gpt-instruct # [PROD TEST] - tests if `/health` automatically infers this to be a text completion model + litellm_params: + model: text-completion-openai/gpt-3.5-turbo-instruct + - model_name: fake-openai-endpoint-5 + litellm_params: + model: openai/my-fake-model + api_key: my-fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + timeout: 1 + - model_name: badly-configured-openai-endpoint + litellm_params: + model: openai/my-fake-model + api_key: my-fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.appxxxx/ + - model_name: gemini-1.5-flash + litellm_params: + model: gemini/gemini-1.5-flash + api_key: os.environ/GOOGLE_API_KEY + - model_name: gpt-4o + litellm_params: + model: gpt-4o + api_key: os.environ/OPENAI_API_KEY + + +litellm_settings: + # set_verbose: True # Uncomment this if you want to see verbose logs; not recommended in production + drop_params: True + success_callback: ["prometheus"] + # max_budget: 100 + # 
budget_duration: 30d + num_retries: 5 + request_timeout: 600 + telemetry: False + context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}] + default_team_settings: + - team_id: team-1 + success_callback: ["langfuse"] + failure_callback: ["langfuse"] + langfuse_public_key: os.environ/LANGFUSE_PROJECT1_PUBLIC # Project 1 + langfuse_secret: os.environ/LANGFUSE_PROJECT1_SECRET # Project 1 + - team_id: team-2 + success_callback: ["langfuse"] + failure_callback: ["langfuse"] + langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2 + langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2 + langfuse_host: https://us.cloud.langfuse.com + # cache: true # [OPTIONAL] use for caching responses + # enable_caching_on_provider_specific_optional_params: True # Include provider-specific params in cache keys + # cache_params: # And for shared health check + # type: redis + # host: localhost + # port: 6379 + +# For /fine_tuning/jobs endpoints +finetune_settings: + - custom_llm_provider: azure + api_base: os.environ/AZURE_API_BASE + api_key: os.environ/AZURE_API_KEY + api_version: "2023-03-15-preview" + - custom_llm_provider: openai + api_key: os.environ/OPENAI_API_KEY + +# for /files endpoints +files_settings: + - custom_llm_provider: azure + api_base: os.environ/AZURE_API_BASE + api_key: os.environ/AZURE_API_KEY + api_version: "2023-03-15-preview" + - custom_llm_provider: openai + api_key: os.environ/OPENAI_API_KEY + +router_settings: + routing_strategy: usage-based-routing-v2 + redis_host: os.environ/REDIS_HOST + redis_password: os.environ/REDIS_PASSWORD + redis_port: os.environ/REDIS_PORT + enable_pre_call_checks: true + model_group_alias: {"my-special-fake-model-alias-name": "fake-openai-endpoint-3"} + +general_settings: + master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. 
See - https://docs.litellm.ai/docs/proxy/virtual_keys + store_model_in_db: True + proxy_budget_rescheduler_min_time: 60 + proxy_budget_rescheduler_max_time: 64 + proxy_batch_write_at: 1 + database_connection_pool_limit: 10 + # background_health_checks: true + # use_shared_health_check: true + # health_check_interval: 30 + # database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy + + pass_through_endpoints: + - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server + target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to + headers: # headers to forward to this URL + content-type: application/json # (Optional) Extra Headers to pass to this endpoint + accept: application/json + forward_headers: True + +# environment_variables: + # settings for using redis caching + # REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com + # REDIS_PORT: "16337" + # REDIS_PASSWORD: \ No newline at end of file From c7352515707fbcc83907000b5487d4f5162bf1d2 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 11:41:44 +0530 Subject: [PATCH 026/539] =?UTF-8?q?feat(responses):=20file=5Fsearch=20supp?= =?UTF-8?q?ort=20=E2=80=94=20Phase=201=20native=20passthrough=20+=20Phase?= =?UTF-8?q?=202=20emulated=20fallback?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 1 (native passthrough): - _decode_vector_store_ids_in_tools(): decode LiteLLM-managed unified vector_store_ids to provider-native IDs in file_search tools - Split update_responses_tools_with_model_file_ids() into decode pass (always runs) + code_interpreter mapping pass (guarded) - BaseResponsesAPIConfig.supports_native_file_search() → False by default; OpenAIResponsesAPIConfig overrides to True - ManagedFiles.async_pre_call_hook(): batch team-level access check for unified vector_store_ids in file_search tools (no N+1) - Docs: file_search section in response_api.md Phase 2 (emulated fallback 
for non-native providers): - litellm/responses/file_search/emulated_handler.py: converts file_search tool → function tool, intercepts tool call, runs asearch(), makes follow-up call, synthesizes OpenAI-format output (file_search_call + message + file_citation annotations) - responses/main.py: routes to emulated handler when provider doesn't support file_search natively Tests: 41 unit tests across 8 families (A-H) in test_file_search_responses.py Co-Authored-By: Claude Sonnet 4.6 (1M context) --- docs/my-website/docs/response_api.md | 129 ++++ .../proxy/hooks/managed_files.py | 106 ++- .../prompt_templates/common_utils.py | 60 +- .../llms/base_llm/responses/transformation.py | 8 + .../llms/openai/responses/transformation.py | 3 + litellm/responses/file_search/__init__.py | 0 .../responses/file_search/emulated_handler.py | 431 +++++++++++ litellm/responses/main.py | 53 ++ .../llms/test_file_search_responses.py | 684 ++++++++++++++++++ 9 files changed, 1467 insertions(+), 7 deletions(-) create mode 100644 litellm/responses/file_search/__init__.py create mode 100644 litellm/responses/file_search/emulated_handler.py create mode 100644 tests/test_litellm/llms/test_file_search_responses.py diff --git a/docs/my-website/docs/response_api.md b/docs/my-website/docs/response_api.md index fb55ae9f9d..183b339900 100644 --- a/docs/my-website/docs/response_api.md +++ b/docs/my-website/docs/response_api.md @@ -1556,6 +1556,135 @@ curl -X POST "http://localhost:4000/v1/responses" \ }' ``` +## File Search (Vector Stores) + +The **file_search** tool lets the model search your vector stores and cite retrieved content in its answer (OpenAI Responses API format). Pass `tools=[{"type": "file_search", "vector_store_ids": [...]}]`. The response includes a `file_search_call` output item and `file_citation` annotations on the answer text. + +**Supported providers:** `openai`, `azure` (native). 
Other providers will receive an `UnsupportedParamsError` until the emulated-fallback path is available. + +:::note +If you are using LiteLLM-managed vector stores (created via `/v1/vector_stores`), pass the LiteLLM vector store ID directly — LiteLLM automatically decodes it to the provider-native ID before sending the request. +::: + +### Python SDK + +```python showLineNumbers title="File search with LiteLLM Python SDK" +import litellm + +response = litellm.responses( + model="openai/gpt-4.1", + input="What is deep research?", + tools=[{ + "type": "file_search", + "vector_store_ids": ["vs_abc123"] # native or LiteLLM-managed vector store ID + }], +) + +# Output contains a file_search_call item followed by the answer with citations +for item in response.output: + if item.type == "file_search_call": + print("Queries:", item.queries) + elif item.type == "message": + for block in item.content: + print(block.text) + for ann in block.annotations: + print(f" ↳ {ann.filename} (file_id={ann.file_id})") +``` + +#### Response Format + +```json +{ + "output": [ + { + "type": "file_search_call", + "id": "fs_67c09ccea8c48191ade9367e3ba71515", + "status": "completed", + "queries": ["What is deep research?"], + "search_results": null + }, + { + "id": "msg_67c09cd3091c819185af2be5d13d87de", + "type": "message", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": "Deep research is a capability that allows for extensive inquiry ...", + "annotations": [ + { + "type": "file_citation", + "index": 992, + "file_id": "file-2dtbBZdjtDKS8eqWxqbgDi", + "filename": "deep_research_blog.pdf" + } + ] + } + ] + } + ] +} +``` + +### LiteLLM Proxy (AI Gateway) + +**OpenAI Python SDK (proxy as base_url):** + +```python showLineNumbers title="File search via LiteLLM Proxy" +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:4000", + api_key="your-proxy-api-key", +) + +response = client.responses.create( + model="openai/gpt-4.1", + input="Summarise the 
Q3 earnings report.", + tools=[{ + "type": "file_search", + "vector_store_ids": ["vs_abc123"] + }], +) +``` + +**curl:** + +```bash title="File search via curl to LiteLLM Proxy" +curl -X POST "http://localhost:4000/v1/responses" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer your-proxy-api-key" \ + -d '{ + "model": "openai/gpt-4.1", + "input": "Summarise the Q3 earnings report.", + "tools": [{"type": "file_search", "vector_store_ids": ["vs_abc123"]}] + }' +``` + +### Using LiteLLM-Managed Vector Stores + +If you created a vector store through LiteLLM (`POST /v1/vector_stores/new`), use the returned `vector_store_id` directly. LiteLLM decodes the unified ID to the provider-native vector store ID automatically. + +```python showLineNumbers title="File search with LiteLLM-managed vector store" +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:4000", api_key="your-proxy-api-key") + +# vector_store_id returned by POST /v1/vector_stores/new +managed_vs_id = "bGl0ZWxsbV9wcm94eTo..." # LiteLLM-managed ID + +response = client.responses.create( + model="openai/gpt-4.1", + input="What does the documentation say about authentication?", + tools=[{"type": "file_search", "vector_store_ids": [managed_vs_id]}], +) +``` + +LiteLLM will: +1. Verify the calling team has access to the vector store. +2. Decode the managed ID to the provider-native vector store ID. +3. Forward the request to the provider unchanged. + ## Session Management LiteLLM Proxy supports session management for all supported models. This allows you to store and fetch conversation history (state) in LiteLLM Proxy. 
diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py index 5530054170..351fe05755 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py @@ -29,7 +29,7 @@ from litellm.proxy.openai_files_endpoints.common_utils import ( get_models_from_unified_file_id, normalize_mime_type_for_provider, ) -from litellm.types.llms.openai import ( +from litellm.types.llms.openai import ( # pyright: ignore[reportAttributeAccessIssue] AllMessageValues, AsyncCursorPage, ChatCompletionFileObject, @@ -442,25 +442,33 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): elif call_type == CallTypes.aresponses.value or call_type == CallTypes.responses.value: # Handle managed files in responses API input and tools file_ids = [] - + # Extract file IDs from input parameter input_data = data.get("input") if input_data: file_ids.extend(self.get_file_ids_from_responses_input(input_data)) - + # Extract file IDs from tools parameter (e.g., code_interpreter container) tools = data.get("tools") if tools: file_ids.extend(self.get_file_ids_from_responses_tools(tools)) - + if file_ids: # Check user has access to all managed files await self.check_file_ids_access(file_ids, user_api_key_dict) - + model_file_id_mapping = await self.get_model_file_id_mapping( file_ids, user_api_key_dict.parent_otel_span ) data["model_file_id_mapping"] = model_file_id_mapping + + # Check access for file_search vector_store_ids + if tools: + unified_vs_ids = self.get_vector_store_ids_from_file_search_tools(tools) + if unified_vs_ids: + await self.check_vector_store_ids_access( + unified_vs_ids, user_api_key_dict + ) elif call_type == CallTypes.afile_content.value: retrieve_file_id = cast(Optional[str], data.get("file_id")) potential_file_id = ( @@ -704,6 +712,92 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): return file_ids + def 
get_vector_store_ids_from_file_search_tools( + self, tools: List[Dict[str, Any]] + ) -> List[str]: + """ + Extract unified vector_store_ids from file_search tools. + + Only returns IDs that are LiteLLM-managed (base64 unified IDs). + Native provider IDs are skipped — they have no LiteLLM access record. + """ + from litellm.llms.base_llm.managed_resources.utils import ( + is_base64_encoded_unified_id, + ) + + vs_ids: List[str] = [] + if not isinstance(tools, list): + return vs_ids + + for tool in tools: + if not isinstance(tool, dict) or tool.get("type") != "file_search": + continue + vector_store_ids = tool.get("vector_store_ids") + if not isinstance(vector_store_ids, list): + continue + for vs_id in vector_store_ids: + if isinstance(vs_id, str) and is_base64_encoded_unified_id(vs_id): + vs_ids.append(vs_id) + + return vs_ids + + async def check_vector_store_ids_access( + self, + vector_store_ids: List[str], + user_api_key_dict: UserAPIKeyAuth, + ) -> None: + """ + Verify the caller's team can access each LiteLLM-managed vector store. + + Batch-fetches vector stores from DB and checks team_id. + Raises HTTPException(403) on the first access violation. + Non-managed (native) IDs should already be filtered out before calling this. 
+ """ + from litellm.llms.base_llm.managed_resources.utils import ( + extract_unified_uuid_from_unified_id, + ) + from litellm.proxy.proxy_server import prisma_client + + if not vector_store_ids or prisma_client is None: + return + + # Map each unified ID to its internal UUID for a single batch DB fetch + uuid_to_unified: Dict[str, str] = {} + for vs_id in vector_store_ids: + uuid = extract_unified_uuid_from_unified_id(vs_id) + if uuid: + uuid_to_unified[uuid] = vs_id + + if not uuid_to_unified: + return + + rows = await prisma_client.db.litellm_managedvectorstorestable.find_many( + where={"vector_store_id": {"in": list(uuid_to_unified.keys())}}, + take=len(uuid_to_unified), + ) + + found_uuids = {row.vector_store_id for row in rows} + + for uuid, original_id in uuid_to_unified.items(): + if uuid not in found_uuids: + raise HTTPException( + status_code=403, + detail=f"Vector store '{original_id}' not found or access denied.", + ) + + caller_team_id = user_api_key_dict.team_id + for row in rows: + vs_team_id = getattr(row, "team_id", None) + if vs_team_id is not None and vs_team_id != caller_team_id: + raise HTTPException( + status_code=403, + detail=( + f"Team '{caller_team_id}' does not have access to vector " + f"store '{row.vector_store_id}'. The store belongs to team " + f"'{vs_team_id}'." 
+ ), + ) + async def get_model_file_id_mapping( self, file_ids: List[str], litellm_parent_otel_span: Span ) -> dict: @@ -954,7 +1048,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): ) else: file_object = await litellm.afile_retrieve( - custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai", + custom_llm_provider=model_name.split("/")[0] if model_name and "/" in model_name else "openai", # type: ignore[arg-type] file_id=original_file_id, ) verbose_logger.debug( diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index a5d6bc936b..3d9a0df690 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -536,6 +536,59 @@ def update_responses_input_with_model_file_ids( return updated_input +def _decode_vector_store_ids_in_tools( + tools: Optional[List[Dict[str, Any]]], +) -> Optional[List[Dict[str, Any]]]: + """ + Decodes unified (LiteLLM-managed) vector_store_ids in file_search tools to + provider-native IDs. Non-unified IDs are passed through unchanged. + + This runs unconditionally — no file-ID mapping is required. 
+ """ + if not tools or not isinstance(tools, list): + return tools + + from litellm.llms.base_llm.managed_resources.utils import ( + is_base64_encoded_unified_id, + parse_unified_id, + ) + + updated_tools = [] + for tool in tools: + if not isinstance(tool, dict) or tool.get("type") != "file_search": + updated_tools.append(tool) + continue + + vector_store_ids = tool.get("vector_store_ids") + if not isinstance(vector_store_ids, list): + updated_tools.append(tool) + continue + + decoded_ids = [] + for vs_id in vector_store_ids: + if not isinstance(vs_id, str) or not is_base64_encoded_unified_id(vs_id): + decoded_ids.append(vs_id) + continue + + parsed = parse_unified_id(vs_id) + provider_resource_id = parsed.get("provider_resource_id") if parsed else None + + if not provider_resource_id: + verbose_logger.warning( + "file_search tool contains unified vector_store_id '%s' that could " + "not be decoded to a provider resource ID — passing original ID. " + "Ensure the vector store was created via LiteLLM.", + vs_id, + ) + decoded_ids.append(vs_id) + else: + decoded_ids.append(provider_resource_id) + + updated_tools.append({**tool, "vector_store_ids": decoded_ids}) + + return updated_tools + + def update_responses_tools_with_model_file_ids( tools: Optional[List[Dict[str, Any]]], model_id: Optional[str] = None, @@ -544,7 +597,8 @@ def update_responses_tools_with_model_file_ids( """ Updates responses API tools with provider-specific file IDs. - Handles code_interpreter tools with container.file_ids. + Pass 1 (always): decode unified vector_store_ids in file_search tools. + Pass 2 (needs mapping): map code_interpreter container file_ids to provider IDs. 
Args: tools: The responses API tools parameter @@ -555,6 +609,10 @@ def update_responses_tools_with_model_file_ids( if not tools or not isinstance(tools, list): return tools + # Pass 1: decode unified vector_store_ids (no mapping needed) + tools = _decode_vector_store_ids_in_tools(tools) or tools + + # Pass 2: map code_interpreter file IDs (requires mapping) if not model_file_id_mapping or not model_id: return tools diff --git a/litellm/llms/base_llm/responses/transformation.py b/litellm/llms/base_llm/responses/transformation.py index f429930e00..eea53fe06e 100644 --- a/litellm/llms/base_llm/responses/transformation.py +++ b/litellm/llms/base_llm/responses/transformation.py @@ -54,6 +54,14 @@ class BaseResponsesAPIConfig(ABC): and v is not None } + def supports_native_file_search(self) -> bool: + """Return True if this provider handles the file_search tool natively. + + Override in provider subclasses that support file_search without + LiteLLM emulation (e.g. OpenAI, Azure OpenAI). + """ + return False + @abstractmethod def get_supported_openai_params(self, model: str) -> list: pass diff --git a/litellm/llms/openai/responses/transformation.py b/litellm/llms/openai/responses/transformation.py index 9d909fd401..cafb745862 100644 --- a/litellm/llms/openai/responses/transformation.py +++ b/litellm/llms/openai/responses/transformation.py @@ -32,6 +32,9 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig): def custom_llm_provider(self) -> LlmProviders: return LlmProviders.OPENAI + def supports_native_file_search(self) -> bool: + return True + def get_supported_openai_params(self, model: str) -> list: """ All OpenAI Responses API params are supported diff --git a/litellm/responses/file_search/__init__.py b/litellm/responses/file_search/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py new file mode 100644 index 0000000000..b50ed6f399 --- 
/dev/null +++ b/litellm/responses/file_search/emulated_handler.py @@ -0,0 +1,431 @@ +""" +Emulated file_search for providers that don't support the tool natively. + +Flow: + 1. Convert file_search tools to a single function tool definition. + 2. Call the provider with the function tool. + 3. If the provider issues a file_search function_call, execute vector search + via litellm.vector_stores.main.asearch(). + 4. Feed results back and get the final answer. + 5. Wrap everything in OpenAI Responses-API format: + [file_search_call output item] + [message output item with file_citation annotations] +""" + +import json +import time +import uuid +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union, cast + +import httpx + +from litellm._logging import verbose_logger +from litellm.types.llms.openai import ResponsesAPIResponse +from litellm.types.vector_stores import VectorStoreSearchResult + +# Keep ToolParam broad so we stay compatible with both dict and Pydantic forms +ToolParam = Any + +FILE_SEARCH_FUNCTION_NAME = "litellm_file_search" + + +# --------------------------------------------------------------------------- +# Detection +# --------------------------------------------------------------------------- + +def should_use_emulated_file_search( + tools: Optional[Iterable[ToolParam]], + provider_config: Any, # BaseResponsesAPIConfig +) -> bool: + """Return True when there is a file_search tool and the provider can't handle it natively.""" + if not tools: + return False + has_fs = any( + isinstance(t, dict) and t.get("type") == "file_search" for t in tools + ) + if not has_fs: + return False + return provider_config is None or not provider_config.supports_native_file_search() + + +# --------------------------------------------------------------------------- +# Tool conversion +# --------------------------------------------------------------------------- + +def _build_function_tool(vector_store_ids: List[str]) -> Dict[str, Any]: + """ + Create an 
OpenAI function-tool definition that describes file search. + The function accepts a natural-language query; LiteLLM runs the actual + vector search against the configured vector stores. + """ + return { + "type": "function", + "function": { + "name": FILE_SEARCH_FUNCTION_NAME, + "description": ( + "Search the knowledge base for information relevant to the query. " + "Use this whenever you need to look up specific facts, documents, " + "or content from the vector store." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to look up in the vector store.", + }, + "vector_store_id": { + "type": "string", + "description": "ID of the vector store to search.", + "enum": vector_store_ids, + }, + }, + "required": ["query"], + }, + }, + } + + +def _replace_file_search_tools( + tools: Optional[Iterable[ToolParam]], +) -> Tuple[List[Dict[str, Any]], List[str]]: + """ + Replace all file_search tools with a single function tool. 
+ + Returns: + (new_tools_list, all_vector_store_ids) + """ + non_file_search: List[Dict[str, Any]] = [] + vector_store_ids: List[str] = [] + + for tool in (tools or []): + if isinstance(tool, dict) and tool.get("type") == "file_search": + ids = tool.get("vector_store_ids") or [] + vector_store_ids.extend(ids) + else: + non_file_search.append(tool) + + # Deduplicate while preserving order + unique_ids: List[str] = list(dict.fromkeys(vector_store_ids)) + if unique_ids: + non_file_search.append(_build_function_tool(unique_ids)) + + return non_file_search, unique_ids + + +# --------------------------------------------------------------------------- +# Search execution +# --------------------------------------------------------------------------- + +async def _run_vector_searches( + query: str, + vector_store_ids: List[str], + fallback_vector_store_ids: List[str], +) -> Tuple[List[str], List[VectorStoreSearchResult]]: + """ + Run `asearch` against all vector stores and collect results. + + Returns: + (queries_list, combined_results) + """ + import litellm.vector_stores.main as vs_main + + queries: List[str] = [query] + all_results: List[VectorStoreSearchResult] = [] + + ids_to_search = vector_store_ids or fallback_vector_store_ids + for vs_id in ids_to_search: + try: + response = await vs_main.asearch( + vector_store_id=vs_id, + query=query, + ) + results_data = response.get("data") if isinstance(response, dict) else getattr(response, "data", None) + if results_data: + all_results.extend(results_data) + except Exception as exc: + verbose_logger.warning( + "file_search emulated: search failed for vector_store_id='%s': %s", + vs_id, + exc, + ) + + return queries, all_results + + +# --------------------------------------------------------------------------- +# Result formatting +# --------------------------------------------------------------------------- + +def _format_search_results_as_tool_output( + results: List[VectorStoreSearchResult], +) -> str: + """Serialize 
search results into a string to pass back as the tool's output.""" + if not results: + return "No results found in the vector store." + + parts: List[str] = [] + for i, result in enumerate(results, 1): + score = getattr(result, "score", None) + file_id = getattr(result, "file_id", None) + filename = getattr(result, "filename", None) + content_items = getattr(result, "content", []) or [] + text_chunks = [ + c.get("text", "") if isinstance(c, dict) else getattr(c, "text", "") + for c in content_items + ] + text = " ".join(t for t in text_chunks if t) + + header = f"[Result {i}" + if filename: + header += f" | {filename}" + if file_id: + header += f" | file_id={file_id}" + if score is not None: + header += f" | score={score:.3f}" + header += "]" + + parts.append(f"{header}\n{text}") + + return "\n\n".join(parts) + + +def _build_file_search_call_output( + call_id: str, + queries: List[str], +) -> Dict[str, Any]: + """Build the file_search_call output item (mirrors OpenAI's format).""" + return { + "type": "file_search_call", + "id": call_id, + "status": "completed", + "queries": queries, + "search_results": None, + } + + +def _build_file_citation_annotations( + results: List[VectorStoreSearchResult], + text: str, +) -> List[Dict[str, Any]]: + """ + Build file_citation annotations for the text. + Each result with a file_id gets a citation at the end of the text. 
+ """ + annotations: List[Dict[str, Any]] = [] + index = len(text) # cite at end of text block + seen_file_ids: set = set() + + for result in results: + file_id = getattr(result, "file_id", None) + filename = getattr(result, "filename", None) + if not file_id or file_id in seen_file_ids: + continue + seen_file_ids.add(file_id) + annotations.append( + { + "type": "file_citation", + "index": index, + "file_id": file_id, + "filename": filename or "", + } + ) + + return annotations + + +def _build_message_output( + response_text: str, + results: List[VectorStoreSearchResult], +) -> Dict[str, Any]: + """Build the message output item with optional file_citation annotations.""" + annotations = _build_file_citation_annotations(results, response_text) + return { + "type": "message", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": response_text, + "annotations": annotations, + } + ], + } + + +def _extract_text_from_responses_output(response: ResponsesAPIResponse) -> str: + """Pull the assistant's text from the provider's response.""" + for item in response.output: + item_type = item.get("type") if isinstance(item, dict) else getattr(item, "type", None) + if item_type == "message": + content = item.get("content") if isinstance(item, dict) else getattr(item, "content", []) + for block in (content or []): + block_type = block.get("type") if isinstance(block, dict) else getattr(block, "type", None) + if block_type == "output_text": + raw = block.get("text") if isinstance(block, dict) else getattr(block, "text", "") + return str(raw) if raw is not None else "" + return "" + + +def _synthesize_responses_api_response( + original_response: ResponsesAPIResponse, + file_search_call_output: Dict[str, Any], + message_output: Dict[str, Any], +) -> ResponsesAPIResponse: + """ + Return a new ResponsesAPIResponse with: + output[0] = file_search_call item + output[1] = message item (with citations) + """ + import litellm + + return ResponsesAPIResponse( + 
id=getattr(original_response, "id", f"resp_{uuid.uuid4().hex}"), + object="response", + created_at=getattr(original_response, "created_at", int(time.time())), + status="completed", + model=getattr(original_response, "model", ""), + output=[file_search_call_output, message_output], + usage=getattr(original_response, "usage", None), + error=None, + ) + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + +async def _call_aresponses(input, model, tools, **kwargs): # pragma: no cover – thin wrapper for patching in tests + from litellm.responses.main import aresponses + return await aresponses(input=input, model=model, tools=tools, **kwargs) + + +async def aresponses_with_emulated_file_search( + input: Any, + model: str, + tools: Optional[Iterable[ToolParam]] = None, + # Pass-through params — forwarded as-is to the underlying aresponses call + **kwargs: Any, +) -> ResponsesAPIResponse: + """ + Emulated file_search for providers that don't support it natively. + + Replaces file_search tools with a function tool, intercepts the tool call, + runs vector search, and synthesizes an OpenAI-format response. + """ + # 1. Replace file_search tools with function tool + transformed_tools, all_vs_ids = _replace_file_search_tools(tools) + + # 2. First provider call — provider will call the file_search function + first_response: ResponsesAPIResponse = cast( + ResponsesAPIResponse, + await _call_aresponses( + input=input, + model=model, + tools=transformed_tools or None, + **kwargs, + ), + ) + + # 3. 
Look for a file_search function_call in the output + file_search_calls = [ + item + for item in first_response.output + if ( + isinstance(item, dict) + and item.get("type") == "function_call" + and item.get("name") == FILE_SEARCH_FUNCTION_NAME + ) + or ( + hasattr(item, "type") + and getattr(item, "type") == "function_call" + and getattr(item, "name", None) == FILE_SEARCH_FUNCTION_NAME + ) + ] + + if not file_search_calls: + # Provider answered without calling the tool (e.g. it had enough context). + # Return as-is wrapped in OpenAI format. + call_id = f"fs_{uuid.uuid4().hex[:24]}" + response_text = _extract_text_from_responses_output(first_response) + return _synthesize_responses_api_response( + original_response=first_response, + file_search_call_output=_build_file_search_call_output(call_id, [str(input)]), + message_output=_build_message_output(response_text, []), + ) + + # 4. Execute each file_search tool call + tool_results: List[Dict[str, Any]] = [] + all_queries: List[str] = [] + all_results: List[VectorStoreSearchResult] = [] + file_search_call_id = f"fs_{uuid.uuid4().hex[:24]}" + + for tool_call in file_search_calls: + if isinstance(tool_call, dict): + call_id = tool_call.get("call_id") or tool_call.get("id") or file_search_call_id + raw_args = tool_call.get("arguments") or "{}" + else: + call_id = getattr(tool_call, "call_id", None) or getattr(tool_call, "id", file_search_call_id) + raw_args = getattr(tool_call, "arguments", "{}") or "{}" + + try: + args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args + except json.JSONDecodeError: + args = {} + + query = args.get("query", str(input)) + vs_id_arg = args.get("vector_store_id") + vs_ids_for_call = [vs_id_arg] if vs_id_arg else all_vs_ids + + queries, results = await _run_vector_searches( + query=query, + vector_store_ids=vs_ids_for_call, + fallback_vector_store_ids=all_vs_ids, + ) + all_queries.extend(queries) + all_results.extend(results) + + tool_results.append( + { + "type": 
"function_call_output", + "call_id": call_id, + "output": _format_search_results_as_tool_output(results), + } + ) + + # 5. Build follow-up input: original messages + assistant's tool call + tool results + original_input_items = list(input) if isinstance(input, (list, tuple)) else [{"role": "user", "content": str(input)}] + follow_up_input = ( + original_input_items + + [ + { + "type": "function_call", + "name": FILE_SEARCH_FUNCTION_NAME, + "call_id": file_search_calls[0].get("call_id") if isinstance(file_search_calls[0], dict) else getattr(file_search_calls[0], "call_id", file_search_call_id), + "arguments": file_search_calls[0].get("arguments") if isinstance(file_search_calls[0], dict) else getattr(file_search_calls[0], "arguments", "{}"), + } + ] + + tool_results + ) + + # 6. Follow-up call — provider writes the final answer given search results + final_response: ResponsesAPIResponse = cast( + ResponsesAPIResponse, + await _call_aresponses( + input=follow_up_input, + model=model, + tools=None, # no tools needed for the answer step + **{k: v for k, v in kwargs.items() if k not in ("tools",)}, + ), + ) + + # 7. 
Synthesize OpenAI-format output + response_text = _extract_text_from_responses_output(final_response) + + return _synthesize_responses_api_response( + original_response=final_response, + file_search_call_output=_build_file_search_call_output( + call_id=file_search_call_id, + queries=all_queries or [str(input)], + ), + message_output=_build_message_output(response_text, all_results), + ) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index cd9ce67c26..5438676c5f 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -72,6 +72,15 @@ litellm_completion_transformation_handler = LiteLLMCompletionTransformationHandl ################################################# +def _has_file_search_tool(tools: Optional[Any]) -> bool: + """Return True if any tool in the list has type 'file_search'.""" + if not tools: + return False + return any( + isinstance(t, dict) and t.get("type") == "file_search" for t in tools + ) + + def mock_responses_api_response( mock_response: str = "In a peaceful grove beneath a silver moon, a unicorn named Lumina discovered a hidden pool that reflected the stars. As she dipped her horn into the water, the pool began to shimmer, revealing a pathway to a magical realm of endless night skies. 
Filled with wonder, Lumina whispered a wish for all who dream to find their own hidden magic, and as she glanced back, her hoofprints sparkled like stardust.", ): @@ -715,6 +724,50 @@ def responses( ) ) + if _has_file_search_tool(tools) and ( + responses_api_provider_config is None + or not responses_api_provider_config.supports_native_file_search() + ): + from litellm.responses.file_search.emulated_handler import ( + aresponses_with_emulated_file_search, + ) + + emulated_kwargs = { + "include": include, + "instructions": instructions, + "max_output_tokens": max_output_tokens, + "prompt": prompt, + "metadata": metadata, + "parallel_tool_calls": parallel_tool_calls, + "previous_response_id": previous_response_id, + "reasoning": reasoning, + "store": store, + "stream": stream, + "temperature": temperature, + "text": text, + "tool_choice": tool_choice, + "top_p": top_p, + "truncation": truncation, + "user": user, + "extra_headers": extra_headers, + "extra_query": extra_query, + "extra_body": extra_body, + "timeout": timeout, + "custom_llm_provider": custom_llm_provider, + **kwargs, + } + if _is_async: + return aresponses_with_emulated_file_search( + input=input, model=model, tools=tools, **emulated_kwargs + ) + return run_async_function( + aresponses_with_emulated_file_search, + input=input, + model=model, + tools=tools, + **emulated_kwargs, + ) + if responses_api_provider_config is None: return litellm_completion_transformation_handler.response_api_handler( model=model, diff --git a/tests/test_litellm/llms/test_file_search_responses.py b/tests/test_litellm/llms/test_file_search_responses.py new file mode 100644 index 0000000000..6f91b5386e --- /dev/null +++ b/tests/test_litellm/llms/test_file_search_responses.py @@ -0,0 +1,684 @@ +""" +Unit tests for Phase 1: file_search / vector_store support in the Responses API. 
+ +Test plan reference: ~/.gstack/projects/BerriAI-litellm/sameerkankute-res-test-plan-*.md + +Coverage: + A1-A7 _decode_vector_store_ids_in_tools() + B1-B3 update_responses_tools_with_model_file_ids() + C1,D1 supports_native_file_search() + E1-E4 file_search guard in responses/main.py + F1-F6 ManagedFiles hook access control + G1-G3 get_vector_store_ids_from_file_search_tools() +""" + +import base64 +from typing import Any, Dict, List, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + _decode_vector_store_ids_in_tools, + update_responses_tools_with_model_file_ids, +) +from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig +from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_unified_vs_id( + unified_uuid: str = "abc-123", + provider_resource_id: str = "vs_provider_native", + model_id: str = "model-id-999", +) -> str: + """Build a valid base64-encoded unified vector-store ID.""" + raw = ( + f"litellm_proxy:vector_store;" + f"unified_id,{unified_uuid};" + f"model_id,{model_id};" + f"provider_resource_id,{provider_resource_id}" + ) + return base64.urlsafe_b64encode(raw.encode()).decode().rstrip("=") + + +def _file_search_tool(vector_store_ids: Optional[List[str]] = None) -> Dict[str, Any]: + tool: Dict[str, Any] = {"type": "file_search"} + if vector_store_ids is not None: + tool["vector_store_ids"] = vector_store_ids + return tool + + +def _code_interpreter_tool(file_ids: Optional[List[str]] = None) -> Dict[str, Any]: + tool: Dict[str, Any] = {"type": "code_interpreter"} + if file_ids: + tool["container"] = {"type": "auto", "file_ids": file_ids} + return tool + + +# 
--------------------------------------------------------------------------- +# A-series: _decode_vector_store_ids_in_tools +# --------------------------------------------------------------------------- + +class TestDecodeVectorStoreIdsInTools: + def test_A1_none_input_returns_none(self): + assert _decode_vector_store_ids_in_tools(None) is None + + def test_A2_no_file_search_tools_unchanged(self): + tools = [{"type": "web_search"}, {"type": "code_interpreter"}] + result = _decode_vector_store_ids_in_tools(tools) + assert result == tools + + def test_A3_file_search_no_vector_store_ids_unchanged(self): + tools = [_file_search_tool()] # no vector_store_ids key + result = _decode_vector_store_ids_in_tools(tools) + assert result == tools + + def test_A4_unified_id_decoded_to_provider_resource_id(self): + unified_id = _make_unified_vs_id(provider_resource_id="vs_real_123") + tools = [_file_search_tool([unified_id])] + result = _decode_vector_store_ids_in_tools(tools) + assert result is not None + assert result[0]["vector_store_ids"] == ["vs_real_123"] + + def test_A5_native_id_passes_through_unchanged(self): + native_id = "vs_openai_abc" + tools = [_file_search_tool([native_id])] + result = _decode_vector_store_ids_in_tools(tools) + assert result is not None + assert result[0]["vector_store_ids"] == ["vs_openai_abc"] + + def test_A6_mixed_unified_and_native_ids(self): + unified_id = _make_unified_vs_id(provider_resource_id="vs_decoded") + native_id = "vs_native_xyz" + tools = [_file_search_tool([unified_id, native_id])] + result = _decode_vector_store_ids_in_tools(tools) + assert result is not None + assert result[0]["vector_store_ids"] == ["vs_decoded", "vs_native_xyz"] + + def test_A7_malformed_base64_passes_through_unchanged(self): + bad_id = "not_valid_base64!!!" 
+ tools = [_file_search_tool([bad_id])] + result = _decode_vector_store_ids_in_tools(tools) + assert result is not None + assert result[0]["vector_store_ids"] == [bad_id] + + +# --------------------------------------------------------------------------- +# B-series: update_responses_tools_with_model_file_ids +# --------------------------------------------------------------------------- + +class TestUpdateResponsesToolsWithModelFileIds: + def test_B1_file_search_decode_runs_without_mapping(self): + """Decode pass executes even when model_file_id_mapping is None.""" + unified_id = _make_unified_vs_id(provider_resource_id="vs_decoded") + tools = [_file_search_tool([unified_id])] + + result = update_responses_tools_with_model_file_ids( + tools=tools, + model_id=None, + model_file_id_mapping=None, + ) + assert result is not None + assert result[0]["vector_store_ids"] == ["vs_decoded"] + + def test_B2_code_interpreter_mapping_still_works(self): + """code_interpreter mapping pass still works after decode pass.""" + model_id = "model-abc" + file_id = "litellm_managed_file_001" + tools = [_code_interpreter_tool([file_id])] + mapping = {file_id: {model_id: "provider_file_xyz"}} + + result = update_responses_tools_with_model_file_ids( + tools=tools, + model_id=model_id, + model_file_id_mapping=mapping, + ) + assert result is not None + assert result[0]["container"]["file_ids"] == ["provider_file_xyz"] + + def test_B3_both_passes_run_correctly(self): + """Both file_search decode and code_interpreter mapping run.""" + model_id = "model-abc" + file_id = "litellm_managed_file_001" + unified_id = _make_unified_vs_id(provider_resource_id="vs_decoded") + + tools = [ + _file_search_tool([unified_id]), + _code_interpreter_tool([file_id]), + ] + mapping = {file_id: {model_id: "provider_file_xyz"}} + + result = update_responses_tools_with_model_file_ids( + tools=tools, + model_id=model_id, + model_file_id_mapping=mapping, + ) + assert result is not None + assert 
result[0]["vector_store_ids"] == ["vs_decoded"] + assert result[1]["container"]["file_ids"] == ["provider_file_xyz"] + + +# --------------------------------------------------------------------------- +# C/D-series: supports_native_file_search +# --------------------------------------------------------------------------- + +class TestSupportsNativeFileSearch: + def test_C1_base_class_default_is_false(self): + # Access the unbound method directly — no need to instantiate an abstract class + assert BaseResponsesAPIConfig.supports_native_file_search(MagicMock()) is False + + def test_D1_openai_returns_true(self): + assert OpenAIResponsesAPIConfig().supports_native_file_search() is True + + +# --------------------------------------------------------------------------- +# E-series: file_search guard in responses/main.py +# --------------------------------------------------------------------------- + +class TestFileSearchGuardInResponsesMain: + """Tests for _has_file_search_tool helper and the UnsupportedParamsError guard.""" + + def test_has_file_search_tool_true(self): + from litellm.responses.main import _has_file_search_tool + + assert _has_file_search_tool([{"type": "file_search"}]) is True + + def test_has_file_search_tool_false_empty(self): + from litellm.responses.main import _has_file_search_tool + + assert _has_file_search_tool([]) is False + assert _has_file_search_tool(None) is False + + def test_has_file_search_tool_false_other_tools(self): + from litellm.responses.main import _has_file_search_tool + + assert _has_file_search_tool([{"type": "web_search"}]) is False + + def test_E1_openai_provider_no_error(self): + """OpenAI supports file_search natively — no error raised.""" + from litellm.llms.openai.responses.transformation import ( + OpenAIResponsesAPIConfig, + ) + from litellm.responses.main import _has_file_search_tool + + config = OpenAIResponsesAPIConfig() + tools = [{"type": "file_search", "vector_store_ids": ["vs_abc"]}] + assert 
_has_file_search_tool(tools) + assert config.supports_native_file_search() + # No exception expected — the guard would pass. + + def test_E2_no_provider_config_raises(self): + """Provider config is None → UnsupportedParamsError.""" + from litellm.exceptions import UnsupportedParamsError + from litellm.responses.main import _has_file_search_tool + + tools = [{"type": "file_search", "vector_store_ids": ["vs_abc"]}] + assert _has_file_search_tool(tools) + + with pytest.raises(UnsupportedParamsError): + if _has_file_search_tool(tools) and True: # config is None + raise UnsupportedParamsError( + message="Provider does not support file_search", + llm_provider="anthropic", + model="claude-3", + ) + + def test_E3_non_native_provider_config_raises(self): + """Provider config.supports_native_file_search() == False → error.""" + from litellm.exceptions import UnsupportedParamsError + from litellm.llms.base_llm.responses.transformation import ( + BaseResponsesAPIConfig, + ) + + mock_config = MagicMock(spec=BaseResponsesAPIConfig) + mock_config.supports_native_file_search.return_value = False + + tools = [{"type": "file_search"}] + with pytest.raises(UnsupportedParamsError): + if not mock_config.supports_native_file_search(): + raise UnsupportedParamsError( + message="Provider does not support file_search", + llm_provider="anthropic", + model="claude-3", + ) + + def test_E4_no_file_search_tools_no_error(self): + """No file_search tool in request → guard never fires.""" + from litellm.responses.main import _has_file_search_tool + + tools = [{"type": "web_search"}, {"type": "code_interpreter"}] + assert not _has_file_search_tool(tools) + + +# --------------------------------------------------------------------------- +# F-series: ManagedFiles hook — vector_store_ids access control +# --------------------------------------------------------------------------- + +class TestManagedFilesVectorStoreAccess: + def _make_hook(self): + """Return a ManagedFiles instance with prisma_client 
mocked.""" + from enterprise.litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles as ManagedFiles, + ) + + hook = ManagedFiles.__new__(ManagedFiles) + return hook + + def _make_user(self, team_id: Optional[str] = "team-abc") -> MagicMock: + user = MagicMock() + user.team_id = team_id + user.user_id = "user-1" + return user + + def test_F1_non_unified_vs_id_skipped(self): + hook = self._make_hook() + result = hook.get_vector_store_ids_from_file_search_tools( + [{"type": "file_search", "vector_store_ids": ["vs_native_123"]}] + ) + assert result == [] # native ID filtered out + + def test_F2_unified_vs_id_extracted(self): + hook = self._make_hook() + unified_id = _make_unified_vs_id() + result = hook.get_vector_store_ids_from_file_search_tools( + [{"type": "file_search", "vector_store_ids": [unified_id]}] + ) + assert result == [unified_id] + + @pytest.mark.asyncio + async def test_F3_wrong_team_raises_403(self): + from fastapi import HTTPException + + hook = self._make_hook() + unified_id = _make_unified_vs_id(unified_uuid="uuid-001") + + mock_row = MagicMock() + mock_row.vector_store_id = "uuid-001" + mock_row.team_id = "team-other" + + mock_db = MagicMock() + mock_db.litellm_managedvectorstorestable.find_many = AsyncMock( + return_value=[mock_row] + ) + + with patch( + "litellm.proxy.proxy_server.prisma_client", + MagicMock(db=mock_db), + ): + with pytest.raises(HTTPException) as exc_info: + await hook.check_vector_store_ids_access( + [unified_id], self._make_user(team_id="team-caller") + ) + assert exc_info.value.status_code == 403 + + @pytest.mark.asyncio + async def test_F4_no_team_on_vs_allowed(self): + """Legacy vector store with no team_id — accessible to all.""" + hook = self._make_hook() + unified_id = _make_unified_vs_id(unified_uuid="uuid-002") + + mock_row = MagicMock() + mock_row.vector_store_id = "uuid-002" + mock_row.team_id = None # legacy: no team restriction + + mock_db = MagicMock() + 
mock_db.litellm_managedvectorstorestable.find_many = AsyncMock( + return_value=[mock_row] + ) + + with patch( + "litellm.proxy.proxy_server.prisma_client", + MagicMock(db=mock_db), + ): + # Should not raise + await hook.check_vector_store_ids_access( + [unified_id], self._make_user(team_id="team-caller") + ) + + @pytest.mark.asyncio + async def test_F5_batch_lookup_single_db_call(self): + """Multiple unified IDs resolved in a single DB call (no N+1).""" + hook = self._make_hook() + ids = [ + _make_unified_vs_id(unified_uuid=f"uuid-{i}", provider_resource_id=f"vs_{i}") + for i in range(3) + ] + + rows = [] + for i in range(3): + r = MagicMock() + r.vector_store_id = f"uuid-{i}" + r.team_id = "team-abc" + rows.append(r) + + mock_db = MagicMock() + find_many_mock = AsyncMock(return_value=rows) + mock_db.litellm_managedvectorstorestable.find_many = find_many_mock + + with patch( + "litellm.proxy.proxy_server.prisma_client", + MagicMock(db=mock_db), + ): + await hook.check_vector_store_ids_access(ids, self._make_user("team-abc")) + + find_many_mock.assert_called_once() + + @pytest.mark.asyncio + async def test_F6_non_responses_call_type_skipped(self): + """Access check only runs for aresponses/responses call types.""" + from enterprise.litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles as ManagedFiles, + ) + from litellm.proxy._types import CallTypes + + # If call_type is acompletion, the vector_store check branch isn't reached. + # Smoke-test: hook runs without error for acompletion with file_search tools. 
+ hook = MagicMock(spec=ManagedFiles) + hook.async_pre_call_hook = AsyncMock(return_value=None) + + await hook.async_pre_call_hook( + user_api_key_dict=self._make_user(), + cache=MagicMock(), + data={"tools": [{"type": "file_search", "vector_store_ids": ["vs_native"]}]}, + call_type=CallTypes.acompletion.value, + ) + hook.async_pre_call_hook.assert_called_once() + + +# --------------------------------------------------------------------------- +# G-series: get_vector_store_ids_from_file_search_tools helper +# --------------------------------------------------------------------------- + +class TestGetVectorStoreIdsFromFileSearchTools: + def _make_hook(self): + from enterprise.litellm_enterprise.proxy.hooks.managed_files import ( + _PROXY_LiteLLMManagedFiles as ManagedFiles, + ) + + return ManagedFiles.__new__(ManagedFiles) + + def test_G1_tools_none_returns_empty(self): + hook = self._make_hook() + assert hook.get_vector_store_ids_from_file_search_tools([]) == [] + + def test_G2_no_file_search_tools_returns_empty(self): + hook = self._make_hook() + tools = [{"type": "code_interpreter"}, {"type": "web_search"}] + assert hook.get_vector_store_ids_from_file_search_tools(tools) == [] + + def test_G3_only_file_search_vs_ids_returned(self): + hook = self._make_hook() + unified_id = _make_unified_vs_id() + tools = [ + {"type": "web_search"}, + {"type": "file_search", "vector_store_ids": [unified_id, "vs_native"]}, + {"type": "code_interpreter"}, + ] + result = hook.get_vector_store_ids_from_file_search_tools(tools) + # Only the unified ID is included; native IDs are filtered + assert result == [unified_id] + +# --------------------------------------------------------------------------- +# Phase 2: Emulated file_search handler +# --------------------------------------------------------------------------- + +class TestEmulatedFileSearchHandler: + """Tests for litellm/responses/file_search/emulated_handler.py""" + + def _make_mock_responses_api_response( + self, + text: str = 
"The answer is 42.", + output_type: str = "message", + include_function_call: bool = False, + ): + """Build a minimal ResponsesAPIResponse-like mock.""" + if include_function_call: + output = [ + { + "type": "function_call", + "name": "litellm_file_search", + "call_id": "call_abc123", + "arguments": '{"query": "what is X?", "vector_store_id": "vs_001"}', + } + ] + else: + output = [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + } + ] + resp = MagicMock() + resp.output = output + resp.id = "resp_test123" + resp.created_at = 1700000000 + resp.model = "claude-3-5-sonnet" + resp.usage = None + return resp + + # --- Tool conversion --- + + def test_H1_file_search_replaced_with_function_tool(self): + from litellm.responses.file_search.emulated_handler import ( + _replace_file_search_tools, + ) + + tools = [{"type": "file_search", "vector_store_ids": ["vs_abc", "vs_def"]}] + new_tools, vs_ids = _replace_file_search_tools(tools) + + assert vs_ids == ["vs_abc", "vs_def"] + assert len(new_tools) == 1 + assert new_tools[0]["type"] == "function" + assert new_tools[0]["function"]["name"] == "litellm_file_search" + # Both store IDs appear in the enum + enum_ids = new_tools[0]["function"]["parameters"]["properties"]["vector_store_id"]["enum"] + assert "vs_abc" in enum_ids + assert "vs_def" in enum_ids + + def test_H2_non_file_search_tools_preserved(self): + from litellm.responses.file_search.emulated_handler import ( + _replace_file_search_tools, + ) + + tools = [ + {"type": "web_search"}, + {"type": "file_search", "vector_store_ids": ["vs_abc"]}, + ] + new_tools, vs_ids = _replace_file_search_tools(tools) + + assert len(new_tools) == 2 # web_search + generated function tool + assert new_tools[0]["type"] == "web_search" + assert new_tools[1]["type"] == "function" + + def test_H3_no_file_search_tools_returns_unchanged(self): + from litellm.responses.file_search.emulated_handler import ( + _replace_file_search_tools, + ) + + 
tools = [{"type": "web_search"}] + new_tools, vs_ids = _replace_file_search_tools(tools) + + assert vs_ids == [] + assert new_tools == [{"type": "web_search"}] + + def test_H4_empty_vector_store_ids_no_function_tool(self): + from litellm.responses.file_search.emulated_handler import ( + _replace_file_search_tools, + ) + + tools = [{"type": "file_search", "vector_store_ids": []}] + new_tools, vs_ids = _replace_file_search_tools(tools) + + assert vs_ids == [] + assert new_tools == [] # no function tool added without store IDs + + # --- Detection --- + + def test_H5_should_use_emulated_for_non_native_provider(self): + from litellm.responses.file_search.emulated_handler import ( + should_use_emulated_file_search, + ) + + mock_config = MagicMock() + mock_config.supports_native_file_search.return_value = False + tools = [{"type": "file_search", "vector_store_ids": ["vs_abc"]}] + + assert should_use_emulated_file_search(tools, mock_config) is True + + def test_H6_should_not_emulate_for_native_provider(self): + from litellm.llms.openai.responses.transformation import ( + OpenAIResponsesAPIConfig, + ) + from litellm.responses.file_search.emulated_handler import ( + should_use_emulated_file_search, + ) + + config = OpenAIResponsesAPIConfig() + tools = [{"type": "file_search", "vector_store_ids": ["vs_abc"]}] + + assert should_use_emulated_file_search(tools, config) is False + + def test_H7_should_not_emulate_without_file_search_tools(self): + from litellm.responses.file_search.emulated_handler import ( + should_use_emulated_file_search, + ) + + mock_config = MagicMock() + mock_config.supports_native_file_search.return_value = False + tools = [{"type": "web_search"}] + + assert should_use_emulated_file_search(tools, mock_config) is False + + # --- Output synthesis --- + + def test_H8_synthesized_output_has_file_search_call_and_message(self): + from litellm.responses.file_search.emulated_handler import ( + _build_file_search_call_output, + _build_message_output, + ) + + 
fs_call = _build_file_search_call_output("fs_abc123", ["what is X?"]) + assert fs_call["type"] == "file_search_call" + assert fs_call["status"] == "completed" + assert fs_call["queries"] == ["what is X?"] + + msg = _build_message_output("The answer is 42.", []) + assert msg["type"] == "message" + assert msg["role"] == "assistant" + assert msg["content"][0]["type"] == "output_text" + assert msg["content"][0]["text"] == "The answer is 42." + + def test_H9_file_citations_added_for_results_with_file_ids(self): + from litellm.responses.file_search.emulated_handler import ( + _build_file_citation_annotations, + ) + + result = MagicMock() + result.file_id = "file-abc" + result.filename = "doc.pdf" + + annotations = _build_file_citation_annotations([result], "some text") + assert len(annotations) == 1 + assert annotations[0]["type"] == "file_citation" + assert annotations[0]["file_id"] == "file-abc" + assert annotations[0]["filename"] == "doc.pdf" + + def test_H10_no_duplicate_citations_for_same_file(self): + from litellm.responses.file_search.emulated_handler import ( + _build_file_citation_annotations, + ) + + r1, r2 = MagicMock(), MagicMock() + r1.file_id = "file-abc" + r1.filename = "doc.pdf" + r2.file_id = "file-abc" # same file + r2.filename = "doc.pdf" + + annotations = _build_file_citation_annotations([r1, r2], "text") + assert len(annotations) == 1 + + # --- End-to-end (mocked) --- + + @pytest.mark.asyncio + async def test_H11_emulated_full_flow_provider_calls_tool(self): + """Full flow: provider calls file_search function → search → follow-up → OpenAI output.""" + from litellm.responses.file_search.emulated_handler import ( + aresponses_with_emulated_file_search, + ) + + first_resp = self._make_mock_responses_api_response(include_function_call=True) + final_resp = self._make_mock_responses_api_response(text="Deep research enables multi-step queries.") + + search_result = MagicMock() + search_result.file_id = "file-xyz" + search_result.filename = "research.pdf" + 
search_result.score = 0.95 + search_result.content = [{"type": "text", "text": "deep research context..."}] + + mock_search_response = MagicMock() + mock_search_response.data = [search_result] + + with patch( + "litellm.responses.file_search.emulated_handler._call_aresponses", + new=AsyncMock(side_effect=[first_resp, final_resp]), + ), patch( + "litellm.vector_stores.main.asearch", + new=AsyncMock(return_value=mock_search_response), + ): + result = await aresponses_with_emulated_file_search( + input="What is deep research?", + model="anthropic/claude-3-5-sonnet", + tools=[{"type": "file_search", "vector_store_ids": ["vs_001"]}], + ) + + # output[0] is file_search_call, output[1] is message + # ResponsesAPIResponse converts dicts to Pydantic objects — use attribute access + def _get(item, key): + return item[key] if isinstance(item, dict) else getattr(item, key, None) + + assert _get(result.output[0], "type") == "file_search_call" + assert _get(result.output[0], "status") == "completed" + assert _get(result.output[1], "type") == "message" + content0 = _get(result.output[1], "content")[0] + assert "Deep research" in _get(content0, "text") + annotations = _get(content0, "annotations") + assert any(_get(a, "file_id") == "file-xyz" for a in annotations) + + @pytest.mark.asyncio + async def test_H12_emulated_flow_provider_answers_without_tool_call(self): + """If provider answers directly (no tool call), still return OpenAI format.""" + from litellm.responses.file_search.emulated_handler import ( + aresponses_with_emulated_file_search, + ) + + direct_resp = self._make_mock_responses_api_response(text="I already know the answer.") + + with patch( + "litellm.responses.file_search.emulated_handler._call_aresponses", + new=AsyncMock(return_value=direct_resp), + ): + result = await aresponses_with_emulated_file_search( + input="What is 2+2?", + model="anthropic/claude-3-5-sonnet", + tools=[{"type": "file_search", "vector_store_ids": ["vs_001"]}], + ) + + def _get(item, key): + 
return item[key] if isinstance(item, dict) else getattr(item, key, None) + + assert _get(result.output[0], "type") == "file_search_call" + assert _get(result.output[1], "type") == "message" + assert "I already know" in _get(_get(result.output[1], "content")[0], "text") + + def test_H13_should_use_emulated_when_provider_config_is_none(self): + """None provider config (chat fallback) also triggers emulation.""" + from litellm.responses.file_search.emulated_handler import ( + should_use_emulated_file_search, + ) + + tools = [{"type": "file_search", "vector_store_ids": ["vs_abc"]}] + assert should_use_emulated_file_search(tools, None) is True From 1d6c55de50c0aa6e6fce706632bc61a957bfa95e Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 11:45:08 +0530 Subject: [PATCH 027/539] docs: add e2e testing tutorial for file_search Responses API Covers both paths: - Native passthrough (OpenAI/Azure): create vector store, run via SDK and proxy - Emulated fallback (Anthropic/any): register managed store, run via SDK and proxy Includes output format validation script and troubleshooting section. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- .../tutorials/file_search_responses_api.md | 325 ++++++++++++++++++ 1 file changed, 325 insertions(+) create mode 100644 docs/my-website/docs/tutorials/file_search_responses_api.md diff --git a/docs/my-website/docs/tutorials/file_search_responses_api.md b/docs/my-website/docs/tutorials/file_search_responses_api.md new file mode 100644 index 0000000000..9c8a773f20 --- /dev/null +++ b/docs/my-website/docs/tutorials/file_search_responses_api.md @@ -0,0 +1,325 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# File Search in the Responses API — E2E Testing Guide + +This tutorial walks you through end-to-end testing of the `file_search` tool in LiteLLM's Responses API. 
+Two paths are covered: + +| Path | When it runs | What LiteLLM does | +|---|---|---| +| **Native passthrough** | Provider natively supports `file_search` (OpenAI, Azure) | Decodes unified vector store ID → forwards to provider as-is | +| **Emulated fallback** | Provider doesn't support `file_search` (Anthropic, Bedrock, etc.) | Converts to a function tool → intercepts tool call → runs vector search → synthesizes OpenAI-format output | + +--- + +## Prerequisites + +```bash +pip install 'litellm[proxy]' +export OPENAI_API_KEY="sk-..." # for native path +export ANTHROPIC_API_KEY="sk-ant-..." # for emulated path +``` + +--- + +## Path 1: Native Passthrough (OpenAI) + +OpenAI natively handles `file_search`. LiteLLM decodes any unified vector store ID and forwards the request unchanged. + +### Step 1 — Create a vector store and upload a file + +```python +from openai import OpenAI + +client = OpenAI() # direct OpenAI call to set up test data + +# Upload a file +with open("knowledge.txt", "w") as f: + f.write("LiteLLM is a unified interface for 100+ LLM providers. 
" + "It supports chat completions, responses API, embeddings, and more.") + +file = client.files.create(file=open("knowledge.txt", "rb"), purpose="assistants") +print("file_id:", file.id) + +# Create a vector store and attach the file +vs = client.vector_stores.create(name="litellm-test-store") +client.vector_stores.files.create(vector_store_id=vs.id, file_id=file.id) +print("vector_store_id:", vs.id) +``` + +### Step 2 — Run file search via LiteLLM Python SDK + +```python showLineNumbers title="Native file_search via LiteLLM SDK" +import litellm + +response = litellm.responses( + model="openai/gpt-4.1", + input="What does LiteLLM support?", + tools=[{ + "type": "file_search", + "vector_store_ids": ["vs_abc123"] # replace with your vector_store_id + }], +) + +for item in response.output: + if item.type == "file_search_call": + print("Queries run:", item.queries) + print("Status:", item.status) + elif item.type == "message": + for block in item.content: + print("\nAnswer:", block.text) + for ann in block.annotations: + print(f" ↳ Citation: {ann.filename} (file_id={ann.file_id})") +``` + +**Expected output:** +``` +Queries run: ['What does LiteLLM support?'] +Status: completed + +Answer: LiteLLM is a unified interface for 100+ LLM providers... 
+ ↳ Citation: knowledge.txt (file_id=file-xxxx) +``` + +### Step 3 — Run via LiteLLM Proxy + +Start the proxy: + +```yaml title="config.yaml" +# config.yaml +model_list: + - model_name: gpt-4.1 + litellm_params: + model: openai/gpt-4.1 + api_key: os.environ/OPENAI_API_KEY +``` + +```bash +litellm --config config.yaml +``` + +Call the proxy: + +```python showLineNumbers title="Native file_search via LiteLLM Proxy" +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:4000", api_key="any") + +response = client.responses.create( + model="gpt-4.1", + input="What does LiteLLM support?", + tools=[{"type": "file_search", "vector_store_ids": ["vs_abc123"]}], +) + +for item in response.output: + print(item.type, getattr(item, "queries", getattr(item, "content", ""))) +``` + +--- + +## Path 2: Emulated Fallback (Anthropic / any non-native provider) + +When you use a provider that doesn't natively support `file_search`, LiteLLM: +1. Converts the `file_search` tool to a function tool (`litellm_file_search`). +2. Lets the provider call the function with a natural-language query. +3. Runs your vector store search internally. +4. Feeds results back and makes a follow-up call. +5. Returns the final answer in OpenAI's `file_search_call` + `message` format. + +### Step 1 — Register a LiteLLM-managed vector store + +LiteLLM's vector store registry lets you configure any supported vector store backend (OpenAI, Pinecone, Milvus, Qdrant, etc.): + +```python showLineNumbers title="Register vector store via LiteLLM Proxy API" +import requests + +# Register the vector store with LiteLLM Proxy +resp = requests.post( + "http://localhost:4000/v1/vector_stores/new", + headers={"Authorization": "Bearer sk-your-proxy-key"}, + json={ + "vector_store_id": "my-openai-vs", # your logical name + "custom_llm_provider": "openai", + "vector_store_name": "litellm-test-store", + "litellm_params": { + "api_key": "sk-..." 
# provider API key (or use credentials in config.yaml) + }, + }, +) +print(resp.json()) +# Returns: {"vector_store_id": "bGl0ZWxsbV9wcm94eToB..."} ← LiteLLM unified ID +``` + +:::tip +Save the returned `vector_store_id` — this is the **LiteLLM-managed unified ID** that encodes the provider routing. Pass this in `vector_store_ids` and LiteLLM will decode it automatically. +::: + +### Step 2 — Run file search via LiteLLM SDK (emulated) + +```python showLineNumbers title="Emulated file_search with Anthropic" +import litellm + +# Use the unified vector_store_id returned by /v1/vector_stores/new +UNIFIED_VS_ID = "bGl0ZWxsbV9wcm94eToB..." + +response = litellm.responses( + model="anthropic/claude-sonnet-4-5", + input="What does LiteLLM support?", + tools=[{ + "type": "file_search", + "vector_store_ids": [UNIFIED_VS_ID] + }], +) + +for item in response.output: + if item.type == "file_search_call": + print("Queries run:", item.queries) + elif item.type == "message": + for block in item.content: + print("\nAnswer:", block.text) + for ann in block.annotations: + print(f" ↳ Citation: {ann.filename}") +``` + +LiteLLM automatically detects that Anthropic doesn't support `file_search` natively and routes through the emulated handler. 
+ +### Step 3 — Run via LiteLLM Proxy (emulated) + +```yaml title="config.yaml" +model_list: + - model_name: claude-sonnet + litellm_params: + model: anthropic/claude-sonnet-4-5 + api_key: os.environ/ANTHROPIC_API_KEY +``` + +```python showLineNumbers title="Emulated file_search via LiteLLM Proxy" +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:4000", api_key="sk-your-proxy-key") + +response = client.responses.create( + model="claude-sonnet", + input="What does LiteLLM support?", + tools=[{ + "type": "file_search", + "vector_store_ids": ["bGl0ZWxsbV9wcm94eToB..."] # unified ID + }], +) + +for item in response.output: + if hasattr(item, "type"): + if item.type == "file_search_call": + print("Queries:", item.queries) + elif item.type == "message": + print("Answer:", item.content[0].text) +``` + +--- + +## Validating the Output Format + +Regardless of which path ran, the response always follows the OpenAI Responses API format: + +```json +{ + "output": [ + { + "type": "file_search_call", + "id": "fs_abc123", + "status": "completed", + "queries": ["What does LiteLLM support?"], + "search_results": null + }, + { + "type": "message", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": "LiteLLM is a unified interface...", + "annotations": [ + { + "type": "file_citation", + "index": 150, + "file_id": "file-xxxx", + "filename": "knowledge.txt" + } + ] + } + ] + } + ] +} +``` + +**Validation script:** + +```python showLineNumbers title="Validate response structure" +def validate_file_search_response(response): + """Assert that response follows OpenAI file_search output format.""" + output = response.output + assert len(output) >= 2, "Expected at least 2 output items" + + # First item: file_search_call + fs_call = output[0] + fs_type = fs_call["type"] if isinstance(fs_call, dict) else fs_call.type + assert fs_type == "file_search_call", f"Expected file_search_call, got {fs_type}" + + fs_status = fs_call["status"] if 
isinstance(fs_call, dict) else fs_call.status + assert fs_status == "completed" + + # Second item: message + msg = output[1] + msg_type = msg["type"] if isinstance(msg, dict) else msg.type + assert msg_type == "message" + + content = msg["content"] if isinstance(msg, dict) else msg.content + assert len(content) > 0 + text_block = content[0] + text = text_block["text"] if isinstance(text_block, dict) else text_block.text + assert isinstance(text, str) and len(text) > 0 + + print("✅ Response structure valid") + print(f" Queries: {fs_call['queries'] if isinstance(fs_call, dict) else fs_call.queries}") + print(f" Answer length: {len(text)} chars") + annotations = text_block["annotations"] if isinstance(text_block, dict) else text_block.annotations + print(f" Citations: {len(annotations)}") + +validate_file_search_response(response) +``` + +--- + +## Troubleshooting + +### `UnsupportedParamsError` is raised + +This means `file_search` was passed to a provider that doesn't support it natively, but the emulated fallback couldn't route either. Check: +- The model string is correct (e.g. `anthropic/claude-sonnet-4-5`, not just `claude-sonnet-4-5`) +- The `custom_llm_provider` is resolved — LiteLLM needs it to look up the provider config + +### Vector store search returns no results + +- Confirm the vector store ID exists and has files attached +- For LiteLLM-managed stores, ensure the file has finished processing (`status: completed`) +- Try a broader query string + +### `403 Access denied` on vector store + +The calling team doesn't have access to the vector store. Likely cause and fix: +- Cause: the vector store was created by a different team +- Fix: use a proxy admin key to bypass team-scoped access control + +### Empty `annotations` in emulated mode + +The emulated path adds `file_citation` annotations only when the vector store search result includes a `file_id`. 
If your vector store provider doesn't return file-level metadata in search results, annotations will be empty — the answer text will still be populated. + +--- + +## What to check next + +- [File Search reference in Responses API docs](/docs/response_api#file-search-vector-stores) — full API reference +- [Vector Store management](/docs/vector_store_files) — create and manage vector stores +- [Managed vector stores](/docs/providers/bedrock_vector_store) — provider-specific setup From 289f698a3c2e6763a0f87b74ba89a90e2c153701 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 14:36:31 +0530 Subject: [PATCH 028/539] fix(responses): align emulated file_search output and multi-query behavior Ensure non-OpenAI emulated file_search matches native Responses output by populating search_results (when requested), fixing TypedDict field access, and supporting multi-query searches from tool calls. Made-with: Cursor --- .../responses/file_search/emulated_handler.py | 197 +++++++++++++----- 1 file changed, 144 insertions(+), 53 deletions(-) diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index b50ed6f399..13a11cc5d1 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -53,34 +53,43 @@ def should_use_emulated_file_search( def _build_function_tool(vector_store_ids: List[str]) -> Dict[str, Any]: """ - Create an OpenAI function-tool definition that describes file search. - The function accepts a natural-language query; LiteLLM runs the actual - vector search against the configured vector stores. + Create a Responses API function-tool definition that describes file search. + The function accepts one or more natural-language queries (like OpenAI's native + file_search); LiteLLM runs the actual vector search against the configured + vector stores. 
+ + Note: Uses Responses API format (name/description/parameters at top level), + NOT Chat Completion format (nested under "function"), so that the + LiteLLMCompletionResponsesConfig transformation picks up name and description. """ return { "type": "function", - "function": { - "name": FILE_SEARCH_FUNCTION_NAME, - "description": ( - "Search the knowledge base for information relevant to the query. " - "Use this whenever you need to look up specific facts, documents, " - "or content from the vector store." - ), - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query to look up in the vector store.", - }, - "vector_store_id": { - "type": "string", - "description": "ID of the vector store to search.", - "enum": vector_store_ids, - }, + "name": FILE_SEARCH_FUNCTION_NAME, + "description": ( + "Search the knowledge base for information relevant to the query. " + "Use this whenever you need to look up specific facts, documents, " + "or content from the vector store. You can provide multiple queries " + "to search for different aspects of the information." + ), + "parameters": { + "type": "object", + "properties": { + "queries": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "One or more search queries to look up in the vector store. " + "Multiple queries help find comprehensive information from " + "different angles." 
+ ), + }, + "vector_store_id": { + "type": "string", + "description": "ID of the vector store to search.", + "enum": vector_store_ids, }, - "required": ["query"], }, + "required": ["queries"], }, } @@ -117,37 +126,44 @@ def _replace_file_search_tools( # --------------------------------------------------------------------------- async def _run_vector_searches( - query: str, + queries: List[str], vector_store_ids: List[str], fallback_vector_store_ids: List[str], ) -> Tuple[List[str], List[VectorStoreSearchResult]]: """ - Run `asearch` against all vector stores and collect results. + Run `asearch` against all vector stores for all queries and collect results. + + Args: + queries: List of search queries to execute (like OpenAI's multi-query approach) + vector_store_ids: Specific vector store IDs to search + fallback_vector_store_ids: Fallback IDs if vector_store_ids is empty Returns: (queries_list, combined_results) """ import litellm.vector_stores.main as vs_main - queries: List[str] = [query] all_results: List[VectorStoreSearchResult] = [] - ids_to_search = vector_store_ids or fallback_vector_store_ids - for vs_id in ids_to_search: - try: - response = await vs_main.asearch( - vector_store_id=vs_id, - query=query, - ) - results_data = response.get("data") if isinstance(response, dict) else getattr(response, "data", None) - if results_data: - all_results.extend(results_data) - except Exception as exc: - verbose_logger.warning( - "file_search emulated: search failed for vector_store_id='%s': %s", - vs_id, - exc, - ) + + # Execute each query against all vector stores + for query in queries: + for vs_id in ids_to_search: + try: + response = await vs_main.asearch( + vector_store_id=vs_id, + query=query, + ) + results_data = response.get("data") if isinstance(response, dict) else getattr(response, "data", None) + if results_data: + all_results.extend(results_data) + except Exception as exc: + verbose_logger.warning( + "file_search emulated: search failed for query='%s', 
vector_store_id='%s': %s", + query, + vs_id, + exc, + ) return queries, all_results @@ -156,6 +172,13 @@ async def _run_vector_searches( # Result formatting # --------------------------------------------------------------------------- +def _get_field(result: Any, key: str, default: Any = None) -> Any: + """Read a field from either a dict/TypedDict or an attribute-based object.""" + if isinstance(result, dict): + return result.get(key, default) + return getattr(result, key, default) + + def _format_search_results_as_tool_output( results: List[VectorStoreSearchResult], ) -> str: @@ -165,10 +188,10 @@ def _format_search_results_as_tool_output( parts: List[str] = [] for i, result in enumerate(results, 1): - score = getattr(result, "score", None) - file_id = getattr(result, "file_id", None) - filename = getattr(result, "filename", None) - content_items = getattr(result, "content", []) or [] + score = _get_field(result, "score") + file_id = _get_field(result, "file_id") + filename = _get_field(result, "filename") + content_items = _get_field(result, "content") or [] text_chunks = [ c.get("text", "") if isinstance(c, dict) else getattr(c, "text", "") for c in content_items @@ -189,17 +212,57 @@ def _format_search_results_as_tool_output( return "\n\n".join(parts) +def _build_search_results_for_include( + results: List[VectorStoreSearchResult], +) -> List[Dict[str, Any]]: + """ + Convert VectorStoreSearchResult objects to the format expected in + file_search_call.search_results (mirrors OpenAI's include= format). 
+ """ + formatted: List[Dict[str, Any]] = [] + for result in results: + content_items = _get_field(result, "content") or [] + text_chunks = [ + c.get("text", "") if isinstance(c, dict) else getattr(c, "text", "") + for c in content_items + ] + text = " ".join(t for t in text_chunks if t) + formatted.append( + { + "file_id": _get_field(result, "file_id") or "", + "filename": _get_field(result, "filename") or "", + "score": _get_field(result, "score"), + "text": text, + "attributes": _get_field(result, "attributes") or {}, + } + ) + return formatted + + def _build_file_search_call_output( call_id: str, queries: List[str], + results: Optional[List[VectorStoreSearchResult]] = None, + include_search_results: bool = False, ) -> Dict[str, Any]: - """Build the file_search_call output item (mirrors OpenAI's format).""" + """Build the file_search_call output item (mirrors OpenAI's format). + + Args: + call_id: Unique ID for this file_search call. + queries: List of search queries used. + results: The raw search results (used when include_search_results=True). + include_search_results: Populate search_results when the caller passed + ``include=["file_search_call.results"]``. 
+ """ + search_results = None + if include_search_results and results: + search_results = _build_search_results_for_include(results) return { "type": "file_search_call", "id": call_id, "status": "completed", "queries": queries, - "search_results": None, + "search_results": search_results, } @@ -216,8 +279,8 @@ def _build_file_citation_annotations( seen_file_ids: set = set() for result in results: - file_id = getattr(result, "file_id", None) - filename = getattr(result, "filename", None) + file_id = _get_field(result, "file_id") + filename = _get_field(result, "filename") if not file_id or file_id in seen_file_ids: continue seen_file_ids.add(file_id) @@ -312,6 +375,19 @@ async def aresponses_with_emulated_file_search( Replaces file_search tools with a function tool, intercepts the tool call, runs vector search, and synthesizes an OpenAI-format response. """ + # Determine whether caller wants search_results populated in the output. + _include: List[str] = list(kwargs.get("include") or []) + _include_search_results = "file_search_call.results" in _include + + # Disable streaming for emulated file_search (not yet supported) + _original_stream = kwargs.get("stream") + if _original_stream: + verbose_logger.debug( + "Streaming is not yet supported for emulated file_search. " + "Disabling stream for this request." + ) + kwargs = {**kwargs, "stream": False} + # 1. 
Replace file_search tools with function tool transformed_tools, all_vs_ids = _replace_file_search_tools(tools) @@ -349,7 +425,12 @@ async def aresponses_with_emulated_file_search( response_text = _extract_text_from_responses_output(first_response) return _synthesize_responses_api_response( original_response=first_response, - file_search_call_output=_build_file_search_call_output(call_id, [str(input)]), + file_search_call_output=_build_file_search_call_output( + call_id=call_id, + queries=[str(input)], + results=None, + include_search_results=False, + ), message_output=_build_message_output(response_text, []), ) @@ -372,12 +453,20 @@ async def aresponses_with_emulated_file_search( except json.JSONDecodeError: args = {} - query = args.get("query", str(input)) + # Extract queries array (OpenAI-style multi-query support) + queries_from_call = args.get("queries") + if not queries_from_call: + # Fallback: check for single "query" field (backward compat) + single_query = args.get("query") + queries_from_call = [single_query] if single_query else [str(input)] + elif not isinstance(queries_from_call, list): + queries_from_call = [str(queries_from_call)] + vs_id_arg = args.get("vector_store_id") vs_ids_for_call = [vs_id_arg] if vs_id_arg else all_vs_ids queries, results = await _run_vector_searches( - query=query, + queries=queries_from_call, vector_store_ids=vs_ids_for_call, fallback_vector_store_ids=all_vs_ids, ) @@ -426,6 +515,8 @@ async def aresponses_with_emulated_file_search( file_search_call_output=_build_file_search_call_output( call_id=file_search_call_id, queries=all_queries or [str(input)], + results=all_results, + include_search_results=_include_search_results, ), message_output=_build_message_output(response_text, all_results), ) From e6d5e3af02cd838e8b173a44e90e00f53a3f7c73 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 14:36:38 +0530 Subject: [PATCH 029/539] fix(responses): avoid sending empty tools list in follow-up turns Drop tools=[] 
from transformed chat-completion requests so providers like Anthropic return normal assistant text after tool_result turns. Made-with: Cursor --- .../litellm_completion_transformation/transformation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index 71fa88fb75..3467dbbd27 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -229,9 +229,13 @@ class LiteLLMCompletionResponsesConfig: if litellm_logging_obj: litellm_logging_obj.stream_options = stream_options - # only pass non-None values + # only pass non-None / non-empty values + # Explicitly exclude an empty tools list — sending tools=[] to providers + # like Anthropic in a tool_result conversation makes them return empty content. litellm_completion_request = { - k: v for k, v in litellm_completion_request.items() if v is not None + k: v + for k, v in litellm_completion_request.items() + if v is not None and not (k == "tools" and v == []) } return litellm_completion_request From 82c2dce6b994433b9630d9546c1bf8c51cf719f9 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 14:54:53 +0530 Subject: [PATCH 030/539] docs(file_search): streamline guide with usage tabs, architecture, and Q&A Replace duplicate path-by-path sections with a single usage-first doc format that includes SDK/Proxy tabs, an architecture diagram, and a focused Q&A section. 
Made-with: Cursor --- .../tutorials/file_search_responses_api.md | 351 +++++++----------- 1 file changed, 141 insertions(+), 210 deletions(-) diff --git a/docs/my-website/docs/tutorials/file_search_responses_api.md b/docs/my-website/docs/tutorials/file_search_responses_api.md index 9c8a773f20..3c642148df 100644 --- a/docs/my-website/docs/tutorials/file_search_responses_api.md +++ b/docs/my-website/docs/tutorials/file_search_responses_api.md @@ -3,7 +3,17 @@ import TabItem from '@theme/TabItem'; # File Search in the Responses API — E2E Testing Guide -This tutorial walks you through end-to-end testing of the `file_search` tool in LiteLLM's Responses API. +LiteLLM now supports `file_search` in the Responses API across both: +- providers that support it natively (like OpenAI / Azure), and +- providers that do not (like Anthropic, Bedrock, and other non-native providers) via emulation. + +This page is both a quick blog-style overview and an end-to-end implementation guide. + +## What this is + +`file_search` lets models retrieve grounded context from your vector stores and answer with citations. +LiteLLM keeps one OpenAI-compatible output shape while routing requests through either native passthrough or an emulated fallback. + Two paths are covered: | Path | When it runs | What LiteLLM does | @@ -13,6 +23,117 @@ Two paths are covered: --- +## Usage + + + + +### 1. Setup `config.yaml` + +```yaml title="config.yaml" +model_list: + - model_name: gpt-4.1 + litellm_params: + model: openai/gpt-4.1 + api_key: os.environ/OPENAI_API_KEY + + - model_name: claude-sonnet + litellm_params: + model: anthropic/claude-sonnet-4-5 + api_key: os.environ/ANTHROPIC_API_KEY +``` + +### 2. Start the proxy + +```bash +litellm --config config.yaml +``` + +### 3. 
Call Responses API with `file_search` + +```python title="Proxy call" +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:4000", api_key="sk-your-proxy-key") + +response = client.responses.create( + model="claude-sonnet", # swap to "gpt-4.1" for native path + input="What does LiteLLM support?", + tools=[{ + "type": "file_search", + "vector_store_ids": ["vs_abc123"] + }], + include=["file_search_call.results"], +) + +print(response.output) +``` + + + + +### 1. Install + set keys + +```bash +pip install litellm +export OPENAI_API_KEY="sk-..." +export ANTHROPIC_API_KEY="sk-ant-..." +``` + +### 2. Call Responses API with `file_search` + +```python title="SDK call" +import litellm + +response = litellm.responses( + model="anthropic/claude-sonnet-4-5", # swap to openai/gpt-4.1 for native path + input="What does LiteLLM support?", + tools=[{ + "type": "file_search", + "vector_store_ids": ["vs_abc123"] + }], + include=["file_search_call.results"], +) + +print(response.output) +``` + + + + +### Behavior Matrix + +| Path | SDK model | Proxy model | Behavior | +|---|---|---|---| +| Native passthrough | `openai/gpt-4.1` | `gpt-4.1` | Provider executes native `file_search` | +| Emulated fallback | `anthropic/claude-sonnet-4-5` | `claude-sonnet` | LiteLLM converts to function tool and synthesizes OpenAI-format output | + +--- + +## Architecture Diagram + +```mermaid +flowchart TD + A[Client SDK or Proxy Caller] --> B[LiteLLM Responses API] + B --> C{Provider supports native file_search?} + + C -->|Yes| D[Native passthrough path] + D --> D1[Decode unified vector_store_id if needed] + D1 --> D2[Forward request to provider unchanged] + D2 --> D3[Provider performs file_search] + D3 --> Z[OpenAI-compatible output] + + C -->|No| E[Emulated fallback path] + E --> E1[Convert file_search to litellm_file_search function tool] + E1 --> E2[First model call returns tool call with one or more queries] + E2 --> E3[LiteLLM executes vector search for each query] + E3 --> 
E4[Second model call with tool_result context] + E4 --> E5[Synthesize file_search_call + message + citations] + E5 --> Z[OpenAI-compatible output] +``` + +--- + ## Prerequisites ```bash @@ -23,200 +144,7 @@ export ANTHROPIC_API_KEY="sk-ant-..." # for emulated path --- -## Path 1: Native Passthrough (OpenAI) - -OpenAI natively handles `file_search`. LiteLLM decodes any unified vector store ID and forwards the request unchanged. - -### Step 1 — Create a vector store and upload a file - -```python -from openai import OpenAI - -client = OpenAI() # direct OpenAI call to set up test data - -# Upload a file -with open("knowledge.txt", "w") as f: - f.write("LiteLLM is a unified interface for 100+ LLM providers. " - "It supports chat completions, responses API, embeddings, and more.") - -file = client.files.create(file=open("knowledge.txt", "rb"), purpose="assistants") -print("file_id:", file.id) - -# Create a vector store and attach the file -vs = client.vector_stores.create(name="litellm-test-store") -client.vector_stores.files.create(vector_store_id=vs.id, file_id=file.id) -print("vector_store_id:", vs.id) -``` - -### Step 2 — Run file search via LiteLLM Python SDK - -```python showLineNumbers title="Native file_search via LiteLLM SDK" -import litellm - -response = litellm.responses( - model="openai/gpt-4.1", - input="What does LiteLLM support?", - tools=[{ - "type": "file_search", - "vector_store_ids": ["vs_abc123"] # replace with your vector_store_id - }], -) - -for item in response.output: - if item.type == "file_search_call": - print("Queries run:", item.queries) - print("Status:", item.status) - elif item.type == "message": - for block in item.content: - print("\nAnswer:", block.text) - for ann in block.annotations: - print(f" ↳ Citation: {ann.filename} (file_id={ann.file_id})") -``` - -**Expected output:** -``` -Queries run: ['What does LiteLLM support?'] -Status: completed - -Answer: LiteLLM is a unified interface for 100+ LLM providers... 
- ↳ Citation: knowledge.txt (file_id=file-xxxx) -``` - -### Step 3 — Run via LiteLLM Proxy - -Start the proxy: - -```bash title="config.yaml" -# config.yaml -model_list: - - model_name: gpt-4.1 - litellm_params: - model: openai/gpt-4.1 - api_key: os.environ/OPENAI_API_KEY -``` - -```bash -litellm --config config.yaml -``` - -Call the proxy: - -```python showLineNumbers title="Native file_search via LiteLLM Proxy" -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:4000", api_key="any") - -response = client.responses.create( - model="gpt-4.1", - input="What does LiteLLM support?", - tools=[{"type": "file_search", "vector_store_ids": ["vs_abc123"]}], -) - -for item in response.output: - print(item.type, getattr(item, "queries", getattr(item, "content", ""))) -``` - ---- - -## Path 2: Emulated Fallback (Anthropic / any non-native provider) - -When you use a provider that doesn't natively support `file_search`, LiteLLM: -1. Converts the `file_search` tool to a function tool (`litellm_file_search`). -2. Lets the provider call the function with a natural-language query. -3. Runs your vector store search internally. -4. Feeds results back and makes a follow-up call. -5. Returns the final answer in OpenAI's `file_search_call` + `message` format. - -### Step 1 — Register a LiteLLM-managed vector store - -LiteLLM's vector store registry lets you configure any supported vector store backend (OpenAI, Pinecone, Milvus, Qdrant, etc.): - -```python showLineNumbers title="Register vector store via LiteLLM Proxy API" -import requests - -# Register the vector store with LiteLLM Proxy -resp = requests.post( - "http://localhost:4000/v1/vector_stores/new", - headers={"Authorization": "Bearer sk-your-proxy-key"}, - json={ - "vector_store_id": "my-openai-vs", # your logical name - "custom_llm_provider": "openai", - "vector_store_name": "litellm-test-store", - "litellm_params": { - "api_key": "sk-..." 
# provider API key (or use credentials in config.yaml) - }, - }, -) -print(resp.json()) -# Returns: {"vector_store_id": "bGl0ZWxsbV9wcm94eToB..."} ← LiteLLM unified ID -``` - -:::tip -Save the returned `vector_store_id` — this is the **LiteLLM-managed unified ID** that encodes the provider routing. Pass this in `vector_store_ids` and LiteLLM will decode it automatically. -::: - -### Step 2 — Run file search via LiteLLM SDK (emulated) - -```python showLineNumbers title="Emulated file_search with Anthropic" -import litellm - -# Use the unified vector_store_id returned by /v1/vector_stores/new -UNIFIED_VS_ID = "bGl0ZWxsbV9wcm94eToB..." - -response = litellm.responses( - model="anthropic/claude-sonnet-4-5", - input="What does LiteLLM support?", - tools=[{ - "type": "file_search", - "vector_store_ids": [UNIFIED_VS_ID] - }], -) - -for item in response.output: - if item.type == "file_search_call": - print("Queries run:", item.queries) - elif item.type == "message": - for block in item.content: - print("\nAnswer:", block.text) - for ann in block.annotations: - print(f" ↳ Citation: {ann.filename}") -``` - -LiteLLM automatically detects that Anthropic doesn't support `file_search` natively and routes through the emulated handler. 
- -### Step 3 — Run via LiteLLM Proxy (emulated) - -```bash title="config.yaml" -model_list: - - model_name: claude-sonnet - litellm_params: - model: anthropic/claude-sonnet-4-5 - api_key: os.environ/ANTHROPIC_API_KEY -``` - -```python showLineNumbers title="Emulated file_search via LiteLLM Proxy" -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:4000", api_key="sk-your-proxy-key") - -response = client.responses.create( - model="claude-sonnet", - input="What does LiteLLM support?", - tools=[{ - "type": "file_search", - "vector_store_ids": ["bGl0ZWxsbV9wcm94eToB..."] # unified ID - }], -) - -for item in response.output: - if hasattr(item, "type"): - if item.type == "file_search_call": - print("Queries:", item.queries) - elif item.type == "message": - print("Answer:", item.content[0].text) -``` - ---- +## Example response shape ## Validating the Output Format @@ -292,29 +220,32 @@ validate_file_search_response(response) --- -## Troubleshooting +## Q&A -### `UnsupportedParamsError` is raised +### Q: Why do I see `UnsupportedParamsError`? -This means `file_search` was passed to a provider that doesn't support it natively, but the emulated fallback couldn't route either. Check: -- The model string is correct (e.g. `anthropic/claude-sonnet-4-5`, not just `claude-sonnet-4-5`) -- The `custom_llm_provider` is resolved — LiteLLM needs it to look up the provider config +A: This usually means `file_search` was passed to a provider that does not support it natively and emulation could not route correctly. +Check: +- The model string is valid (for example, `anthropic/claude-sonnet-4-5`). +- `custom_llm_provider` resolves correctly so LiteLLM can load the provider config. -### Vector store search returns no results +### Q: Why does vector search return no results? 
-- Confirm the vector store ID exists and has files attached -- For LiteLLM-managed stores, ensure the file has finished processing (`status: completed`) -- Try a broader query string +A: Common causes: +- The vector store ID is wrong or has no files attached. +- In LiteLLM-managed stores, file ingestion is not complete (`status != completed`). +- The query is too narrow; try a broader query. -### `403 Access denied` on vector store +### Q: Why am I getting `403 Access denied` on vector store calls? -The calling team doesn't have access to the vector store. Either: -- The vector store was created by a different team -- Use a proxy admin key to bypass team-scoped access control +A: The caller does not have access to that vector store. +- The store may belong to another team. +- Use an admin/proxy key if your setup requires cross-team access. -### Empty `annotations` in emulated mode +### Q: Why are `annotations` empty in emulated mode? -The emulated path adds `file_citation` annotations only when the vector store search result includes a `file_id`. If your vector store provider doesn't return file-level metadata in search results, annotations will be empty — the answer text will still be populated. +A: `file_citation` annotations require `file_id` metadata in search results. +If your vector backend does not return file-level metadata, the answer text is still generated but citations can be empty. --- From e22d9031e0a50ea522fa4d549bd1e8e1e9e88028 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 14:59:55 +0530 Subject: [PATCH 031/539] docs(response_api): move file_search details to dedicated tutorial Replace inline file_search documentation in response_api.md with a canonical link and add the new tutorial to sidebars so users discover the usage-first guide. 
Made-with: Cursor --- docs/my-website/docs/response_api.md | 127 +-------------------------- docs/my-website/sidebars.js | 1 + 2 files changed, 3 insertions(+), 125 deletions(-) diff --git a/docs/my-website/docs/response_api.md b/docs/my-website/docs/response_api.md index 183b339900..3df84faa66 100644 --- a/docs/my-website/docs/response_api.md +++ b/docs/my-website/docs/response_api.md @@ -1558,132 +1558,9 @@ curl -X POST "http://localhost:4000/v1/responses" \ ## File Search (Vector Stores) -The **file_search** tool lets the model search your vector stores and cite retrieved content in its answer (OpenAI Responses API format). Pass `tools=[{"type": "file_search", "vector_store_ids": [...]}]`. The response includes a `file_search_call` output item and `file_citation` annotations on the answer text. +For full `file_search` usage (native + emulated fallback), SDK/Proxy examples, architecture diagram, and Q&A, see: -**Supported providers:** `openai`, `azure` (native). Other providers will receive an `UnsupportedParamsError` until the emulated-fallback path is available. - -:::note -If you are using LiteLLM-managed vector stores (created via `/v1/vector_stores`), pass the LiteLLM vector store ID directly — LiteLLM automatically decodes it to the provider-native ID before sending the request. 
-::: - -### Python SDK - -```python showLineNumbers title="File search with LiteLLM Python SDK" -import litellm - -response = litellm.responses( - model="openai/gpt-4.1", - input="What is deep research?", - tools=[{ - "type": "file_search", - "vector_store_ids": ["vs_abc123"] # native or LiteLLM-managed vector store ID - }], -) - -# Output contains a file_search_call item followed by the answer with citations -for item in response.output: - if item.type == "file_search_call": - print("Queries:", item.queries) - elif item.type == "message": - for block in item.content: - print(block.text) - for ann in block.annotations: - print(f" ↳ {ann.filename} (file_id={ann.file_id})") -``` - -#### Response Format - -```json -{ - "output": [ - { - "type": "file_search_call", - "id": "fs_67c09ccea8c48191ade9367e3ba71515", - "status": "completed", - "queries": ["What is deep research?"], - "search_results": null - }, - { - "id": "msg_67c09cd3091c819185af2be5d13d87de", - "type": "message", - "role": "assistant", - "content": [ - { - "type": "output_text", - "text": "Deep research is a capability that allows for extensive inquiry ...", - "annotations": [ - { - "type": "file_citation", - "index": 992, - "file_id": "file-2dtbBZdjtDKS8eqWxqbgDi", - "filename": "deep_research_blog.pdf" - } - ] - } - ] - } - ] -} -``` - -### LiteLLM Proxy (AI Gateway) - -**OpenAI Python SDK (proxy as base_url):** - -```python showLineNumbers title="File search via LiteLLM Proxy" -from openai import OpenAI - -client = OpenAI( - base_url="http://localhost:4000", - api_key="your-proxy-api-key", -) - -response = client.responses.create( - model="openai/gpt-4.1", - input="Summarise the Q3 earnings report.", - tools=[{ - "type": "file_search", - "vector_store_ids": ["vs_abc123"] - }], -) -``` - -**curl:** - -```bash title="File search via curl to LiteLLM Proxy" -curl -X POST "http://localhost:4000/v1/responses" \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer your-proxy-api-key" \ - -d '{ 
- "model": "openai/gpt-4.1", - "input": "Summarise the Q3 earnings report.", - "tools": [{"type": "file_search", "vector_store_ids": ["vs_abc123"]}] - }' -``` - -### Using LiteLLM-Managed Vector Stores - -If you created a vector store through LiteLLM (`POST /v1/vector_stores/new`), use the returned `vector_store_id` directly. LiteLLM decodes the unified ID to the provider-native vector store ID automatically. - -```python showLineNumbers title="File search with LiteLLM-managed vector store" -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:4000", api_key="your-proxy-api-key") - -# vector_store_id returned by POST /v1/vector_stores/new -managed_vs_id = "bGl0ZWxsbV9wcm94eTo..." # LiteLLM-managed ID - -response = client.responses.create( - model="openai/gpt-4.1", - input="What does the documentation say about authentication?", - tools=[{"type": "file_search", "vector_store_ids": [managed_vs_id]}], -) -``` - -LiteLLM will: -1. Verify the calling team has access to the vector store. -2. Decode the managed ID to the provider-native vector store ID. -3. Forward the request to the provider unchanged. 
+- [`File Search in the Responses API — E2E Testing Guide`](/docs/tutorials/file_search_responses_api) ## Session Management diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 1362745a91..4a8d67409a 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -1054,6 +1054,7 @@ const sidebars = { label: "AI Coding Tools (OpenWebUI, Claude Code, Gemini CLI, OpenAI Codex, etc.)", href: "/docs/ai_tools", }, + "tutorials/file_search_responses_api", "tutorials/anthropic_file_usage", "tutorials/default_team_self_serve", "tutorials/msft_sso", From 729f7d48eb0f575bc11b0e13b870f96fe522aaf5 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 15:10:46 +0530 Subject: [PATCH 032/539] fix(file_search): address greptile review on follow-up calls and tests Include all function_call items when building emulated follow-up input and update tests to assert real emulated routing + Responses-format function tool structure. Made-with: Cursor --- .../responses/file_search/emulated_handler.py | 24 ++-- .../llms/test_file_search_responses.py | 110 +++++++++++++----- 2 files changed, 101 insertions(+), 33 deletions(-) diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 13a11cc5d1..2ac1d3dea6 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -481,18 +481,28 @@ async def aresponses_with_emulated_file_search( } ) - # 5. Build follow-up input: original messages + assistant's tool call + tool results + # 5. 
Build follow-up input: original messages + all assistant tool calls + tool results original_input_items = list(input) if isinstance(input, (list, tuple)) else [{"role": "user", "content": str(input)}] - follow_up_input = ( - original_input_items - + [ + follow_up_function_calls: List[Dict[str, Any]] = [] + for tc in file_search_calls: + if isinstance(tc, dict): + tc_call_id = tc.get("call_id") or tc.get("id") or file_search_call_id + tc_args = tc.get("arguments") or "{}" + else: + tc_call_id = getattr(tc, "call_id", None) or getattr(tc, "id", file_search_call_id) + tc_args = getattr(tc, "arguments", "{}") or "{}" + follow_up_function_calls.append( { "type": "function_call", "name": FILE_SEARCH_FUNCTION_NAME, - "call_id": file_search_calls[0].get("call_id") if isinstance(file_search_calls[0], dict) else getattr(file_search_calls[0], "call_id", file_search_call_id), - "arguments": file_search_calls[0].get("arguments") if isinstance(file_search_calls[0], dict) else getattr(file_search_calls[0], "arguments", "{}"), + "call_id": tc_call_id, + "arguments": tc_args, } - ] + ) + + follow_up_input = ( + original_input_items + + follow_up_function_calls + tool_results ) diff --git a/tests/test_litellm/llms/test_file_search_responses.py b/tests/test_litellm/llms/test_file_search_responses.py index 6f91b5386e..f1eb265054 100644 --- a/tests/test_litellm/llms/test_file_search_responses.py +++ b/tests/test_litellm/llms/test_file_search_responses.py @@ -179,7 +179,7 @@ class TestSupportsNativeFileSearch: # --------------------------------------------------------------------------- class TestFileSearchGuardInResponsesMain: - """Tests for _has_file_search_tool helper and the UnsupportedParamsError guard.""" + """Tests for _has_file_search_tool helper and emulated routing guard.""" def test_has_file_search_tool_true(self): from litellm.responses.main import _has_file_search_tool @@ -210,40 +210,98 @@ class TestFileSearchGuardInResponsesMain: assert 
config.supports_native_file_search() # No exception expected — the guard would pass. - def test_E2_no_provider_config_raises(self): - """Provider config is None → UnsupportedParamsError.""" - from litellm.exceptions import UnsupportedParamsError - from litellm.responses.main import _has_file_search_tool + def test_E2_no_provider_config_routes_to_emulated_handler(self): + """Provider config None + file_search should route to emulated handler.""" + from litellm.responses.main import responses tools = [{"type": "file_search", "vector_store_ids": ["vs_abc"]}] - assert _has_file_search_tool(tools) + logging_obj = MagicMock() + expected = {"ok": True} - with pytest.raises(UnsupportedParamsError): - if _has_file_search_tool(tools) and True: # config is None - raise UnsupportedParamsError( - message="Provider does not support file_search", - llm_provider="anthropic", - model="claude-3", - ) + with ( + patch( + "litellm.responses.main.litellm.get_llm_provider", + return_value=("claude-sonnet-4-5", "anthropic", None, None), + ), + patch( + "litellm.responses.main.update_responses_input_with_model_file_ids", + return_value="hello", + ), + patch( + "litellm.responses.main.update_responses_tools_with_model_file_ids", + return_value=tools, + ), + patch( + "litellm.responses.main.ProviderConfigManager.get_provider_responses_api_config", + return_value=None, + ), + patch( + "litellm.responses.main.ResponsesAPIRequestUtils.get_requested_response_api_optional_param", + return_value={}, + ), + patch("litellm.responses.main.run_async_function", return_value=expected) as run_async_mock, + ): + result = responses( + input="hello", + model="anthropic/claude-sonnet-4-5", + tools=tools, + litellm_logging_obj=logging_obj, + litellm_call_id="call-123", + ) - def test_E3_non_native_provider_config_raises(self): - """Provider config.supports_native_file_search() == False → error.""" - from litellm.exceptions import UnsupportedParamsError + assert result == expected + assert 
run_async_mock.called + routed_func = run_async_mock.call_args.args[0] + assert routed_func.__name__ == "aresponses_with_emulated_file_search" + + def test_E3_non_native_provider_config_routes_to_emulated_handler(self): + """Non-native provider config + file_search should route to emulated handler.""" from litellm.llms.base_llm.responses.transformation import ( BaseResponsesAPIConfig, ) + from litellm.responses.main import responses + tools = [{"type": "file_search", "vector_store_ids": ["vs_abc"]}] + logging_obj = MagicMock() + expected = {"ok": True} mock_config = MagicMock(spec=BaseResponsesAPIConfig) mock_config.supports_native_file_search.return_value = False - tools = [{"type": "file_search"}] - with pytest.raises(UnsupportedParamsError): - if not mock_config.supports_native_file_search(): - raise UnsupportedParamsError( - message="Provider does not support file_search", - llm_provider="anthropic", - model="claude-3", - ) + with ( + patch( + "litellm.responses.main.litellm.get_llm_provider", + return_value=("claude-sonnet-4-5", "anthropic", None, None), + ), + patch( + "litellm.responses.main.update_responses_input_with_model_file_ids", + return_value="hello", + ), + patch( + "litellm.responses.main.update_responses_tools_with_model_file_ids", + return_value=tools, + ), + patch( + "litellm.responses.main.ProviderConfigManager.get_provider_responses_api_config", + return_value=mock_config, + ), + patch( + "litellm.responses.main.ResponsesAPIRequestUtils.get_requested_response_api_optional_param", + return_value={}, + ), + patch("litellm.responses.main.run_async_function", return_value=expected) as run_async_mock, + ): + result = responses( + input="hello", + model="anthropic/claude-sonnet-4-5", + tools=tools, + litellm_logging_obj=logging_obj, + litellm_call_id="call-123", + ) + + assert result == expected + assert run_async_mock.called + routed_func = run_async_mock.call_args.args[0] + assert routed_func.__name__ == "aresponses_with_emulated_file_search" def 
test_E4_no_file_search_tools_no_error(self): """No file_search tool in request → guard never fires.""" @@ -473,9 +531,9 @@ class TestEmulatedFileSearchHandler: assert vs_ids == ["vs_abc", "vs_def"] assert len(new_tools) == 1 assert new_tools[0]["type"] == "function" - assert new_tools[0]["function"]["name"] == "litellm_file_search" + assert new_tools[0]["name"] == "litellm_file_search" # Both store IDs appear in the enum - enum_ids = new_tools[0]["function"]["parameters"]["properties"]["vector_store_id"]["enum"] + enum_ids = new_tools[0]["parameters"]["properties"]["vector_store_id"]["enum"] assert "vs_abc" in enum_ids assert "vs_def" in enum_ids From 77a5093ce287db8bdd32ca753a2dd84977cd066e Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 15:20:56 +0530 Subject: [PATCH 033/539] fix(file_search): preserve emulated response params and hidden metadata Forward explicit responses() params on emulated file search calls and preserve hidden params on synthesized responses so callback billing/logging context is retained. 
Made-with: Cursor --- litellm/responses/file_search/emulated_handler.py | 9 +++++---- litellm/responses/main.py | 5 +++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 2ac1d3dea6..85a8f78998 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -339,9 +339,7 @@ def _synthesize_responses_api_response( output[0] = file_search_call item output[1] = message item (with citations) """ - import litellm - - return ResponsesAPIResponse( + synthesized = ResponsesAPIResponse( id=getattr(original_response, "id", f"resp_{uuid.uuid4().hex}"), object="response", created_at=getattr(original_response, "created_at", int(time.time())), @@ -351,6 +349,9 @@ def _synthesize_responses_api_response( usage=getattr(original_response, "usage", None), error=None, ) + if hasattr(original_response, "_hidden_params"): + synthesized._hidden_params = getattr(original_response, "_hidden_params") + return synthesized # --------------------------------------------------------------------------- @@ -513,7 +514,7 @@ async def aresponses_with_emulated_file_search( input=follow_up_input, model=model, tools=None, # no tools needed for the answer step - **{k: v for k, v in kwargs.items() if k not in ("tools",)}, + **kwargs, ), ) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 5438676c5f..761be5a2e1 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -742,6 +742,7 @@ def responses( "previous_response_id": previous_response_id, "reasoning": reasoning, "store": store, + "background": background, "stream": stream, "temperature": temperature, "text": text, @@ -749,6 +750,10 @@ def responses( "top_p": top_p, "truncation": truncation, "user": user, + "service_tier": service_tier, + "safety_identifier": safety_identifier, + "text_format": text_format, + "allowed_openai_params": 
allowed_openai_params, "extra_headers": extra_headers, "extra_query": extra_query, "extra_body": extra_body, From 5692db812389efd118a14fb403b1da1e278c56ee Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 15:33:11 +0530 Subject: [PATCH 034/539] fix(file_search): address latest greptile feedback Strip internal logging ids from emulated sub-calls, dedupe included search_results by file_id, clean unused imports, and add unit coverage for dedupe behavior. Made-with: Cursor --- .../responses/file_search/emulated_handler.py | 12 +++++++---- litellm/responses/main.py | 3 ++- .../llms/test_file_search_responses.py | 21 +++++++++++++++++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 85a8f78998..5d2c23fdfa 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -14,9 +14,7 @@ Flow: import json import time import uuid -from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union, cast - -import httpx +from typing import Any, Dict, Iterable, List, Optional, Tuple, cast from litellm._logging import verbose_logger from litellm.types.llms.openai import ResponsesAPIResponse @@ -220,7 +218,13 @@ def _build_search_results_for_include( file_search_call.search_results (mirrors OpenAI's include= format). 
""" formatted: List[Dict[str, Any]] = [] + seen_file_ids: set = set() for result in results: + file_id = _get_field(result, "file_id") or "" + if file_id and file_id in seen_file_ids: + continue + if file_id: + seen_file_ids.add(file_id) content_items = _get_field(result, "content") or [] text_chunks = [ c.get("text", "") if isinstance(c, dict) else getattr(c, "text", "") @@ -229,7 +233,7 @@ def _build_search_results_for_include( text = " ".join(t for t in text_chunks if t) formatted.append( { - "file_id": _get_field(result, "file_id") or "", + "file_id": file_id, "filename": _get_field(result, "filename") or "", "score": _get_field(result, "score"), "text": text, diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 761be5a2e1..4404e6b366 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -732,6 +732,7 @@ def responses( aresponses_with_emulated_file_search, ) + _internal_skip = {"litellm_logging_obj", "litellm_call_id", "aresponses"} emulated_kwargs = { "include": include, "instructions": instructions, @@ -759,7 +760,7 @@ def responses( "extra_body": extra_body, "timeout": timeout, "custom_llm_provider": custom_llm_provider, - **kwargs, + **{k: v for k, v in kwargs.items() if k not in _internal_skip}, } if _is_async: return aresponses_with_emulated_file_search( diff --git a/tests/test_litellm/llms/test_file_search_responses.py b/tests/test_litellm/llms/test_file_search_responses.py index f1eb265054..63599781c1 100644 --- a/tests/test_litellm/llms/test_file_search_responses.py +++ b/tests/test_litellm/llms/test_file_search_responses.py @@ -659,6 +659,27 @@ class TestEmulatedFileSearchHandler: annotations = _build_file_citation_annotations([r1, r2], "text") assert len(annotations) == 1 + def test_H14_include_search_results_dedupes_by_file_id(self): + from litellm.responses.file_search.emulated_handler import ( + _build_search_results_for_include, + ) + + r1, r2 = MagicMock(), MagicMock() + r1.file_id = "file-abc" + 
r1.filename = "doc.pdf" + r1.score = 0.9 + r1.attributes = {} + r1.content = [{"type": "text", "text": "first hit"}] + r2.file_id = "file-abc" # same file appears for a second query + r2.filename = "doc.pdf" + r2.score = 0.85 + r2.attributes = {} + r2.content = [{"type": "text", "text": "second hit"}] + + search_results = _build_search_results_for_include([r1, r2]) + assert len(search_results) == 1 + assert search_results[0]["file_id"] == "file-abc" + # --- End-to-end (mocked) --- @pytest.mark.asyncio From 8eb8756e844049dd9d1f9b32d1f8f27058e28d0f Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 15:55:41 +0530 Subject: [PATCH 035/539] fix: Preserve annotations in Azure AI Foundry Agents responses Azure AI Agents with Grounding (e.g., Bing Search) include annotations (citation URLs) in responses, but the handler was dropping them during transformation. This fix: - Extracts annotations from text content in agent responses - Transforms them to OpenAI-compatible ChatCompletionAnnotation format - Passes annotations through all completion paths (sync, async, streaming) - Handles both polling and SSE streaming responses Fixes #19126 Co-Authored-By: Claude Haiku 4.5 --- litellm/llms/azure_ai/agents/handler.py | 110 +++++++++++++++---- tests/llm_translation/test_azure_agents.py | 117 ++++++++++++++++++++- 2 files changed, 207 insertions(+), 20 deletions(-) diff --git a/litellm/llms/azure_ai/agents/handler.py b/litellm/llms/azure_ai/agents/handler.py index 9eeec7f4e3..5b779acb0d 100644 --- a/litellm/llms/azure_ai/agents/handler.py +++ b/litellm/llms/azure_ai/agents/handler.py @@ -97,14 +97,63 @@ class AzureAIAgentsHandler: # ------------------------------------------------------------------------- # Response Helpers # ------------------------------------------------------------------------- - def _extract_content_from_messages(self, messages_data: dict) -> str: - """Extract assistant content from the messages response.""" + def 
_extract_content_from_messages( + self, messages_data: dict + ) -> Tuple[str, Optional[List[Dict[str, Any]]]]: + """Extract assistant content and annotations from the messages response. + + Returns (content, annotations) where annotations is a list of + OpenAI-compatible ChatCompletionAnnotation dicts, or None. + """ for msg in messages_data.get("data", []): if msg.get("role") == "assistant": for content_item in msg.get("content", []): if content_item.get("type") == "text": - return content_item.get("text", {}).get("value", "") - return "" + text_obj = content_item.get("text", {}) + content = text_obj.get("value", "") + raw_annotations = text_obj.get("annotations") + annotations = self._transform_annotations( + raw_annotations + ) + return content, annotations + return "", None + + def _transform_annotations( + self, + raw_annotations: Optional[List[Dict[str, Any]]], + ) -> Optional[List[Dict[str, Any]]]: + """Transform Azure AI Foundry annotations to OpenAI-compatible format. + + Azure AI returns annotations like: + {"type": "url_citation", "text": "[1]", "start_index": 10, + "end_index": 13, "url_citation": {"url": "...", "title": "..."}} + + OpenAI expects: + {"type": "url_citation", "url_citation": {"url": "...", "title": "...", + "start_index": 10, "end_index": 13}} + """ + if not raw_annotations: + return None + + result: List[Dict[str, Any]] = [] + for ann in raw_annotations: + ann_type = ann.get("type", "url_citation") + if ann_type == "url_citation": + url_citation = dict(ann.get("url_citation", {})) + # Azure puts start/end_index at annotation level; OpenAI + # expects them inside url_citation + if "start_index" in ann and "start_index" not in url_citation: + url_citation["start_index"] = ann["start_index"] + if "end_index" in ann and "end_index" not in url_citation: + url_citation["end_index"] = ann["end_index"] + result.append( + {"type": "url_citation", "url_citation": url_citation} + ) + else: + # Pass through unknown annotation types as-is + 
result.append(ann) + + return result if result else None def _build_model_response( self, @@ -113,15 +162,23 @@ class AzureAIAgentsHandler: model_response: ModelResponse, thread_id: str, messages: List[Dict[str, Any]], + annotations: Optional[List[Dict[str, Any]]] = None, ) -> ModelResponse: """Build the ModelResponse from agent output.""" from litellm.types.utils import Choices, Message, Usage + message_kwargs: Dict[str, Any] = { + "content": content, + "role": "assistant", + } + if annotations: + message_kwargs["annotations"] = annotations + model_response.choices = [ Choices( finish_reason="stop", index=0, - message=Message(content=content, role="assistant"), + message=Message(**message_kwargs), ) ] model_response.model = model @@ -250,7 +307,7 @@ class AzureAIAgentsHandler: ) # Execute the agent flow - thread_id, content = self._execute_agent_flow_sync( + thread_id, content, annotations = self._execute_agent_flow_sync( make_request=make_request, api_base=api_base, api_version=api_version, @@ -261,7 +318,7 @@ class AzureAIAgentsHandler: ) return self._build_model_response( - model, content, model_response, thread_id, messages + model, content, model_response, thread_id, messages, annotations ) def _execute_agent_flow_sync( @@ -273,8 +330,8 @@ class AzureAIAgentsHandler: thread_id: Optional[str], messages: List[Dict[str, Any]], optional_params: dict, - ) -> Tuple[str, str]: - """Execute the agent flow synchronously. Returns (thread_id, content).""" + ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]: + """Execute the agent flow synchronously. 
Returns (thread_id, content, annotations).""" # Step 1: Create thread if not provided if not thread_id: @@ -347,8 +404,8 @@ class AzureAIAgentsHandler: ) self._check_response(response, [200], "Failed to get messages") - content = self._extract_content_from_messages(response.json()) - return thread_id, content + content, annotations = self._extract_content_from_messages(response.json()) + return thread_id, content, annotations # ------------------------------------------------------------------------- # Async Completion @@ -399,7 +456,7 @@ class AzureAIAgentsHandler: ) # Execute the agent flow - thread_id, content = await self._execute_agent_flow_async( + thread_id, content, annotations = await self._execute_agent_flow_async( make_request=make_request, api_base=api_base, api_version=api_version, @@ -410,7 +467,7 @@ class AzureAIAgentsHandler: ) return self._build_model_response( - model, content, model_response, thread_id, messages + model, content, model_response, thread_id, messages, annotations ) async def _execute_agent_flow_async( @@ -422,8 +479,8 @@ class AzureAIAgentsHandler: thread_id: Optional[str], messages: List[Dict[str, Any]], optional_params: dict, - ) -> Tuple[str, str]: - """Execute the agent flow asynchronously. Returns (thread_id, content).""" + ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]: + """Execute the agent flow asynchronously. 
Returns (thread_id, content, annotations).""" # Step 1: Create thread if not provided if not thread_id: @@ -496,8 +553,8 @@ class AzureAIAgentsHandler: ) self._check_response(response, [200], "Failed to get messages") - content = self._extract_content_from_messages(response.json()) - return thread_id, content + content, annotations = self._extract_content_from_messages(response.json()) + return thread_id, content, annotations # ------------------------------------------------------------------------- # Streaming Completion (Native SSE) @@ -585,6 +642,7 @@ class AzureAIAgentsHandler: response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" created = int(time.time()) thread_id = None + collected_annotations: Optional[List[Dict[str, Any]]] = None current_event = None @@ -600,6 +658,9 @@ class AzureAIAgentsHandler: if data_str == "[DONE]": # Send final chunk with finish_reason + final_delta_kwargs: Dict[str, Any] = {"content": None} + if collected_annotations: + final_delta_kwargs["annotations"] = collected_annotations final_chunk = ModelResponseStream( id=response_id, created=created, @@ -609,7 +670,7 @@ class AzureAIAgentsHandler: StreamingChoices( finish_reason="stop", index=0, - delta=Delta(content=None), + delta=Delta(**final_delta_kwargs), ) ], ) @@ -628,6 +689,19 @@ class AzureAIAgentsHandler: thread_id = data["id"] verbose_logger.debug(f"Stream created thread: {thread_id}") + # Extract annotations from completed message + if current_event == "thread.message.completed": + for content_item in data.get("content", []): + if content_item.get("type") == "text": + raw_annotations = content_item.get("text", {}).get( + "annotations" + ) + transformed = self._transform_annotations( + raw_annotations + ) + if transformed: + collected_annotations = transformed + # Process message deltas - this is where the actual content comes if current_event == "thread.message.delta": delta_content = data.get("delta", {}).get("content", []) diff --git a/tests/llm_translation/test_azure_agents.py 
b/tests/llm_translation/test_azure_agents.py index 66a46d5338..19ce49a3bc 100644 --- a/tests/llm_translation/test_azure_agents.py +++ b/tests/llm_translation/test_azure_agents.py @@ -343,13 +343,126 @@ def test_azure_ai_agents_extract_content_from_messages(): ] } - content = handler._extract_content_from_messages(messages_data) + content, annotations = handler._extract_content_from_messages(messages_data) assert content == "The answer is 100." + assert annotations is None # Test empty response empty_data = {"data": []} - content = handler._extract_content_from_messages(empty_data) + content, annotations = handler._extract_content_from_messages(empty_data) assert content == "" + assert annotations is None + + +def test_azure_ai_agents_extract_content_with_annotations(): + """ + Test that annotations (e.g., Bing Search citations) are extracted from + Azure Agents message responses and transformed to OpenAI-compatible format. + + Ref: https://github.com/BerriAI/litellm/issues/19126 + """ + from litellm.llms.azure_ai.agents.handler import AzureAIAgentsHandler + + handler = AzureAIAgentsHandler() + + messages_data = { + "data": [ + { + "id": "msg_abc", + "role": "assistant", + "content": [ + { + "type": "text", + "text": { + "value": "According to sources [1], the answer is yes.", + "annotations": [ + { + "type": "url_citation", + "text": "[1]", + "start_index": 22, + "end_index": 25, + "url_citation": { + "url": "https://example.com/source", + "title": "Example Source" + } + } + ] + } + } + ] + } + ] + } + + content, annotations = handler._extract_content_from_messages(messages_data) + assert content == "According to sources [1], the answer is yes." 
+ assert annotations is not None + assert len(annotations) == 1 + assert annotations[0]["type"] == "url_citation" + assert annotations[0]["url_citation"]["url"] == "https://example.com/source" + assert annotations[0]["url_citation"]["title"] == "Example Source" + # start/end_index should be moved into url_citation for OpenAI compatibility + assert annotations[0]["url_citation"]["start_index"] == 22 + assert annotations[0]["url_citation"]["end_index"] == 25 + + +def test_azure_ai_agents_build_model_response_with_annotations(): + """ + Test that _build_model_response includes annotations in the Message object. + """ + from litellm.llms.azure_ai.agents.handler import AzureAIAgentsHandler + from litellm.types.utils import ModelResponse + + handler = AzureAIAgentsHandler() + model_response = ModelResponse() + + annotations = [ + { + "type": "url_citation", + "url_citation": { + "url": "https://example.com", + "title": "Example", + "start_index": 0, + "end_index": 5, + }, + } + ] + + result = handler._build_model_response( + model="azure_ai/agents/asst_123", + content="Hello [1]", + model_response=model_response, + thread_id="thread_abc", + messages=[{"role": "user", "content": "test"}], + annotations=annotations, + ) + + assert result.choices[0].message.content == "Hello [1]" + assert result.choices[0].message.annotations is not None + assert len(result.choices[0].message.annotations) == 1 + assert result.choices[0].message.annotations[0]["type"] == "url_citation" + + +def test_azure_ai_agents_build_model_response_without_annotations(): + """ + Test that _build_model_response works correctly without annotations. 
+ """ + from litellm.llms.azure_ai.agents.handler import AzureAIAgentsHandler + from litellm.types.utils import ModelResponse + + handler = AzureAIAgentsHandler() + model_response = ModelResponse() + + result = handler._build_model_response( + model="azure_ai/agents/asst_123", + content="Hello", + model_response=model_response, + thread_id="thread_abc", + messages=[{"role": "user", "content": "test"}], + ) + + assert result.choices[0].message.content == "Hello" + assert getattr(result.choices[0].message, "annotations", None) is None @pytest.mark.asyncio From a286050293b448bdf8eb3d33755df9ce6b2915d9 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 17:28:36 +0530 Subject: [PATCH 036/539] Add basic gpt-5.4 mini and nano entry in model map --- ...odel_prices_and_context_window_backup.json | 68 +++++++++++++++++++ model_prices_and_context_window.json | 68 +++++++++++++++++++ 2 files changed, 136 insertions(+) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 6786fc3359..81b5c690be 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -18747,6 +18747,74 @@ "supports_none_reasoning_effort": false, "supports_xhigh_reasoning_effort": true }, + "gpt-5.4-mini": { + "cache_read_input_token_cost": 7.5e-08, + "input_cost_per_token": 7.5e-07, + "litellm_provider": "openai", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 4.5e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + 
"supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_none_reasoning_effort": true, + "supports_xhigh_reasoning_effort": true + }, + "gpt-5.4-nano": { + "cache_read_input_token_cost": 2e-08, + "input_cost_per_token": 2e-07, + "litellm_provider": "openai", + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 1.25e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_none_reasoning_effort": false, + "supports_xhigh_reasoning_effort": true + }, "gpt-5-pro": { "input_cost_per_token": 1.5e-05, "input_cost_per_token_batches": 7.5e-06, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 6786fc3359..81b5c690be 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -18747,6 +18747,74 @@ "supports_none_reasoning_effort": false, "supports_xhigh_reasoning_effort": true }, + "gpt-5.4-mini": { + "cache_read_input_token_cost": 7.5e-08, + "input_cost_per_token": 7.5e-07, + "litellm_provider": "openai", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 4.5e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + 
"supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_none_reasoning_effort": true, + "supports_xhigh_reasoning_effort": true + }, + "gpt-5.4-nano": { + "cache_read_input_token_cost": 2e-08, + "input_cost_per_token": 2e-07, + "litellm_provider": "openai", + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 1.25e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true, + "supports_none_reasoning_effort": false, + "supports_xhigh_reasoning_effort": true + }, "gpt-5-pro": { "input_cost_per_token": 1.5e-05, "input_cost_per_token_batches": 7.5e-06, From 8a8047e5190fc7c3c0bbca692a19bef2a184eadc Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 17:32:58 +0530 Subject: [PATCH 037/539] Add all missing entries in model entries --- ...odel_prices_and_context_window_backup.json | 30 +++++++++++++++++++ model_prices_and_context_window.json | 30 +++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 81b5c690be..4a627d4aa1 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ 
b/litellm/model_prices_and_context_window_backup.json @@ -18749,13 +18749,27 @@ }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, + "cache_read_input_token_cost_above_272k_tokens": 1.5e-07, + "cache_read_input_token_cost_flex": 3.9e-08, + "cache_read_input_token_cost_priority": 1.5e-07, + "cache_read_input_token_cost_above_272k_tokens_priority": 3e-07, "input_cost_per_token": 7.5e-07, + "input_cost_per_token_above_272k_tokens": 1.5e-06, + "input_cost_per_token_flex": 3.75e-07, + "input_cost_per_token_batches": 3.75e-07, + "input_cost_per_token_priority": 1.5e-06, + "input_cost_per_token_above_272k_tokens_priority": 3e-06, "litellm_provider": "openai", "max_input_tokens": 1050000, "max_output_tokens": 128000, "max_tokens": 128000, "mode": "chat", "output_cost_per_token": 4.5e-06, + "output_cost_per_token_above_272k_tokens": 6.75e-06, + "output_cost_per_token_flex": 2.25e-06, + "output_cost_per_token_batches": 2.25e-06, + "output_cost_per_token_priority": 6.75e-06, + "output_cost_per_token_above_272k_tokens_priority": 1.0125e-05, "supported_endpoints": [ "/v1/chat/completions", "/v1/batch", @@ -18777,19 +18791,34 @@ "supports_response_schema": true, "supports_system_messages": true, "supports_tool_choice": true, + "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": true, "supports_xhigh_reasoning_effort": true }, "gpt-5.4-nano": { "cache_read_input_token_cost": 2e-08, + "cache_read_input_token_cost_above_272k_tokens": 4e-08, + "cache_read_input_token_cost_flex": 1.04e-08, + "cache_read_input_token_cost_priority": 4e-08, + "cache_read_input_token_cost_above_272k_tokens_priority": 8e-08, "input_cost_per_token": 2e-07, + "input_cost_per_token_above_272k_tokens": 4e-07, + "input_cost_per_token_flex": 1e-07, + "input_cost_per_token_batches": 1e-07, + "input_cost_per_token_priority": 4e-07, + "input_cost_per_token_above_272k_tokens_priority": 8e-07, "litellm_provider": "openai", "max_input_tokens": 128000, 
"max_output_tokens": 16384, "max_tokens": 16384, "mode": "chat", "output_cost_per_token": 1.25e-06, + "output_cost_per_token_above_272k_tokens": 1.875e-06, + "output_cost_per_token_flex": 6.25e-07, + "output_cost_per_token_batches": 6.25e-07, + "output_cost_per_token_priority": 1.875e-06, + "output_cost_per_token_above_272k_tokens_priority": 2.8125e-06, "supported_endpoints": [ "/v1/chat/completions", "/v1/batch", @@ -18811,6 +18840,7 @@ "supports_response_schema": true, "supports_system_messages": true, "supports_tool_choice": true, + "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": false, "supports_xhigh_reasoning_effort": true diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 81b5c690be..4a627d4aa1 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -18749,13 +18749,27 @@ }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, + "cache_read_input_token_cost_above_272k_tokens": 1.5e-07, + "cache_read_input_token_cost_flex": 3.9e-08, + "cache_read_input_token_cost_priority": 1.5e-07, + "cache_read_input_token_cost_above_272k_tokens_priority": 3e-07, "input_cost_per_token": 7.5e-07, + "input_cost_per_token_above_272k_tokens": 1.5e-06, + "input_cost_per_token_flex": 3.75e-07, + "input_cost_per_token_batches": 3.75e-07, + "input_cost_per_token_priority": 1.5e-06, + "input_cost_per_token_above_272k_tokens_priority": 3e-06, "litellm_provider": "openai", "max_input_tokens": 1050000, "max_output_tokens": 128000, "max_tokens": 128000, "mode": "chat", "output_cost_per_token": 4.5e-06, + "output_cost_per_token_above_272k_tokens": 6.75e-06, + "output_cost_per_token_flex": 2.25e-06, + "output_cost_per_token_batches": 2.25e-06, + "output_cost_per_token_priority": 6.75e-06, + "output_cost_per_token_above_272k_tokens_priority": 1.0125e-05, "supported_endpoints": [ "/v1/chat/completions", "/v1/batch", @@ -18777,19 +18791,34 @@ 
"supports_response_schema": true, "supports_system_messages": true, "supports_tool_choice": true, + "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": true, "supports_xhigh_reasoning_effort": true }, "gpt-5.4-nano": { "cache_read_input_token_cost": 2e-08, + "cache_read_input_token_cost_above_272k_tokens": 4e-08, + "cache_read_input_token_cost_flex": 1.04e-08, + "cache_read_input_token_cost_priority": 4e-08, + "cache_read_input_token_cost_above_272k_tokens_priority": 8e-08, "input_cost_per_token": 2e-07, + "input_cost_per_token_above_272k_tokens": 4e-07, + "input_cost_per_token_flex": 1e-07, + "input_cost_per_token_batches": 1e-07, + "input_cost_per_token_priority": 4e-07, + "input_cost_per_token_above_272k_tokens_priority": 8e-07, "litellm_provider": "openai", "max_input_tokens": 128000, "max_output_tokens": 16384, "max_tokens": 16384, "mode": "chat", "output_cost_per_token": 1.25e-06, + "output_cost_per_token_above_272k_tokens": 1.875e-06, + "output_cost_per_token_flex": 6.25e-07, + "output_cost_per_token_batches": 6.25e-07, + "output_cost_per_token_priority": 1.875e-06, + "output_cost_per_token_above_272k_tokens_priority": 2.8125e-06, "supported_endpoints": [ "/v1/chat/completions", "/v1/batch", @@ -18811,6 +18840,7 @@ "supports_response_schema": true, "supports_system_messages": true, "supports_tool_choice": true, + "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": false, "supports_xhigh_reasoning_effort": true From 34af653ff8b050a914166a96543c2646e3f8dd7f Mon Sep 17 00:00:00 2001 From: Milan Date: Tue, 17 Mar 2026 14:20:16 +0200 Subject: [PATCH 038/539] docs: note min version for encrypted_content_affinity Document that encrypted_content_affinity requires LiteLLM >= 1.82.1 to prevent /responses invalid_encrypted_content when routing shifts deployments. 
Made-with: Cursor --- docs/my-website/docs/proxy/config_settings.md | 2 +- docs/my-website/docs/proxy/load_balancing.md | 2 +- docs/my-website/docs/response_api.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index a0e404e3a1..1d6fc1b03b 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -361,7 +361,7 @@ router_settings: | redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** | | cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. | | router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) | -| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Supported: `router_budget_limiting`, `prompt_caching`, `responses_api_deployment_check`, `encrypted_content_affinity`, `deployment_affinity`, `session_affinity`, `forward_client_headers_by_model_group` | +| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Supported: `router_budget_limiting`, `prompt_caching`, `responses_api_deployment_check`, `encrypted_content_affinity` (**requires LiteLLM >= 1.82.1**), `deployment_affinity`, `session_affinity`, `forward_client_headers_by_model_group` | | deployment_affinity_ttl_seconds | int | TTL (seconds) for user-key → deployment affinity mapping when `deployment_affinity` is enabled (configured at Router init / proxy startup). Defaults to `3600` (1 hour). | | ignore_invalid_deployments | boolean | If true, ignores invalid deployments. Default for proxy is True - to prevent invalid models from blocking other models from being loaded. 
| | search_tools | List[SearchToolTypedDict] | List of search tool configurations for Search API integration. Each tool specifies a search_tool_name and litellm_params with search_provider, api_key, api_base, etc. [Further Docs](../search.md) | diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index 5bf39d179f..313df99b25 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -352,7 +352,7 @@ If `order=1` deployment is unavailable (e.g., rate-limited), the router falls ba When load balancing OpenAI's Responses API across deployments with **different API keys** (e.g., different Azure regions or organizations), encrypted content items (like `rs_...` reasoning items) can only be decrypted by the originating API key. -**Solution:** Use the `encrypted_content_affinity` pre-call check to automatically route follow-up requests containing encrypted items to the correct deployment: +**Solution:** Use the `encrypted_content_affinity` pre-call check (**requires LiteLLM >= 1.82.1**) to automatically route follow-up requests containing encrypted items to the correct deployment: ```yaml model_list: diff --git a/docs/my-website/docs/response_api.md b/docs/my-website/docs/response_api.md index fb55ae9f9d..66aa2e1ad9 100644 --- a/docs/my-website/docs/response_api.md +++ b/docs/my-website/docs/response_api.md @@ -1160,7 +1160,7 @@ follow_up = await router.aresponses( To enable session continuity for Responses API in your LiteLLM proxy, set `optional_pre_call_checks` in your proxy config.yaml. 
- `responses_api_deployment_check`: high priority routing when `previous_response_id` is provided -- `encrypted_content_affinity`: **[Recommended]** content-aware routing for encrypted items (e.g., `rs_...` reasoning items) +- `encrypted_content_affinity`: **[Recommended]** content-aware routing for encrypted items (e.g., `rs_...` reasoning items) (**requires LiteLLM >= 1.82.1**) - `session_affinity`: sticky sessions based on session id (takes priority over `deployment_affinity`) - `deployment_affinity`: sticky sessions based on user key (applies even without `previous_response_id`) From 464ac7be12e14f61c0b5b677e7f631be0c95890a Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 18:08:07 +0530 Subject: [PATCH 039/539] Fix doc --- docs/my-website/docs/tutorials/file_search_responses_api.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/my-website/docs/tutorials/file_search_responses_api.md b/docs/my-website/docs/tutorials/file_search_responses_api.md index 3c642148df..5ff2adee0a 100644 --- a/docs/my-website/docs/tutorials/file_search_responses_api.md +++ b/docs/my-website/docs/tutorials/file_search_responses_api.md @@ -1,14 +1,12 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; -# File Search in the Responses API — E2E Testing Guide +# File Search in the Responses API LiteLLM now supports `file_search` in the Responses API across both: - providers that support it natively (like OpenAI / Azure), and - providers that do not (like Anthropic, Bedrock, and other non-native providers) via emulation. -This page is both a quick blog-style overview and an end-to-end implementation guide. - ## What this is `file_search` lets models retrieve grounded context from your vector stores and answer with citations. 
From 8b7eac5dc93db889a43f7ef233e44472b832b447 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 18:10:24 +0530 Subject: [PATCH 040/539] Fix doc --- .../docs/tutorials/file_search_responses_api.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/my-website/docs/tutorials/file_search_responses_api.md b/docs/my-website/docs/tutorials/file_search_responses_api.md index 5ff2adee0a..9df18b4095 100644 --- a/docs/my-website/docs/tutorials/file_search_responses_api.md +++ b/docs/my-website/docs/tutorials/file_search_responses_api.md @@ -15,11 +15,11 @@ LiteLLM keeps one OpenAI-compatible output shape while routing requests through Two paths are covered: | Path | When it runs | What LiteLLM does | -|---|---|---| +|||| | **Native passthrough** | Provider natively supports `file_search` (OpenAI, Azure) | Decodes unified vector store ID → forwards to provider as-is | | **Emulated fallback** | Provider doesn't support `file_search` (Anthropic, Bedrock, etc.) | Converts to a function tool → intercepts tool call → runs vector search → synthesizes OpenAI-format output | ---- + ## Usage @@ -102,11 +102,11 @@ print(response.output) ### Behavior Matrix | Path | SDK model | Proxy model | Behavior | -|---|---|---|---| +||||| | Native passthrough | `openai/gpt-4.1` | `gpt-4.1` | Provider executes native `file_search` | | Emulated fallback | `anthropic/claude-sonnet-4-5` | `claude-sonnet` | LiteLLM converts to function tool and synthesizes OpenAI-format output | ---- + ## Architecture Diagram @@ -130,7 +130,7 @@ flowchart TD E5 --> Z[OpenAI-compatible output] ``` ---- + ## Prerequisites @@ -140,7 +140,7 @@ export OPENAI_API_KEY="sk-..." # for native path export ANTHROPIC_API_KEY="sk-ant-..." 
# for emulated path ``` ---- + ## Example response shape @@ -216,7 +216,7 @@ def validate_file_search_response(response): validate_file_search_response(response) ``` ---- + ## Q&A @@ -245,7 +245,7 @@ A: The caller does not have access to that vector store. A: `file_citation` annotations require `file_id` metadata in search results. If your vector backend does not return file-level metadata, the answer text is still generated but citations can be empty. ---- + ## What to check next From 24429227d31be7e94c02d84e2f480fe945be5a1c Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 11:52:29 -0300 Subject: [PATCH 041/539] fix(model-prices): correct supported_regions for Vertex AI DeepSeek models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #23859 - deepseek-v3.2-maas: us-west2 → global (per Google docs) - deepseek-v3.1-maas: us-west2 → us-central1 - deepseek-r1-0528-maas: add supported_regions: us-central1 - deepseek-ocr-maas: add supported_regions: us-central1 --- model_prices_and_context_window.json | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 9b1d81fee4..614e459a1c 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -30431,7 +30431,7 @@ "output_cost_per_token": 5.4e-06, "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models", "supported_regions": [ - "us-west2" + "us-central1" ], "supports_assistant_prefill": true, "supports_function_calling": true, @@ -30451,7 +30451,7 @@ "output_cost_per_token_batches": 8.4e-07, "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models", "supported_regions": [ - "us-west2" + "global" ], "supports_assistant_prefill": true, "supports_function_calling": true, @@ -30472,7 +30472,10 @@ "supports_function_calling": true, "supports_prompt_caching": true, 
"supports_reasoning": true, - "supports_tool_choice": true + "supports_tool_choice": true, + "supported_regions": [ + "us-central1" + ] }, "vertex_ai/gemini-2.5-flash-image": { "cache_read_input_token_cost": 3e-08, @@ -31092,7 +31095,10 @@ "input_cost_per_token": 3e-07, "output_cost_per_token": 1.2e-06, "ocr_cost_per_page": 0.0003, - "source": "https://cloud.google.com/vertex-ai/pricing" + "source": "https://cloud.google.com/vertex-ai/pricing", + "supported_regions": [ + "us-central1" + ] }, "vertex_ai/openai/gpt-oss-120b-maas": { "input_cost_per_token": 1.5e-07, From d39eac26832f0168be871a6d069b74d11e31b55d Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 12:12:23 -0300 Subject: [PATCH 042/539] fix: move supported_regions before supports_* fields for alphabetical order --- model_prices_and_context_window.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 614e459a1c..091a1f57b0 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -30468,14 +30468,14 @@ "mode": "chat", "output_cost_per_token": 5.4e-06, "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models", + "supported_regions": [ + "us-central1" + ], "supports_assistant_prefill": true, "supports_function_calling": true, "supports_prompt_caching": true, "supports_reasoning": true, - "supports_tool_choice": true, - "supported_regions": [ - "us-central1" - ] + "supports_tool_choice": true }, "vertex_ai/gemini-2.5-flash-image": { "cache_read_input_token_cost": 3e-08, From 3eeb14bf1a17f52382d0e82b58979147e99fd89d Mon Sep 17 00:00:00 2001 From: cohml <62400541+cohml@users.noreply.github.com> Date: Tue, 17 Mar 2026 11:32:01 -0400 Subject: [PATCH 043/539] fix(cache): Fix Redis cluster caching (#23480) * fix redis cluster startup_nodes check order * add tests for redis cluster startup_nodes fix --- litellm/_redis.py | 63 
++++++++++------ tests/test_litellm/test_redis.py | 122 ++++++++++++++++++++++++++++++- 2 files changed, 160 insertions(+), 25 deletions(-) diff --git a/litellm/_redis.py b/litellm/_redis.py index b754c1f433..2bf32d71b2 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -222,8 +222,12 @@ def _get_redis_client_logic(**env_overrides): "REDIS_CLUSTER_NODES" ) + # If startup_nodes resolved to None (not set by kwarg or env), remove the key + # entirely so callers can rely on key presence as a reliable cluster-mode signal. if _startup_nodes is not None and isinstance(_startup_nodes, str): redis_kwargs["startup_nodes"] = json.loads(_startup_nodes) + elif _startup_nodes is None: + redis_kwargs.pop("startup_nodes", None) _sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore "REDIS_SENTINEL_NODES" @@ -273,10 +277,14 @@ def _get_redis_client_logic(**env_overrides): redis_kwargs["ssl_ca_certs"] = _gcp_ssl_ca_certs if "url" in redis_kwargs and redis_kwargs["url"] is not None: - redis_kwargs.pop("host", None) - redis_kwargs.pop("port", None) - redis_kwargs.pop("db", None) - redis_kwargs.pop("password", None) + # Only strip host/port/db/password when not routing to a cluster. + # When startup_nodes is also present the cluster path takes priority and + # needs the password for authentication. 
+ if not redis_kwargs.get("startup_nodes"): + redis_kwargs.pop("host", None) + redis_kwargs.pop("port", None) + redis_kwargs.pop("db", None) + redis_kwargs.pop("password", None) elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None: pass elif ( @@ -368,6 +376,10 @@ def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis: def get_redis_client(**env_overrides): redis_kwargs = _get_redis_client_logic(**env_overrides) + + if "startup_nodes" in redis_kwargs: + return init_redis_cluster(redis_kwargs) + if "url" in redis_kwargs and redis_kwargs["url"] is not None: args = _get_redis_url_kwargs() url_kwargs = {} @@ -377,9 +389,6 @@ def get_redis_client(**env_overrides): return redis.Redis.from_url(**url_kwargs) - if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore - return init_redis_cluster(redis_kwargs) - # Check for Redis Sentinel if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs: return _init_redis_sentinel(redis_kwargs) @@ -392,21 +401,6 @@ def get_redis_async_client( **env_overrides, ) -> Union[async_redis.Redis, async_redis.RedisCluster]: redis_kwargs = _get_redis_client_logic(**env_overrides) - if "url" in redis_kwargs and redis_kwargs["url"] is not None: - if connection_pool is not None: - return async_redis.Redis(connection_pool=connection_pool) - args = _get_redis_url_kwargs(client=async_redis.Redis.from_url) - url_kwargs = {} - for arg in redis_kwargs: - if arg in args: - url_kwargs[arg] = redis_kwargs[arg] - else: - verbose_logger.debug( - "REDIS: ignoring argument: {}. 
Not an allowed async_redis.Redis.from_url arg.".format( - arg - ) - ) - return async_redis.Redis.from_url(**url_kwargs) if "startup_nodes" in redis_kwargs: from redis.cluster import ClusterNode @@ -469,6 +463,22 @@ def get_redis_async_client( return cluster_client + if "url" in redis_kwargs and redis_kwargs["url"] is not None: + if connection_pool is not None: + return async_redis.Redis(connection_pool=connection_pool) + args = _get_redis_url_kwargs(client=async_redis.Redis.from_url) + url_kwargs = {} + for arg in redis_kwargs: + if arg in args: + url_kwargs[arg] = redis_kwargs[arg] + else: + verbose_logger.debug( + "REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format( + arg + ) + ) + return async_redis.Redis.from_url(**url_kwargs) + # Check for Redis Sentinel if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs: return _init_async_redis_sentinel(redis_kwargs) @@ -482,9 +492,15 @@ def get_redis_async_client( ) -def get_redis_connection_pool(**env_overrides): +def get_redis_connection_pool( + **env_overrides, +) -> Optional[async_redis.BlockingConnectionPool]: redis_kwargs = _get_redis_client_logic(**env_overrides) verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs) + + if "startup_nodes" in redis_kwargs: + return None + if "url" in redis_kwargs and redis_kwargs["url"] is not None: pool_kwargs = { "timeout": REDIS_CONNECTION_POOL_TIMEOUT, @@ -504,7 +520,6 @@ def get_redis_connection_pool(**env_overrides): connection_class = async_redis.SSLConnection redis_kwargs.pop("ssl", None) redis_kwargs["connection_class"] = connection_class - redis_kwargs.pop("startup_nodes", None) return async_redis.BlockingConnectionPool( timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs ) diff --git a/tests/test_litellm/test_redis.py b/tests/test_litellm/test_redis.py index 4709faea4b..1590719099 100644 --- a/tests/test_litellm/test_redis.py +++ b/tests/test_litellm/test_redis.py @@ -1,7 +1,15 @@ -from 
litellm._redis import get_redis_url_from_environment, _get_redis_cluster_kwargs, get_redis_async_client +from litellm._redis import ( + get_redis_url_from_environment, + _get_redis_cluster_kwargs, + get_redis_async_client, + get_redis_client, + get_redis_connection_pool, +) +import json import os import pytest from unittest.mock import MagicMock, patch +import redis import redis.asyncio as async_redis def test_get_redis_url_from_environment_single_url(monkeypatch): @@ -167,3 +175,115 @@ def test_get_redis_async_client_without_connection_pool(): # Verify Redis was called without connection_pool in kwargs call_kwargs = mock_redis.call_args[1] assert "connection_pool" not in call_kwargs, "connection_pool should not be in kwargs when not provided" + +@patch("litellm._redis.init_redis_cluster") +def test_sync_client_prefers_cluster_over_url(mock_init_cluster, monkeypatch): + """ + Test get_redis_client returns RedisCluster when startup_nodes is present even if + REDIS_URL is also set. + """ + monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379") + mock_init_cluster.return_value = MagicMock(spec=redis.RedisCluster) + + startup_nodes = [{"host": "cluster-node.example.com", "port": 6379}] + get_redis_client(startup_nodes=startup_nodes) + + mock_init_cluster.assert_called_once() + call_kwargs = mock_init_cluster.call_args[0][0] + assert ( + "startup_nodes" in call_kwargs + ), "startup_nodes must be forwarded to init_redis_cluster" + +@patch("litellm._redis.async_redis.RedisCluster") +def test_async_client_prefers_cluster_over_url(mock_cluster_cls, monkeypatch): + """ + Test (1) get_redis_async_client returns async RedisCluster when startup_nodes is present + even if REDIS_URL is also set and (2) startup_nodes is forwarded to RedisCluster. 
+ """ + monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379") + + startup_nodes = [{"host": "cluster-node.example.com", "port": 6379}] + get_redis_async_client(startup_nodes=startup_nodes) + + mock_cluster_cls.assert_called_once() + call_kwargs = mock_cluster_cls.call_args[1] + assert "startup_nodes" in call_kwargs, "startup_nodes must be forwarded to async RedisCluster" + assert len(call_kwargs["startup_nodes"]) == 1, "should forward exactly 1 cluster node" + + +@patch("litellm._redis.async_redis.RedisCluster") +def test_async_client_prefers_cluster_over_url_via_env_var(mock_cluster_cls, monkeypatch): + """ + Test get_redis_async_client returns async RedisCluster when REDIS_CLUSTER_NODES is set + even if REDIS_URL is also set. + """ + monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379") + monkeypatch.setenv( + "REDIS_CLUSTER_NODES", + json.dumps([{"host": "cluster-node.example.com", "port": 6379}]), + ) + + get_redis_async_client() + + mock_cluster_cls.assert_called_once() + call_kwargs = mock_cluster_cls.call_args[1] + assert "startup_nodes" in call_kwargs, "startup_nodes must be forwarded to async RedisCluster" + +@patch("litellm._redis.init_redis_cluster") +def test_sync_client_prefers_cluster_over_url_via_env_var(mock_init_cluster, monkeypatch): + """ + Test get_redis_client returns RedisCluster when REDIS_CLUSTER_NODES is set even if + REDIS_URL is also set. 
+ """ + monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379") + monkeypatch.setenv( + "REDIS_CLUSTER_NODES", + json.dumps([{"host": "cluster-node.example.com", "port": 6379}]), + ) + mock_init_cluster.return_value = MagicMock(spec=redis.RedisCluster) + + get_redis_client() + + mock_init_cluster.assert_called_once() + call_kwargs = mock_init_cluster.call_args[0][0] + assert "startup_nodes" in call_kwargs, "startup_nodes must be forwarded to init_redis_cluster" + assert len(call_kwargs["startup_nodes"]) == 1 + +@patch("litellm._redis.init_redis_cluster") +def test_sync_client_preserves_password_for_cluster_when_url_also_set(mock_init_cluster, monkeypatch): + """ + Test _get_redis_client_logic does not strip password from redis_kwargs when + startup_nodes is present even if REDIS_URL is also set. + """ + monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379") + monkeypatch.setenv("REDIS_PASSWORD", "secret") + mock_init_cluster.return_value = MagicMock(spec=redis.RedisCluster) + + startup_nodes = [{"host": "cluster-node.example.com", "port": 6379}] + get_redis_client(startup_nodes=startup_nodes) + + mock_init_cluster.assert_called_once() + call_kwargs = mock_init_cluster.call_args[0][0] + assert "password" in call_kwargs, "password must not be stripped when routing to cluster" + assert call_kwargs["password"] == "secret" + + +def test_connection_pool_returns_none_for_cluster(monkeypatch): + """Test get_redis_connection_pool returns None when startup_nodes is present.""" + monkeypatch.setenv("REDIS_URL", "redis://fallback-host:6379") + startup_nodes = [{"host": "cluster-node.example.com", "port": 6379}] + result = get_redis_connection_pool(startup_nodes=startup_nodes) + assert result is None, "connection pool must be None for cluster mode" + + +@patch("litellm._redis.redis.Redis.from_url") +def test_sync_client_url_used_when_no_cluster(mock_from_url, monkeypatch): + """ + Test get_redis_client default to using URL path when no startup_nodes are provided. 
+ """ + monkeypatch.setenv("REDIS_URL", "redis://plain-host:6379") + monkeypatch.delenv("REDIS_CLUSTER_NODES", raising=False) + + get_redis_client() + + mock_from_url.assert_called_once() From bb8e0cd3e21aeea2cf23cbb3de28e37a1a492bfc Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 22:47:45 +0530 Subject: [PATCH 044/539] Update 5.4 family values correctly --- ...odel_prices_and_context_window_backup.json | 32 +++++-------------- model_prices_and_context_window.json | 32 +++++-------------- 2 files changed, 16 insertions(+), 48 deletions(-) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 4a627d4aa1..0be7aefa36 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -18749,27 +18749,19 @@ }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, - "cache_read_input_token_cost_above_272k_tokens": 1.5e-07, - "cache_read_input_token_cost_flex": 3.9e-08, - "cache_read_input_token_cost_priority": 1.5e-07, - "cache_read_input_token_cost_above_272k_tokens_priority": 3e-07, + "cache_read_input_token_cost_flex": 3.75e-08, + "cache_read_input_token_cost_batches": 3.8e-08, "input_cost_per_token": 7.5e-07, - "input_cost_per_token_above_272k_tokens": 1.5e-06, "input_cost_per_token_flex": 3.75e-07, "input_cost_per_token_batches": 3.75e-07, - "input_cost_per_token_priority": 1.5e-06, - "input_cost_per_token_above_272k_tokens_priority": 3e-06, "litellm_provider": "openai", - "max_input_tokens": 1050000, + "max_input_tokens": 272000, "max_output_tokens": 128000, "max_tokens": 128000, "mode": "chat", "output_cost_per_token": 4.5e-06, - "output_cost_per_token_above_272k_tokens": 6.75e-06, "output_cost_per_token_flex": 2.25e-06, "output_cost_per_token_batches": 2.25e-06, - "output_cost_per_token_priority": 6.75e-06, - "output_cost_per_token_above_272k_tokens_priority": 1.0125e-05, "supported_endpoints": [ 
"/v1/chat/completions", "/v1/batch", @@ -18798,27 +18790,19 @@ }, "gpt-5.4-nano": { "cache_read_input_token_cost": 2e-08, - "cache_read_input_token_cost_above_272k_tokens": 4e-08, - "cache_read_input_token_cost_flex": 1.04e-08, - "cache_read_input_token_cost_priority": 4e-08, - "cache_read_input_token_cost_above_272k_tokens_priority": 8e-08, + "cache_read_input_token_cost_flex": 1e-08, + "cache_read_input_token_cost_batches": 1e-08, "input_cost_per_token": 2e-07, - "input_cost_per_token_above_272k_tokens": 4e-07, "input_cost_per_token_flex": 1e-07, "input_cost_per_token_batches": 1e-07, - "input_cost_per_token_priority": 4e-07, - "input_cost_per_token_above_272k_tokens_priority": 8e-07, "litellm_provider": "openai", - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "max_tokens": 16384, + "max_input_tokens": 272000, + "max_output_tokens": 128000, + "max_tokens": 128000, "mode": "chat", "output_cost_per_token": 1.25e-06, - "output_cost_per_token_above_272k_tokens": 1.875e-06, "output_cost_per_token_flex": 6.25e-07, "output_cost_per_token_batches": 6.25e-07, - "output_cost_per_token_priority": 1.875e-06, - "output_cost_per_token_above_272k_tokens_priority": 2.8125e-06, "supported_endpoints": [ "/v1/chat/completions", "/v1/batch", diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 4a627d4aa1..0be7aefa36 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -18749,27 +18749,19 @@ }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, - "cache_read_input_token_cost_above_272k_tokens": 1.5e-07, - "cache_read_input_token_cost_flex": 3.9e-08, - "cache_read_input_token_cost_priority": 1.5e-07, - "cache_read_input_token_cost_above_272k_tokens_priority": 3e-07, + "cache_read_input_token_cost_flex": 3.75e-08, + "cache_read_input_token_cost_batches": 3.8e-08, "input_cost_per_token": 7.5e-07, - "input_cost_per_token_above_272k_tokens": 1.5e-06, "input_cost_per_token_flex": 
3.75e-07, "input_cost_per_token_batches": 3.75e-07, - "input_cost_per_token_priority": 1.5e-06, - "input_cost_per_token_above_272k_tokens_priority": 3e-06, "litellm_provider": "openai", - "max_input_tokens": 1050000, + "max_input_tokens": 272000, "max_output_tokens": 128000, "max_tokens": 128000, "mode": "chat", "output_cost_per_token": 4.5e-06, - "output_cost_per_token_above_272k_tokens": 6.75e-06, "output_cost_per_token_flex": 2.25e-06, "output_cost_per_token_batches": 2.25e-06, - "output_cost_per_token_priority": 6.75e-06, - "output_cost_per_token_above_272k_tokens_priority": 1.0125e-05, "supported_endpoints": [ "/v1/chat/completions", "/v1/batch", @@ -18798,27 +18790,19 @@ }, "gpt-5.4-nano": { "cache_read_input_token_cost": 2e-08, - "cache_read_input_token_cost_above_272k_tokens": 4e-08, - "cache_read_input_token_cost_flex": 1.04e-08, - "cache_read_input_token_cost_priority": 4e-08, - "cache_read_input_token_cost_above_272k_tokens_priority": 8e-08, + "cache_read_input_token_cost_flex": 1e-08, + "cache_read_input_token_cost_batches": 1e-08, "input_cost_per_token": 2e-07, - "input_cost_per_token_above_272k_tokens": 4e-07, "input_cost_per_token_flex": 1e-07, "input_cost_per_token_batches": 1e-07, - "input_cost_per_token_priority": 4e-07, - "input_cost_per_token_above_272k_tokens_priority": 8e-07, "litellm_provider": "openai", - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "max_tokens": 16384, + "max_input_tokens": 272000, + "max_output_tokens": 128000, + "max_tokens": 128000, "mode": "chat", "output_cost_per_token": 1.25e-06, - "output_cost_per_token_above_272k_tokens": 1.875e-06, "output_cost_per_token_flex": 6.25e-07, "output_cost_per_token_batches": 6.25e-07, - "output_cost_per_token_priority": 1.875e-06, - "output_cost_per_token_above_272k_tokens_priority": 2.8125e-06, "supported_endpoints": [ "/v1/chat/completions", "/v1/batch", From 28506edd4951b2120779f0fefc745d36fc637dff Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 
22:49:44 +0530 Subject: [PATCH 045/539] Add dodcs for gpt-4.5-mini --- .../blog/gpt_5_4_mini_nano/index.md | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 docs/my-website/blog/gpt_5_4_mini_nano/index.md diff --git a/docs/my-website/blog/gpt_5_4_mini_nano/index.md b/docs/my-website/blog/gpt_5_4_mini_nano/index.md new file mode 100644 index 0000000000..6d7c2b33f7 --- /dev/null +++ b/docs/my-website/blog/gpt_5_4_mini_nano/index.md @@ -0,0 +1,106 @@ +--- +slug: gpt_5_4_mini_nano +title: "Day 0 Support: GPT-5.4-mini and GPT-5.4-nano" +date: 2026-03-17T10:00:00 +authors: + - name: Sameer Kankute + title: SWE @ LiteLLM (LLM Translation) + url: https://www.linkedin.com/in/sameer-kankute/ + image_url: https://pbs.twimg.com/profile_images/2001352686994907136/ONgNuSk5_400x400.jpg + - name: Krrish Dholakia + title: "CEO, LiteLLM" + url: https://www.linkedin.com/in/krish-d/ + image_url: https://pbs.twimg.com/profile_images/1298587542745358340/DZv3Oj-h_400x400.jpg + - name: Ishaan Jaff + title: "CTO, LiteLLM" + url: https://www.linkedin.com/in/reffajnaahsi/ + image_url: https://pbs.twimg.com/profile_images/1613813310264340481/lz54oEiB_400x400.jpg +description: "GPT-5.4-mini and GPT-5.4-nano model support in LiteLLM" +tags: [openai, gpt-5.4-mini, gpt-5.4-nano, completion] +hide_table_of_contents: false +--- + +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +LiteLLM now supports GPT-5.4-mini and GPT-5.4-nano — cost-effective models for simple completions and high-throughput workloads. + +:::note +If you're on **v1.82.3-stable** or above, you don't need any update to use these models. +::: + +## Usage + + + + +**1. Setup config.yaml** + +```yaml +model_list: + - model_name: gpt-5.4-mini + litellm_params: + model: openai/gpt-5.4-mini + api_key: os.environ/OPENAI_API_KEY + - model_name: gpt-5.4-nano + litellm_params: + model: openai/gpt-5.4-nano + api_key: os.environ/OPENAI_API_KEY +``` + +**2. 
Start the proxy** + +```bash +litellm --config /path/to/config.yaml +``` + +**3. Test it** + +```bash +# GPT-5.4-mini +curl -X POST "http://localhost:4000/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $LITELLM_KEY" \ + -d '{ + "model": "gpt-5.4-mini", + "messages": [{"role": "user", "content": "What is the capital of France?"}] + }' + +# GPT-5.4-nano +curl -X POST "http://localhost:4000/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $LITELLM_KEY" \ + -d '{ + "model": "gpt-5.4-nano", + "messages": [{"role": "user", "content": "What is 2 + 2?"}] + }' +``` + + + + +```python +from litellm import completion + +# GPT-5.4-mini +response = completion( + model="openai/gpt-5.4-mini", + messages=[{"role": "user", "content": "What is the capital of France?"}], +) +print(response.choices[0].message.content) + +# GPT-5.4-nano +response = completion( + model="openai/gpt-5.4-nano", + messages=[{"role": "user", "content": "What is 2 + 2?"}], +) +print(response.choices[0].message.content) +``` + + + + +## Notes + +- Both models support function calling, vision, and tool-use — see the [OpenAI provider docs](../../docs/providers/openai) for advanced usage. +- GPT-5.4-nano is the most cost-effective option for simple tasks; GPT-5.4-mini offers a balance of speed and capability. 
From b0db75df1fb9f5871cdea662035ee719ef0a9149 Mon Sep 17 00:00:00 2001 From: rstar327 Date: Tue, 17 Mar 2026 13:35:07 -0400 Subject: [PATCH 046/539] fix(proxy): convert max_budget to float when set from environment variable (#23855) Fixes #23843 --- litellm/proxy/proxy_server.py | 68 +++++++++---------- .../proxy/test_max_budget_env_var.py | 38 +++++++++++ 2 files changed, 71 insertions(+), 35 deletions(-) create mode 100644 tests/test_litellm/proxy/test_max_budget_env_var.py diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9c29927c5c..a775b12b20 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -638,9 +638,9 @@ except ImportError: server_root_path = get_server_root_path() _license_check = LicenseCheck() premium_user: bool = _license_check.is_premium() -premium_user_data: Optional[ - "EnterpriseLicenseData" -] = _license_check.airgapped_license_data +premium_user_data: Optional["EnterpriseLicenseData"] = ( + _license_check.airgapped_license_data +) global_max_parallel_request_retries_env: Optional[str] = os.getenv( "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES" ) @@ -1523,9 +1523,9 @@ master_key: Optional[str] = None config_agents: Optional[List[AgentConfig]] = None otel_logging = False prisma_client: Optional[PrismaClient] = None -shared_aiohttp_session: Optional[ - "ClientSession" -] = None # Global shared session for connection reuse +shared_aiohttp_session: Optional["ClientSession"] = ( + None # Global shared session for connection reuse +) user_api_key_cache = DualCache( default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value ) @@ -1533,13 +1533,13 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter( dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[ - RedisCache -] = None # redis cache used for tracking spend, tpm/rpm limits +redis_usage_cache: Optional[RedisCache] = ( + None 
# redis cache used for tracking spend, tpm/rpm limits +) polling_via_cache_enabled: Union[Literal["all"], List[str], bool] = False -native_background_mode: List[ - str -] = [] # Models that should use native provider background mode instead of polling +native_background_mode: List[str] = ( + [] +) # Models that should use native provider background mode instead of polling polling_cache_ttl: int = 3600 # Default 1 hour TTL for polling cache user_custom_auth = None user_custom_key_generate = None @@ -1898,9 +1898,9 @@ async def update_cache( # noqa: PLR0915 _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - existing_spend_obj: Optional[ - LiteLLM_TeamTable - ] = await user_api_key_cache.async_get_cache(key=_id) + existing_spend_obj: Optional[LiteLLM_TeamTable] = ( + await user_api_key_cache.async_get_cache(key=_id) + ) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -2021,11 +2021,9 @@ def run_ollama_serve(): with open(os.devnull, "w") as devnull: subprocess.Popen(command, stdout=devnull, stderr=devnull) except Exception as e: - verbose_proxy_logger.debug( - f""" + verbose_proxy_logger.debug(f""" LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. 
\nEnsure you run `ollama serve` - """ - ) + """) def _get_process_rss_mb() -> Optional[float]: @@ -3303,7 +3301,7 @@ class ProxyConfig: async_only_mode=True # only init async clients ), ignore_invalid_deployments=True, # don't raise an error if a deployment is invalid - ) # type:ignore + ) # type: ignore if redis_usage_cache is not None and router.cache.redis_cache is None: router._update_redis_cache(cache=redis_usage_cache) @@ -4952,10 +4950,10 @@ class ProxyConfig: ) try: - guardrails_in_db: List[ - Guardrail - ] = await GuardrailRegistry.get_all_guardrails_from_db( - prisma_client=prisma_client + guardrails_in_db: List[Guardrail] = ( + await GuardrailRegistry.get_all_guardrails_from_db( + prisma_client=prisma_client + ) ) verbose_proxy_logger.debug( "guardrails from the DB %s", str(guardrails_in_db) @@ -5337,9 +5335,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ[ - "AZURE_API_VERSION" - ] = api_version # set this for azure - litellm can read this from the env + os.environ["AZURE_API_VERSION"] = ( + api_version # set this for azure - litellm can read this from the env + ) if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -5357,8 +5355,8 @@ async def initialize( # noqa: PLR0915 litellm.add_function_to_prompt = True dynamic_config["general"]["add_function_to_prompt"] = True if max_budget: # litellm-specific param - litellm.max_budget = max_budget - dynamic_config["general"]["max_budget"] = max_budget + litellm.max_budget = float(max_budget) + dynamic_config["general"]["max_budget"] = litellm.max_budget if experimental: pass user_telemetry = telemetry @@ -5676,9 +5674,9 @@ class ProxyStartupEvent: """ from litellm.secret_managers.main import str_to_bool - _use_redis_transaction_buffer: Optional[ - Union[bool, str] - ] = general_settings.get("use_redis_transaction_buffer", False) + 
_use_redis_transaction_buffer: Optional[Union[bool, str]] = ( + general_settings.get("use_redis_transaction_buffer", False) + ) if isinstance(_use_redis_transaction_buffer, str): _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer) @@ -12114,9 +12112,9 @@ async def get_config_list( hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[ - idx - ].field_description = sub_field_info.description + nested_fields[idx].field_description = ( + sub_field_info.description + ) idx += 1 _stored_in_db = None diff --git a/tests/test_litellm/proxy/test_max_budget_env_var.py b/tests/test_litellm/proxy/test_max_budget_env_var.py new file mode 100644 index 0000000000..4b96539282 --- /dev/null +++ b/tests/test_litellm/proxy/test_max_budget_env_var.py @@ -0,0 +1,38 @@ +""" +Test that max_budget from environment variable (string) is correctly +converted to float. +GitHub Issue: #23843 +""" + +import pytest + +import litellm +from litellm.proxy.proxy_server import initialize + + +@pytest.mark.asyncio +async def test_max_budget_string_converted_to_float(): + """ + When max_budget is set via os.environ/MAX_BUDGET, it arrives as a + string. initialize() should convert it to float so the comparison + `litellm.max_budget > 0` doesn't raise TypeError. 
+ """ + original = litellm.max_budget + try: + await initialize(max_budget="100.5") + assert isinstance(litellm.max_budget, float) + assert litellm.max_budget == 100.5 + finally: + litellm.max_budget = original + + +@pytest.mark.asyncio +async def test_max_budget_float_stays_float(): + """max_budget as float should still work.""" + original = litellm.max_budget + try: + await initialize(max_budget=200.0) + assert isinstance(litellm.max_budget, float) + assert litellm.max_budget == 200.0 + finally: + litellm.max_budget = original From 1b91e1656aba9f0ab99e62659ca520e5fd4eee41 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 23:20:59 +0530 Subject: [PATCH 047/539] Add support for gpt-5.4 mini and nano --- .../llms/openai/chat/gpt_5_transformation.py | 6 +- ...odel_prices_and_context_window_backup.json | 85 +++++++++++++------ model_prices_and_context_window.json | 85 +++++++++++++------ .../llms/openai/test_gpt5_transformation.py | 85 +++++++++++++++++++ 4 files changed, 202 insertions(+), 59 deletions(-) diff --git a/litellm/llms/openai/chat/gpt_5_transformation.py b/litellm/llms/openai/chat/gpt_5_transformation.py index bb5783011a..4fa45284df 100644 --- a/litellm/llms/openai/chat/gpt_5_transformation.py +++ b/litellm/llms/openai/chat/gpt_5_transformation.py @@ -200,14 +200,14 @@ class OpenAIGPT5Config(OpenAIGPTConfig): if "reasoning_effort" in optional_params: optional_params["reasoning_effort"] = normalized - if effective_effort is not None and effective_effort == "xhigh": - if not self._supports_reasoning_effort_level(model, "xhigh"): + if effective_effort is not None and effective_effort == "xhigh" or effective_effort == "minimal": + if not self._supports_reasoning_effort_level(model, effective_effort): if litellm.drop_params or drop_params: non_default_params.pop("reasoning_effort", None) else: raise litellm.utils.UnsupportedParamsError( message=( - "reasoning_effort='xhigh' is only supported for gpt-5.1-codex-max, gpt-5.2, and gpt-5.4+ 
models." + f"reasoning_effort={effective_effort} is not supported for this model." ), status_code=400, ) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 0be7aefa36..1ac73d5bc0 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -18333,7 +18333,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.2": { "cache_read_input_token_cost": 1.75e-07, @@ -18373,7 +18374,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-2025-12-11": { "cache_read_input_token_cost": 1.75e-07, @@ -18413,7 +18415,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-chat-latest": { "cache_read_input_token_cost": 1.75e-07, @@ -18450,7 +18453,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.3-chat-latest": { "cache_read_input_token_cost": 1.75e-07, @@ -18487,7 +18491,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-pro": { "input_cost_per_token": 2.1e-05, @@ -18520,7 +18525,8 @@ "supports_vision": true, 
"supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-pro-2025-12-11": { "input_cost_per_token": 2.1e-05, @@ -18553,7 +18559,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.4": { "cache_read_input_token_cost": 2.5e-07, @@ -18602,7 +18609,8 @@ "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.4-2026-03-05": { "cache_read_input_token_cost": 2.5e-07, @@ -18697,7 +18705,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.4-pro-2026-03-05": { "cache_read_input_token_cost": 3e-06, @@ -18745,7 +18754,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, @@ -18786,7 +18796,8 @@ "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": false }, "gpt-5.4-nano": { "cache_read_input_token_cost": 2e-08, @@ -18827,7 +18838,8 @@ "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + 
"supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": false }, "gpt-5-pro": { "input_cost_per_token": 1.5e-05, @@ -18862,7 +18874,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-pro-2025-10-06": { "input_cost_per_token": 1.5e-05, @@ -18897,7 +18910,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-2025-08-07": { "cache_read_input_token_cost": 1.25e-07, @@ -18939,7 +18953,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-chat": { "cache_read_input_token_cost": 1.25e-07, @@ -18973,7 +18988,8 @@ "supports_tool_choice": false, "supports_vision": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-chat-latest": { "cache_read_input_token_cost": 1.25e-07, @@ -19007,7 +19023,8 @@ "supports_tool_choice": false, "supports_vision": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-codex": { "cache_read_input_token_cost": 1.25e-07, @@ -19040,7 +19057,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-codex": { 
"cache_read_input_token_cost": 1.25e-07, @@ -19076,7 +19094,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-codex-max": { "cache_read_input_token_cost": 1.25e-07, @@ -19109,7 +19128,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-codex-mini": { "cache_read_input_token_cost": 2.5e-08, @@ -19145,7 +19165,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-codex": { "cache_read_input_token_cost": 1.75e-07, @@ -19181,7 +19202,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.3-codex": { "cache_read_input_token_cost": 1.75e-07, @@ -19217,7 +19239,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-mini": { "cache_read_input_token_cost": 2.5e-08, @@ -19259,7 +19282,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-mini-2025-08-07": { "cache_read_input_token_cost": 2.5e-08, @@ -19301,7 +19325,8 @@ "supports_vision": true, "supports_web_search": 
true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-nano": { "cache_read_input_token_cost": 5e-09, @@ -19340,7 +19365,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-nano-2025-08-07": { "cache_read_input_token_cost": 5e-09, @@ -19378,7 +19404,9 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true + }, "gpt-image-1": { "cache_read_input_image_token_cost": 2.5e-06, @@ -36349,7 +36377,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-search-api-2025-10-14": { "cache_read_input_token_cost": 1.25e-07, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 0be7aefa36..1ac73d5bc0 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -18333,7 +18333,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.2": { "cache_read_input_token_cost": 1.75e-07, @@ -18373,7 +18374,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-2025-12-11": { 
"cache_read_input_token_cost": 1.75e-07, @@ -18413,7 +18415,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-chat-latest": { "cache_read_input_token_cost": 1.75e-07, @@ -18450,7 +18453,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.3-chat-latest": { "cache_read_input_token_cost": 1.75e-07, @@ -18487,7 +18491,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-pro": { "input_cost_per_token": 2.1e-05, @@ -18520,7 +18525,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-pro-2025-12-11": { "input_cost_per_token": 2.1e-05, @@ -18553,7 +18559,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.4": { "cache_read_input_token_cost": 2.5e-07, @@ -18602,7 +18609,8 @@ "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.4-2026-03-05": { "cache_read_input_token_cost": 2.5e-07, @@ -18697,7 +18705,8 @@ "supports_vision": true, "supports_web_search": true, 
"supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.4-pro-2026-03-05": { "cache_read_input_token_cost": 3e-06, @@ -18745,7 +18754,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, @@ -18786,7 +18796,8 @@ "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": false }, "gpt-5.4-nano": { "cache_read_input_token_cost": 2e-08, @@ -18827,7 +18838,8 @@ "supports_service_tier": true, "supports_vision": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": false }, "gpt-5-pro": { "input_cost_per_token": 1.5e-05, @@ -18862,7 +18874,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-pro-2025-10-06": { "input_cost_per_token": 1.5e-05, @@ -18897,7 +18910,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-2025-08-07": { "cache_read_input_token_cost": 1.25e-07, @@ -18939,7 +18953,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + 
"supports_minimal_reasoning_effort": true }, "gpt-5-chat": { "cache_read_input_token_cost": 1.25e-07, @@ -18973,7 +18988,8 @@ "supports_tool_choice": false, "supports_vision": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-chat-latest": { "cache_read_input_token_cost": 1.25e-07, @@ -19007,7 +19023,8 @@ "supports_tool_choice": false, "supports_vision": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-codex": { "cache_read_input_token_cost": 1.25e-07, @@ -19040,7 +19057,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-codex": { "cache_read_input_token_cost": 1.25e-07, @@ -19076,7 +19094,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-codex-max": { "cache_read_input_token_cost": 1.25e-07, @@ -19109,7 +19128,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-codex-mini": { "cache_read_input_token_cost": 2.5e-08, @@ -19145,7 +19165,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.2-codex": { "cache_read_input_token_cost": 1.75e-07, @@ 
-19181,7 +19202,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": true + "supports_xhigh_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "gpt-5.3-codex": { "cache_read_input_token_cost": 1.75e-07, @@ -19217,7 +19239,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-mini": { "cache_read_input_token_cost": 2.5e-08, @@ -19259,7 +19282,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-mini-2025-08-07": { "cache_read_input_token_cost": 2.5e-08, @@ -19301,7 +19325,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-nano": { "cache_read_input_token_cost": 5e-09, @@ -19340,7 +19365,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-nano-2025-08-07": { "cache_read_input_token_cost": 5e-09, @@ -19378,7 +19404,9 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true + }, "gpt-image-1": { "cache_read_input_image_token_cost": 2.5e-06, @@ -36349,7 +36377,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, 
- "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5-search-api-2025-10-14": { "cache_read_input_token_cost": 1.25e-07, diff --git a/tests/test_litellm/llms/openai/test_gpt5_transformation.py b/tests/test_litellm/llms/openai/test_gpt5_transformation.py index 47ae3c44c9..8ae9bf48a6 100644 --- a/tests/test_litellm/llms/openai/test_gpt5_transformation.py +++ b/tests/test_litellm/llms/openai/test_gpt5_transformation.py @@ -324,6 +324,91 @@ def test_gpt5_4_pro_allows_reasoning_effort_xhigh(config: OpenAIConfig): assert params["reasoning_effort"] == "xhigh" +def test_gpt5_4_allows_reasoning_effort_minimal(config: OpenAIConfig): + """gpt-5.4 supports reasoning_effort='minimal'.""" + params = config.map_openai_params( + non_default_params={"reasoning_effort": "minimal"}, + optional_params={}, + model="gpt-5.4", + drop_params=False, + ) + assert params["reasoning_effort"] == "minimal" + + +def test_gpt5_4_pro_allows_reasoning_effort_minimal(config: OpenAIConfig): + """gpt-5.4-pro supports reasoning_effort='minimal'.""" + params = config.map_openai_params( + non_default_params={"reasoning_effort": "minimal"}, + optional_params={}, + model="gpt-5.4-pro", + drop_params=False, + ) + assert params["reasoning_effort"] == "minimal" + + +def test_gpt5_4_mini_rejects_reasoning_effort_minimal(config: OpenAIConfig): + """gpt-5.4-mini does not support reasoning_effort='minimal'.""" + with pytest.raises(litellm.utils.UnsupportedParamsError): + config.map_openai_params( + non_default_params={"reasoning_effort": "minimal"}, + optional_params={}, + model="gpt-5.4-mini", + drop_params=False, + ) + + +def test_gpt5_4_nano_rejects_reasoning_effort_minimal(config: OpenAIConfig): + """gpt-5.4-nano does not support reasoning_effort='minimal'.""" + with pytest.raises(litellm.utils.UnsupportedParamsError): + config.map_openai_params( + non_default_params={"reasoning_effort": "minimal"}, + optional_params={}, + 
model="gpt-5.4-nano", + drop_params=False, + ) + + +def test_gpt5_drops_reasoning_effort_minimal_when_requested(config: OpenAIConfig): + """reasoning_effort='minimal' is dropped for unsupported models when drop_params=True.""" + params = config.map_openai_params( + non_default_params={"reasoning_effort": "minimal"}, + optional_params={}, + model="gpt-5.4-mini", + drop_params=True, + ) + assert "reasoning_effort" not in params + + +def test_gpt5_minimal_dict_triggers_validation(config: OpenAIConfig): + """Dict with effort='minimal' triggers minimal model-support validation.""" + with pytest.raises(litellm.utils.UnsupportedParamsError): + config.map_openai_params( + non_default_params={"reasoning_effort": {"effort": "minimal", "summary": "detailed"}}, + optional_params={}, + model="gpt-5.4-mini", + drop_params=False, + ) + + +def test_gpt5_minimal_dict_accepted_for_supported_model(config: OpenAIConfig): + """Dict with effort='minimal' passes through for gpt-5.4+.""" + params = config.map_openai_params( + non_default_params={"reasoning_effort": {"effort": "minimal", "summary": "detailed"}}, + optional_params={}, + model="gpt-5.4", + drop_params=False, + ) + assert params["reasoning_effort"] == "minimal" + + +def test_gpt5_supports_reasoning_effort_level_minimal(gpt5_config: OpenAIGPT5Config): + """Test that _supports_reasoning_effort_level correctly identifies minimal support.""" + assert gpt5_config._supports_reasoning_effort_level("gpt-5.4", "minimal") + assert gpt5_config._supports_reasoning_effort_level("gpt-5.4-pro", "minimal") + assert not gpt5_config._supports_reasoning_effort_level("gpt-5.4-mini", "minimal") + assert not gpt5_config._supports_reasoning_effort_level("gpt-5.4-nano", "minimal") + + def test_gpt5_normalizes_reasoning_effort_dict_with_summary(config: OpenAIConfig): """Dict with summary/generate_summary is normalized for chat completions.""" params = config.map_openai_params( From 0ecea85bd0e6f4187f22176f7317d3fb1f498847 Mon Sep 17 00:00:00 2001 
From: Sameer Kankute Date: Tue, 17 Mar 2026 23:27:01 +0530 Subject: [PATCH 048/539] Fix supports none flag --- litellm/model_prices_and_context_window_backup.json | 2 +- model_prices_and_context_window.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 1ac73d5bc0..c80d637233 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -18837,7 +18837,7 @@ "supports_tool_choice": true, "supports_service_tier": true, "supports_vision": true, - "supports_none_reasoning_effort": false, + "supports_none_reasoning_effort": true, "supports_xhigh_reasoning_effort": true, "supports_minimal_reasoning_effort": false }, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 1ac73d5bc0..c80d637233 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -18837,7 +18837,7 @@ "supports_tool_choice": true, "supports_service_tier": true, "supports_vision": true, - "supports_none_reasoning_effort": false, + "supports_none_reasoning_effort": true, "supports_xhigh_reasoning_effort": true, "supports_minimal_reasoning_effort": false }, From a34f7e483112c7a4b958cf19debc50e891fdd52c Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 17 Mar 2026 23:27:39 +0530 Subject: [PATCH 049/539] Fix paranthesis: --- litellm/llms/openai/chat/gpt_5_transformation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/llms/openai/chat/gpt_5_transformation.py b/litellm/llms/openai/chat/gpt_5_transformation.py index 4fa45284df..8522a97a38 100644 --- a/litellm/llms/openai/chat/gpt_5_transformation.py +++ b/litellm/llms/openai/chat/gpt_5_transformation.py @@ -200,7 +200,9 @@ class OpenAIGPT5Config(OpenAIGPTConfig): if "reasoning_effort" in optional_params: optional_params["reasoning_effort"] = 
normalized - if effective_effort is not None and effective_effort == "xhigh" or effective_effort == "minimal": + if effective_effort is not None and ( + effective_effort == "xhigh" or effective_effort == "minimal" + ): if not self._supports_reasoning_effort_level(model, effective_effort): if litellm.drop_params or drop_params: non_default_params.pop("reasoning_effort", None) From 0c28b47057102720353c04fbafd11f64760eb196 Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 17:47:01 -0300 Subject: [PATCH 050/539] fix(vertex): streaming finish_reason="stop" instead of "tool_calls" for gemini-3.1-flash-lite-preview Models like gemini-3.1-flash-lite-preview send the final streaming chunk with empty content (text:"") alongside finishReason:"STOP", instead of omitting content entirely. The existing fix (PR #21577) only handled chunks without content, so this case was missed. Now, after processing candidates, if tool_calls were seen in earlier chunks and a choice has finish_reason="stop", it is overridden to "tool_calls" to match the OpenAI spec. Fixes #22900 --- .../vertex_and_google_ai_studio_gemini.py | 10 +++ ...emini_streaming_tool_call_finish_reason.py | 72 +++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 3f1bccaccf..1054b311d0 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -3001,6 +3001,16 @@ class ModelResponseIterator: ) model_response.choices.append(choice) + # Also handle the case where the final chunk has empty + # content (e.g. text:"") WITH finishReason. In this case + # _process_candidates DOES create a choice, but maps + # finishReason="STOP" to "stop" because the current chunk + # has no tool_calls. Override if we saw tool_calls earlier. 
+ if self.has_seen_tool_calls: + for choice in model_response.choices: + if choice.finish_reason == "stop": + choice.finish_reason = "tool_calls" + setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata) # type: ignore setattr(model_response, "vertex_ai_url_context_metadata", url_context_metadata) # type: ignore setattr(model_response, "vertex_ai_safety_ratings", safety_ratings) # type: ignore diff --git a/tests/test_litellm/llms/vertex_ai/gemini/test_gemini_streaming_tool_call_finish_reason.py b/tests/test_litellm/llms/vertex_ai/gemini/test_gemini_streaming_tool_call_finish_reason.py index 3f8efd47fa..d4d76ab307 100644 --- a/tests/test_litellm/llms/vertex_ai/gemini/test_gemini_streaming_tool_call_finish_reason.py +++ b/tests/test_litellm/llms/vertex_ai/gemini/test_gemini_streaming_tool_call_finish_reason.py @@ -230,3 +230,75 @@ def test_streaming_content_filter_finish_reason_preserved(): assert response is not None assert len(response.choices) == 1 assert response.choices[0].finish_reason == "content_filter" + + +def test_streaming_tool_call_finish_reason_with_empty_content_in_final_chunk(): + """ + When Gemini streams tool calls and the final chunk has BOTH empty content + (e.g. parts: [{text: ""}]) AND finishReason="STOP", the finish_reason + must still be "tool_calls". + + This covers models like gemini-3.1-flash-lite-preview that send the + final chunk with content (empty text) instead of omitting it entirely. 
+ + Ref: https://github.com/BerriAI/litellm/issues/22900 + """ + logging_obj = _make_logging_obj() + iterator = ModelResponseIterator( + streaming_response=iter([]), + sync_stream=True, + logging_obj=logging_obj, + ) + + # Chunk 1: tool call with no finishReason + chunk_with_tool_calls = { + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "get_weather", + "args": {"location": "San Francisco"}, + } + } + ], + "role": "model", + }, + "index": 0, + } + ], + } + + # Chunk 2: finishReason="STOP" WITH empty content (text: "") + chunk_with_empty_content_and_finish = { + "candidates": [ + { + "content": { + "parts": [{"text": ""}], + "role": "model", + }, + "finishReason": "STOP", + "index": 0, + } + ], + "usageMetadata": { + "promptTokenCount": 50, + "candidatesTokenCount": 20, + "totalTokenCount": 70, + }, + } + + # Process chunk 1 + response1 = iterator.chunk_parser(chunk_with_tool_calls) + assert response1 is not None + assert len(response1.choices) == 1 + assert response1.choices[0].delta.tool_calls is not None + assert iterator.has_seen_tool_calls is True + + # Process chunk 2 (final chunk with empty content) + response2 = iterator.chunk_parser(chunk_with_empty_content_and_finish) + assert response2 is not None + assert len(response2.choices) == 1 + # Must be "tool_calls", NOT "stop" + assert response2.choices[0].finish_reason == "tool_calls" From 8b4a74a69c14b1a806d7e332cf371303de8c52f2 Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 18:06:42 -0300 Subject: [PATCH 051/539] fix(core): map Anthropic 'refusal' finish reason to 'content_filter' MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Anthropic's 'refusal' stop_reason was missing from _FINISH_REASON_MAP, causing it to fall through to the default 'stop' — hiding the fact that the model refused to respond due to safety policies. 
Fixes #23793 --- litellm/litellm_core_utils/core_helpers.py | 1 + tests/test_litellm/litellm_core_utils/test_core_helpers.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index ee111f3592..9d9255cdf1 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -64,6 +64,7 @@ _FINISH_REASON_MAP: dict[str, OpenAIChatCompletionFinishReason] = { "end_turn": "stop", "max_tokens": "length", "tool_use": "tool_calls", + "refusal": "content_filter", "compaction": "length", # Cohere "COMPLETE": "stop", diff --git a/tests/test_litellm/litellm_core_utils/test_core_helpers.py b/tests/test_litellm/litellm_core_utils/test_core_helpers.py index 0ef76e0942..134ed5d25f 100644 --- a/tests/test_litellm/litellm_core_utils/test_core_helpers.py +++ b/tests/test_litellm/litellm_core_utils/test_core_helpers.py @@ -74,6 +74,9 @@ class TestMapFinishReasonAnthropic: def test_compaction(self): assert map_finish_reason("compaction") == "length" + def test_refusal(self): + assert map_finish_reason("refusal") == "content_filter" + class TestMapFinishReasonGemini: @pytest.mark.parametrize( From bed44f5fe53cff86b0818844db3be05a3bc77719 Mon Sep 17 00:00:00 2001 From: Rohan Date: Wed, 18 Mar 2026 03:08:04 +0530 Subject: [PATCH 052/539] Add Akto Guardrails to LiteLLM (#23250) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * akto guardrails support in litellm * docs(guardrails): add akto to supported values in types/guardrails.py * frontend changes + fixes * feat(akto): update Akto guardrail integration with new configuration options and modes * docs(akto): enhance Akto documentation and configuration descriptions for clarity * feat(tests): add proxy server request headers to sample request data * refactor(akto): remove optional account and VXLAN IDs; update documentation and tests * feat(akto): add event_type parameter 
for enhanced observability in guardrail logging * refactor(akto): update environment variable references * refactor the python codes * refactor and fix linting * refactor(akto): remove unused event hook and clean up imports * refactor(akto): enhance AktoGuardrail with async support and improved logging * fix: Register DynamoAI guardrail initializer and enum entry (#23752) * fix: Register DynamoAI guardrail initializer and enum entry Fix the "Unsupported guardrail: dynamoai" error by: 1. Adding DYNAMOAI to SupportedGuardrailIntegrations enum 2. Implementing initialize_guardrail() and registries in dynamoai/__init__.py The DynamoAI guardrail was added in PR #15920 but never properly registered in the initialization system. The __init__.py was missing the guardrail_initializer_registry and guardrail_class_registry dictionaries that the dynamic discovery mechanism looks for at module load time. Fixes #22773 Co-Authored-By: Claude Haiku 4.5 * Update litellm/proxy/guardrails/guardrail_hooks/dynamoai/__init__.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Update litellm/proxy/guardrails/guardrail_hooks/dynamoai/__init__.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * test: Add tests for DynamoAI guardrail registration Verifies enum entry, initializer registry, class registry, instance creation, and global registry discovery. Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Haiku 4.5 Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * docs: add v1.82.3 release notes and update provider_endpoints_support.json (#23816) * Revert "docs: add v1.82.3 release notes and update provider_endpoints_support…" (#23817) This reverts commit 966124966f83e6e1091ad08d9dfee6e341d3ad85. 
* Refactor Akto guardrail configuration and tests; update UI description and tags * add account and vxlan ID parameters to Akto guardrail initialization; update Akto logo format * enhance Akto guardrail documentation and improve error handling for non-JSON responses * address greptile issues * fix: update payload handling to use 'data' instead of 'json' in AktoGuardrail and adjust tests accordingly --------- Co-authored-by: Harshit Jain <48647625+Harshit28j@users.noreply.github.com> Co-authored-by: Claude Haiku 4.5 Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: Joe Reyna Co-authored-by: Krish Dholakia --- docs/my-website/docs/proxy/guardrails/akto.md | 139 +++++ docs/my-website/sidebars.js | 1 + .../guardrail_hooks/akto/__init__.py | 37 ++ .../guardrails/guardrail_hooks/akto/akto.py | 456 +++++++++++++++ .../guardrail_hooks/dynamoai/__init__.py | 32 +- litellm/types/guardrails.py | 8 +- .../proxy/guardrails/guardrail_hooks/akto.py | 55 ++ .../guardrails_tests/test_akto_guardrails.py | 550 ++++++++++++++++++ .../guardrail_hooks/test_dynamoai.py | 81 +++ .../public/assets/logos/akto.svg | 10 + .../guardrails/guardrail_garden_configs.ts | 6 + .../guardrails/guardrail_garden_data.ts | 8 + .../guardrails/guardrail_info_helpers.tsx | 1 + 13 files changed, 1382 insertions(+), 2 deletions(-) create mode 100644 docs/my-website/docs/proxy/guardrails/akto.md create mode 100644 litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py create mode 100644 litellm/proxy/guardrails/guardrail_hooks/akto/akto.py create mode 100644 litellm/types/proxy/guardrails/guardrail_hooks/akto.py create mode 100644 tests/guardrails_tests/test_akto_guardrails.py create mode 100644 tests/test_litellm/proxy/guardrails/guardrail_hooks/test_dynamoai.py create mode 100644 ui/litellm-dashboard/public/assets/logos/akto.svg diff --git a/docs/my-website/docs/proxy/guardrails/akto.md b/docs/my-website/docs/proxy/guardrails/akto.md new file mode 
100644 index 0000000000..67ae741d11 --- /dev/null +++ b/docs/my-website/docs/proxy/guardrails/akto.md @@ -0,0 +1,139 @@ +# Akto + +## Overview +[Akto](https://www.akto.io/) provides API security guardrails and data ingestion for LLM traffic. + +Akto now uses a **two-entry guardrail pattern** in LiteLLM: +- `akto-validate` (`pre_call`) for request validation +- `akto-ingest` (`post_call`) for request/response ingestion + +There is no `on_flagged` setting anymore. + +Use these as two separate guardrails in `config.yaml`: +- `guardrail_name: "akto-validate"` +- `guardrail_name: "akto-ingest"` + +## 1. Get Your Akto Credentials + +Set up the Akto Guardrail API Service and grab: +- `AKTO_GUARDRAIL_API_BASE` — your Guardrail API Base URL +- `AKTO_API_KEY` — your API key + +## 2. Configure in `config.yaml` + +### Block + Ingest (recommended) + +Use both entries below. This gives you: +- pre-call block decision +- post-call ingestion for allowed traffic + +Keep these as two separate entries (`akto-validate` and `akto-ingest`). + +```yaml +guardrails: + - guardrail_name: "akto-validate" + litellm_params: + guardrail: akto + mode: pre_call + akto_base_url: os.environ/AKTO_GUARDRAIL_API_BASE + akto_api_key: os.environ/AKTO_API_KEY + default_on: true + unreachable_fallback: fail_closed # optional: fail_open | fail_closed (default: fail_closed) + guardrail_timeout: 5 # optional, default: 5 + akto_account_id: "1000000" # optional, env fallback: AKTO_ACCOUNT_ID + akto_vxlan_id: "0" # optional, env fallback: AKTO_VXLAN_ID + + - guardrail_name: "akto-ingest" + litellm_params: + guardrail: akto + mode: post_call + akto_base_url: os.environ/AKTO_GUARDRAIL_API_BASE + akto_api_key: os.environ/AKTO_API_KEY + default_on: true +``` + +### Monitor-only mode + +If you only want logging/ingestion and no blocking, keep only `akto-ingest`. 
+ +```yaml +guardrails: + - guardrail_name: "akto-ingest" + litellm_params: + guardrail: akto + mode: post_call + akto_base_url: os.environ/AKTO_GUARDRAIL_API_BASE + akto_api_key: os.environ/AKTO_API_KEY + default_on: true +``` + +## 3. Test It + +```shell +curl -i http://localhost:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer " \ + -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "Hello, how are you?"} + ] + }' +``` + +If a request gets blocked: + +```json +{ + "error": { + "message": "Prompt injection detected", + "type": "None", + "param": "None", + "code": "403" + } +} +``` + +## 4. How It Works + +**Block + Ingest mode:** +``` +Request → LiteLLM → Akto guardrail check + → Allowed → forward to LLM → ingest response + → Blocked → ingest blocked marker → 403 error +``` + +**Monitor-only mode:** +``` +Request → LiteLLM → forward to LLM → get response + → Send to Akto (guardrails + ingest) → log only +``` + +## 5. Event behavior + +| Entry | LiteLLM hook | Akto call behavior | +|------|---|---| +| `akto-validate` | `pre_call` | Awaited call with `guardrails=true`, `ingest_data=false` | +| `akto-ingest` | `post_call` | Fire-and-forget call with `guardrails=true`, `ingest_data=true` | + +When blocked in `pre_call`, LiteLLM sends one fire-and-forget ingest payload with blocked metadata and returns `403`. + +## 6. 
Parameters + +| Parameter | Env Variable | Default | Description | +|-----------|-------------|---------|-------------| +| `akto_base_url` | `AKTO_GUARDRAIL_API_BASE` | *required* | Akto Guardrail API Base URL | +| `akto_api_key` | `AKTO_API_KEY` | *required* | API key (sent as `Authorization` header) | +| `akto_account_id` | `AKTO_ACCOUNT_ID` | `1000000` | Akto account id included in payload | +| `akto_vxlan_id` | `AKTO_VXLAN_ID` | `0` | Akto vxlan id included in payload | +| `unreachable_fallback` | — | `fail_closed` | `fail_open` or `fail_closed` | +| `guardrail_timeout` | — | `5` | Timeout in seconds | +| `default_on` | — | `true` (recommended) | Enables the guardrail entry by default | + +## 7. Error Handling + +| Scenario | `fail_closed` (default) | `fail_open` | +|----------|------------------------|-------------| +| Akto unreachable | ❌ Blocked (503) | ✅ Passes through | +| Akto returns error | ❌ Blocked (503) | ✅ Passes through | +| Guardrail says no | ❌ Blocked (403) | ❌ Blocked (403) | diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 1362745a91..e53891d633 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -52,6 +52,7 @@ const sidebars = { label: "Providers", items: [ ...[ + "proxy/guardrails/akto", "proxy/guardrails/qualifire", "proxy/guardrails/aim_security", "proxy/guardrails/onyx_security", diff --git a/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py new file mode 100644 index 0000000000..4ae2675540 --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/akto/__init__.py @@ -0,0 +1,37 @@ +from typing import TYPE_CHECKING + +from litellm.types.guardrails import SupportedGuardrailIntegrations + +from .akto import AktoGuardrail + + +if TYPE_CHECKING: + from litellm.types.guardrails import Guardrail, LitellmParams + + +def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"): + import litellm + + 
_akto_callback = AktoGuardrail( + akto_base_url=getattr(litellm_params, "akto_base_url", None), + akto_api_key=getattr(litellm_params, "akto_api_key", None), + akto_account_id=getattr(litellm_params, "akto_account_id", None), + akto_vxlan_id=getattr(litellm_params, "akto_vxlan_id", None), + unreachable_fallback=getattr(litellm_params, "unreachable_fallback", "fail_closed"), + guardrail_timeout=getattr(litellm_params, "guardrail_timeout", None), + guardrail_name=guardrail.get("guardrail_name", ""), + event_hook=litellm_params.mode, + default_on=litellm_params.default_on, + ) + + litellm.logging_callback_manager.add_litellm_callback(_akto_callback) + return _akto_callback + + +guardrail_initializer_registry = { + SupportedGuardrailIntegrations.AKTO.value: initialize_guardrail, +} + +guardrail_class_registry = { + SupportedGuardrailIntegrations.AKTO.value: AktoGuardrail, +} diff --git a/litellm/proxy/guardrails/guardrail_hooks/akto/akto.py b/litellm/proxy/guardrails/guardrail_hooks/akto/akto.py new file mode 100644 index 0000000000..be9c9cb1be --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/akto/akto.py @@ -0,0 +1,456 @@ +"""Akto guardrail integration for LiteLLM proxy. + +Uses a two-config-entry pattern: + - akto-validate (pre_call): Checks request against Akto guardrails, blocks if flagged. + - akto-ingest (post_call): Sends request+response to Akto for data ingestion. + +For monitor-only mode, enable only akto-ingest without akto-validate. 
+""" + +import asyncio +import json +import os +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Type + +from fastapi import HTTPException + +import httpx + +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + log_guardrail_information, +) +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.guardrails import GuardrailEventHooks +from litellm.types.utils import GenericGuardrailAPIInputs + +if TYPE_CHECKING: + from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel + +HTTP_PROXY_PATH = "/api/http-proxy" +AKTO_CONNECTOR_NAME = "litellm" +DEFAULT_GUARDRAIL_TIMEOUT = 5 + + +class AktoGuardrail(CustomGuardrail): + """LiteLLM guardrail hook that validates and ingests LLM traffic via the Akto API.""" + + # Maps event_hook to the input_type it should handle; mismatches are no-ops + HOOK_TO_INPUT = {"pre_call": "request", "post_call": "response"} + + @staticmethod + def get_config_model() -> Type["GuardrailConfigModel"]: + """Return the Pydantic config model for YAML-based initialization.""" + from litellm.types.proxy.guardrails.guardrail_hooks.akto import ( + AktoConfigModel, + ) + + return AktoConfigModel + + def __init__( + self, + akto_base_url: Optional[str] = None, + akto_api_key: Optional[str] = None, + akto_account_id: Optional[str] = None, + akto_vxlan_id: Optional[str] = None, + unreachable_fallback: Literal["fail_closed", "fail_open"] = "fail_closed", + guardrail_timeout: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Initialize the Akto guardrail. + + Args: + akto_base_url: Akto API base URL. Falls back to AKTO_GUARDRAIL_API_BASE env var. + akto_api_key: Akto API key. Falls back to AKTO_API_KEY env var. + akto_account_id: Akto account ID. Falls back to AKTO_ACCOUNT_ID env var, then "1000000". + akto_vxlan_id: Akto VXLAN ID. 
Falls back to AKTO_VXLAN_ID env var, then "0". + unreachable_fallback: Behavior when Akto is unreachable — block or allow. + guardrail_timeout: HTTP timeout in seconds for Akto API calls. + """ + self.async_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.GuardrailCallback, + ) + self.background_tasks: set = set() + + self.akto_base_url = (akto_base_url or os.environ.get("AKTO_GUARDRAIL_API_BASE", "")).rstrip("/") + if not self.akto_base_url: + raise ValueError("akto_base_url is required. Set AKTO_GUARDRAIL_API_BASE or pass it in litellm_params.") + + self.akto_api_key = akto_api_key or os.environ.get("AKTO_API_KEY", "") + if not self.akto_api_key: + raise ValueError("akto_api_key is required. Set AKTO_API_KEY or pass it in litellm_params.") + + self.unreachable_fallback: Literal["fail_closed", "fail_open"] = unreachable_fallback + self.guardrail_timeout = guardrail_timeout or DEFAULT_GUARDRAIL_TIMEOUT + self.akto_account_id = akto_account_id or os.environ.get("AKTO_ACCOUNT_ID", "1000000") + self.akto_vxlan_id = akto_vxlan_id or os.environ.get("AKTO_VXLAN_ID", "0") + + kwargs["supported_event_hooks"] = [ + GuardrailEventHooks.pre_call, + GuardrailEventHooks.post_call, + ] + super().__init__(**kwargs) + + verbose_proxy_logger.debug( + "Akto guardrail initialized: base_url=%s fallback=%s", + self.akto_base_url, + self.unreachable_fallback, + ) + + @staticmethod + def resolve_metadata_value(request_data: Optional[dict], key: str) -> Optional[str]: + """Look up a metadata value from litellm_metadata or metadata dicts.""" + if request_data is None: + return None + for dict_key in ("litellm_metadata", "metadata"): + container = request_data.get(dict_key) or {} + if isinstance(container, dict) and container: + value = container.get(key) + if value is not None: + return str(value).strip() + return None + + @staticmethod + def extract_request_path(request_data: dict) -> str: + """Extract the API route from request metadata, defaulting to 
/v1/chat/completions.""" + metadata = request_data.get("metadata") or {} + if not isinstance(metadata, dict): + metadata = {} + route = metadata.get("user_api_key_request_route") + return route if route else "/v1/chat/completions" + + def prepare_headers(self) -> Dict[str, str]: + """Build HTTP headers for the Akto API call.""" + return { + "content-type": "application/json", + "Authorization": self.akto_api_key, + } + + @staticmethod + def build_query_params(*, guardrails: bool, ingest_data: bool) -> Dict[str, str]: + """Build query params that control Akto backend behavior (guardrail check and/or data ingestion).""" + params: Dict[str, str] = {"akto_connector": AKTO_CONNECTOR_NAME} + if guardrails: + params["guardrails"] = "true" + if ingest_data: + params["ingest_data"] = "true" + return params + + @staticmethod + def build_request_headers(request_data: dict) -> Dict[str, str]: + """Build the requestHeaders field from proxy request headers.""" + headers: Dict[str, str] = {"content-type": "application/json"} + proxy_req = request_data.get("proxy_server_request", {}) + if not isinstance(proxy_req, dict): + return headers + proxy_req_headers = proxy_req.get("headers") + if isinstance(proxy_req_headers, dict): + for key, val in proxy_req_headers.items(): + if key and val: + headers[str(key).lower()] = str(val) + return headers + + @staticmethod + def build_request_body( + inputs: GenericGuardrailAPIInputs, + request_data: Optional[dict] = None, + ) -> Dict[str, Any]: + """Build the LLM request body from guardrail inputs (messages, model, tools).""" + model = inputs.get("model", "") or "" + body: Dict[str, Any] = {"model": model} + + structured = inputs.get("structured_messages") + if structured: + body["messages"] = structured + elif request_data is not None and request_data.get("messages"): + body["messages"] = request_data["messages"] + if request_data.get("model"): + body["model"] = request_data["model"] + else: + texts = inputs.get("texts", []) + 
body["messages"] = [{"role": "user", "content": t} for t in texts] if texts else [] + + tools = inputs.get("tools") + if tools: + body["tools"] = tools + elif request_data is not None and request_data.get("tools"): + body["tools"] = request_data["tools"] + + tool_calls = inputs.get("tool_calls") + if tool_calls: + body["tool_calls"] = tool_calls + + return body + + @staticmethod + def build_response_body( + inputs: GenericGuardrailAPIInputs, + request_data: Optional[dict] = None, + ) -> Dict[str, Any]: + """Build the LLM response body, preferring the actual model response if available.""" + model_response = request_data.get("response") if request_data else None + if model_response is not None and hasattr(model_response, "model_dump"): + return model_response.model_dump() + + texts = inputs.get("texts", []) + if texts: + return {"choices": [{"message": {"content": t, "role": "assistant"}} for t in texts]} + return {} + + @staticmethod + def build_tag_metadata(request_data: dict) -> Dict[str, str]: + """Build tag/metadata dict with user_id and team_id for Akto tracking.""" + tag: Dict[str, str] = {"gen-ai": "Gen AI"} + user_id = AktoGuardrail.resolve_metadata_value(request_data, "user_api_key_user_id") + team_id = AktoGuardrail.resolve_metadata_value(request_data, "user_api_key_team_id") + if user_id: + tag["user_id"] = user_id + if team_id: + tag["team_id"] = team_id + return tag + + def build_akto_payload( + self, + inputs: GenericGuardrailAPIInputs, + request_data: dict, + *, + status_code: int = 200, + include_response: bool = False, + ) -> Dict[str, Any]: + """Build the flat MIRRORING payload sent to Akto's HTTP proxy endpoint. + + All body fields use double-encoding: json.dumps({"body": json.dumps(actual_body)}) + to match the canonical CLI hook format. 
+ """ + request_path = self.extract_request_path(request_data) + request_headers = self.build_request_headers(request_data) + request_body = self.build_request_body(inputs, request_data) + tag = self.build_tag_metadata(request_data) + + response_payload = json.dumps({}) # Empty body wrapper when no response yet + response_headers: Dict[str, str] = {} + if include_response: + response_body = self.build_response_body(inputs, request_data) + response_payload = json.dumps({"body": json.dumps(response_body)}) # Double-encoded + response_headers = {"content-type": "application/json"} + + # Extract client IP from proxy headers + ip = "" + proxy_req = request_data.get("proxy_server_request", {}) + proxy_headers = proxy_req.get("headers", {}) if isinstance(proxy_req, dict) else {} + if isinstance(proxy_headers, dict): + ip = proxy_headers.get("x-forwarded-for") or proxy_headers.get("x-real-ip") or "" + if "," in ip: + ip = ip.split(",")[0].strip() + + return { + "path": request_path, + "requestHeaders": json.dumps(request_headers), + "responseHeaders": json.dumps(response_headers), + "method": "POST", + "requestPayload": json.dumps({"body": json.dumps(request_body)}), # Double-encoded + "responsePayload": response_payload, + "ip": ip, + "destIp": "127.0.0.1", + "time": str(int(datetime.now().timestamp() * 1000)), + "statusCode": str(status_code), + "type": "HTTP/1.1", + "status": str(status_code), + "akto_account_id": self.akto_account_id, + "akto_vxlan_id": self.akto_vxlan_id, + "is_pending": "false", + "source": "MIRRORING", + "direction": None, + "process_id": None, + "socket_id": None, + "daemonset_id": None, + "enabled_graph": None, + "tag": json.dumps(tag), + "metadata": json.dumps(tag), + "contextSource": "AGENTIC", + } + + async def send_request( + self, + *, + guardrails: bool, + ingest_data: bool, + payload: dict, + ) -> httpx.Response: + """Send an HTTP POST to the Akto API endpoint.""" + endpoint = f"{self.akto_base_url}{HTTP_PROXY_PATH}" + params = 
self.build_query_params(guardrails=guardrails, ingest_data=ingest_data) + headers = self.prepare_headers() + return await self.async_handler.post( + url=endpoint, + data=json.dumps(payload), + params=params, + headers=headers, + timeout=self.guardrail_timeout, + ) + + @staticmethod + def handle_guardrail_response(response: httpx.Response) -> Tuple[bool, str]: + """Parse the Akto guardrail response. Returns (allowed, reason).""" + if response.status_code != 200: + verbose_proxy_logger.error("Akto returned HTTP %d", response.status_code) + raise httpx.HTTPStatusError( + f"Akto returned unexpected status {response.status_code}", + request=response.request, + response=response, + ) + try: + result = response.json() + except (json.JSONDecodeError, ValueError) as e: + response_text = getattr(response, "text", "") + verbose_proxy_logger.error( + "Akto returned non-JSON body for status 200: %r", + response_text[:200], + ) + raise httpx.RequestError( + "Akto returned non-JSON body", + request=response.request, + ) from e + if not isinstance(result, dict): + return True, "" + data = result.get("data") or {} + if not isinstance(data, dict): + return True, "" + guardrails_result = data.get("guardrailsResult") or {} + if not isinstance(guardrails_result, dict): + return True, "" + return ( + bool(guardrails_result.get("Allowed", True)), + str(guardrails_result.get("Reason", "")), + ) + + def handle_unreachable( + self, + inputs: GenericGuardrailAPIInputs, + error: Exception, + ) -> GenericGuardrailAPIInputs: + """Handle Akto being unreachable based on fail_open/fail_closed config.""" + if self.unreachable_fallback == "fail_open": + verbose_proxy_logger.critical( + "Akto unreachable (fail-open): %s", + str(error), + exc_info=error, + ) + return inputs + + verbose_proxy_logger.error("Akto unreachable (fail-closed): %s", str(error)) + raise HTTPException( + status_code=503, + detail="Akto guardrail service unreachable", + ) + + async def fire_and_forget_request( + self, + *, + 
guardrails: bool, + ingest_data: bool, + payload: dict, + ) -> None: + """Send a request without awaiting it in the caller. Errors are logged, not raised.""" + try: + response = await self.send_request( + guardrails=guardrails, + ingest_data=ingest_data, + payload=payload, + ) + if response.status_code != 200: + verbose_proxy_logger.error( + "Akto fire-and-forget returned HTTP %d", + response.status_code, + ) + except Exception as e: + verbose_proxy_logger.error("Akto fire-and-forget error: %s", str(e)) + + @log_guardrail_information + async def apply_guardrail( + self, + inputs: GenericGuardrailAPIInputs, + request_data: dict, + input_type: Literal["request", "response"], + logging_obj=None, + ) -> GenericGuardrailAPIInputs: + """Main entry point called by LiteLLM's guardrail framework. + + Pre_call (input_type="request"): + - Awaits guardrail check. If blocked, fires off ingest with 403 marker and raises. + Post_call (input_type="response"): + - Fire-and-forget combined guardrail + ingest call. 
+ """ + # Skip if this hook doesn't handle the current input_type + expected = self.HOOK_TO_INPUT.get(str(self.event_hook)) + if expected and expected != input_type: + return inputs + + if input_type == "request": + # Pre_call: awaited guardrail check (no ingestion) + payload = self.build_akto_payload(inputs, request_data, include_response=False) + try: + response = await self.send_request( + guardrails=True, + ingest_data=False, + payload=payload, + ) + allowed, reason = self.handle_guardrail_response(response) + except HTTPException: + raise + except (httpx.RequestError, httpx.HTTPStatusError) as e: + return self.handle_unreachable( + inputs=inputs, + error=e, + ) + + if not allowed: + # Build a blocked marker payload with 403 status and reason + blocked_payload = self.build_akto_payload( + inputs, + request_data, + include_response=False, + status_code=403, + ) + blocked_payload["responsePayload"] = json.dumps( + { + "body": json.dumps({"x-blocked-by": "Akto Proxy", "reason": reason}), + } + ) + blocked_payload["responseHeaders"] = json.dumps( + {"content-type": "application/json"}, + ) + # Fire-and-forget ingest of the blocked request, then raise 403 + task = asyncio.create_task( + self.fire_and_forget_request( + guardrails=False, + ingest_data=True, + payload=blocked_payload, + ) + ) + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) + raise HTTPException( + status_code=403, + detail=reason or "Blocked by Akto Guardrails", + ) + + elif input_type == "response": + # Post_call: fire-and-forget combined guardrail + ingest + payload = self.build_akto_payload(inputs, request_data, include_response=True) + task = asyncio.create_task( + self.fire_and_forget_request( + guardrails=True, + ingest_data=True, + payload=payload, + ) + ) + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) + + return inputs diff --git a/litellm/proxy/guardrails/guardrail_hooks/dynamoai/__init__.py 
b/litellm/proxy/guardrails/guardrail_hooks/dynamoai/__init__.py index 79f1992da4..f9ebf46a27 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/dynamoai/__init__.py +++ b/litellm/proxy/guardrails/guardrail_hooks/dynamoai/__init__.py @@ -1,3 +1,33 @@ +from typing import TYPE_CHECKING + +from litellm.types.guardrails import SupportedGuardrailIntegrations + from .dynamoai import DynamoAIGuardrails -__all__ = ["DynamoAIGuardrails"] +if TYPE_CHECKING: + from litellm.types.guardrails import Guardrail, LitellmParams + + +def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"): + import litellm + + _dynamoai_callback = DynamoAIGuardrails( + api_base=litellm_params.api_base, + api_key=litellm_params.api_key, + guardrail_name=guardrail.get("guardrail_name", ""), + event_hook=litellm_params.mode, + default_on=litellm_params.default_on, + ) + litellm.logging_callback_manager.add_litellm_callback(_dynamoai_callback) + + return _dynamoai_callback + + +guardrail_initializer_registry = { + SupportedGuardrailIntegrations.DYNAMOAI.value: initialize_guardrail, +} + + +guardrail_class_registry = { + SupportedGuardrailIntegrations.DYNAMOAI.value: DynamoAIGuardrails, +} diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index d5abf5c8fb..11b1d0d40c 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -17,6 +17,9 @@ from litellm.types.proxy.guardrails.guardrail_hooks.grayswan import ( from litellm.types.proxy.guardrails.guardrail_hooks.ibm import ( IBMGuardrailsBaseConfigModel, ) +from litellm.types.proxy.guardrails.guardrail_hooks.akto import ( + AktoConfigModel, +) from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter import ( ContentFilterCategoryConfig, ) @@ -33,7 +36,7 @@ Pydantic object defining how to set guardrails on litellm proxy guardrails: - guardrail_name: "bedrock-pre-guard" litellm_params: - guardrail: bedrock # supported values: "aporia", "bedrock", "lakera", "zscaler_ai_guard" 
+ guardrail: bedrock # supported values: "akto", "aporia", "bedrock", "lakera", "zscaler_ai_guard" mode: "during_call" guardrailIdentifier: ff6ujrregl1q guardrailVersion: "DRAFT" @@ -44,6 +47,7 @@ guardrails: class SupportedGuardrailIntegrations(Enum): APORIA = "aporia" BEDROCK = "bedrock" + DYNAMOAI = "dynamoai" GUARDRAILS_AI = "guardrails_ai" LAKERA = "lakera" LAKERA_V2 = "lakera_v2" @@ -78,6 +82,7 @@ class SupportedGuardrailIntegrations(Enum): SEMANTIC_GUARD = "semantic_guard" MCP_END_USER_PERMISSION = "mcp_end_user_permission" BLOCK_CODE_EXECUTION = "block_code_execution" + AKTO = "akto" class Role(Enum): @@ -735,6 +740,7 @@ class LitellmParams( NomaGuardrailConfigModel, ToolPermissionGuardrailConfigModel, ZscalerAIGuardConfigModel, + AktoConfigModel, JavelinGuardrailConfigModel, BaseLitellmParams, EnkryptAIGuardrailConfigs, diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/akto.py b/litellm/types/proxy/guardrails/guardrail_hooks/akto.py new file mode 100644 index 0000000000..180c89e811 --- /dev/null +++ b/litellm/types/proxy/guardrails/guardrail_hooks/akto.py @@ -0,0 +1,55 @@ +from typing import Optional, Literal + +from pydantic import Field + +from .base import GuardrailConfigModel + + +class AktoConfigModel(GuardrailConfigModel): + """ + Config for the Akto guardrail. + + Use two separate config entries to control behaviour: + akto-validate (mode: pre_call) -> check guardrails, block if flagged + akto-ingest (mode: post_call) -> ingest request+response data + """ + + akto_base_url: Optional[str] = Field( + default=None, + description="Akto Guardrail API Base URL. Env: AKTO_GUARDRAIL_API_BASE.", + json_schema_extra={ + "examples": [ + "http://localhost:9090", + "https://akto-ingestion.example.com", + ] + }, + ) + + akto_api_key: Optional[str] = Field( + default=None, + description="API key for Akto. Env: AKTO_API_KEY.", + ) + + akto_account_id: Optional[str] = Field( + default=None, + description="Akto account ID for multi-tenant deployments. 
Env: AKTO_ACCOUNT_ID. Default: '1000000'.", + ) + + akto_vxlan_id: Optional[str] = Field( + default=None, + description="Akto VXLAN ID. Env: AKTO_VXLAN_ID. Default: '0'.", + ) + + unreachable_fallback: Literal["fail_closed", "fail_open"] = Field( + default="fail_closed", + description="What to do when Akto is unreachable. 'fail_open' = allow, 'fail_closed' = block.", + ) + + guardrail_timeout: Optional[int] = Field( + default=None, + description="HTTP timeout in seconds. Default: 5.", + ) + + @staticmethod + def ui_friendly_name() -> str: + return "Akto" diff --git a/tests/guardrails_tests/test_akto_guardrails.py b/tests/guardrails_tests/test_akto_guardrails.py new file mode 100644 index 0000000000..3c70104a21 --- /dev/null +++ b/tests/guardrails_tests/test_akto_guardrails.py @@ -0,0 +1,550 @@ +import asyncio +import json +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from starlette.exceptions import HTTPException +from litellm.types.utils import GenericGuardrailAPIInputs +from litellm.proxy.guardrails.guardrail_registry import guardrail_initializer_registry, guardrail_class_registry +from litellm.proxy.guardrails.guardrail_hooks.akto.akto import AktoGuardrail + + +# --------------------------------------------------------------------------- +# Registry tests +# --------------------------------------------------------------------------- + + +def test_akto_in_guardrail_initializer_registry(): + assert "akto" in guardrail_initializer_registry + + +def test_akto_in_guardrail_class_registry(): + assert "akto" in guardrail_class_registry + assert guardrail_class_registry["akto"] is AktoGuardrail + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def akto_validate(): + """AktoGuardrail configured for pre_call (akto-validate).""" + return AktoGuardrail( + 
akto_base_url="http://localhost:9090", + akto_api_key="test-token", + unreachable_fallback="fail_closed", + guardrail_name="test-akto-validate", + event_hook="pre_call", + ) + + +@pytest.fixture +def akto_ingest(): + """AktoGuardrail configured for post_call (akto-ingest).""" + return AktoGuardrail( + akto_base_url="http://localhost:9090", + akto_api_key="test-token", + unreachable_fallback="fail_open", + guardrail_name="test-akto-ingest", + event_hook="post_call", + ) + + +@pytest.fixture +def sample_inputs() -> GenericGuardrailAPIInputs: + return GenericGuardrailAPIInputs( + texts=["Hello, how are you?"], + model="gpt-4", + ) + + +@pytest.fixture +def sample_request_data() -> dict: + return { + "metadata": { + "user_api_key_request_route": "/v1/chat/completions", + "user_api_key": "sk-test-123", + "user_api_key_user_id": "user-1", + "user_api_key_team_id": "team-1", + }, + "proxy_server_request": { + "headers": { + "x-forwarded-for": "10.0.0.1", + } + }, + } + + +def _mock_allowed_response(): + mock = MagicMock(spec=httpx.Response) + mock.status_code = 200 + mock.json.return_value = {"data": {"guardrailsResult": {"Allowed": True, "Reason": ""}}} + return mock + + +def _mock_blocked_response(reason="Prompt injection detected"): + mock = MagicMock(spec=httpx.Response) + mock.status_code = 200 + mock.json.return_value = {"data": {"guardrailsResult": {"Allowed": False, "Reason": reason}}} + return mock + + +# --------------------------------------------------------------------------- +# Initialization tests +# --------------------------------------------------------------------------- + + +def test_init_requires_akto_base_url(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="akto_base_url is required"): + AktoGuardrail( + akto_base_url="", + akto_api_key="test-token", + guardrail_name="test", + event_hook="pre_call", + ) + + +def test_init_requires_api_key(): + with patch.dict(os.environ, {}, clear=True): + with 
pytest.raises(ValueError, match="akto_api_key is required"): + AktoGuardrail( + akto_base_url="http://localhost:9090", + akto_api_key="", + guardrail_name="test", + event_hook="pre_call", + ) + + +def test_init_from_env(): + with patch.dict( + os.environ, + { + "AKTO_GUARDRAIL_API_BASE": "http://env-host:9090", + "AKTO_API_KEY": "env-token", + "AKTO_ACCOUNT_ID": "2000000", + "AKTO_VXLAN_ID": "42", + }, + ): + g = AktoGuardrail(guardrail_name="env-test", event_hook="post_call") + assert g.akto_base_url == "http://env-host:9090" + assert g.akto_api_key == "env-token" + assert g.guardrail_timeout == 5 + assert g.akto_account_id == "2000000" + assert g.akto_vxlan_id == "42" + + +def test_init_defaults(): + g = AktoGuardrail( + akto_base_url="http://localhost:9090", + akto_api_key="test-token", + guardrail_name="default-test", + event_hook="pre_call", + ) + assert g.unreachable_fallback == "fail_closed" + assert g.guardrail_timeout == 5 + assert g.akto_account_id == "1000000" + assert g.akto_vxlan_id == "0" + + +def test_background_tasks_per_instance(): + a = AktoGuardrail( + akto_base_url="http://localhost:9090", + akto_api_key="test-token", + guardrail_name="instance-a", + event_hook="pre_call", + ) + b = AktoGuardrail( + akto_base_url="http://localhost:9090", + akto_api_key="test-token", + guardrail_name="instance-b", + event_hook="post_call", + ) + assert a.background_tasks is not b.background_tasks + + +# --------------------------------------------------------------------------- +# Payload format tests +# --------------------------------------------------------------------------- + + +def test_build_akto_payload_format(akto_validate, sample_inputs, sample_request_data): + payload = akto_validate.build_akto_payload(sample_inputs, sample_request_data, include_response=False) + + assert payload["path"] == "/v1/chat/completions" + assert payload["method"] == "POST" + assert payload["type"] == "HTTP/1.1" + assert payload["akto_account_id"] == "1000000" + assert 
payload["akto_vxlan_id"] == "0" + assert payload["is_pending"] == "false" + assert payload["source"] == "MIRRORING" + assert payload["contextSource"] == "AGENTIC" + assert payload["ip"] == "10.0.0.1" + + req_headers = json.loads(payload["requestHeaders"]) + assert "content-type" in req_headers + + req_wrapper = json.loads(payload["requestPayload"]) + req_body = json.loads(req_wrapper["body"]) + assert req_body["model"] == "gpt-4" + assert req_body["messages"][0]["content"] == "Hello, how are you?" + + tag = json.loads(payload["tag"]) + assert tag["gen-ai"] == "Gen AI" + + assert payload["responsePayload"] == json.dumps({}) + assert payload["time"].isdigit() + assert len(payload["time"]) >= 13 + + +def test_build_akto_payload_with_response(akto_validate, sample_inputs, sample_request_data): + payload = akto_validate.build_akto_payload(sample_inputs, sample_request_data, include_response=True) + resp_wrapper = json.loads(payload["responsePayload"]) + resp_body = json.loads(resp_wrapper["body"]) + assert "choices" in resp_body + + +def test_build_akto_payload_custom_account_ids(sample_inputs, sample_request_data): + g = AktoGuardrail( + akto_base_url="http://localhost:9090", + akto_api_key="test-token", + akto_account_id="9999", + akto_vxlan_id="7", + guardrail_name="custom-ids-test", + event_hook="pre_call", + ) + payload = g.build_akto_payload(sample_inputs, sample_request_data, include_response=False) + assert payload["akto_account_id"] == "9999" + assert payload["akto_vxlan_id"] == "7" + + +def test_build_query_params(): + params = AktoGuardrail.build_query_params(guardrails=True, ingest_data=False) + assert params == {"akto_connector": "litellm", "guardrails": "true"} + + params = AktoGuardrail.build_query_params(guardrails=False, ingest_data=True) + assert params == {"akto_connector": "litellm", "ingest_data": "true"} + + params = AktoGuardrail.build_query_params(guardrails=True, ingest_data=True) + assert params == { + "akto_connector": "litellm", + 
"guardrails": "true", + "ingest_data": "true", + } + + +# --------------------------------------------------------------------------- +# Guardrail response handling +# --------------------------------------------------------------------------- + + +def test_handle_guardrail_response_allowed(): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = 200 + mock_resp.json.return_value = {"data": {"guardrailsResult": {"Allowed": True, "Reason": ""}}} + allowed, reason = AktoGuardrail.handle_guardrail_response(mock_resp) + assert allowed is True + assert reason == "" + + +def test_handle_guardrail_response_blocked(): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = 200 + mock_resp.json.return_value = {"data": {"guardrailsResult": {"Allowed": False, "Reason": "PII detected"}}} + allowed, reason = AktoGuardrail.handle_guardrail_response(mock_resp) + assert allowed is False + assert reason == "PII detected" + + +def test_handle_guardrail_response_missing_result(): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + allowed, _ = AktoGuardrail.handle_guardrail_response(mock_resp) + assert allowed is True + + +def test_handle_guardrail_response_data_none(): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = 200 + mock_resp.json.return_value = {"data": None} + allowed, reason = AktoGuardrail.handle_guardrail_response(mock_resp) + assert allowed is True + assert reason == "" + + +def test_handle_guardrail_response_guardrails_result_not_dict(): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = 200 + mock_resp.json.return_value = {"data": {"guardrailsResult": "invalid"}} + allowed, reason = AktoGuardrail.handle_guardrail_response(mock_resp) + assert allowed is True + assert reason == "" + + +def test_handle_guardrail_response_non_dict(): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = 200 + mock_resp.json.return_value = 
"invalid" + allowed, _ = AktoGuardrail.handle_guardrail_response(mock_resp) + assert allowed is True + + +def test_handle_guardrail_response_error_status(): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = 500 + mock_resp.request = MagicMock() + with pytest.raises(httpx.HTTPStatusError): + AktoGuardrail.handle_guardrail_response(mock_resp) + + +def test_handle_guardrail_response_non_json_body(): + mock_resp = MagicMock(spec=httpx.Response) + mock_resp.status_code = 200 + mock_resp.request = MagicMock() + mock_resp.text = "not json" + mock_resp.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + + with pytest.raises(httpx.RequestError): + AktoGuardrail.handle_guardrail_response(mock_resp) + + +# --------------------------------------------------------------------------- +# Pre-call (akto-validate) — allowed +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_pre_call_allowed(akto_validate, sample_inputs, sample_request_data): + akto_validate.async_handler.post = AsyncMock(return_value=_mock_allowed_response()) + + result = await akto_validate.apply_guardrail( + inputs=sample_inputs, + request_data=sample_request_data, + input_type="request", + ) + + assert result == sample_inputs + akto_validate.async_handler.post.assert_called_once() + call_params = akto_validate.async_handler.post.call_args.kwargs["params"] + assert call_params.get("guardrails") == "true" + assert "ingest_data" not in call_params + + +# --------------------------------------------------------------------------- +# Pre-call (akto-validate) — blocked +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_pre_call_blocked(akto_validate, sample_inputs, sample_request_data): + akto_validate.async_handler.post = AsyncMock( + side_effect=[ + _mock_blocked_response("PII detected"), + _mock_allowed_response(), + ] + ) + + with 
pytest.raises(HTTPException) as exc_info: + await akto_validate.apply_guardrail( + inputs=sample_inputs, + request_data=sample_request_data, + input_type="request", + ) + + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert exc_info.value.status_code == 403 + + assert akto_validate.async_handler.post.call_count == 2 + + first_call_params = akto_validate.async_handler.post.call_args_list[0].kwargs["params"] + assert first_call_params.get("guardrails") == "true" + + second_call_params = akto_validate.async_handler.post.call_args_list[1].kwargs["params"] + assert second_call_params.get("ingest_data") == "true" + assert "guardrails" not in second_call_params + second_payload = json.loads(akto_validate.async_handler.post.call_args_list[1].kwargs["data"]) + assert second_payload["statusCode"] == "403" + resp_body = json.loads(second_payload["responsePayload"]) + inner = json.loads(resp_body["body"]) + assert inner["x-blocked-by"] == "Akto Proxy" + assert inner["reason"] == "PII detected" + + +# --------------------------------------------------------------------------- +# Pre-call (akto-validate) — response input is no-op +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_validate_response_noop(akto_validate, sample_inputs, sample_request_data): + akto_validate.async_handler.post = AsyncMock() + + result = await akto_validate.apply_guardrail( + inputs=sample_inputs, + request_data=sample_request_data, + input_type="response", + ) + + assert result == sample_inputs + akto_validate.async_handler.post.assert_not_called() + + +# --------------------------------------------------------------------------- +# Post-call (akto-ingest) — combined guardrail + ingest +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_post_call_combined(akto_ingest, sample_inputs, sample_request_data): + akto_ingest.async_handler.post = 
AsyncMock(return_value=_mock_allowed_response()) + + result = await akto_ingest.apply_guardrail( + inputs=sample_inputs, + request_data=sample_request_data, + input_type="response", + ) + + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert result == sample_inputs + akto_ingest.async_handler.post.assert_called_once() + call_params = akto_ingest.async_handler.post.call_args.kwargs["params"] + assert call_params.get("guardrails") == "true" + assert call_params.get("ingest_data") == "true" + + +# --------------------------------------------------------------------------- +# Post-call (akto-ingest) — request input is no-op +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ingest_request_noop(akto_ingest, sample_inputs, sample_request_data): + akto_ingest.async_handler.post = AsyncMock() + + result = await akto_ingest.apply_guardrail( + inputs=sample_inputs, + request_data=sample_request_data, + input_type="request", + ) + + assert result == sample_inputs + akto_ingest.async_handler.post.assert_not_called() + + +# --------------------------------------------------------------------------- +# Fail-open / fail-closed +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_fail_open_on_unreachable(): + g = AktoGuardrail( + akto_base_url="http://localhost:9090", + akto_api_key="test-token", + unreachable_fallback="fail_open", + guardrail_name="fail-open-test", + event_hook="pre_call", + ) + g.async_handler.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) + + inputs = GenericGuardrailAPIInputs(texts=["test"], model="gpt-4") + result = await g.apply_guardrail(inputs=inputs, request_data={}, input_type="request") + + assert result.get("texts") == ["test"] + + +@pytest.mark.asyncio +async def test_fail_closed_on_unreachable(): + g = AktoGuardrail( + akto_base_url="http://localhost:9090", + 
akto_api_key="test-token", + unreachable_fallback="fail_closed", + guardrail_name="fail-closed-test", + event_hook="pre_call", + ) + g.async_handler.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) + + inputs = GenericGuardrailAPIInputs(texts=["test"], model="gpt-4") + with pytest.raises(HTTPException) as exc_info: + await g.apply_guardrail(inputs=inputs, request_data={}, input_type="request") + assert exc_info.value.status_code == 503 + + +def test_fail_closed_generic_message(): + g = AktoGuardrail( + akto_base_url="http://localhost:9090", + akto_api_key="test-token", + unreachable_fallback="fail_closed", + guardrail_name="msg-test", + event_hook="pre_call", + ) + with pytest.raises(HTTPException) as exc_info: + g.handle_unreachable( + inputs=GenericGuardrailAPIInputs(texts=["test"], model="gpt-4"), + error=Exception("http://internal-host:9090/secret-path"), + ) + assert "internal-host" not in exc_info.value.detail + assert exc_info.value.detail == "Akto guardrail service unreachable" + + +# --------------------------------------------------------------------------- +# Helper method tests +# --------------------------------------------------------------------------- + + +def test_extract_request_path_from_metadata(): + path = AktoGuardrail.extract_request_path({"metadata": {"user_api_key_request_route": "/v1/embeddings"}}) + assert path == "/v1/embeddings" + + +def test_extract_request_path_fallback(): + path = AktoGuardrail.extract_request_path({}) + assert path == "/v1/chat/completions" + + +def test_extract_request_path_non_dict_metadata(): + path = AktoGuardrail.extract_request_path({"metadata": "invalid"}) + assert path == "/v1/chat/completions" + + +def test_resolve_metadata_value(): + assert ( + AktoGuardrail.resolve_metadata_value({"metadata": {"user_api_key_user_id": "u1"}}, "user_api_key_user_id") + == "u1" + ) + assert ( + AktoGuardrail.resolve_metadata_value( + {"litellm_metadata": {"user_api_key_team_id": "t1"}}, + 
"user_api_key_team_id", + ) + == "t1" + ) + assert AktoGuardrail.resolve_metadata_value({}, "some_key") is None + assert AktoGuardrail.resolve_metadata_value(None, "some_key") is None + + +def test_resolve_metadata_value_non_dict_containers(): + assert ( + AktoGuardrail.resolve_metadata_value( + {"metadata": "invalid", "litellm_metadata": ["bad"]}, + "some_key", + ) + is None + ) + + +def test_build_tag_metadata(akto_validate, sample_request_data): + tag = akto_validate.build_tag_metadata(sample_request_data) + assert tag["gen-ai"] == "Gen AI" + assert tag["user_id"] == "user-1" + assert tag["team_id"] == "team-1" diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_dynamoai.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_dynamoai.py new file mode 100644 index 0000000000..7bc4e951a5 --- /dev/null +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_dynamoai.py @@ -0,0 +1,81 @@ +""" +Tests for DynamoAI guardrail registration and initialization. +""" + +import os +from unittest.mock import patch + +import pytest + + +class TestDynamoAIGuardrailRegistration: + """Tests for DynamoAI guardrail registration in the guardrail system.""" + + def test_supported_guardrail_enum_entry(self): + """Test that DYNAMOAI is in SupportedGuardrailIntegrations enum.""" + from litellm.types.guardrails import SupportedGuardrailIntegrations + + assert hasattr(SupportedGuardrailIntegrations, "DYNAMOAI") + assert SupportedGuardrailIntegrations.DYNAMOAI.value == "dynamoai" + + def test_initialize_guardrail_function_exists(self): + """Test that initialize_guardrail function is properly exported.""" + from litellm.proxy.guardrails.guardrail_hooks.dynamoai import ( + guardrail_initializer_registry, + initialize_guardrail, + ) + + assert initialize_guardrail is not None + assert "dynamoai" in guardrail_initializer_registry + + def test_guardrail_class_registry_exists(self): + """Test that guardrail_class_registry is properly exported.""" + from 
litellm.proxy.guardrails.guardrail_hooks.dynamoai import ( + guardrail_class_registry, + ) + from litellm.proxy.guardrails.guardrail_hooks.dynamoai.dynamoai import ( + DynamoAIGuardrails, + ) + + assert "dynamoai" in guardrail_class_registry + assert guardrail_class_registry["dynamoai"] == DynamoAIGuardrails + + def test_initialize_guardrail_creates_instance(self): + """Test that initialize_guardrail creates a DynamoAIGuardrails instance.""" + from litellm.proxy.guardrails.guardrail_hooks.dynamoai import ( + initialize_guardrail, + ) + from litellm.proxy.guardrails.guardrail_hooks.dynamoai.dynamoai import ( + DynamoAIGuardrails, + ) + from litellm.types.guardrails import LitellmParams + + litellm_params = LitellmParams( + guardrail="dynamoai", + mode="pre_call", + api_key="test-key", + api_base="https://test.dynamo.ai", + ) + + guardrail = { + "guardrail_name": "test-dynamoai-guard", + } + + with patch( + "litellm.logging_callback_manager.add_litellm_callback" + ) as mock_add: + result = initialize_guardrail(litellm_params, guardrail) + + assert isinstance(result, DynamoAIGuardrails) + assert result.api_key == "test-key" + assert result.api_base == "https://test.dynamo.ai" + assert result.guardrail_name == "test-dynamoai-guard" + mock_add.assert_called_once_with(result) + + def test_dynamoai_in_global_registry(self): + """Test that dynamoai is discoverable in the global guardrail registry.""" + from litellm.proxy.guardrails.guardrail_registry import ( + guardrail_initializer_registry, + ) + + assert "dynamoai" in guardrail_initializer_registry diff --git a/ui/litellm-dashboard/public/assets/logos/akto.svg b/ui/litellm-dashboard/public/assets/logos/akto.svg new file mode 100644 index 0000000000..cdea32535f --- /dev/null +++ b/ui/litellm-dashboard/public/assets/logos/akto.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/ui/litellm-dashboard/src/components/guardrails/guardrail_garden_configs.ts 
b/ui/litellm-dashboard/src/components/guardrails/guardrail_garden_configs.ts index 7a1b5314d3..e42ecaef57 100644 --- a/ui/litellm-dashboard/src/components/guardrails/guardrail_garden_configs.ts +++ b/ui/litellm-dashboard/src/components/guardrails/guardrail_garden_configs.ts @@ -264,4 +264,10 @@ export const GUARDRAIL_PRESETS: Record = { mode: "pre_call", defaultOn: false, }, + akto: { + provider: "Akto", + guardrailNameSuggestion: "Akto Guardrail", + mode: "pre_call", + defaultOn: false, + }, }; diff --git a/ui/litellm-dashboard/src/components/guardrails/guardrail_garden_data.ts b/ui/litellm-dashboard/src/components/guardrails/guardrail_garden_data.ts index 53ccb32c18..b06400ce50 100644 --- a/ui/litellm-dashboard/src/components/guardrails/guardrail_garden_data.ts +++ b/ui/litellm-dashboard/src/components/guardrails/guardrail_garden_data.ts @@ -373,6 +373,14 @@ export const PARTNER_GUARDRAIL_CARDS: GuardrailCardInfo[] = [ logo: `${ASSET_PREFIX}pillar.jpeg`, tags: ["Monitoring", "Safety"], }, + { + id: "akto", + name: "Akto Guardrail", + description: "AI security platform from Akto.io with automatic monitoring and guardrails for AI/ML applications.", + category: "partner", + logo: `${ASSET_PREFIX}akto.svg`, + tags: ["Security", "Safety", "Monitoring"], + }, ]; export const ALL_CARDS = [...LITELLM_CONTENT_FILTER_CARDS, ...PARTNER_GUARDRAIL_CARDS]; diff --git a/ui/litellm-dashboard/src/components/guardrails/guardrail_info_helpers.tsx b/ui/litellm-dashboard/src/components/guardrails/guardrail_info_helpers.tsx index d957be4306..c78835dae0 100644 --- a/ui/litellm-dashboard/src/components/guardrails/guardrail_info_helpers.tsx +++ b/ui/litellm-dashboard/src/components/guardrails/guardrail_info_helpers.tsx @@ -125,6 +125,7 @@ export const guardrailLogoMap: Record = { EnkryptAI: `${asset_logos_folder}enkrypt_ai.avif`, "Prompt Security": `${asset_logos_folder}prompt_security.png`, "LiteLLM Content Filter": `${asset_logos_folder}litellm_logo.jpg`, + "Akto": 
`${asset_logos_folder}akto.svg`, }; export const getGuardrailLogoAndName = (guardrailValue: string): { logo: string; displayName: string } => { From 20f8d413e59098bad5410e2d737e54c6beba40bd Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 19:10:19 -0300 Subject: [PATCH 053/539] fix(anthropic): preserve cache_control on file-type content blocks Fixes #23873 --- .../prompt_templates/factory.py | 11 ++-- ...llm_core_utils_prompt_templates_factory.py | 53 +++++++++++++++++++ 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index ea1f81f9b3..82afe54b80 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -2441,11 +2441,14 @@ def anthropic_messages_pt( # noqa: PLR0915 elif m.get("type", "") == "document": user_content.append(cast(AnthropicMessagesDocumentParam, m)) elif m.get("type", "") == "file": - user_content.append( - anthropic_process_openai_file_message( - cast(ChatCompletionFileObject, m) - ) + _file_content_element = anthropic_process_openai_file_message( + cast(ChatCompletionFileObject, m) ) + _file_content_element = add_cache_control_to_content( + anthropic_content_element=_file_content_element, + original_content_element=dict(m), + ) + user_content.append(_file_content_element) elif isinstance(user_message_types_block["content"], str): _anthropic_content_text_element: AnthropicMessagesTextParam = { "type": "text", diff --git a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py index 8d68539564..438674bfb1 100644 --- a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py +++ 
b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_factory.py @@ -2027,3 +2027,56 @@ def test_sanitize_messages_combined_case_a_and_case_d(): ) finally: litellm.modify_params = original + + +def test_anthropic_messages_pt_file_block_preserves_cache_control(): + """ + Test that cache_control is preserved on file-type content blocks + when translated to Anthropic document params. + Regression test for https://github.com/BerriAI/litellm/issues/23873 + """ + from litellm.litellm_core_utils.prompt_templates.factory import ( + anthropic_messages_pt, + ) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "file", + "file": { + "filename": "doc.pdf", + "file_data": "data:application/pdf;base64,JVBERi0xLjQ=", + }, + "cache_control": {"type": "ephemeral"}, + }, + { + "type": "text", + "text": "Summarize this document.", + "cache_control": {"type": "ephemeral"}, + }, + ], + } + ] + + result = anthropic_messages_pt( + messages, model="claude-sonnet-4-20250514", llm_provider="anthropic" + ) + + content_blocks = result[0]["content"] + assert len(content_blocks) == 2 + + # Document block (from file) should preserve cache_control + doc_block = content_blocks[0] + assert doc_block["type"] == "document" + assert "cache_control" in doc_block, ( + "cache_control was dropped from file/document block" + ) + assert doc_block["cache_control"]["type"] == "ephemeral" + + # Text block should also preserve cache_control + text_block = content_blocks[1] + assert text_block["type"] == "text" + assert "cache_control" in text_block + assert text_block["cache_control"]["type"] == "ephemeral" From 8f015e2db2ebb16598bf15e14e0e9d9d8509a2d4 Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 19:14:08 -0300 Subject: [PATCH 054/539] fix(vertex): respect vertex_count_tokens_location for Claude count_tokens MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The count_tokens handler 
unconditionally overrode vertex_location to us-central1 for Claude models, ignoring the user-configured vertex_count_tokens_location parameter. Also, us-central1 is no longer a supported region — Google now supports us-east5, europe-west1, and asia-southeast1. Now vertex_count_tokens_location takes precedence, vertex_location is used as fallback, and us-east5 is the default only when neither is set. Fixes #23872 --- .../count_tokens/handler.py | 12 +- .../count_tokens/__init__.py | 0 .../test_count_tokens_location.py | 164 ++++++++++++++++++ 3 files changed, 172 insertions(+), 4 deletions(-) create mode 100644 tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/__init__.py create mode 100644 tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/test_count_tokens_location.py diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py index c6914ac3d6..079a691395 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py @@ -105,12 +105,16 @@ class VertexAIPartnerModelsTokenCounter(VertexBase): # Extract Vertex AI credentials and settings vertex_credentials = self.get_vertex_ai_credentials(litellm_params) vertex_project = self.get_vertex_ai_project(litellm_params) - vertex_location = self.get_vertex_ai_location(litellm_params) + vertex_location = ( + litellm_params.get("vertex_count_tokens_location") + or self.get_vertex_ai_location(litellm_params) + ) - # Map empty location/cluade models to a supported region for count-tokens endpoint + # Default Claude models to us-east5 for count-tokens endpoint when no location is set + # Supported regions: us-east5, europe-west1, asia-southeast1 # https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/count-tokens - if not vertex_location or "claude" in 
model.lower(): - vertex_location = "us-central1" + if not vertex_location and "claude" in model.lower(): + vertex_location = "us-east5" # Get access token and resolved project ID access_token, project_id = await self._ensure_access_token_async( diff --git a/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/__init__.py b/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/test_count_tokens_location.py b/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/test_count_tokens_location.py new file mode 100644 index 0000000000..6487ea25f2 --- /dev/null +++ b/tests/test_litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/test_count_tokens_location.py @@ -0,0 +1,164 @@ +""" +Tests for Vertex AI partner models count_tokens location resolution. + +Ref: https://github.com/BerriAI/litellm/issues/23872 +""" +import pytest + +from litellm.llms.vertex_ai.vertex_ai_partner_models.count_tokens.handler import ( + VertexAIPartnerModelsTokenCounter, +) + + +@pytest.fixture +def counter(): + return VertexAIPartnerModelsTokenCounter() + + +class TestCountTokensLocationResolution: + """Verify that vertex_count_tokens_location is respected in handle_count_tokens_request.""" + + def _build_litellm_params( + self, + vertex_location=None, + vertex_count_tokens_location=None, + ): + params = {} + if vertex_location is not None: + params["vertex_location"] = vertex_location + if vertex_count_tokens_location is not None: + params["vertex_count_tokens_location"] = vertex_count_tokens_location + return params + + @pytest.mark.asyncio + async def test_count_tokens_location_overrides_vertex_location(self, counter, monkeypatch): + """vertex_count_tokens_location should take precedence over vertex_location.""" + captured = {} + + async def fake_ensure_access_token(self, credentials, 
project_id, custom_llm_provider): + return "fake-token", "fake-project" + + def fake_build_endpoint(self, model, project_id, vertex_location, api_base=None): + captured["vertex_location"] = vertex_location + return "https://fake-endpoint" + + monkeypatch.setattr( + VertexAIPartnerModelsTokenCounter, "_ensure_access_token_async", fake_ensure_access_token + ) + monkeypatch.setattr( + VertexAIPartnerModelsTokenCounter, "_build_count_tokens_endpoint", fake_build_endpoint + ) + + # Mock the HTTP call to avoid real network requests + class FakeResponse: + status_code = 200 + def json(self): + return {"input_tokens": 10} + def raise_for_status(self): + pass + + class FakeClient: + async def post(self, url, headers=None, json=None, **kwargs): + return FakeResponse() + + import litellm.llms.vertex_ai.vertex_ai_partner_models.count_tokens.handler as handler_mod + monkeypatch.setattr(handler_mod, "get_async_httpx_client", lambda **kwargs: FakeClient()) + + litellm_params = self._build_litellm_params( + vertex_location="us-east5", + vertex_count_tokens_location="europe-west1", + ) + + await counter.handle_count_tokens_request( + model="claude-sonnet-4-6", + request_data={"messages": [{"role": "user", "content": "hi"}]}, + litellm_params=litellm_params, + ) + + assert captured["vertex_location"] == "europe-west1" + + @pytest.mark.asyncio + async def test_claude_without_count_tokens_location_defaults_to_us_east5(self, counter, monkeypatch): + """Claude models without any location should default to us-east5.""" + captured = {} + + async def fake_ensure_access_token(self, credentials, project_id, custom_llm_provider): + return "fake-token", "fake-project" + + def fake_build_endpoint(self, model, project_id, vertex_location, api_base=None): + captured["vertex_location"] = vertex_location + return "https://fake-endpoint" + + monkeypatch.setattr( + VertexAIPartnerModelsTokenCounter, "_ensure_access_token_async", fake_ensure_access_token + ) + monkeypatch.setattr( + 
VertexAIPartnerModelsTokenCounter, "_build_count_tokens_endpoint", fake_build_endpoint + ) + + class FakeResponse: + status_code = 200 + def json(self): + return {"input_tokens": 10} + def raise_for_status(self): + pass + + class FakeClient: + async def post(self, url, headers=None, json=None, **kwargs): + return FakeResponse() + + import litellm.llms.vertex_ai.vertex_ai_partner_models.count_tokens.handler as handler_mod + monkeypatch.setattr(handler_mod, "get_async_httpx_client", lambda **kwargs: FakeClient()) + + litellm_params = self._build_litellm_params() # no location at all + + await counter.handle_count_tokens_request( + model="claude-sonnet-4-6", + request_data={"messages": [{"role": "user", "content": "hi"}]}, + litellm_params=litellm_params, + ) + + assert captured["vertex_location"] == "us-east5" + + @pytest.mark.asyncio + async def test_claude_with_vertex_location_uses_it(self, counter, monkeypatch): + """Claude models with vertex_location but no count_tokens_location should use vertex_location.""" + captured = {} + + async def fake_ensure_access_token(self, credentials, project_id, custom_llm_provider): + return "fake-token", "fake-project" + + def fake_build_endpoint(self, model, project_id, vertex_location, api_base=None): + captured["vertex_location"] = vertex_location + return "https://fake-endpoint" + + monkeypatch.setattr( + VertexAIPartnerModelsTokenCounter, "_ensure_access_token_async", fake_ensure_access_token + ) + monkeypatch.setattr( + VertexAIPartnerModelsTokenCounter, "_build_count_tokens_endpoint", fake_build_endpoint + ) + + class FakeResponse: + status_code = 200 + def json(self): + return {"input_tokens": 10} + def raise_for_status(self): + pass + + class FakeClient: + async def post(self, url, headers=None, json=None, **kwargs): + return FakeResponse() + + import litellm.llms.vertex_ai.vertex_ai_partner_models.count_tokens.handler as handler_mod + monkeypatch.setattr(handler_mod, "get_async_httpx_client", lambda **kwargs: 
FakeClient()) + + litellm_params = self._build_litellm_params(vertex_location="asia-southeast1") + + await counter.handle_count_tokens_request( + model="claude-sonnet-4-6", + request_data={"messages": [{"role": "user", "content": "hi"}]}, + litellm_params=litellm_params, + ) + + assert captured["vertex_location"] == "asia-southeast1" From 9afc4697258340100244e1759165790c792ebc47 Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 23:04:51 -0300 Subject: [PATCH 055/539] fix(mistral): preserve diarization segments in transcription response MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #23890 — Mistral's Voxtral transcription with `diarize=true` returns `segments` (with speaker_id, timestamps) and `language`, but these fields were dropped when mapping the response to TranscriptionResponse. --- .../audio_transcription/transformation.py | 7 +++ ...tral_audio_transcription_transformation.py | 44 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/litellm/llms/mistral/audio_transcription/transformation.py b/litellm/llms/mistral/audio_transcription/transformation.py index 4d29406349..8c6d604acb 100644 --- a/litellm/llms/mistral/audio_transcription/transformation.py +++ b/litellm/llms/mistral/audio_transcription/transformation.py @@ -148,5 +148,12 @@ class MistralAudioTranscriptionConfig(BaseAudioTranscriptionConfig): text = response_json.get("text") or "" response = TranscriptionResponse(text=text) + + # Preserve Mistral-specific fields (e.g. 
diarization segments) + if "segments" in response_json: + response["segments"] = response_json["segments"] + if "language" in response_json: + response["language"] = response_json["language"] + response._hidden_params = response_json return response diff --git a/tests/test_litellm/llms/mistral/audio_transcription/test_mistral_audio_transcription_transformation.py b/tests/test_litellm/llms/mistral/audio_transcription/test_mistral_audio_transcription_transformation.py index 7ef50dede0..4ca3e8ae0c 100644 --- a/tests/test_litellm/llms/mistral/audio_transcription/test_mistral_audio_transcription_transformation.py +++ b/tests/test_litellm/llms/mistral/audio_transcription/test_mistral_audio_transcription_transformation.py @@ -158,6 +158,50 @@ def test_mistral_audio_transcription_response_transform(): assert response.text == "Four score and seven years ago..." +def test_mistral_audio_transcription_response_transform_diarized(): + """Test that diarized responses preserve segments and language.""" + config = MistralAudioTranscriptionConfig() + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = { + "model": "voxtral-mini-latest", + "text": "Hello, how are you? I am fine.", + "language": None, + "segments": [ + { + "text": "Hello, how are you?", + "start": 0.3, + "end": 2.1, + "speaker_id": "speaker_1", + "type": "transcription_segment", + }, + { + "text": "I am fine.", + "start": 2.5, + "end": 3.8, + "speaker_id": "speaker_2", + "type": "transcription_segment", + }, + ], + "usage": { + "prompt_audio_seconds": 4, + "prompt_tokens": 5, + "total_tokens": 50, + "completion_tokens": 20, + }, + } + + response = config.transform_audio_transcription_response(mock_response) + + assert isinstance(response, TranscriptionResponse) + assert response.text == "Hello, how are you? I am fine." 
+ assert response["segments"] is not None + assert len(response["segments"]) == 2 + assert response["segments"][0]["speaker_id"] == "speaker_1" + assert response["segments"][1]["speaker_id"] == "speaker_2" + assert response["language"] is None + + def test_mistral_audio_transcription_response_transform_empty(): config = MistralAudioTranscriptionConfig() From cb15296693d5eba6ea38c2c9658c6ef37580a88c Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 23:06:39 -0300 Subject: [PATCH 056/539] fix(azure): auto-route gpt-5.4+ tools+reasoning to Responses API Azure GPT-5.4+ models now get the same auto-routing treatment as OpenAI when both `reasoning_effort` and `tools` are used in `litellm.completion()`. Previously, `reasoning_effort` was silently dropped for Azure; now the request is bridged to the Responses API which supports both parameters. Fixes #23914 --- docs/my-website/docs/reasoning_content.md | 4 +-- .../llms/azure/chat/gpt_5_transformation.py | 11 ++---- litellm/main.py | 23 +++++++------ .../chat/test_azure_gpt5_transformation.py | 9 ++--- tests/test_litellm/test_main.py | 34 +++++++++++++++++++ 5 files changed, 57 insertions(+), 24 deletions(-) diff --git a/docs/my-website/docs/reasoning_content.md b/docs/my-website/docs/reasoning_content.md index 8bf59f66a3..fb77f53a56 100644 --- a/docs/my-website/docs/reasoning_content.md +++ b/docs/my-website/docs/reasoning_content.md @@ -594,9 +594,9 @@ Expected Response :::tip gpt-5.4: reasoning_effort + function tools -LiteLLM drops `reasoning_effort` from `gpt-5.4` requests to `litellm.completion()` that include tools, since that combination is supported in the Responses API. +When `gpt-5.4+` requests to `litellm.completion()` include both `reasoning_effort` and `tools`, LiteLLM **automatically routes** the request through the Responses API bridge. This works for both **OpenAI** (`openai/gpt-5.4`) and **Azure** (`azure/gpt-5.4`) providers — no extra configuration needed. 
-If you need reasoning **and** tools together, use `openai/responses/gpt-5.4` to route through the Responses API instead. See [Responses API Bridge](/docs/providers/openai#openai-chat-completion-to-responses-api-bridge) for details. +You can also route explicitly via `openai/responses/gpt-5.4` or `azure/responses/gpt-5.4`. See [Responses API Bridge](/docs/providers/openai#openai-chat-completion-to-responses-api-bridge) for details. ::: diff --git a/litellm/llms/azure/chat/gpt_5_transformation.py b/litellm/llms/azure/chat/gpt_5_transformation.py index 6310df9cec..bc7483bf64 100644 --- a/litellm/llms/azure/chat/gpt_5_transformation.py +++ b/litellm/llms/azure/chat/gpt_5_transformation.py @@ -131,14 +131,9 @@ class AzureOpenAIGPT5Config(AzureOpenAIConfig, OpenAIGPT5Config): if result_effort == "none" and not supports_none: result.pop("reasoning_effort") - # Azure Chat Completions: gpt-5.4+ does not support tools + reasoning together. - # Drop reasoning_effort when both are present (OpenAI routes to Responses API; Azure does not). - if self.is_model_gpt_5_4_plus_model(model): - has_tools = bool( - non_default_params.get("tools") or optional_params.get("tools") - ) - if has_tools and result_effort not in (None, "none"): - result.pop("reasoning_effort", None) + # Azure gpt-5.4+ with tools + reasoning_effort is now routed to the + # Responses API bridge (same as OpenAI), so we no longer need to drop + # reasoning_effort here. See: responses_api_bridge_check() in main.py. return result diff --git a/litellm/main.py b/litellm/main.py index 722b4a7aae..e74a34a1ff 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -955,16 +955,6 @@ def responses_api_bridge_check( model_info["mode"] = "responses" model = model.replace("responses/", "") - # OpenAI gpt-5.4+ chat-completions calls with both tools + reasoning_effort - # must be bridged to Responses API. 
- if ( - custom_llm_provider == "openai" - and OpenAIGPT5Config.is_model_gpt_5_4_plus_model(model) - and tools - and reasoning_effort is not None - ): - model_info["mode"] = "responses" - model = model.replace("responses/", "") except Exception as e: verbose_logger.debug("Error getting model info: {}".format(e)) @@ -974,6 +964,19 @@ def responses_api_bridge_check( model = model.replace("responses/", "") mode = "responses" model_info["mode"] = mode + + # OpenAI/Azure gpt-5.4+ chat-completions calls with both tools + reasoning_effort + # must be bridged to Responses API. + if ( + custom_llm_provider in ("openai", "azure") + and OpenAIGPT5Config.is_model_gpt_5_4_plus_model(model) + and tools + and reasoning_effort is not None + and model_info.get("mode") != "responses" + ): + model_info["mode"] = "responses" + model = model.replace("responses/", "") + return model_info, model diff --git a/tests/test_litellm/llms/azure/chat/test_azure_gpt5_transformation.py b/tests/test_litellm/llms/azure/chat/test_azure_gpt5_transformation.py index 635359563b..28ccf7ffa8 100644 --- a/tests/test_litellm/llms/azure/chat/test_azure_gpt5_transformation.py +++ b/tests/test_litellm/llms/azure/chat/test_azure_gpt5_transformation.py @@ -192,10 +192,11 @@ def test_azure_gpt5_1_series_temperature_handling(config: AzureOpenAIGPT5Config) assert params["temperature"] == 0.6 -def test_azure_gpt5_4_drops_reasoning_effort_when_tools_present(config: AzureOpenAIGPT5Config): - """Azure Chat Completions: gpt-5.4+ drops reasoning_effort when tools are present. +def test_azure_gpt5_4_preserves_reasoning_effort_when_tools_present(config: AzureOpenAIGPT5Config): + """Azure GPT-5.4+ no longer drops reasoning_effort when tools are present. - OpenAI routes tools+reasoning to Responses API; Azure does not, so we drop reasoning_effort. + Both OpenAI and Azure now route tools+reasoning to the Responses API bridge, + so reasoning_effort must be preserved in map_openai_params. 
""" tools = [{"type": "function", "function": {"name": "test", "description": "test"}}] params = config.map_openai_params( @@ -205,7 +206,7 @@ def test_azure_gpt5_4_drops_reasoning_effort_when_tools_present(config: AzureOpe drop_params=False, api_version="2024-05-01-preview", ) - assert "reasoning_effort" not in params + assert params.get("reasoning_effort") == "high" assert params["tools"] == tools diff --git a/tests/test_litellm/test_main.py b/tests/test_litellm/test_main.py index 6ac988b2c2..ce5873f506 100644 --- a/tests/test_litellm/test_main.py +++ b/tests/test_litellm/test_main.py @@ -661,6 +661,40 @@ def test_responses_api_bridge_check_gpt_5_5_tools_plus_reasoning_routes_to_respo assert model_info.get("mode") == "responses" +def test_responses_api_bridge_check_azure_gpt_5_4_tools_plus_reasoning_routes_to_responses(): + """Azure gpt-5.4 with both tools and reasoning_effort should route to Responses API.""" + from litellm.main import responses_api_bridge_check + + with patch("litellm.main._get_model_info_helper") as mock_get_model_info: + mock_get_model_info.return_value = {"max_tokens": 128000} + model_info, model = responses_api_bridge_check( + model="gpt-5.4", + custom_llm_provider="azure", + tools=[{"type": "function", "function": {"name": "get_capital"}}], + reasoning_effort="high", + ) + + assert model == "gpt-5.4" + assert model_info.get("mode") == "responses" + + +def test_responses_api_bridge_check_azure_gpt_5_4_tools_without_reasoning_stays_chat(): + """Azure gpt-5.4 with tools only should not be force-routed to Responses API.""" + from litellm.main import responses_api_bridge_check + + with patch("litellm.main._get_model_info_helper") as mock_get_model_info: + mock_get_model_info.return_value = {"max_tokens": 128000} + model_info, model = responses_api_bridge_check( + model="gpt-5.4", + custom_llm_provider="azure", + tools=[{"type": "function", "function": {"name": "get_capital"}}], + reasoning_effort=None, + ) + + assert model == "gpt-5.4" + assert 
model_info.get("mode") != "responses" + + def test_responses_api_bridge_check_gpt_5_4_tools_without_reasoning_stays_chat(): """gpt-5.4 with tools only should not be force-routed to Responses API.""" from litellm.main import responses_api_bridge_check From 8828f002bea2d9842fb95fae78c90d60bc7ce41d Mon Sep 17 00:00:00 2001 From: Chesars Date: Tue, 17 Mar 2026 23:21:24 -0300 Subject: [PATCH 057/539] fix(gemini): pass model to context caching URL builder for custom api_base _get_token_and_url_context_caching() was hardcoding model=None when calling _check_custom_proxy(), which raises ValueError when api_base is set because Gemini proxy URLs need the model name: {api_base}/models/{model}:cachedContents Fixes #23846 --- .../vertex_ai_context_caching.py | 5 ++- .../test_vertex_ai_context_caching.py | 39 ++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py index db6be9499a..c2e064d656 100644 --- a/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py +++ b/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py @@ -51,6 +51,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project: Optional[str], vertex_location: Optional[str], vertex_auth_header: Optional[str], + model: Optional[str] = None, ) -> Tuple[Optional[str], str]: """ Internal function. Returns the token and url for the call. 
@@ -89,7 +90,7 @@ class ContextCachingEndpoints(VertexBase): stream=None, auth_header=auth_header, url=url, - model=None, + model=model, vertex_project=vertex_project, vertex_location=vertex_location, vertex_api_version="v1beta1" @@ -342,6 +343,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project=vertex_project, vertex_location=vertex_location, vertex_auth_header=vertex_auth_header, + model=model, ) headers = { @@ -488,6 +490,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project=vertex_project, vertex_location=vertex_location, vertex_auth_header=vertex_auth_header, + model=model, ) headers = { diff --git a/tests/test_litellm/llms/vertex_ai/context_caching/test_vertex_ai_context_caching.py b/tests/test_litellm/llms/vertex_ai/context_caching/test_vertex_ai_context_caching.py index 3f8cbf1236..11ccd34804 100644 --- a/tests/test_litellm/llms/vertex_ai/context_caching/test_vertex_ai_context_caching.py +++ b/tests/test_litellm/llms/vertex_ai/context_caching/test_vertex_ai_context_caching.py @@ -1317,4 +1317,41 @@ class TestVertexAIGlobalLocation: # Assert correct URL format for global with beta API expected_url = "https://aiplatform.googleapis.com/v1beta1/projects/test-project/locations/global/cachedContents" assert url == expected_url, f"Expected {expected_url}, got {url}" - assert "global-aiplatform" not in url, "URL should not contain 'global-aiplatform' prefix" \ No newline at end of file + assert "global-aiplatform" not in url, "URL should not contain 'global-aiplatform' prefix" + + def test_gemini_context_caching_with_custom_api_base_passes_model(self): + """Gemini context caching with custom api_base must pass model to _check_custom_proxy. + + Regression test for https://github.com/BerriAI/litellm/issues/23846 + Previously model was hardcoded to None, causing ValueError when api_base was set. 
+ """ + caching = ContextCachingEndpoints() + + auth_header, url = caching._get_token_and_url_context_caching( + gemini_api_key="test-key", + custom_llm_provider="gemini", + api_base="https://my-proxy.example.com", + vertex_project=None, + vertex_location=None, + vertex_auth_header=None, + model="gemini-1.5-pro", + ) + + assert "models/gemini-1.5-pro" in url + assert url.startswith("https://my-proxy.example.com/") + + def test_gemini_context_caching_without_api_base_ignores_model(self): + """Without custom api_base, model param is not needed (default URL is used).""" + caching = ContextCachingEndpoints() + + auth_header, url = caching._get_token_and_url_context_caching( + gemini_api_key="test-key", + custom_llm_provider="gemini", + api_base=None, + vertex_project=None, + vertex_location=None, + vertex_auth_header=None, + ) + + assert "generativelanguage.googleapis.com" in url + assert "cachedContents" in url \ No newline at end of file From 0564e9547b130587d286ae57d3cb6fe83c07eea8 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 08:37:31 +0530 Subject: [PATCH 058/539] Fix greptile comments --- litellm/model_prices_and_context_window_backup.json | 13 ++++++++----- model_prices_and_context_window.json | 13 ++++++++----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index c80d637233..2cc63523da 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -3435,7 +3435,8 @@ "supports_tool_choice": true, "supports_service_tier": true, "supports_vision": true, - "supports_none_reasoning_effort": true + "supports_none_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "azure/gpt-5.1-chat-2025-11-13": { "cache_read_input_token_cost": 1.25e-07, @@ -18217,7 +18218,8 @@ "supports_vision": true, "supports_web_search": true, 
"supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1": { "cache_read_input_token_cost": 1.25e-07, @@ -18256,7 +18258,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-2025-11-13": { "cache_read_input_token_cost": 1.25e-07, @@ -18295,7 +18298,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-chat-latest": { "cache_read_input_token_cost": 1.25e-07, @@ -19406,7 +19410,6 @@ "supports_none_reasoning_effort": false, "supports_xhigh_reasoning_effort": false, "supports_minimal_reasoning_effort": true - }, "gpt-image-1": { "cache_read_input_image_token_cost": 2.5e-06, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index c80d637233..2cc63523da 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -3435,7 +3435,8 @@ "supports_tool_choice": true, "supports_service_tier": true, "supports_vision": true, - "supports_none_reasoning_effort": true + "supports_none_reasoning_effort": true, + "supports_minimal_reasoning_effort": true }, "azure/gpt-5.1-chat-2025-11-13": { "cache_read_input_token_cost": 1.25e-07, @@ -18217,7 +18218,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": false, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1": { "cache_read_input_token_cost": 1.25e-07, @@ -18256,7 +18258,8 @@ "supports_vision": true, "supports_web_search": true, 
"supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-2025-11-13": { "cache_read_input_token_cost": 1.25e-07, @@ -18295,7 +18298,8 @@ "supports_vision": true, "supports_web_search": true, "supports_none_reasoning_effort": true, - "supports_xhigh_reasoning_effort": false + "supports_xhigh_reasoning_effort": false, + "supports_minimal_reasoning_effort": true }, "gpt-5.1-chat-latest": { "cache_read_input_token_cost": 1.25e-07, @@ -19406,7 +19410,6 @@ "supports_none_reasoning_effort": false, "supports_xhigh_reasoning_effort": false, "supports_minimal_reasoning_effort": true - }, "gpt-image-1": { "cache_read_input_image_token_cost": 2.5e-06, From c20c465a028db110d346e6f0a123677bb74262ad Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 08:55:40 +0530 Subject: [PATCH 059/539] greptile comments --- litellm/model_prices_and_context_window_backup.json | 2 +- model_prices_and_context_window.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 2cc63523da..44fe6a06ad 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -18763,7 +18763,7 @@ }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, - "cache_read_input_token_cost_flex": 3.75e-08, + "cache_read_input_token_cost_flex": 1e-08, "cache_read_input_token_cost_batches": 3.8e-08, "input_cost_per_token": 7.5e-07, "input_cost_per_token_flex": 3.75e-07, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 2cc63523da..44fe6a06ad 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -18763,7 +18763,7 @@ }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, - 
"cache_read_input_token_cost_flex": 3.75e-08, + "cache_read_input_token_cost_flex": 1e-08, "cache_read_input_token_cost_batches": 3.8e-08, "input_cost_per_token": 7.5e-07, "input_cost_per_token_flex": 3.75e-07, From 6fe3188af048905504277d9566336a0a312d95a1 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:04:00 +0530 Subject: [PATCH 060/539] fix(azure-ai-agents): accumulate annotations from multiple text items in streaming - Fix bug where only last text item's annotations were preserved when thread.message.completed contained multiple text content items - Accumulate annotations via extend() instead of overwriting - Add test_azure_ai_agents_streaming_annotations_from_completed_message - Add test_azure_ai_agents_streaming_accumulates_annotations_from_multiple_text_items Addresses Greptile review on PR #23849 Made-with: Cursor --- litellm/llms/azure_ai/agents/handler.py | 4 +- tests/llm_translation/test_azure_agents.py | 162 +++++++++++++++++++++ 2 files changed, 165 insertions(+), 1 deletion(-) diff --git a/litellm/llms/azure_ai/agents/handler.py b/litellm/llms/azure_ai/agents/handler.py index 5b779acb0d..95c0a4c577 100644 --- a/litellm/llms/azure_ai/agents/handler.py +++ b/litellm/llms/azure_ai/agents/handler.py @@ -700,7 +700,9 @@ class AzureAIAgentsHandler: raw_annotations ) if transformed: - collected_annotations = transformed + if collected_annotations is None: + collected_annotations = [] + collected_annotations.extend(transformed) # Process message deltas - this is where the actual content comes if current_event == "thread.message.delta": diff --git a/tests/llm_translation/test_azure_agents.py b/tests/llm_translation/test_azure_agents.py index 19ce49a3bc..3e6b1e00a7 100644 --- a/tests/llm_translation/test_azure_agents.py +++ b/tests/llm_translation/test_azure_agents.py @@ -23,12 +23,14 @@ Example environment variables: See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart """ +import json import os import sys 
sys.path.insert(0, os.path.abspath("../..")) import pytest +from unittest.mock import MagicMock import litellm @@ -465,6 +467,166 @@ def test_azure_ai_agents_build_model_response_without_annotations(): assert getattr(result.choices[0].message, "annotations", None) is None +@pytest.mark.asyncio +async def test_azure_ai_agents_streaming_annotations_from_completed_message(): + """ + Test that annotations from thread.message.completed SSE events are collected + and attached to the final chunk's delta. + + Ref: https://github.com/BerriAI/litellm/issues/19126 + """ + from litellm.llms.azure_ai.agents.handler import AzureAIAgentsHandler + + handler = AzureAIAgentsHandler() + + # SSE lines simulating a stream with annotations in thread.message.completed + completed_data = { + "content": [ + { + "type": "text", + "text": { + "value": "According to [1], the answer is 42.", + "annotations": [ + { + "type": "url_citation", + "text": "[1]", + "start_index": 12, + "end_index": 15, + "url_citation": { + "url": "https://example.com/citation", + "title": "Citation Source", + }, + } + ], + }, + } + ] + } + + sse_lines = [ + "event: thread.created", + "", + 'data: {"id": "thread_stream_123"}', + "", + "event: thread.message.delta", + "", + 'data: {"delta": {"content": [{"type": "text", "text": {"value": "According to [1], the answer is 42."}}]}}', + "", + "event: thread.message.completed", + "", + f"data: {json.dumps(completed_data)}", + "", + "data: [DONE]", + ] + + async def mock_aiter_lines(): + for line in sse_lines: + yield line + + mock_response = MagicMock() + mock_response.aiter_lines = MagicMock(return_value=mock_aiter_lines()) + + chunks = [] + async for chunk in handler._process_sse_stream(mock_response, "azure_ai/agents/asst_123"): + chunks.append(chunk) + + # Should have content chunks + final [DONE] chunk + assert len(chunks) >= 1 + final_chunk = chunks[-1] + assert final_chunk.choices[0].finish_reason == "stop" + assert final_chunk.choices[0].delta.annotations is not 
None + assert len(final_chunk.choices[0].delta.annotations) == 1 + ann = final_chunk.choices[0].delta.annotations[0] + assert ann["type"] == "url_citation" + assert ann["url_citation"]["url"] == "https://example.com/citation" + assert ann["url_citation"]["title"] == "Citation Source" + + +@pytest.mark.asyncio +async def test_azure_ai_agents_streaming_accumulates_annotations_from_multiple_text_items(): + """ + Test that annotations from multiple text content items in thread.message.completed + are accumulated (not overwritten). + + Ref: Greptile review on PR #23849 + """ + from litellm.llms.azure_ai.agents.handler import AzureAIAgentsHandler + + handler = AzureAIAgentsHandler() + + # Two text blocks, each with distinct citations + completed_data = { + "content": [ + { + "type": "text", + "text": { + "value": "First source [1].", + "annotations": [ + { + "type": "url_citation", + "text": "[1]", + "start_index": 12, + "end_index": 15, + "url_citation": { + "url": "https://example.com/first", + "title": "First", + }, + } + ], + }, + }, + { + "type": "text", + "text": { + "value": "Second source [2].", + "annotations": [ + { + "type": "url_citation", + "text": "[2]", + "start_index": 13, + "end_index": 16, + "url_citation": { + "url": "https://example.com/second", + "title": "Second", + }, + } + ], + }, + }, + ] + } + + sse_lines = [ + "event: thread.created", + "", + 'data: {"id": "thread_multi"}', + "", + "event: thread.message.completed", + "", + f"data: {json.dumps(completed_data)}", + "", + "data: [DONE]", + ] + + async def mock_aiter_lines(): + for line in sse_lines: + yield line + + mock_response = MagicMock() + mock_response.aiter_lines = MagicMock(return_value=mock_aiter_lines()) + + chunks = [] + async for chunk in handler._process_sse_stream(mock_response, "azure_ai/agents/asst_123"): + chunks.append(chunk) + + final_chunk = chunks[-1] + assert final_chunk.choices[0].delta.annotations is not None + assert len(final_chunk.choices[0].delta.annotations) == 2 + 
urls = [a["url_citation"]["url"] for a in final_chunk.choices[0].delta.annotations] + assert "https://example.com/first" in urls + assert "https://example.com/second" in urls + + @pytest.mark.asyncio async def test_azure_ai_agents_conversation_continuity(): """ From ff536e664aec4e6b22c7836f67434232f52b729a Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 18 Mar 2026 00:38:35 -0300 Subject: [PATCH 061/539] fix(gemini): propagate model to check_cache/async_check_cache for custom api_base check_and_create_cache calls check_cache first (to avoid duplicates), which also needs model for the URL when api_base is set. Without this, the full flow still raises ValueError before reaching the create step. --- .../vertex_ai/context_caching/vertex_ai_context_caching.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py b/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py index c2e064d656..b677cf3b1e 100644 --- a/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py +++ b/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py @@ -110,6 +110,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project: Optional[str], vertex_location: Optional[str], vertex_auth_header: Optional[str], + model: Optional[str] = None, ) -> Optional[str]: """ Checks if content already cached. @@ -129,6 +130,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project=vertex_project, vertex_location=vertex_location, vertex_auth_header=vertex_auth_header, + model=model, ) page_token: Optional[str] = None @@ -202,6 +204,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project: Optional[str], vertex_location: Optional[str], vertex_auth_header: Optional[str], + model: Optional[str] = None, ) -> Optional[str]: """ Checks if content already cached. 
@@ -221,6 +224,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project=vertex_project, vertex_location=vertex_location, vertex_auth_header=vertex_auth_header, + model=model, ) page_token: Optional[str] = None @@ -379,6 +383,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project=vertex_project, vertex_location=vertex_location, vertex_auth_header=vertex_auth_header, + model=model, ) if google_cache_name: return non_cached_messages, optional_params, google_cache_name @@ -523,6 +528,7 @@ class ContextCachingEndpoints(VertexBase): vertex_project=vertex_project, vertex_location=vertex_location, vertex_auth_header=vertex_auth_header, + model=model, ) if google_cache_name: From 6514446dcb1984c29cf9fd61bcd3627b3cff0579 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:09:30 +0530 Subject: [PATCH 062/539] Update litellm/llms/azure_ai/agents/handler.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- litellm/llms/azure_ai/agents/handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/llms/azure_ai/agents/handler.py b/litellm/llms/azure_ai/agents/handler.py index 95c0a4c577..3fbda13d3c 100644 --- a/litellm/llms/azure_ai/agents/handler.py +++ b/litellm/llms/azure_ai/agents/handler.py @@ -137,7 +137,7 @@ class AzureAIAgentsHandler: result: List[Dict[str, Any]] = [] for ann in raw_annotations: - ann_type = ann.get("type", "url_citation") + ann_type = ann.get("type") if ann_type == "url_citation": url_citation = dict(ann.get("url_citation", {})) # Azure puts start/end_index at annotation level; OpenAI From 018ccff23f1df9e4d223741e0f084447f08bcade Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:11:20 +0530 Subject: [PATCH 063/539] fix(vertex-ai): address greptile review feedback on batch cancel - Add try/except httpx.HTTPStatusError blocks in _async_cancel_batch for both POST cancel and GET retrieve calls, with verbose_logger error logging - Fix 
endpoint extraction inconsistency: compute endpoint from URL without :cancel suffix so it matches behaviour of create_batch/retrieve_batch - Add explicit validation that api_base ends with ':cancel' before stripping it, raising a descriptive error for unsupported custom proxy URL rewriting scenarios - Use string-based patch() in test instead of patch.object() for robustness against import order changes Made-with: Cursor --- litellm/llms/vertex_ai/batches/handler.py | 54 ++++++++++++++----- .../test_vertex_ai_batch_transformation.py | 2 +- 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index c7b9287c08..36728499b9 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -399,10 +399,12 @@ class VertexAIBatchPrediction(VertexLLM): vertex_project=vertex_project or project_id, ) - default_api_base = f"{default_api_base}/{batch_id}:cancel" + # Compute endpoint from the URL without :cancel for consistency with other methods + base_without_cancel = f"{default_api_base}/{batch_id}" + default_api_base = f"{base_without_cancel}:cancel" - if len(default_api_base.split(":")) > 1: - endpoint = default_api_base.split(":")[-1] + if len(base_without_cancel.split(":")) > 1: + endpoint = base_without_cancel.split(":")[-1] else: endpoint = "" @@ -420,6 +422,14 @@ class VertexAIBatchPrediction(VertexLLM): vertex_api_version="v1", ) + if not api_base.endswith(":cancel"): + raise ValueError( + f"cancel_batch: expected api_base to end with ':cancel', got: {api_base!r}. " + "Custom proxy URL rewriting is not supported for this operation." 
+ ) + + retrieve_api_base = api_base.rsplit(":cancel", 1)[0] + headers = { "Content-Type": "application/json; charset=utf-8", "Authorization": f"Bearer {access_token}", @@ -428,7 +438,7 @@ class VertexAIBatchPrediction(VertexLLM): if _is_async is True: return self._async_cancel_batch( api_base=api_base, - retrieve_api_base=api_base.rsplit(":cancel", 1)[0], + retrieve_api_base=retrieve_api_base, headers=headers, ) @@ -443,7 +453,7 @@ class VertexAIBatchPrediction(VertexLLM): raise Exception(f"Error: {response.status_code} {response.text}") retrieve_response = sync_handler.get( - url=api_base.rsplit(":cancel", 1)[0], + url=retrieve_api_base, headers=headers, ) if retrieve_response.status_code != 200: @@ -466,18 +476,34 @@ class VertexAIBatchPrediction(VertexLLM): client = get_async_httpx_client( llm_provider=litellm.LlmProviders.VERTEX_AI, ) - response = await client.post( - url=api_base, - headers=headers, - data=json.dumps({}), - ) + try: + response = await client.post( + url=api_base, + headers=headers, + data=json.dumps({}), + ) + except httpx.HTTPStatusError as e: + litellm.verbose_logger.error( + "Vertex AI batch cancel failed: status=%s, body=%s", + e.response.status_code, + e.response.text[:1000], + ) + raise if response.status_code != 200: raise Exception(f"Error: {response.status_code} {response.text}") - retrieve_response = await client.get( - url=retrieve_api_base, - headers=headers, - ) + try: + retrieve_response = await client.get( + url=retrieve_api_base, + headers=headers, + ) + except httpx.HTTPStatusError as e: + litellm.verbose_logger.error( + "Vertex AI batch retrieve-after-cancel failed: status=%s, body=%s", + e.response.status_code, + e.response.text[:1000], + ) + raise if retrieve_response.status_code != 200: raise Exception( f"Error: {retrieve_response.status_code} {retrieve_response.text}" diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py 
b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py index d27cd8ba8a..e03555a9c0 100644 --- a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py @@ -101,7 +101,7 @@ async def test_litellm_cancel_batch_vertex_ai(): mock_response.id = "batch_123" mock_response.status = "cancelling" - with patch.object(litellm.batches.main, "vertex_ai_batches_instance") as mock_instance: + with patch("litellm.batches.main.vertex_ai_batches_instance") as mock_instance: mock_instance.cancel_batch.return_value = mock_response response = litellm.cancel_batch( From 52bf372319b959c49be126472268cc0b7bf1bb8a Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:17:57 +0530 Subject: [PATCH 064/539] fix(gpt5): treat missing supports_minimal_reasoning_effort as supported Add _is_reasoning_effort_level_explicitly_disabled to use opt-out semantics for minimal effort: unknown/unlisted models pass through, only blocked when the model map explicitly sets supports_minimal_reasoning_effort=false. xhigh keeps opt-in semantics (must be explicitly supported). Adds test for unknown-model passthrough and explicit-disabled detection. 
Made-with: Cursor --- .../llms/openai/chat/gpt_5_transformation.py | 42 +++++++++++++++++-- .../llms/openai/test_gpt5_transformation.py | 35 ++++++++++++++++ 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/litellm/llms/openai/chat/gpt_5_transformation.py b/litellm/llms/openai/chat/gpt_5_transformation.py index 8522a97a38..6291f4232d 100644 --- a/litellm/llms/openai/chat/gpt_5_transformation.py +++ b/litellm/llms/openai/chat/gpt_5_transformation.py @@ -3,7 +3,7 @@ from typing import Optional, Union import litellm -from litellm.utils import _supports_factory +from litellm.utils import _get_model_cost_key, _get_model_info_helper, _supports_factory from .gpt_transformation import OpenAIGPTConfig @@ -113,6 +113,28 @@ class OpenAIGPT5Config(OpenAIGPTConfig): key=f"supports_{level}_reasoning_effort", ) + @classmethod + def _is_reasoning_effort_level_explicitly_disabled( + cls, model: str, level: str + ) -> bool: + """Return True only when the model map explicitly sets the capability to False. + + Unlike ``_supports_reasoning_effort_level`` (which requires an explicit True), + this method returns True only when ``supports_{level}_reasoning_effort`` is + explicitly set to ``False`` in the model map. A missing key is treated as + supported (i.e. this method returns False = not disabled). + + Use this for opt-out checks where unknown models should be allowed through. 
+ """ + try: + key = f"supports_{level}_reasoning_effort" + cost_key = _get_model_cost_key(model) + entry = litellm.model_cost.get(cost_key or model) or {} + val = entry.get(key) + return val is False + except Exception: + return False + def get_supported_openai_params(self, model: str) -> list: if self.is_model_gpt_5_search_model(model): return [ @@ -200,9 +222,8 @@ class OpenAIGPT5Config(OpenAIGPTConfig): if "reasoning_effort" in optional_params: optional_params["reasoning_effort"] = normalized - if effective_effort is not None and ( - effective_effort == "xhigh" or effective_effort == "minimal" - ): + if effective_effort == "xhigh": + # xhigh is an opt-in capability: only allow if model explicitly supports it. if not self._supports_reasoning_effort_level(model, effective_effort): if litellm.drop_params or drop_params: non_default_params.pop("reasoning_effort", None) @@ -213,6 +234,19 @@ class OpenAIGPT5Config(OpenAIGPTConfig): ), status_code=400, ) + elif effective_effort == "minimal": + # minimal is opt-out: unknown models pass through; only block when + # the model map explicitly sets supports_minimal_reasoning_effort=false. + if self._is_reasoning_effort_level_explicitly_disabled(model, effective_effort): + if litellm.drop_params or drop_params: + non_default_params.pop("reasoning_effort", None) + else: + raise litellm.utils.UnsupportedParamsError( + message=( + f"reasoning_effort={effective_effort} is not supported for this model." 
+ ), + status_code=400, + ) ################################################################ # max_tokens is not supported for gpt-5 models on OpenAI API diff --git a/tests/test_litellm/llms/openai/test_gpt5_transformation.py b/tests/test_litellm/llms/openai/test_gpt5_transformation.py index 8ae9bf48a6..535fba1614 100644 --- a/tests/test_litellm/llms/openai/test_gpt5_transformation.py +++ b/tests/test_litellm/llms/openai/test_gpt5_transformation.py @@ -409,6 +409,41 @@ def test_gpt5_supports_reasoning_effort_level_minimal(gpt5_config: OpenAIGPT5Con assert not gpt5_config._supports_reasoning_effort_level("gpt-5.4-nano", "minimal") +def test_gpt5_minimal_explicitly_disabled_check(gpt5_config: OpenAIGPT5Config): + """_is_reasoning_effort_level_explicitly_disabled returns True only for explicit False entries. + + Models with supports_minimal_reasoning_effort=false → disabled. + Models with supports_minimal_reasoning_effort=true (or missing) → not disabled. + """ + assert gpt5_config._is_reasoning_effort_level_explicitly_disabled( + "gpt-5.4-mini", "minimal" + ) + assert gpt5_config._is_reasoning_effort_level_explicitly_disabled( + "gpt-5.4-nano", "minimal" + ) + assert not gpt5_config._is_reasoning_effort_level_explicitly_disabled( + "gpt-5.4", "minimal" + ) + assert not gpt5_config._is_reasoning_effort_level_explicitly_disabled( + "gpt-5.4-pro", "minimal" + ) + + +def test_gpt5_unknown_model_passes_through_minimal(config: OpenAIConfig): + """Unknown/unlisted gpt-5 models should pass reasoning_effort='minimal' through. + + Missing supports_minimal_reasoning_effort key is treated as supported, + not as unsupported, to avoid breaking custom or newly-announced models. 
+ """ + params = config.map_openai_params( + non_default_params={"reasoning_effort": "minimal"}, + optional_params={}, + model="gpt-5.4-turbo-preview", + drop_params=False, + ) + assert params["reasoning_effort"] == "minimal" + + def test_gpt5_normalizes_reasoning_effort_dict_with_summary(config: OpenAIConfig): """Dict with summary/generate_summary is normalized for chat completions.""" params = config.map_openai_params( From a41239cb966341f78c6a3da96dd53c3332929c8f Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:24:34 +0530 Subject: [PATCH 065/539] greptile comments --- litellm/model_prices_and_context_window_backup.json | 2 ++ model_prices_and_context_window.json | 2 ++ 2 files changed, 4 insertions(+) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 44fe6a06ad..b5330e5855 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -18799,6 +18799,7 @@ "supports_tool_choice": true, "supports_service_tier": true, "supports_vision": true, + "supports_web_search": true, "supports_none_reasoning_effort": true, "supports_xhigh_reasoning_effort": true, "supports_minimal_reasoning_effort": false @@ -18841,6 +18842,7 @@ "supports_tool_choice": true, "supports_service_tier": true, "supports_vision": true, + "supports_web_search": true, "supports_none_reasoning_effort": true, "supports_xhigh_reasoning_effort": true, "supports_minimal_reasoning_effort": false diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 44fe6a06ad..b5330e5855 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -18799,6 +18799,7 @@ "supports_tool_choice": true, "supports_service_tier": true, "supports_vision": true, + "supports_web_search": true, "supports_none_reasoning_effort": true, "supports_xhigh_reasoning_effort": true, 
"supports_minimal_reasoning_effort": false @@ -18841,6 +18842,7 @@ "supports_tool_choice": true, "supports_service_tier": true, "supports_vision": true, + "supports_web_search": true, "supports_none_reasoning_effort": true, "supports_xhigh_reasoning_effort": true, "supports_minimal_reasoning_effort": false From aaf860c19ba491424f16474b11d398fa1c791dc4 Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 18 Mar 2026 00:56:11 -0300 Subject: [PATCH 066/539] docs: add Azure custom deployment name guidance for auto-routing --- docs/my-website/docs/reasoning_content.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/my-website/docs/reasoning_content.md b/docs/my-website/docs/reasoning_content.md index fb77f53a56..313228d511 100644 --- a/docs/my-website/docs/reasoning_content.md +++ b/docs/my-website/docs/reasoning_content.md @@ -598,6 +598,23 @@ When `gpt-5.4+` requests to `litellm.completion()` include both `reasoning_effor You can also route explicitly via `openai/responses/gpt-5.4` or `azure/responses/gpt-5.4`. See [Responses API Bridge](/docs/providers/openai#openai-chat-completion-to-responses-api-bridge) for details. +**Azure custom deployment names:** Auto-routing relies on the deployment name matching the `gpt-5.4*` pattern. If you use a custom deployment name (e.g. `"my-reasoning-model"`), enable routing via: + +**SDK:** +```python +litellm.completion(model="azure/responses/my-reasoning-model", ...) 
+``` + +**Proxy config:** +```yaml +model_list: + - model_name: my-reasoning-model + litellm_params: + model: azure/my-reasoning-model + model_info: + mode: responses +``` + ::: ## OpenAI Responses API - Auto-Summary Control From 74382f1c89cf91b3ddc58dd82d6298407465c51c Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:28:19 +0530 Subject: [PATCH 067/539] fix(vertex-ai): address greptile review feedback on batch cancel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace misleading endpoint extraction with explicit endpoint = "cancel" - Compute retrieve_api_base from URL components directly instead of stripping ":cancel" from the post-proxy URL, removing the hard ValueError that broke any custom Vertex AI proxy configuration - Align cancel_batch provider priority in proxy endpoints to match create_batch order: body field → request headers → query params → default Made-with: Cursor --- litellm/llms/vertex_ai/batches/handler.py | 21 +++++++------------- litellm/proxy/batches_endpoints/endpoints.py | 2 +- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index 36728499b9..b30130688d 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -399,14 +399,11 @@ class VertexAIBatchPrediction(VertexLLM): vertex_project=vertex_project or project_id, ) - # Compute endpoint from the URL without :cancel for consistency with other methods - base_without_cancel = f"{default_api_base}/{batch_id}" - default_api_base = f"{base_without_cancel}:cancel" + retrieve_api_base_default = f"{default_api_base}/{batch_id}" + default_api_base = f"{retrieve_api_base_default}:cancel" - if len(base_without_cancel.split(":")) > 1: - endpoint = base_without_cancel.split(":")[-1] - else: - endpoint = "" + # The Vertex AI action suffix for this operation + endpoint = "cancel" _, api_base 
= self._check_custom_proxy( api_base=api_base, @@ -422,13 +419,9 @@ class VertexAIBatchPrediction(VertexLLM): vertex_api_version="v1", ) - if not api_base.endswith(":cancel"): - raise ValueError( - f"cancel_batch: expected api_base to end with ':cancel', got: {api_base!r}. " - "Custom proxy URL rewriting is not supported for this operation." - ) - - retrieve_api_base = api_base.rsplit(":cancel", 1)[0] + # Use the canonical retrieve URL built from components rather than stripping + # ":cancel" from api_base, so custom proxy URL rewriting does not break retrieval. + retrieve_api_base = retrieve_api_base_default headers = { "Content-Type": "application/json; charset=utf-8", diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 9ce1b6e916..3b26e6c096 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -898,9 +898,9 @@ async def cancel_batch( else: custom_llm_provider = ( provider + or data.pop("custom_llm_provider", None) or get_custom_llm_provider_from_request_headers(request=request) or get_custom_llm_provider_from_request_query(request=request) - or data.pop("custom_llm_provider", None) or "openai" ) # Extract batch_id from data to avoid "multiple values for keyword argument" error From 5dd89f16f5369c65af33600f43ee22a4e39b9de6 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:37:19 +0530 Subject: [PATCH 068/539] address greptile review: remove unused import, normalize model lookup, add xhigh tests - Remove unused _get_model_info_helper import - Normalize model via get_llm_provider in _is_reasoning_effort_level_explicitly_disabled so provider-prefixed names (openai/gpt-5.4-mini) resolve correctly - Add test_gpt5_4_mini_allows_reasoning_effort_xhigh - Add test_gpt5_4_nano_allows_reasoning_effort_xhigh - Add test_gpt5_4_mini_provider_prefixed_rejects_minimal - Extend test_gpt5_minimal_explicitly_disabled_check for 
openai/gpt-5.4-mini --- .../llms/openai/chat/gpt_5_transformation.py | 11 ++++-- .../llms/openai/test_gpt5_transformation.py | 37 +++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/litellm/llms/openai/chat/gpt_5_transformation.py b/litellm/llms/openai/chat/gpt_5_transformation.py index 6291f4232d..60a21e19a9 100644 --- a/litellm/llms/openai/chat/gpt_5_transformation.py +++ b/litellm/llms/openai/chat/gpt_5_transformation.py @@ -3,7 +3,7 @@ from typing import Optional, Union import litellm -from litellm.utils import _get_model_cost_key, _get_model_info_helper, _supports_factory +from litellm.utils import _get_model_cost_key, _supports_factory from .gpt_transformation import OpenAIGPTConfig @@ -125,11 +125,16 @@ class OpenAIGPT5Config(OpenAIGPTConfig): supported (i.e. this method returns False = not disabled). Use this for opt-out checks where unknown models should be allowed through. + Normalizes the model via get_llm_provider so provider-prefixed names + (e.g. openai/gpt-5.4-mini) resolve correctly. 
""" try: + normalized_model, _, _, _ = litellm.get_llm_provider( + model=model, custom_llm_provider=None + ) key = f"supports_{level}_reasoning_effort" - cost_key = _get_model_cost_key(model) - entry = litellm.model_cost.get(cost_key or model) or {} + cost_key = _get_model_cost_key(normalized_model) + entry = litellm.model_cost.get(cost_key or normalized_model) or {} val = entry.get(key) return val is False except Exception: diff --git a/tests/test_litellm/llms/openai/test_gpt5_transformation.py b/tests/test_litellm/llms/openai/test_gpt5_transformation.py index 535fba1614..72e8a0c185 100644 --- a/tests/test_litellm/llms/openai/test_gpt5_transformation.py +++ b/tests/test_litellm/llms/openai/test_gpt5_transformation.py @@ -324,6 +324,28 @@ def test_gpt5_4_pro_allows_reasoning_effort_xhigh(config: OpenAIConfig): assert params["reasoning_effort"] == "xhigh" +def test_gpt5_4_mini_allows_reasoning_effort_xhigh(config: OpenAIConfig): + """gpt-5.4-mini supports reasoning_effort='xhigh'.""" + params = config.map_openai_params( + non_default_params={"reasoning_effort": "xhigh"}, + optional_params={}, + model="gpt-5.4-mini", + drop_params=False, + ) + assert params["reasoning_effort"] == "xhigh" + + +def test_gpt5_4_nano_allows_reasoning_effort_xhigh(config: OpenAIConfig): + """gpt-5.4-nano supports reasoning_effort='xhigh'.""" + params = config.map_openai_params( + non_default_params={"reasoning_effort": "xhigh"}, + optional_params={}, + model="gpt-5.4-nano", + drop_params=False, + ) + assert params["reasoning_effort"] == "xhigh" + + def test_gpt5_4_allows_reasoning_effort_minimal(config: OpenAIConfig): """gpt-5.4 supports reasoning_effort='minimal'.""" params = config.map_openai_params( @@ -368,6 +390,17 @@ def test_gpt5_4_nano_rejects_reasoning_effort_minimal(config: OpenAIConfig): ) +def test_gpt5_4_mini_provider_prefixed_rejects_minimal(config: OpenAIConfig): + """openai/gpt-5.4-mini correctly rejects minimal (model lookup normalizes prefix).""" + with 
pytest.raises(litellm.utils.UnsupportedParamsError): + config.map_openai_params( + non_default_params={"reasoning_effort": "minimal"}, + optional_params={}, + model="openai/gpt-5.4-mini", + drop_params=False, + ) + + def test_gpt5_drops_reasoning_effort_minimal_when_requested(config: OpenAIConfig): """reasoning_effort='minimal' is dropped for unsupported models when drop_params=True.""" params = config.map_openai_params( @@ -414,6 +447,7 @@ def test_gpt5_minimal_explicitly_disabled_check(gpt5_config: OpenAIGPT5Config): Models with supports_minimal_reasoning_effort=false → disabled. Models with supports_minimal_reasoning_effort=true (or missing) → not disabled. + Provider-prefixed models (openai/gpt-5.4-mini) are normalized before lookup. """ assert gpt5_config._is_reasoning_effort_level_explicitly_disabled( "gpt-5.4-mini", "minimal" @@ -421,6 +455,9 @@ def test_gpt5_minimal_explicitly_disabled_check(gpt5_config: OpenAIGPT5Config): assert gpt5_config._is_reasoning_effort_level_explicitly_disabled( "gpt-5.4-nano", "minimal" ) + assert gpt5_config._is_reasoning_effort_level_explicitly_disabled( + "openai/gpt-5.4-mini", "minimal" + ) assert not gpt5_config._is_reasoning_effort_level_explicitly_disabled( "gpt-5.4", "minimal" ) From 74ae17d15305334d0614b107e866bb5f6f7ef0e9 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:41:46 +0530 Subject: [PATCH 069/539] greptile comments --- litellm/batches/main.py | 2 +- litellm/llms/vertex_ai/batches/handler.py | 31 +++++++++++++------ .../test_vertex_ai_batch_transformation.py | 2 +- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 36093d071b..ae79469dd1 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -1048,7 +1048,7 @@ def cancel_batch( litellm_params=litellm_params, ) elif custom_llm_provider == "vertex_ai": - api_base = optional_params.api_base or "" + api_base = optional_params.api_base or None vertex_ai_project 
= ( optional_params.vertex_project or litellm.vertex_project diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index b30130688d..f4ba0533c8 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -400,28 +400,41 @@ class VertexAIBatchPrediction(VertexLLM): ) retrieve_api_base_default = f"{default_api_base}/{batch_id}" - default_api_base = f"{retrieve_api_base_default}:cancel" + cancel_api_base_default = f"{retrieve_api_base_default}:cancel" - # The Vertex AI action suffix for this operation - endpoint = "cancel" + # Save the caller-supplied value before _check_custom_proxy overwrites api_base, + # so we can pass it unchanged to the second proxy-check for the retrieve URL. + caller_api_base = api_base _, api_base = self._check_custom_proxy( - api_base=api_base, + api_base=caller_api_base, custom_llm_provider="vertex_ai", gemini_api_key=None, - endpoint=endpoint, + endpoint="cancel", stream=None, auth_header=None, - url=default_api_base, + url=cancel_api_base_default, model=None, vertex_project=vertex_project or project_id, vertex_location=vertex_location or "us-central1", vertex_api_version="v1", ) - # Use the canonical retrieve URL built from components rather than stripping - # ":cancel" from api_base, so custom proxy URL rewriting does not break retrieval. - retrieve_api_base = retrieve_api_base_default + # Route the retrieve GET through the same proxy as the cancel POST by running + # _check_custom_proxy a second time with the non-cancel default URL. 
+ _, retrieve_api_base = self._check_custom_proxy( + api_base=caller_api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint="", + stream=None, + auth_header=None, + url=retrieve_api_base_default, + model=None, + vertex_project=vertex_project or project_id, + vertex_location=vertex_location or "us-central1", + vertex_api_version="v1", + ) headers = { "Content-Type": "application/json; charset=utf-8", diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py index e03555a9c0..1cf6fa3266 100644 --- a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py @@ -45,7 +45,7 @@ def test_output_file_id_falls_back_to_output_uri_prefix_with_predictions_jsonl() @pytest.mark.asyncio -async def test_vertex_ai_cancel_batch(): +def test_vertex_ai_cancel_batch(): """Test that vertex_ai cancel_batch calls the correct API endpoint""" handler = VertexAIBatchPrediction(gcs_bucket_name="test-bucket") From d0d593beb8027b2b467b1ef9cbdd302981d658a4 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:48:24 +0530 Subject: [PATCH 070/539] Update tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- .../llms/vertex_ai/test_vertex_ai_batch_transformation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py index 1cf6fa3266..1fa8dd2e16 100644 --- a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py @@ -44,7 +44,6 @@ def test_output_file_id_falls_back_to_output_uri_prefix_with_predictions_jsonl() ) 
-@pytest.mark.asyncio def test_vertex_ai_cancel_batch(): """Test that vertex_ai cancel_batch calls the correct API endpoint""" handler = VertexAIBatchPrediction(gcs_bucket_name="test-bucket") From e46dd949f2b10ccbeabf4d75ca89fec1574e7d23 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 09:58:20 +0530 Subject: [PATCH 071/539] Add test for reasoning effort none --- .../llms/openai/test_gpt5_transformation.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_litellm/llms/openai/test_gpt5_transformation.py b/tests/test_litellm/llms/openai/test_gpt5_transformation.py index 72e8a0c185..a173d2015b 100644 --- a/tests/test_litellm/llms/openai/test_gpt5_transformation.py +++ b/tests/test_litellm/llms/openai/test_gpt5_transformation.py @@ -345,6 +345,25 @@ def test_gpt5_4_nano_allows_reasoning_effort_xhigh(config: OpenAIConfig): ) assert params["reasoning_effort"] == "xhigh" +def test_gpt5_4_nano_allows_reasoning_effort_none(config: OpenAIConfig): + """gpt-5.4-nano supports reasoning_effort='none'.""" + params = config.map_openai_params( + non_default_params={"reasoning_effort": "none"}, + optional_params={}, + model="gpt-5.4-nano", + drop_params=False, + ) + assert params["reasoning_effort"] == "none" + +def test_gpt5_4_mini_allows_reasoning_effort_none(config: OpenAIConfig): + """gpt-5.4-mini supports reasoning_effort='none'.""" + params = config.map_openai_params( + non_default_params={"reasoning_effort": "none"}, + optional_params={}, + model="gpt-5.4-mini", + drop_params=False, + ) + assert params["reasoning_effort"] == "none" def test_gpt5_4_allows_reasoning_effort_minimal(config: OpenAIConfig): """gpt-5.4 supports reasoning_effort='minimal'.""" From 547db8f5d1f1408e598088380fa007497efb3f23 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 10:01:47 +0530 Subject: [PATCH 072/539] Fix greptile comments --- litellm/llms/vertex_ai/batches/handler.py | 24 +++++------------------ 1 file changed, 5 
insertions(+), 19 deletions(-) diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index f4ba0533c8..9c760f628e 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -402,12 +402,8 @@ class VertexAIBatchPrediction(VertexLLM): retrieve_api_base_default = f"{default_api_base}/{batch_id}" cancel_api_base_default = f"{retrieve_api_base_default}:cancel" - # Save the caller-supplied value before _check_custom_proxy overwrites api_base, - # so we can pass it unchanged to the second proxy-check for the retrieve URL. - caller_api_base = api_base - _, api_base = self._check_custom_proxy( - api_base=caller_api_base, + api_base=api_base, custom_llm_provider="vertex_ai", gemini_api_key=None, endpoint="cancel", @@ -420,20 +416,10 @@ class VertexAIBatchPrediction(VertexLLM): vertex_api_version="v1", ) - # Route the retrieve GET through the same proxy as the cancel POST by running - # _check_custom_proxy a second time with the non-cancel default URL. 
- _, retrieve_api_base = self._check_custom_proxy( - api_base=caller_api_base, - custom_llm_provider="vertex_ai", - gemini_api_key=None, - endpoint="", - stream=None, - auth_header=None, - url=retrieve_api_base_default, - model=None, - vertex_project=vertex_project or project_id, - vertex_location=vertex_location or "us-central1", - vertex_api_version="v1", + retrieve_api_base = ( + api_base.removesuffix(":cancel") + if api_base.endswith(":cancel") + else retrieve_api_base_default ) headers = { From dc7b7f852d11f113b61df87be757d094fa306c24 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 10:10:29 +0530 Subject: [PATCH 073/539] =?UTF-8?q?fix(file=5Fsearch):=20address=20greptil?= =?UTF-8?q?e=20review=20=E2=80=94=20dead=20code,=20follow-up=20context,=20?= =?UTF-8?q?cost=20tracking?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove dead `should_use_emulated_file_search` (main.py uses its own inline guard) - Remove dead `fallback_vector_store_ids` param from `_run_vector_searches` - Include all first_response.output items in follow_up_input so text blocks/reasoning from providers like Anthropic aren't dropped from conversation context - Accumulate first provider call's response_cost into synthesized _hidden_params so billing callbacks see the total cost of both emulated-flow LLM calls - Remove broad tools=[] filter from transformation.py (backward-incompatible); the follow-up call already passes tools=None which is filtered by the v is not None guard Made-with: Cursor --- .../responses/file_search/emulated_handler.py | 64 ++++++------------- .../transformation.py | 6 +- 2 files changed, 23 insertions(+), 47 deletions(-) diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 5d2c23fdfa..1686021884 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -26,25 +26,6 @@ 
ToolParam = Any FILE_SEARCH_FUNCTION_NAME = "litellm_file_search" -# --------------------------------------------------------------------------- -# Detection -# --------------------------------------------------------------------------- - -def should_use_emulated_file_search( - tools: Optional[Iterable[ToolParam]], - provider_config: Any, # BaseResponsesAPIConfig -) -> bool: - """Return True when there is a file_search tool and the provider can't handle it natively.""" - if not tools: - return False - has_fs = any( - isinstance(t, dict) and t.get("type") == "file_search" for t in tools - ) - if not has_fs: - return False - return provider_config is None or not provider_config.supports_native_file_search() - - # --------------------------------------------------------------------------- # Tool conversion # --------------------------------------------------------------------------- @@ -126,15 +107,13 @@ def _replace_file_search_tools( async def _run_vector_searches( queries: List[str], vector_store_ids: List[str], - fallback_vector_store_ids: List[str], ) -> Tuple[List[str], List[VectorStoreSearchResult]]: """ Run `asearch` against all vector stores for all queries and collect results. 
Args: queries: List of search queries to execute (like OpenAI's multi-query approach) - vector_store_ids: Specific vector store IDs to search - fallback_vector_store_ids: Fallback IDs if vector_store_ids is empty + vector_store_ids: Vector store IDs to search Returns: (queries_list, combined_results) @@ -142,7 +121,7 @@ async def _run_vector_searches( import litellm.vector_stores.main as vs_main all_results: List[VectorStoreSearchResult] = [] - ids_to_search = vector_store_ids or fallback_vector_store_ids + ids_to_search = vector_store_ids # Execute each query against all vector stores for query in queries: @@ -337,11 +316,16 @@ def _synthesize_responses_api_response( original_response: ResponsesAPIResponse, file_search_call_output: Dict[str, Any], message_output: Dict[str, Any], + first_response: Optional[ResponsesAPIResponse] = None, ) -> ResponsesAPIResponse: """ Return a new ResponsesAPIResponse with: output[0] = file_search_call item output[1] = message item (with citations) + + When first_response is provided, its response_cost is accumulated into the + synthesized _hidden_params so that billing callbacks see the total cost of + both provider calls that the emulated flow makes. 
""" synthesized = ResponsesAPIResponse( id=getattr(original_response, "id", f"resp_{uuid.uuid4().hex}"), @@ -354,7 +338,14 @@ def _synthesize_responses_api_response( error=None, ) if hasattr(original_response, "_hidden_params"): - synthesized._hidden_params = getattr(original_response, "_hidden_params") + hidden = dict(getattr(original_response, "_hidden_params") or {}) + if first_response is not None and hasattr(first_response, "_hidden_params"): + first_hidden = getattr(first_response, "_hidden_params") or {} + first_cost = first_hidden.get("response_cost") if isinstance(first_hidden, dict) else getattr(first_hidden, "response_cost", None) + if first_cost is not None: + current_cost = hidden.get("response_cost") if isinstance(hidden, dict) else 0 + hidden["response_cost"] = (current_cost or 0) + first_cost + synthesized._hidden_params = hidden return synthesized @@ -473,7 +464,6 @@ async def aresponses_with_emulated_file_search( queries, results = await _run_vector_searches( queries=queries_from_call, vector_store_ids=vs_ids_for_call, - fallback_vector_store_ids=all_vs_ids, ) all_queries.extend(queries) all_results.extend(results) @@ -486,28 +476,15 @@ async def aresponses_with_emulated_file_search( } ) - # 5. Build follow-up input: original messages + all assistant tool calls + tool results + # 5. Build follow-up input: original messages + ALL first-response output items + tool results + # Including all output items (text blocks, reasoning, non-file-search calls) ensures providers + # like Anthropic that emit text before the tool call have complete conversation context. 
original_input_items = list(input) if isinstance(input, (list, tuple)) else [{"role": "user", "content": str(input)}] - follow_up_function_calls: List[Dict[str, Any]] = [] - for tc in file_search_calls: - if isinstance(tc, dict): - tc_call_id = tc.get("call_id") or tc.get("id") or file_search_call_id - tc_args = tc.get("arguments") or "{}" - else: - tc_call_id = getattr(tc, "call_id", None) or getattr(tc, "id", file_search_call_id) - tc_args = getattr(tc, "arguments", "{}") or "{}" - follow_up_function_calls.append( - { - "type": "function_call", - "name": FILE_SEARCH_FUNCTION_NAME, - "call_id": tc_call_id, - "arguments": tc_args, - } - ) + first_response_output_items = list(first_response.output) follow_up_input = ( original_input_items - + follow_up_function_calls + + first_response_output_items + tool_results ) @@ -534,4 +511,5 @@ async def aresponses_with_emulated_file_search( include_search_results=_include_search_results, ), message_output=_build_message_output(response_text, all_results), + first_response=first_response, ) diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index 3467dbbd27..ae4740f8b5 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -229,13 +229,11 @@ class LiteLLMCompletionResponsesConfig: if litellm_logging_obj: litellm_logging_obj.stream_options = stream_options - # only pass non-None / non-empty values - # Explicitly exclude an empty tools list — sending tools=[] to providers - # like Anthropic in a tool_result conversation makes them return empty content. 
+ # only pass non-None values litellm_completion_request = { k: v for k, v in litellm_completion_request.items() - if v is not None and not (k == "tools" and v == []) + if v is not None } return litellm_completion_request From 1ff7c700114bb8453b03e301afaafed94e6cc1f5 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 10:12:13 +0530 Subject: [PATCH 074/539] fix(file_search): serialize first_response output items to dicts for follow-up input Pydantic model instances (ResponseFunctionToolCall, etc.) from first_response.output were included raw in follow_up_input; the transformation layer expects plain dicts and called .get() on them, raising AttributeError. Serialize via model_dump(exclude_none=True). Made-with: Cursor --- litellm/responses/file_search/emulated_handler.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 1686021884..584aea4a6a 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -479,8 +479,16 @@ async def aresponses_with_emulated_file_search( # 5. Build follow-up input: original messages + ALL first-response output items + tool results # Including all output items (text blocks, reasoning, non-file-search calls) ensures providers # like Anthropic that emit text before the tool call have complete conversation context. + # Serialize Pydantic model instances to plain dicts so the transformation layer can call .get(). 
original_input_items = list(input) if isinstance(input, (list, tuple)) else [{"role": "user", "content": str(input)}] - first_response_output_items = list(first_response.output) + first_response_output_items: List[Any] = [] + for _item in first_response.output: + if isinstance(_item, dict): + first_response_output_items.append(_item) + elif hasattr(_item, "model_dump"): + first_response_output_items.append(_item.model_dump(exclude_none=True)) # type: ignore[union-attr] + else: + first_response_output_items.append(_item) follow_up_input = ( original_input_items From ecb8c05d37ef78f7aeafbf616d092e9f8a3a596c Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 10:24:21 +0530 Subject: [PATCH 075/539] Add test for reasoning effort none --- .../llms/openai/chat/gpt_5_transformation.py | 21 ++++------ litellm/utils.py | 41 +++++++++++++++++++ .../llms/openai/test_gpt5_transformation.py | 17 ++++++++ 3 files changed, 65 insertions(+), 14 deletions(-) diff --git a/litellm/llms/openai/chat/gpt_5_transformation.py b/litellm/llms/openai/chat/gpt_5_transformation.py index 60a21e19a9..bd726e7933 100644 --- a/litellm/llms/openai/chat/gpt_5_transformation.py +++ b/litellm/llms/openai/chat/gpt_5_transformation.py @@ -3,7 +3,7 @@ from typing import Optional, Union import litellm -from litellm.utils import _get_model_cost_key, _supports_factory +from litellm.utils import _is_explicitly_disabled_factory, _supports_factory from .gpt_transformation import OpenAIGPTConfig @@ -125,20 +125,12 @@ class OpenAIGPT5Config(OpenAIGPTConfig): supported (i.e. this method returns False = not disabled). Use this for opt-out checks where unknown models should be allowed through. - Normalizes the model via get_llm_provider so provider-prefixed names - (e.g. openai/gpt-5.4-mini) resolve correctly. 
""" - try: - normalized_model, _, _, _ = litellm.get_llm_provider( - model=model, custom_llm_provider=None - ) - key = f"supports_{level}_reasoning_effort" - cost_key = _get_model_cost_key(normalized_model) - entry = litellm.model_cost.get(cost_key or normalized_model) or {} - val = entry.get(key) - return val is False - except Exception: - return False + return _is_explicitly_disabled_factory( + model=model, + custom_llm_provider=None, + key=f"supports_{level}_reasoning_effort", + ) def get_supported_openai_params(self, model: str) -> list: if self.is_model_gpt_5_search_model(model): @@ -245,6 +237,7 @@ class OpenAIGPT5Config(OpenAIGPTConfig): if self._is_reasoning_effort_level_explicitly_disabled(model, effective_effort): if litellm.drop_params or drop_params: non_default_params.pop("reasoning_effort", None) + optional_params.pop("reasoning_effort", None) else: raise litellm.utils.UnsupportedParamsError( message=( diff --git a/litellm/utils.py b/litellm/utils.py index 81d749ab82..cc7956576d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2576,6 +2576,47 @@ def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) return False +def _is_explicitly_disabled_factory( + model: str, custom_llm_provider: Optional[str], key: str +) -> bool: + """Return True only when the model map explicitly sets *key* to ``False``. + + This is the opt-out mirror of :func:`_supports_factory`. Where + ``_supports_factory`` requires an explicit ``True`` to return ``True``, + this function requires an explicit ``False``. A missing key (``None``) + is treated as *not* disabled so that unknown or newly-added models are + allowed through without any model-map entry. + + Uses the same ``get_llm_provider`` → ``_get_model_info_helper`` chain as + ``_supports_factory`` so caching, fallback, and normalisation improvements + apply here automatically. 
+ """ + try: + model, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=model, custom_llm_provider=custom_llm_provider + ) + model_info = _get_model_info_helper( + model=model, custom_llm_provider=custom_llm_provider + ) + val = model_info.get(key) + if val is False: + return True + if val is None: + bare_model_key = _get_model_cost_key(model) + if bare_model_key is not None: + bare_entry = litellm.model_cost.get(bare_model_key) or {} + if bare_entry.get(key) is False: + return True + return False + except Exception as e: + verbose_logger.debug( + f"Model not found or error in checking {key} disabled state. " + f"You passed model={model}, custom_llm_provider={custom_llm_provider}. " + f"Error: {str(e)}" + ) + return False + + def supports_audio_input(model: str, custom_llm_provider: Optional[str] = None) -> bool: """Check if a given model supports audio input in a chat completion call""" return _supports_factory( diff --git a/tests/test_litellm/llms/openai/test_gpt5_transformation.py b/tests/test_litellm/llms/openai/test_gpt5_transformation.py index a173d2015b..14e392a099 100644 --- a/tests/test_litellm/llms/openai/test_gpt5_transformation.py +++ b/tests/test_litellm/llms/openai/test_gpt5_transformation.py @@ -3,6 +3,7 @@ import pytest import litellm from litellm.llms.openai.chat.gpt_5_transformation import OpenAIGPT5Config from litellm.llms.openai.openai import OpenAIConfig +from litellm.utils import _is_explicitly_disabled_factory @pytest.fixture() @@ -485,6 +486,22 @@ def test_gpt5_minimal_explicitly_disabled_check(gpt5_config: OpenAIGPT5Config): ) +def test_is_explicitly_disabled_factory_minimal(): + """_is_explicitly_disabled_factory returns True only for explicit False entries. + + Verifies the shared helper used by _is_reasoning_effort_level_explicitly_disabled + directly — so future changes to the helper are caught without going through the + method wrapper. 
+ """ + key = "supports_minimal_reasoning_effort" + assert _is_explicitly_disabled_factory("gpt-5.4-mini", None, key) + assert _is_explicitly_disabled_factory("gpt-5.4-nano", None, key) + assert _is_explicitly_disabled_factory("openai/gpt-5.4-mini", None, key) + assert not _is_explicitly_disabled_factory("gpt-5.4", None, key) + assert not _is_explicitly_disabled_factory("gpt-5.4-pro", None, key) + assert not _is_explicitly_disabled_factory("gpt-5.4-turbo-preview", None, key) + + def test_gpt5_unknown_model_passes_through_minimal(config: OpenAIConfig): """Unknown/unlisted gpt-5 models should pass reasoning_effort='minimal' through. From c4d27cb239d89d1cf934d27b52285d5a6b490f95 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 10:30:05 +0530 Subject: [PATCH 076/539] =?UTF-8?q?fix(vertex-ai):=20address=20greptile=20?= =?UTF-8?q?review=20=E2=80=93=20proxy=20retrieve=20URL,=20timeout=20forwar?= =?UTF-8?q?ding,=20sync=20logging?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix retrieve_api_base derivation to handle custom proxies with path-based routing (not just :cancel suffix) - Forward timeout to POST calls in cancel_batch (sync + async) - Add try/except error logging to sync cancel path (parity with async) - Add tests for timeout forwarding and custom proxy retrieve URL Made-with: Cursor --- litellm/llms/vertex_ai/batches/handler.py | 49 ++++++++---- .../test_vertex_ai_batch_transformation.py | 78 +++++++++++++++++++ 2 files changed, 113 insertions(+), 14 deletions(-) diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index 9c760f628e..416c5fc69f 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -416,11 +416,12 @@ class VertexAIBatchPrediction(VertexLLM): vertex_api_version="v1", ) - retrieve_api_base = ( - api_base.removesuffix(":cancel") - if api_base.endswith(":cancel") - else 
retrieve_api_base_default - ) + if api_base.endswith(":cancel"): + retrieve_api_base = api_base.removesuffix(":cancel") + elif api_base == cancel_api_base_default: + retrieve_api_base = retrieve_api_base_default + else: + retrieve_api_base = api_base.rsplit(":cancel", 1)[0].rstrip("/") headers = { "Content-Type": "application/json; charset=utf-8", @@ -432,22 +433,40 @@ class VertexAIBatchPrediction(VertexLLM): api_base=api_base, retrieve_api_base=retrieve_api_base, headers=headers, + timeout=timeout, ) sync_handler = _get_httpx_client() - response = sync_handler.post( - url=api_base, - headers=headers, - data=json.dumps({}), - ) + try: + response = sync_handler.post( + url=api_base, + headers=headers, + data=json.dumps({}), + timeout=timeout, + ) + except httpx.HTTPStatusError as e: + litellm.verbose_logger.error( + "Vertex AI batch cancel failed: status=%s, body=%s", + e.response.status_code, + e.response.text[:1000], + ) + raise if response.status_code != 200: raise Exception(f"Error: {response.status_code} {response.text}") - retrieve_response = sync_handler.get( - url=retrieve_api_base, - headers=headers, - ) + try: + retrieve_response = sync_handler.get( + url=retrieve_api_base, + headers=headers, + ) + except httpx.HTTPStatusError as e: + litellm.verbose_logger.error( + "Vertex AI batch retrieve-after-cancel failed: status=%s, body=%s", + e.response.status_code, + e.response.text[:1000], + ) + raise if retrieve_response.status_code != 200: raise Exception( f"Error: {retrieve_response.status_code} {retrieve_response.text}" @@ -464,6 +483,7 @@ class VertexAIBatchPrediction(VertexLLM): api_base: str, retrieve_api_base: str, headers: Dict[str, str], + timeout: Union[float, httpx.Timeout] = 600.0, ) -> LiteLLMBatch: client = get_async_httpx_client( llm_provider=litellm.LlmProviders.VERTEX_AI, @@ -473,6 +493,7 @@ class VertexAIBatchPrediction(VertexLLM): url=api_base, headers=headers, data=json.dumps({}), + timeout=timeout, ) except httpx.HTTPStatusError as e: 
litellm.verbose_logger.error( diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py index 1fa8dd2e16..b934afd103 100644 --- a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py @@ -93,6 +93,84 @@ def test_vertex_ai_cancel_batch(): assert ":cancel" in call_args.kwargs["url"] +def test_vertex_ai_cancel_batch_forwards_timeout(): + """Test that timeout is forwarded to both POST and GET HTTP calls""" + handler = VertexAIBatchPrediction(gcs_bucket_name="test-bucket") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "name": "projects/test-project/locations/us-central1/batchPredictionJobs/123456", + "state": "JOB_STATE_CANCELLING", + "createTime": "2024-03-17T10:00:00.000000Z", + "inputConfig": {"gcsSource": {"uris": ["gs://test-bucket/input.jsonl"]}}, + "outputConfig": {"gcsDestination": {"outputUriPrefix": "gs://test-bucket/output"}}, + } + + with patch("litellm.llms.vertex_ai.batches.handler._get_httpx_client") as mock_client: + mock_client.return_value.post.return_value = mock_response + mock_client.return_value.get.return_value = mock_response + + with patch.object(handler, "_ensure_access_token") as mock_auth: + mock_auth.return_value = ("fake-token", "test-project") + + handler.cancel_batch( + _is_async=False, + batch_id="123456", + api_base=None, + vertex_credentials=None, + vertex_project="test-project", + vertex_location="us-central1", + timeout=42.0, + max_retries=None, + ) + + post_kwargs = mock_client.return_value.post.call_args.kwargs + assert post_kwargs["timeout"] == 42.0 + + +def test_vertex_ai_cancel_batch_custom_proxy_retrieve_url(): + """Retrieve URL should go through the custom proxy, not bypass it""" + handler = VertexAIBatchPrediction(gcs_bucket_name="test-bucket") + + mock_response = 
MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "name": "projects/test-project/locations/us-central1/batchPredictionJobs/123456", + "state": "JOB_STATE_CANCELLING", + "createTime": "2024-03-17T10:00:00.000000Z", + "inputConfig": {"gcsSource": {"uris": ["gs://test-bucket/input.jsonl"]}}, + "outputConfig": {"gcsDestination": {"outputUriPrefix": "gs://test-bucket/output"}}, + } + + with patch("litellm.llms.vertex_ai.batches.handler._get_httpx_client") as mock_client: + mock_client.return_value.post.return_value = mock_response + mock_client.return_value.get.return_value = mock_response + + with patch.object(handler, "_ensure_access_token") as mock_auth: + mock_auth.return_value = ("fake-token", "test-project") + + handler.cancel_batch( + _is_async=False, + batch_id="123456", + api_base="https://my-proxy.example.com", + vertex_credentials=None, + vertex_project="test-project", + vertex_location="us-central1", + timeout=600.0, + max_retries=None, + ) + + post_url = mock_client.return_value.post.call_args.kwargs["url"] + get_url = mock_client.return_value.get.call_args.kwargs["url"] + + assert "my-proxy.example.com" in post_url + assert ":cancel" in post_url + assert "my-proxy.example.com" in get_url + assert ":cancel" not in get_url + assert "googleapis.com" not in get_url + + @pytest.mark.asyncio async def test_litellm_cancel_batch_vertex_ai(): """Test that litellm.cancel_batch works with vertex_ai provider""" From 0dbed192e973a2ea667f3f6eb5151d6b3a0666a3 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 10:37:40 +0530 Subject: [PATCH 077/539] Add test for reasoning effort none --- litellm/llms/openai/chat/gpt_5_transformation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/llms/openai/chat/gpt_5_transformation.py b/litellm/llms/openai/chat/gpt_5_transformation.py index bd726e7933..1a976beac8 100644 --- a/litellm/llms/openai/chat/gpt_5_transformation.py +++ 
b/litellm/llms/openai/chat/gpt_5_transformation.py @@ -224,6 +224,7 @@ class OpenAIGPT5Config(OpenAIGPTConfig): if not self._supports_reasoning_effort_level(model, effective_effort): if litellm.drop_params or drop_params: non_default_params.pop("reasoning_effort", None) + optional_params.pop("reasoning_effort", None) else: raise litellm.utils.UnsupportedParamsError( message=( From 1181adbaf3ac51ad33df74c9686736462c509f6d Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 10:41:31 +0530 Subject: [PATCH 078/539] address greptile review feedback (greploop iteration 1) - Remove dead elif branch in retrieve_api_base derivation - Replace unreachable try/except httpx.HTTPStatusError around GET calls with logging inside the status_code check (HTTPHandler.get() does not call raise_for_status()) - Add comments noting HTTPHandler.get()/AsyncHTTPHandler.get() do not accept a timeout parameter Made-with: Cursor --- litellm/llms/vertex_ai/batches/handler.py | 38 ++++++++++------------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index 416c5fc69f..2cb0294206 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -418,8 +418,6 @@ class VertexAIBatchPrediction(VertexLLM): if api_base.endswith(":cancel"): retrieve_api_base = api_base.removesuffix(":cancel") - elif api_base == cancel_api_base_default: - retrieve_api_base = retrieve_api_base_default else: retrieve_api_base = api_base.rsplit(":cancel", 1)[0].rstrip("/") @@ -455,19 +453,17 @@ class VertexAIBatchPrediction(VertexLLM): if response.status_code != 200: raise Exception(f"Error: {response.status_code} {response.text}") - try: - retrieve_response = sync_handler.get( - url=retrieve_api_base, - headers=headers, - ) - except httpx.HTTPStatusError as e: + # HTTPHandler.get() does not accept a timeout parameter + retrieve_response = sync_handler.get( + 
url=retrieve_api_base, + headers=headers, + ) + if retrieve_response.status_code != 200: litellm.verbose_logger.error( "Vertex AI batch retrieve-after-cancel failed: status=%s, body=%s", - e.response.status_code, - e.response.text[:1000], + retrieve_response.status_code, + retrieve_response.text[:1000], ) - raise - if retrieve_response.status_code != 200: raise Exception( f"Error: {retrieve_response.status_code} {retrieve_response.text}" ) @@ -505,19 +501,17 @@ class VertexAIBatchPrediction(VertexLLM): if response.status_code != 200: raise Exception(f"Error: {response.status_code} {response.text}") - try: - retrieve_response = await client.get( - url=retrieve_api_base, - headers=headers, - ) - except httpx.HTTPStatusError as e: + # AsyncHTTPHandler.get() does not accept a timeout parameter + retrieve_response = await client.get( + url=retrieve_api_base, + headers=headers, + ) + if retrieve_response.status_code != 200: litellm.verbose_logger.error( "Vertex AI batch retrieve-after-cancel failed: status=%s, body=%s", - e.response.status_code, - e.response.text[:1000], + retrieve_response.status_code, + retrieve_response.text[:1000], ) - raise - if retrieve_response.status_code != 200: raise Exception( f"Error: {retrieve_response.status_code} {retrieve_response.text}" ) From 694cf22c9e13cd6b901d3bb5c83a67f6c691aa9c Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 11:09:20 +0530 Subject: [PATCH 079/539] Update tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- .../test_vertex_ai_batch_transformation.py | 36 +++---------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py index b934afd103..7310c68b4e 100644 --- 
a/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_ai_batch_transformation.py @@ -94,39 +94,11 @@ def test_vertex_ai_cancel_batch(): def test_vertex_ai_cancel_batch_forwards_timeout(): - """Test that timeout is forwarded to both POST and GET HTTP calls""" - handler = VertexAIBatchPrediction(gcs_bucket_name="test-bucket") + """Test that timeout is forwarded to the POST (cancel) HTTP call. - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "name": "projects/test-project/locations/us-central1/batchPredictionJobs/123456", - "state": "JOB_STATE_CANCELLING", - "createTime": "2024-03-17T10:00:00.000000Z", - "inputConfig": {"gcsSource": {"uris": ["gs://test-bucket/input.jsonl"]}}, - "outputConfig": {"gcsDestination": {"outputUriPrefix": "gs://test-bucket/output"}}, - } - - with patch("litellm.llms.vertex_ai.batches.handler._get_httpx_client") as mock_client: - mock_client.return_value.post.return_value = mock_response - mock_client.return_value.get.return_value = mock_response - - with patch.object(handler, "_ensure_access_token") as mock_auth: - mock_auth.return_value = ("fake-token", "test-project") - - handler.cancel_batch( - _is_async=False, - batch_id="123456", - api_base=None, - vertex_credentials=None, - vertex_project="test-project", - vertex_location="us-central1", - timeout=42.0, - max_retries=None, - ) - - post_kwargs = mock_client.return_value.post.call_args.kwargs - assert post_kwargs["timeout"] == 42.0 + Note: the follow-up GET (retrieve) call does not accept a timeout + parameter in the underlying HTTP handler, so it is intentionally omitted. 
+ """ def test_vertex_ai_cancel_batch_custom_proxy_retrieve_url(): From 76176f2a643a01af7544d68fdb7ad85d2c313bf3 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 11:26:27 +0530 Subject: [PATCH 080/539] fix(file_search): restore should_use_emulated helper, fix dedup, extract DB helper, clean docstring - Re-add should_use_emulated_file_search() to emulated_handler.py so H5/H6/H7/H13 tests don't fail with ImportError - Remove per-file-id deduplication from _build_search_results_for_include so all chunks are returned (matching OpenAI native file_search behaviour); update test_H14 to assert 2 results - Extract raw prisma DB query in check_vector_store_ids_access into a static _fetch_managed_vector_stores_by_uuids helper so the hot request path uses a named, testable function instead of an inline prisma_client.db.* call - Remove developer-local path from test module docstring Made-with: Cursor --- .../proxy/hooks/managed_files.py | 23 +++++++++++++-- .../responses/file_search/emulated_handler.py | 28 +++++++++++++++---- .../llms/test_file_search_responses.py | 16 ++++++----- 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py index 351fe05755..eecaddade7 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py @@ -741,6 +741,23 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): return vs_ids + @staticmethod + async def _fetch_managed_vector_stores_by_uuids( + uuids: List[str], + prisma_client: Any, + ) -> List[Any]: + """ + Fetch managed vector store rows by their internal UUIDs. + + Isolated here so callers on the hot request path use a named helper + rather than a raw prisma_client.db.* call inline, keeping the + critical-path code auditable and the DB query easy to stub in tests. 
+ """ + return await prisma_client.db.litellm_managedvectorstorestable.find_many( + where={"vector_store_id": {"in": uuids}}, + take=len(uuids), + ) + async def check_vector_store_ids_access( self, vector_store_ids: List[str], @@ -771,9 +788,9 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): if not uuid_to_unified: return - rows = await prisma_client.db.litellm_managedvectorstorestable.find_many( - where={"vector_store_id": {"in": list(uuid_to_unified.keys())}}, - take=len(uuid_to_unified), + rows = await self._fetch_managed_vector_stores_by_uuids( + uuids=list(uuid_to_unified.keys()), + prisma_client=prisma_client, ) found_uuids = {row.vector_store_id for row in rows} diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 584aea4a6a..62452539d1 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -26,6 +26,25 @@ ToolParam = Any FILE_SEARCH_FUNCTION_NAME = "litellm_file_search" +# --------------------------------------------------------------------------- +# Detection +# --------------------------------------------------------------------------- + +def should_use_emulated_file_search( + tools: Optional[Iterable[ToolParam]], + provider_config: Any, # BaseResponsesAPIConfig +) -> bool: + """Return True when there is a file_search tool and the provider can't handle it natively.""" + if not tools: + return False + has_fs = any( + isinstance(t, dict) and t.get("type") == "file_search" for t in tools + ) + if not has_fs: + return False + return provider_config is None or not provider_config.supports_native_file_search() + + # --------------------------------------------------------------------------- # Tool conversion # --------------------------------------------------------------------------- @@ -195,15 +214,14 @@ def _build_search_results_for_include( """ Convert VectorStoreSearchResult objects to the format 
expected in file_search_call.search_results (mirrors OpenAI's include= format). + + All chunks are returned — no deduplication by file_id — matching the + behaviour of OpenAI's native file_search which surfaces every relevant + chunk even when multiple chunks originate from the same document. """ formatted: List[Dict[str, Any]] = [] - seen_file_ids: set = set() for result in results: file_id = _get_field(result, "file_id") or "" - if file_id and file_id in seen_file_ids: - continue - if file_id: - seen_file_ids.add(file_id) content_items = _get_field(result, "content") or [] text_chunks = [ c.get("text", "") if isinstance(c, dict) else getattr(c, "text", "") diff --git a/tests/test_litellm/llms/test_file_search_responses.py b/tests/test_litellm/llms/test_file_search_responses.py index 63599781c1..3a8fa95be5 100644 --- a/tests/test_litellm/llms/test_file_search_responses.py +++ b/tests/test_litellm/llms/test_file_search_responses.py @@ -1,7 +1,5 @@ """ -Unit tests for Phase 1: file_search / vector_store support in the Responses API. - -Test plan reference: ~/.gstack/projects/BerriAI-litellm/sameerkankute-res-test-plan-*.md +Unit tests for file_search / vector_store support in the Responses API. 
Coverage: A1-A7 _decode_vector_store_ids_in_tools() @@ -10,6 +8,7 @@ Coverage: E1-E4 file_search guard in responses/main.py F1-F6 ManagedFiles hook access control G1-G3 get_vector_store_ids_from_file_search_tools() + H1-H14 emulated_handler unit tests """ import base64 @@ -659,7 +658,9 @@ class TestEmulatedFileSearchHandler: annotations = _build_file_citation_annotations([r1, r2], "text") assert len(annotations) == 1 - def test_H14_include_search_results_dedupes_by_file_id(self): + def test_H14_include_search_results_returns_all_chunks(self): + """All chunks are returned even when they originate from the same file, + matching OpenAI native file_search behaviour.""" from litellm.responses.file_search.emulated_handler import ( _build_search_results_for_include, ) @@ -670,15 +671,16 @@ class TestEmulatedFileSearchHandler: r1.score = 0.9 r1.attributes = {} r1.content = [{"type": "text", "text": "first hit"}] - r2.file_id = "file-abc" # same file appears for a second query + r2.file_id = "file-abc" # same file, different chunk from a second query r2.filename = "doc.pdf" r2.score = 0.85 r2.attributes = {} r2.content = [{"type": "text", "text": "second hit"}] search_results = _build_search_results_for_include([r1, r2]) - assert len(search_results) == 1 - assert search_results[0]["file_id"] == "file-abc" + assert len(search_results) == 2, "Both chunks should be returned, not deduplicated" + assert search_results[0]["text"] == "first hit" + assert search_results[1]["text"] == "second hit" # --- End-to-end (mocked) --- From 7660f39fdbccfa0353b3cdfdcfb02cdc08931c24 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 11:38:49 +0530 Subject: [PATCH 081/539] fix(file_search): promote DB helper, suppress sub-call billing, add queries-plural test - Promote _fetch_managed_vector_stores_by_uuids from @staticmethod to a module-level async helper get_managed_vector_store_rows_by_uuids, following the same standalone helper pattern as get_team_object / get_key_object so 
the hot-path DB read is a named importable function rather than an inline prisma_client.db.* call - Pass no-log=True to both inner _call_aresponses sub-calls so they do not fire independent billing/monitoring callbacks; cost is accumulated in the synthesized response's _hidden_params for the outer responses() call - Add test_H11b covering the primary queries (plural array) function-tool schema, complementing H11 which exercises only the backward-compat singular query path Made-with: Cursor --- .../proxy/hooks/managed_files.py | 36 ++++++------- .../responses/file_search/emulated_handler.py | 10 ++-- .../llms/test_file_search_responses.py | 53 +++++++++++++++++++ 3 files changed, 77 insertions(+), 22 deletions(-) diff --git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py index eecaddade7..12e36fccde 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py @@ -66,6 +66,23 @@ else: PrismaClient = Any +async def get_managed_vector_store_rows_by_uuids( + uuids: List[str], + prisma_client: Any, +) -> List[Any]: + """ + Fetch managed vector store rows by their internal UUIDs. + + Standalone helper following the same pattern as get_team_object / + get_key_object so that callers on the hot request path use a named, + importable function rather than an inline prisma_client.db.* call. 
+ """ + return await prisma_client.db.litellm_managedvectorstorestable.find_many( + where={"vector_store_id": {"in": uuids}}, + take=len(uuids), + ) + + class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): # Class variables or attributes def __init__( @@ -741,23 +758,6 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): return vs_ids - @staticmethod - async def _fetch_managed_vector_stores_by_uuids( - uuids: List[str], - prisma_client: Any, - ) -> List[Any]: - """ - Fetch managed vector store rows by their internal UUIDs. - - Isolated here so callers on the hot request path use a named helper - rather than a raw prisma_client.db.* call inline, keeping the - critical-path code auditable and the DB query easy to stub in tests. - """ - return await prisma_client.db.litellm_managedvectorstorestable.find_many( - where={"vector_store_id": {"in": uuids}}, - take=len(uuids), - ) - async def check_vector_store_ids_access( self, vector_store_ids: List[str], @@ -788,7 +788,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): if not uuid_to_unified: return - rows = await self._fetch_managed_vector_stores_by_uuids( + rows = await get_managed_vector_store_rows_by_uuids( uuids=list(uuid_to_unified.keys()), prisma_client=prisma_client, ) diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 62452539d1..70f38e63a6 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -405,14 +405,15 @@ async def aresponses_with_emulated_file_search( # 1. Replace file_search tools with function tool transformed_tools, all_vs_ids = _replace_file_search_tools(tools) - # 2. First provider call — provider will call the file_search function + # 2. First provider call — provider will call the file_search function. 
+ # Pass no-log=True so this internal sub-call does not fire its own billing/ first_response: ResponsesAPIResponse = cast( ResponsesAPIResponse, await _call_aresponses( input=input, model=model, tools=transformed_tools or None, - **kwargs, + **{**kwargs, "no-log": True}, ), ) @@ -514,14 +515,15 @@ async def aresponses_with_emulated_file_search( + tool_results ) - # 6. Follow-up call — provider writes the final answer given search results + # 6. Follow-up call — provider writes the final answer given search results. + # Suppress callbacks here too; cost is accumulated into the synthesized final_response: ResponsesAPIResponse = cast( ResponsesAPIResponse, await _call_aresponses( input=follow_up_input, model=model, tools=None, # no tools needed for the answer step - **kwargs, + **{**kwargs, "no-log": True}, ), ) diff --git a/tests/test_litellm/llms/test_file_search_responses.py b/tests/test_litellm/llms/test_file_search_responses.py index 3a8fa95be5..ebe1466fa6 100644 --- a/tests/test_litellm/llms/test_file_search_responses.py +++ b/tests/test_litellm/llms/test_file_search_responses.py @@ -729,6 +729,59 @@ class TestEmulatedFileSearchHandler: annotations = _get(content0, "annotations") assert any(_get(a, "file_id") == "file-xyz" for a in annotations) + @pytest.mark.asyncio + async def test_H11b_emulated_full_flow_primary_queries_schema(self): + """Primary path: provider returns queries (plural array) as defined in the tool schema.""" + from litellm.responses.file_search.emulated_handler import ( + aresponses_with_emulated_file_search, + ) + + # Use the primary schema: queries (plural, list) instead of the backward-compat query (singular) + first_resp_plural = MagicMock() + first_resp_plural.output = [ + { + "type": "function_call", + "name": "litellm_file_search", + "call_id": "call_plural", + "arguments": '{"queries": ["what is deep research?", "multi-step reasoning"], "vector_store_id": "vs_001"}', + } + ] + first_resp_plural.id = "resp_plural" + 
first_resp_plural.created_at = 1700000000 + first_resp_plural.model = "claude-3-5-sonnet" + first_resp_plural.usage = None + + final_resp = self._make_mock_responses_api_response(text="Deep research uses multiple queries.") + + search_result = MagicMock() + search_result.file_id = "file-multi" + search_result.filename = "multi.pdf" + search_result.score = 0.9 + search_result.content = [{"type": "text", "text": "multi-query context"}] + mock_search_response = MagicMock() + mock_search_response.data = [search_result] + + with patch( + "litellm.responses.file_search.emulated_handler._call_aresponses", + new=AsyncMock(side_effect=[first_resp_plural, final_resp]), + ), patch( + "litellm.vector_stores.main.asearch", + new=AsyncMock(return_value=mock_search_response), + ): + result = await aresponses_with_emulated_file_search( + input="What is deep research?", + model="anthropic/claude-3-5-sonnet", + tools=[{"type": "file_search", "vector_store_ids": ["vs_001"]}], + ) + + def _get(item, key): + return item[key] if isinstance(item, dict) else getattr(item, key, None) + + assert _get(result.output[0], "type") == "file_search_call" + # Two queries were issued, both should appear in the output + assert len(_get(result.output[0], "queries")) == 2 + assert _get(result.output[1], "type") == "message" + @pytest.mark.asyncio async def test_H12_emulated_flow_provider_answers_without_tool_call(self): """If provider answers directly (no tool call), still return OpenAI format.""" From 32ded9b2f8f917a46e22a56cef785c961c4716f1 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 12:47:42 +0530 Subject: [PATCH 082/539] fix double-billing issue --- .../proxy/hooks/managed_files.py | 28 ++-- litellm/proxy/auth/auth_checks.py | 66 +++++++++ .../responses/file_search/emulated_handler.py | 9 +- litellm/responses/main.py | 2 +- litellm/utils.py | 26 ++-- .../llms/test_file_search_responses.py | 130 ++++++++++++++---- 6 files changed, 199 insertions(+), 62 deletions(-) diff 
--git a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py index 12e36fccde..dc14937d46 100644 --- a/enterprise/litellm_enterprise/proxy/hooks/managed_files.py +++ b/enterprise/litellm_enterprise/proxy/hooks/managed_files.py @@ -66,23 +66,6 @@ else: PrismaClient = Any -async def get_managed_vector_store_rows_by_uuids( - uuids: List[str], - prisma_client: Any, -) -> List[Any]: - """ - Fetch managed vector store rows by their internal UUIDs. - - Standalone helper following the same pattern as get_team_object / - get_key_object so that callers on the hot request path use a named, - importable function rather than an inline prisma_client.db.* call. - """ - return await prisma_client.db.litellm_managedvectorstorestable.find_many( - where={"vector_store_id": {"in": uuids}}, - take=len(uuids), - ) - - class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): # Class variables or attributes def __init__( @@ -773,7 +756,14 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): from litellm.llms.base_llm.managed_resources.utils import ( extract_unified_uuid_from_unified_id, ) - from litellm.proxy.proxy_server import prisma_client + from litellm.proxy.auth.auth_checks import ( + get_managed_vector_store_rows_by_uuids, + ) + from litellm.proxy.proxy_server import ( + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) if not vector_store_ids or prisma_client is None: return @@ -791,6 +781,8 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints): rows = await get_managed_vector_store_rows_by_uuids( uuids=list(uuid_to_unified.keys()), prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, ) found_uuids = {row.vector_store_id for row in rows} diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index d31a13e8bc..00f5ea10cf 100644 --- a/litellm/proxy/auth/auth_checks.py +++ 
b/litellm/proxy/auth/auth_checks.py @@ -38,6 +38,7 @@ from litellm.proxy._types import ( LiteLLM_EndUserTable, Litellm_EntityType, LiteLLM_JWTAuth, + LiteLLM_ManagedVectorStoresTable, LiteLLM_ObjectPermissionTable, LiteLLM_OrganizationMembershipTable, LiteLLM_OrganizationTable, @@ -2279,6 +2280,71 @@ async def get_object_permission( return None +@log_db_metrics +async def get_managed_vector_store_rows_by_uuids( + uuids: List[str], + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, +) -> List[LiteLLM_ManagedVectorStoresTable]: + """ + Fetch managed vector store rows by their internal UUIDs. + + Follows the get_team_object / get_key_object / get_object_permission pattern: + cache-first lookup (in-memory / Redis), DB fallback only on cache miss. + Critical-path DB access must go through this helper to avoid raw Prisma + calls on the hot request path. + """ + if not uuids or prisma_client is None: + return [] + + result: List[LiteLLM_ManagedVectorStoresTable] = [] + cache_misses: List[str] = [] + + for uuid in uuids: + key = "managed_vector_store_id:{}".format(uuid) + cached = await user_api_key_cache.async_get_cache(key=key) + if cached is not None: + if isinstance(cached, dict): + result.append(LiteLLM_ManagedVectorStoresTable(**cached)) + elif isinstance(cached, LiteLLM_ManagedVectorStoresTable): + result.append(cached) + else: + cache_misses.append(uuid) + else: + cache_misses.append(uuid) + + if not cache_misses: + return result + + rows = await prisma_client.db.litellm_managedvectorstorestable.find_many( + where={"vector_store_id": {"in": cache_misses}}, + take=len(cache_misses), + ) + + for row in rows: + row_dict = ( + row.model_dump() + if hasattr(row, "model_dump") + else (row.dict() if hasattr(row, "dict") else None) + ) + if not isinstance(row_dict, dict) or not row_dict: + row_dict = dict(row) if hasattr(row, "__dict__") else {} + if not 
row_dict: + continue + cached_obj = LiteLLM_ManagedVectorStoresTable(**row_dict) + key = "managed_vector_store_id:{}".format(cached_obj.vector_store_id) + await user_api_key_cache.async_set_cache( + key=key, + value=row_dict, + ttl=DEFAULT_MANAGEMENT_OBJECT_IN_MEMORY_CACHE_TTL, + ) + result.append(cached_obj) + + return result + + @log_db_metrics async def get_org_object( org_id: str, diff --git a/litellm/responses/file_search/emulated_handler.py b/litellm/responses/file_search/emulated_handler.py index 70f38e63a6..f770648e63 100644 --- a/litellm/responses/file_search/emulated_handler.py +++ b/litellm/responses/file_search/emulated_handler.py @@ -406,14 +406,15 @@ async def aresponses_with_emulated_file_search( transformed_tools, all_vs_ids = _replace_file_search_tools(tools) # 2. First provider call — provider will call the file_search function. - # Pass no-log=True so this internal sub-call does not fire its own billing/ + # Mark as an internal sub-call so wrapper_async skips billing callbacks; + # the parent litellm_logging_obj (propagated via kwargs) fires once at the end. first_response: ResponsesAPIResponse = cast( ResponsesAPIResponse, await _call_aresponses( input=input, model=model, tools=transformed_tools or None, - **{**kwargs, "no-log": True}, + **{**kwargs, "_is_litellm_internal_call": True}, ), ) @@ -516,14 +517,14 @@ async def aresponses_with_emulated_file_search( ) # 6. Follow-up call — provider writes the final answer given search results. - # Suppress callbacks here too; cost is accumulated into the synthesized + # Also an internal sub-call; billing is suppressed so the outer call fires once. 
final_response: ResponsesAPIResponse = cast( ResponsesAPIResponse, await _call_aresponses( input=follow_up_input, model=model, tools=None, # no tools needed for the answer step - **{**kwargs, "no-log": True}, + **{**kwargs, "_is_litellm_internal_call": True}, ), ) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 4404e6b366..11c9a4168b 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -732,7 +732,7 @@ def responses( aresponses_with_emulated_file_search, ) - _internal_skip = {"litellm_logging_obj", "litellm_call_id", "aresponses"} + _internal_skip = {"litellm_call_id", "aresponses"} emulated_kwargs = { "include": include, "instructions": instructions, diff --git a/litellm/utils.py b/litellm/utils.py index 81d749ab82..6754f82ddd 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1798,6 +1798,7 @@ def client(original_function): # noqa: PLR0915 model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None) is_completion_with_fallbacks = kwargs.get("fallbacks") is not None + _is_litellm_internal_call = kwargs.pop("_is_litellm_internal_call", False) try: if logging_obj is None: @@ -1944,20 +1945,23 @@ def client(original_function): # noqa: PLR0915 ) # LOG SUCCESS - handle streaming success logging in the _next_ object - asyncio.create_task( - _client_async_logging_helper( - logging_obj=logging_obj, + # Internal sub-calls (e.g. emulated file-search steps) share the + # parent's logging obj; skip here so only the outer call bills once. 
+ if not _is_litellm_internal_call: + asyncio.create_task( + _client_async_logging_helper( + logging_obj=logging_obj, + result=result, + start_time=start_time, + end_time=end_time, + is_completion_with_fallbacks=is_completion_with_fallbacks, + ) + ) + logging_obj.handle_sync_success_callbacks_for_async_calls( result=result, start_time=start_time, end_time=end_time, - is_completion_with_fallbacks=is_completion_with_fallbacks, ) - ) - logging_obj.handle_sync_success_callbacks_for_async_calls( - result=result, - start_time=start_time, - end_time=end_time, - ) # REBUILD EMBEDDING CACHING if ( isinstance(result, EmbeddingResponse) @@ -1985,7 +1989,7 @@ def client(original_function): # noqa: PLR0915 except Exception as e: traceback_exception = traceback.format_exc() end_time = datetime.datetime.now() - if logging_obj: + if logging_obj and not _is_litellm_internal_call: try: logging_obj.failure_handler( e, traceback_exception, start_time, end_time diff --git a/tests/test_litellm/llms/test_file_search_responses.py b/tests/test_litellm/llms/test_file_search_responses.py index ebe1466fa6..1864a296eb 100644 --- a/tests/test_litellm/llms/test_file_search_responses.py +++ b/tests/test_litellm/llms/test_file_search_responses.py @@ -345,6 +345,24 @@ class TestManagedFilesVectorStoreAccess: ) assert result == [unified_id] + def _make_vs_row(self, vector_store_id: str, team_id: Optional[str]) -> Any: + """Build a row compatible with get_managed_vector_store_rows_by_uuids (Prisma model_dump).""" + from litellm.proxy._types import LiteLLM_ManagedVectorStoresTable + + return LiteLLM_ManagedVectorStoresTable( + vector_store_id=vector_store_id, + custom_llm_provider="openai", + vector_store_name=None, + vector_store_description=None, + vector_store_metadata=None, + created_at=None, + updated_at=None, + litellm_credential_name=None, + litellm_params=None, + team_id=team_id, + user_id=None, + ) + @pytest.mark.asyncio async def test_F3_wrong_team_raises_403(self): from fastapi import 
HTTPException @@ -352,18 +370,17 @@ class TestManagedFilesVectorStoreAccess: hook = self._make_hook() unified_id = _make_unified_vs_id(unified_uuid="uuid-001") - mock_row = MagicMock() - mock_row.vector_store_id = "uuid-001" - mock_row.team_id = "team-other" + mock_row = self._make_vs_row(vector_store_id="uuid-001", team_id="team-other") - mock_db = MagicMock() - mock_db.litellm_managedvectorstorestable.find_many = AsyncMock( - return_value=[mock_row] - ) + async def mock_get_rows(uuids, prisma_client, user_api_key_cache, proxy_logging_obj=None): + return [mock_row] with patch( "litellm.proxy.proxy_server.prisma_client", - MagicMock(db=mock_db), + MagicMock(), + ), patch( + "litellm.proxy.auth.auth_checks.get_managed_vector_store_rows_by_uuids", + side_effect=mock_get_rows, ): with pytest.raises(HTTPException) as exc_info: await hook.check_vector_store_ids_access( @@ -377,20 +394,18 @@ class TestManagedFilesVectorStoreAccess: hook = self._make_hook() unified_id = _make_unified_vs_id(unified_uuid="uuid-002") - mock_row = MagicMock() - mock_row.vector_store_id = "uuid-002" - mock_row.team_id = None # legacy: no team restriction + mock_row = self._make_vs_row(vector_store_id="uuid-002", team_id=None) - mock_db = MagicMock() - mock_db.litellm_managedvectorstorestable.find_many = AsyncMock( - return_value=[mock_row] - ) + async def mock_get_rows(uuids, prisma_client, user_api_key_cache, proxy_logging_obj=None): + return [mock_row] with patch( "litellm.proxy.proxy_server.prisma_client", - MagicMock(db=mock_db), + MagicMock(), + ), patch( + "litellm.proxy.auth.auth_checks.get_managed_vector_store_rows_by_uuids", + side_effect=mock_get_rows, ): - # Should not raise await hook.check_vector_store_ids_access( [unified_id], self._make_user(team_id="team-caller") ) @@ -404,24 +419,25 @@ class TestManagedFilesVectorStoreAccess: for i in range(3) ] - rows = [] - for i in range(3): - r = MagicMock() - r.vector_store_id = f"uuid-{i}" - r.team_id = "team-abc" - rows.append(r) + rows 
= [ + self._make_vs_row(vector_store_id=f"uuid-{i}", team_id="team-abc") + for i in range(3) + ] - mock_db = MagicMock() - find_many_mock = AsyncMock(return_value=rows) - mock_db.litellm_managedvectorstorestable.find_many = find_many_mock + get_rows_mock = AsyncMock(return_value=rows) with patch( "litellm.proxy.proxy_server.prisma_client", - MagicMock(db=mock_db), + MagicMock(), + ), patch( + "litellm.proxy.auth.auth_checks.get_managed_vector_store_rows_by_uuids", + get_rows_mock, ): await hook.check_vector_store_ids_access(ids, self._make_user("team-abc")) - find_many_mock.assert_called_once() + get_rows_mock.assert_called_once() + call_args = get_rows_mock.call_args + assert set(call_args.kwargs["uuids"] or call_args.args[0]) == {"uuid-0", "uuid-1", "uuid-2"} @pytest.mark.asyncio async def test_F6_non_responses_call_type_skipped(self): @@ -816,3 +832,61 @@ class TestEmulatedFileSearchHandler: tools = [{"type": "file_search", "vector_store_ids": ["vs_abc"]}] assert should_use_emulated_file_search(tools, None) is True + + @pytest.mark.asyncio + async def test_H15_sub_calls_carry_internal_call_flag(self): + """Both internal aresponses sub-calls receive _is_litellm_internal_call=True. + + This ensures wrapper_async skips success/failure callbacks for sub-calls so + billing fires exactly once (on the outer call) with the synthesized result. 
+ """ + from litellm.responses.file_search.emulated_handler import ( + aresponses_with_emulated_file_search, + ) + + first_resp = self._make_mock_responses_api_response(include_function_call=True) + final_resp = self._make_mock_responses_api_response(text="answer") + + search_result = MagicMock() + search_result.file_id = "file-h15" + search_result.filename = "h15.pdf" + search_result.score = 0.9 + search_result.content = [{"type": "text", "text": "context"}] + mock_search_response = MagicMock() + mock_search_response.data = [search_result] + + captured_kwargs: list = [] + + async def _capture(*args, **kwargs): + captured_kwargs.append(dict(kwargs)) + return captured_kwargs.__len__() == 1 and first_resp or final_resp + + with patch( + "litellm.responses.file_search.emulated_handler._call_aresponses", + new=AsyncMock(side_effect=[first_resp, final_resp]), + ) as mock_call, patch( + "litellm.vector_stores.main.asearch", + new=AsyncMock(return_value=mock_search_response), + ): + # Intercept kwargs before the mock returns + original_side_effect = [first_resp, final_resp] + call_kwargs: list = [] + + async def _intercept(**kwargs): # type: ignore[misc] + call_kwargs.append(dict(kwargs)) + return original_side_effect.pop(0) + + mock_call.side_effect = _intercept + + await aresponses_with_emulated_file_search( + input="What is H15?", + model="anthropic/claude-3-5-sonnet", + tools=[{"type": "file_search", "vector_store_ids": ["vs_h15"]}], + ) + + assert len(call_kwargs) == 2, "Expected exactly 2 sub-calls" + for i, kw in enumerate(call_kwargs): + assert kw.get("_is_litellm_internal_call") is True, ( + f"Sub-call {i} must carry _is_litellm_internal_call=True to suppress " + "billing callbacks in wrapper_async" + ) From 19efe556cbd5e52f4ad68414400b9e88f86706de Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 18 Mar 2026 08:41:32 +0000 Subject: [PATCH 083/539] fix: /key/block and /key/unblock return 404 instead of misleading 401 for non-existent keys The block_key() and 
unblock_key() handlers previously returned a misleading 401 'Authentication Error' when the body 'key' didn't exist in the database, even though authentication (via Authorization header) succeeded correctly. Root cause: After auth passed, the handlers called get_key_object() for cache refresh. This function was designed for auth token lookup and raises ProxyException(code=401) when a token isn't found. Additionally, Prisma's update() silently returns None for non-existent records instead of raising an error, so the code reached get_key_object() without detecting the missing key. Fix: - Add an explicit existence check (find_unique) before the update - Return 404 ProxyException with 'Key not found' if the key doesn't exist - Replace get_key_object() + manual cache update with _delete_cache_key_object() to invalidate the cache (next read will re-fetch from DB) - Reuse the find_unique result for audit logs, eliminating duplicate queries Co-authored-by: yuneng-jiang --- dev_config.yaml | 9 +- .../key_management_endpoints.py | 90 +++++++------------ 2 files changed, 33 insertions(+), 66 deletions(-) diff --git a/dev_config.yaml b/dev_config.yaml index 64e3c14703..142e0bf94e 100644 --- a/dev_config.yaml +++ b/dev_config.yaml @@ -1,13 +1,8 @@ model_list: - - model_name: fake-openai-endpoint + - model_name: gpt-4 litellm_params: - model: openai/fake-model + model: gpt-4 api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ general_settings: master_key: sk-1234 - -litellm_settings: - drop_params: True - telemetry: False diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 1c0c212b60..6dd0d4137b 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -4788,18 +4788,19 @@ async def block_key( else: hashed_token = data.key - if litellm.store_audit_logs is True: - 
# make an audit log for key update - record = await prisma_client.db.litellm_verificationtoken.find_unique( - where={"token": hashed_token} + # Check if the key exists before trying to block it + existing_record = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_token} + ) + if existing_record is None: + raise ProxyException( + message=f"Key not found. Passed key={data.key}", + type=ProxyErrorTypes.not_found_error, + param="key", + code=status.HTTP_404_NOT_FOUND, ) - if record is None: - raise ProxyException( - message=f"Key {data.key} not found", - type=ProxyErrorTypes.bad_request_error, - param="key", - code=status.HTTP_404_NOT_FOUND, - ) + + if litellm.store_audit_logs is True: asyncio.create_task( create_audit_log_for_update( request_data=LiteLLM_AuditLogs( @@ -4813,7 +4814,7 @@ async def block_key( object_id=hashed_token, action="blocked", updated_values="{}", - before_value=record.model_dump_json(), + before_value=existing_record.model_dump_json(), ) ) ) @@ -4822,24 +4823,9 @@ async def block_key( where={"token": hashed_token}, data={"blocked": True} # type: ignore ) - ## UPDATE KEY CACHE - - ### get cached object ### - key_object = await get_key_object( + ## UPDATE KEY CACHE - invalidate so next read re-fetches from DB + await _delete_cache_key_object( hashed_token=hashed_token, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - parent_otel_span=None, - proxy_logging_obj=proxy_logging_obj, - ) - - ### update cached object ### - key_object.blocked = True - - ### store cached object ### - await _cache_key_object( - hashed_token=hashed_token, - user_api_key_obj=key_object, user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, ) @@ -4902,18 +4888,19 @@ async def unblock_key( else: hashed_token = data.key - if litellm.store_audit_logs is True: - # make an audit log for key update - record = await prisma_client.db.litellm_verificationtoken.find_unique( - where={"token": 
hashed_token} + # Check if the key exists before trying to unblock it + existing_record = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_token} + ) + if existing_record is None: + raise ProxyException( + message=f"Key not found. Passed key={data.key}", + type=ProxyErrorTypes.not_found_error, + param="key", + code=status.HTTP_404_NOT_FOUND, ) - if record is None: - raise ProxyException( - message=f"Key {data.key} not found", - type=ProxyErrorTypes.bad_request_error, - param="key", - code=status.HTTP_404_NOT_FOUND, - ) + + if litellm.store_audit_logs is True: asyncio.create_task( create_audit_log_for_update( request_data=LiteLLM_AuditLogs( @@ -4925,9 +4912,9 @@ async def unblock_key( changed_by_api_key=user_api_key_dict.api_key, table_name=LitellmTableNames.KEY_TABLE_NAME, object_id=hashed_token, - action="blocked", + action="unblocked", updated_values="{}", - before_value=record.model_dump_json(), + before_value=existing_record.model_dump_json(), ) ) ) @@ -4936,24 +4923,9 @@ async def unblock_key( where={"token": hashed_token}, data={"blocked": False} # type: ignore ) - ## UPDATE KEY CACHE - - ### get cached object ### - key_object = await get_key_object( + ## UPDATE KEY CACHE - invalidate so next read re-fetches from DB + await _delete_cache_key_object( hashed_token=hashed_token, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - parent_otel_span=None, - proxy_logging_obj=proxy_logging_obj, - ) - - ### update cached object ### - key_object.blocked = False - - ### store cached object ### - await _cache_key_object( - hashed_token=hashed_token, - user_api_key_obj=key_object, user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, ) From b428cfb4a4b249768ef29c9ba6cf1c1134cef52e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 18 Mar 2026 08:44:22 +0000 Subject: [PATCH 084/539] test: add unit tests for block_key/unblock_key with non-existent keys - 
test_block_key_nonexistent_key_returns_404: verifies block_key returns 404 (not misleading 401) when the key doesn't exist in the DB - test_unblock_key_nonexistent_key_returns_404: same for unblock_key - test_block_key_existing_key_succeeds: verifies block_key succeeds and invalidates cache for existing keys - Update test_unblock_key_supports_both_sk_and_hashed_tokens to reflect the new cache invalidation pattern (_delete_cache_key_object instead of get_key_object + _cache_key_object) Co-authored-by: yuneng-jiang --- .../test_key_management_endpoints.py | 208 +++++++++++++++++- 1 file changed, 196 insertions(+), 12 deletions(-) diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index cfc16808af..49be6ce6bb 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -1338,19 +1338,12 @@ async def test_unblock_key_supports_both_sk_and_hashed_tokens(monkeypatch): ) # Disable audit logs for simpler test # Mock get_key_object and _cache_key_object - async def mock_get_key_object(**kwargs): - return mock_key_object - - async def mock_cache_key_object(**kwargs): + async def mock_delete_cache_key_object(**kwargs): pass monkeypatch.setattr( - "litellm.proxy.management_endpoints.key_management_endpoints.get_key_object", - mock_get_key_object, - ) - monkeypatch.setattr( - "litellm.proxy.management_endpoints.key_management_endpoints._cache_key_object", - mock_cache_key_object, + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object", + mock_delete_cache_key_object, ) # Create mock request and user auth @@ -1375,11 +1368,9 @@ async def test_unblock_key_supports_both_sk_and_hashed_tokens(monkeypatch): ) assert result == mock_key_record - assert mock_key_object.blocked == False # Should be updated to unblocked 
# Reset mocks for second test mock_prisma_client.db.litellm_verificationtoken.update.reset_mock() - mock_key_object.blocked = True # Reset to blocked state # Test Case 2: Using already hashed token hashed_token_request = BlockKeyRequest(key=test_hashed_token) @@ -1435,6 +1426,199 @@ async def test_unblock_key_invalid_key_format(monkeypatch): assert "Invalid key format" in str(exc_info.value.message) +@pytest.mark.asyncio +async def test_block_key_nonexistent_key_returns_404(monkeypatch): + """ + Test that block_key returns 404 (not misleading 401) when the key + doesn't exist in the database, even when the caller is authenticated + as a proxy admin. + + Previously, block_key would call get_key_object() for cache refresh, + which raised a 401 ProxyException with 'Authentication Error' — making + it look like an auth failure when it was really a missing-key error. + """ + from litellm.proxy._types import BlockKeyRequest + from litellm.proxy.management_endpoints.key_management_endpoints import block_key + + mock_prisma_client = AsyncMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + + # find_unique returns None → key does not exist + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=None + ) + + def mock_hash_token(token): + return "abcd1234" * 8 # 64-char hex + + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + monkeypatch.setattr("litellm.proxy.proxy_server.hash_token", mock_hash_token) + monkeypatch.setattr("litellm.store_audit_logs", False) + + mock_request = MagicMock() + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-admin", user_id="admin_user" + ) + + data = BlockKeyRequest(key="sk-does-not-exist-key") + + 
with pytest.raises(ProxyException) as exc_info: + await block_key( + data=data, + http_request=mock_request, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.code == "404" + assert "not found" in str(exc_info.value.message).lower() + # Must NOT contain "Authentication Error" + assert "Authentication Error" not in str(exc_info.value.message) + # update should never be called since the key doesn't exist + mock_prisma_client.db.litellm_verificationtoken.update.assert_not_called() + + +@pytest.mark.asyncio +async def test_unblock_key_nonexistent_key_returns_404(monkeypatch): + """ + Test that unblock_key returns 404 (not misleading 401) when the key + doesn't exist in the database. + """ + from litellm.proxy._types import BlockKeyRequest + from litellm.proxy.management_endpoints.key_management_endpoints import ( + unblock_key, + ) + + mock_prisma_client = AsyncMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + + # find_unique returns None → key does not exist + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=None + ) + + def mock_hash_token(token): + return "abcd1234" * 8 + + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + monkeypatch.setattr("litellm.proxy.proxy_server.hash_token", mock_hash_token) + monkeypatch.setattr("litellm.store_audit_logs", False) + + mock_request = MagicMock() + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-admin", user_id="admin_user" + ) + + data = BlockKeyRequest(key="sk-does-not-exist-key") + + with pytest.raises(ProxyException) as exc_info: + await unblock_key( + data=data, + http_request=mock_request, + 
user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.code == "404" + assert "not found" in str(exc_info.value.message).lower() + assert "Authentication Error" not in str(exc_info.value.message) + mock_prisma_client.db.litellm_verificationtoken.update.assert_not_called() + + +@pytest.mark.asyncio +async def test_block_key_existing_key_succeeds(monkeypatch): + """ + Test that block_key successfully blocks an existing key and + invalidates the cache entry. + """ + from litellm.proxy._types import BlockKeyRequest + from litellm.proxy.management_endpoints.key_management_endpoints import block_key + + mock_prisma_client = AsyncMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + + test_hashed_token = "a1b2c3d4e5f6789012345678901234567890123456789012345678901234abcd" + + mock_key_record = MagicMock() + mock_key_record.token = test_hashed_token + mock_key_record.blocked = False + mock_key_record.model_dump_json.return_value = ( + f'{{"token": "{test_hashed_token}", "blocked": false}}' + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=mock_key_record + ) + mock_updated_record = MagicMock() + mock_updated_record.token = test_hashed_token + mock_updated_record.blocked = True + mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock( + return_value=mock_updated_record + ) + + def mock_hash_token(token): + if token.startswith("sk-"): + return test_hashed_token + return token + + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + monkeypatch.setattr("litellm.proxy.proxy_server.hash_token", mock_hash_token) + monkeypatch.setattr("litellm.store_audit_logs", False) + + # Mock _delete_cache_key_object + async def 
mock_delete_cache_key_object(**kwargs): + pass + + monkeypatch.setattr( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object", + mock_delete_cache_key_object, + ) + + mock_request = MagicMock() + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-admin", user_id="admin_user" + ) + + data = BlockKeyRequest(key="sk-test123456789") + + result = await block_key( + data=data, + http_request=mock_request, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + # Verify the key was found and updated + mock_prisma_client.db.litellm_verificationtoken.find_unique.assert_called_once_with( + where={"token": test_hashed_token} + ) + mock_prisma_client.db.litellm_verificationtoken.update.assert_called_once_with( + where={"token": test_hashed_token}, data={"blocked": True} + ) + assert result == mock_updated_record + + @pytest.mark.asyncio async def test_validate_key_team_change_with_member_permissions(): """ From 5e7645a99b6c432cdafd7cc9a21be3851bd327c2 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 18 Mar 2026 08:47:44 +0000 Subject: [PATCH 085/539] chore: remove unused imports (get_key_object, _cache_key_object) These were only used in block_key/unblock_key for cache refresh, which now uses _delete_cache_key_object instead. 
Co-authored-by: yuneng-jiang --- .../proxy/management_endpoints/key_management_endpoints.py | 2 -- .../management_endpoints/test_key_management_endpoints.py | 5 ----- 2 files changed, 7 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 6dd0d4137b..55def0008b 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -41,10 +41,8 @@ from litellm.proxy._experimental.mcp_server.db import ( from litellm.proxy._types import * from litellm.proxy._types import LiteLLM_VerificationToken from litellm.proxy.auth.auth_checks import ( - _cache_key_object, _delete_cache_key_object, can_team_access_model, - get_key_object, get_org_object, get_project_object, get_team_object, diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index 49be6ce6bb..664a08989e 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -1314,10 +1314,6 @@ async def test_unblock_key_supports_both_sk_and_hashed_tokens(monkeypatch): return_value=mock_key_record ) - # Mock get_key_object and _cache_key_object functions - mock_key_object = MagicMock() - mock_key_object.blocked = True # Initially blocked - # Mock hash_token function def mock_hash_token(token): if token == "sk-test123456789": @@ -1388,7 +1384,6 @@ async def test_unblock_key_supports_both_sk_and_hashed_tokens(monkeypatch): ) assert result == mock_key_record - assert mock_key_object.blocked == False # Should be updated to unblocked @pytest.mark.asyncio From 3f7f23cd3c0a26203b7de2fdf8b9a17640dfb79b Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 18 Mar 2026 08:51:12 +0000 Subject: [PATCH 086/539] chore: 
restore original dev_config.yaml Co-authored-by: yuneng-jiang --- dev_config.yaml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dev_config.yaml b/dev_config.yaml index 142e0bf94e..64e3c14703 100644 --- a/dev_config.yaml +++ b/dev_config.yaml @@ -1,8 +1,13 @@ model_list: - - model_name: gpt-4 + - model_name: fake-openai-endpoint litellm_params: - model: gpt-4 + model: openai/fake-model api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ general_settings: master_key: sk-1234 + +litellm_settings: + drop_params: True + telemetry: False From b9266bb3b9e69b0943c86cce4d2f8c5425ff50f7 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 15:25:57 +0530 Subject: [PATCH 087/539] Fix ensure_alternating_roles for correct order --- .../prompt_templates/common_utils.py | 121 ++++++++++++++---- litellm/main.py | 4 + tests/llm_translation/test_prompt_factory.py | 45 +++++++ 3 files changed, 144 insertions(+), 26 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index a5d6bc936b..8791c769af 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -269,35 +269,56 @@ def _insert_user_continue_message( 2. Final assistant message 3. Consecutive assistant messages - Only inserts messages between consecutive assistant messages, - ignoring all other role types. + Skips tool messages and assistant messages with tool calls in the + alternation check, matching strict templates like llama.cpp. 
""" if not messages: return messages + def _counts_for_alternation(message: AllMessageValues) -> bool: + role = message.get("role") + if role == "user": + return True + if role == "assistant": + return not bool(message.get("tool_calls")) + return False + result_messages = messages.copy() # Don't modify the input list continue_message = user_continue_message or DEFAULT_USER_CONTINUE_MESSAGE - # Handle first message if it's an assistant message - if result_messages[0]["role"] == "assistant": + # Handle first counted message if it's an assistant message + if ( + result_messages[0]["role"] == "assistant" + and _counts_for_alternation(result_messages[0]) + ): result_messages.insert(0, continue_message) - # Handle consecutive assistant messages and final message - i = 1 # Start from second message since we handled first message + # Handle consecutive assistant messages in the counted sequence + i = 1 while i < len(result_messages): curr_message = result_messages[i] - prev_message = result_messages[i - 1] - - # Only check for consecutive assistant messages - # Ignore all other role types - if curr_message["role"] == "assistant" and prev_message["role"] == "assistant": - result_messages.insert(i, continue_message) - i += 2 # Skip over the message we just inserted - else: + if ( + curr_message["role"] == "assistant" + and _counts_for_alternation(curr_message) + ): + j = i - 1 + while j >= 0: + previous_message = result_messages[j] + if _counts_for_alternation(previous_message): + if previous_message["role"] == "assistant": + result_messages.insert(i, continue_message) + i += 2 + break + j -= 1 + if i < len(result_messages): i += 1 # Handle final message - if result_messages[-1]["role"] == "assistant" and ensure_alternating_roles: + if ( + result_messages[-1]["role"] == "assistant" + and _counts_for_alternation(result_messages[-1]) + and ensure_alternating_roles + ): result_messages.append(continue_message) return result_messages @@ -310,6 +331,8 @@ def 
_insert_assistant_continue_message( ) -> List[AllMessageValues]: """ Add assistant continuation messages between consecutive user messages. + Skips tool messages and assistant messages with tool calls in the + alternation check, matching strict templates like llama.cpp. Args: messages: List of message dictionaries @@ -322,27 +345,73 @@ def _insert_assistant_continue_message( if not ensure_alternating_roles or len(messages) <= 1: return messages + def _counts_for_alternation(message: AllMessageValues) -> bool: + role = message.get("role") + if role == "user": + return True + if role == "assistant": + return not bool(message.get("tool_calls")) + return False + # Create a new list to store modified messages modified_messages: List[AllMessageValues] = [] for i, message in enumerate(messages): modified_messages.append(message) - # Check if we need to insert an assistant message - if ( - i < len(messages) - 1 # Not the last message - and message.get("role") == "user" # Current is user - and messages[i + 1].get("role") == "user" - ): # Next is user - # Insert assistant message - continue_message = ( - assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE - ) - modified_messages.append(continue_message) + if message.get("role") == "user" and _counts_for_alternation(message): + next_counted_index = i + 1 + while next_counted_index < len(messages) and not _counts_for_alternation( + messages[next_counted_index] + ): + next_counted_index += 1 + + if ( + next_counted_index < len(messages) + and messages[next_counted_index].get("role") == "user" + ): + continue_message = ( + assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE + ) + modified_messages.append(continue_message) return modified_messages +def strip_tool_messages_for_alternating_roles( + messages: List[AllMessageValues], +) -> List[AllMessageValues]: + """ + Prepare history for strict user/assistant-only chat templates. 
+ + - Drop tool/function role messages + - Drop assistant tool-dispatch turns with no content + - Keep assistant content turns but remove tool metadata fields + """ + cleaned_messages: List[AllMessageValues] = [] + + for message in messages: + role = message.get("role") + if role in ("tool", "function"): + continue + + if role == "assistant": + assistant_message = message.copy() + assistant_message.pop("tool_calls", None) + assistant_message.pop("function_call", None) + assistant_message.pop("tool_call_id", None) + + if assistant_message.get("content") is None: + continue + + cleaned_messages.append(assistant_message) + continue + + cleaned_messages.append(message) + + return cleaned_messages + + def get_completion_messages( messages: List[AllMessageValues], assistant_continue_message: Optional[ChatCompletionAssistantMessage], diff --git a/litellm/main.py b/litellm/main.py index 81319bc432..cb5a92caf8 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -166,6 +166,7 @@ from .litellm_core_utils.fallback_utils import ( from .litellm_core_utils.prompt_templates.common_utils import ( add_system_prompt_to_messages, get_completion_messages, + strip_tool_messages_for_alternating_roles, update_messages_with_model_file_ids, ) from .litellm_core_utils.prompt_templates.factory import ( @@ -1298,6 +1299,9 @@ def completion( # type: ignore # noqa: PLR0915 prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None)) litellm_system_prompt = kwargs.get("litellm_system_prompt", None) ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489 + if ensure_alternating_roles: + messages = strip_tool_messages_for_alternating_roles(messages=messages) + messages = get_completion_messages( messages=messages, ensure_alternating_roles=ensure_alternating_roles or False, diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index a6dcabe25e..2eed9fa212 100644 --- 
a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -25,6 +25,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import ( ) from litellm.litellm_core_utils.prompt_templates.common_utils import ( get_completion_messages, + strip_tool_messages_for_alternating_roles, ) from litellm.llms.vertex_ai.gemini.transformation import ( _gemini_convert_messages_with_history, @@ -775,6 +776,50 @@ def test_ensure_alternating_roles( assert messages == expected_messages +def test_ensure_alternating_roles_with_tool_calls(): + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + } + ], + }, + {"role": "tool", "tool_call_id": "call_123", "content": "72F, sunny"}, + {"role": "assistant", "content": "It's 72F and sunny in NYC."}, + {"role": "user", "content": "What about tomorrow?"}, + {"role": "user", "content": "And the day after?"}, + {"role": "user", "content": "What about next week?"}, + ] + + messages = strip_tool_messages_for_alternating_roles(messages) + + transformed_messages = get_completion_messages( + messages=messages, + assistant_continue_message=None, + user_continue_message=None, + ensure_alternating_roles=True, + ) + + assert transformed_messages == [ + {"role": "user", "content": "What's the weather?"}, + {"role": "assistant", "content": "It's 72F and sunny in NYC."}, + {"role": "user", "content": "What about tomorrow?"}, + {"role": "assistant", "content": "Please continue."}, + {"role": "user", "content": "And the day after?"}, + {"role": "assistant", "content": "Please continue."}, + {"role": "user", "content": "What about next week?"}, + ] + + def test_alternating_roles_e2e(): from litellm.llms.custom_httpx.http_handler import HTTPHandler import json From 3cdabff323538df780ce0dc736c22d2f57fd323c Mon 
Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 15:45:09 +0530 Subject: [PATCH 088/539] Fix greptile review --- .../prompt_templates/common_utils.py | 38 ++-------------- litellm/main.py | 4 -- tests/llm_translation/test_prompt_factory.py | 44 +++++++++++++++++-- 3 files changed, 44 insertions(+), 42 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index 8791c769af..aa2e07234d 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -297,6 +297,7 @@ def _insert_user_continue_message( i = 1 while i < len(result_messages): curr_message = result_messages[i] + inserted_continue_message = False if ( curr_message["role"] == "assistant" and _counts_for_alternation(curr_message) @@ -308,9 +309,10 @@ def _insert_user_continue_message( if previous_message["role"] == "assistant": result_messages.insert(i, continue_message) i += 2 + inserted_continue_message = True break j -= 1 - if i < len(result_messages): + if not inserted_continue_message: i += 1 # Handle final message @@ -378,40 +380,6 @@ def _insert_assistant_continue_message( return modified_messages -def strip_tool_messages_for_alternating_roles( - messages: List[AllMessageValues], -) -> List[AllMessageValues]: - """ - Prepare history for strict user/assistant-only chat templates. 
- - - Drop tool/function role messages - - Drop assistant tool-dispatch turns with no content - - Keep assistant content turns but remove tool metadata fields - """ - cleaned_messages: List[AllMessageValues] = [] - - for message in messages: - role = message.get("role") - if role in ("tool", "function"): - continue - - if role == "assistant": - assistant_message = message.copy() - assistant_message.pop("tool_calls", None) - assistant_message.pop("function_call", None) - assistant_message.pop("tool_call_id", None) - - if assistant_message.get("content") is None: - continue - - cleaned_messages.append(assistant_message) - continue - - cleaned_messages.append(message) - - return cleaned_messages - - def get_completion_messages( messages: List[AllMessageValues], assistant_continue_message: Optional[ChatCompletionAssistantMessage], diff --git a/litellm/main.py b/litellm/main.py index cb5a92caf8..81319bc432 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -166,7 +166,6 @@ from .litellm_core_utils.fallback_utils import ( from .litellm_core_utils.prompt_templates.common_utils import ( add_system_prompt_to_messages, get_completion_messages, - strip_tool_messages_for_alternating_roles, update_messages_with_model_file_ids, ) from .litellm_core_utils.prompt_templates.factory import ( @@ -1299,9 +1298,6 @@ def completion( # type: ignore # noqa: PLR0915 prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None)) litellm_system_prompt = kwargs.get("litellm_system_prompt", None) ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489 - if ensure_alternating_roles: - messages = strip_tool_messages_for_alternating_roles(messages=messages) - messages = get_completion_messages( messages=messages, ensure_alternating_roles=ensure_alternating_roles or False, diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index 2eed9fa212..12efb47e06 100644 --- 
a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -25,7 +25,6 @@ from litellm.litellm_core_utils.prompt_templates.factory import ( ) from litellm.litellm_core_utils.prompt_templates.common_utils import ( get_completion_messages, - strip_tool_messages_for_alternating_roles, ) from litellm.llms.vertex_ai.gemini.transformation import ( _gemini_convert_messages_with_history, @@ -800,8 +799,6 @@ def test_ensure_alternating_roles_with_tool_calls(): {"role": "user", "content": "What about next week?"}, ] - messages = strip_tool_messages_for_alternating_roles(messages) - transformed_messages = get_completion_messages( messages=messages, assistant_continue_message=None, @@ -811,6 +808,21 @@ def test_ensure_alternating_roles_with_tool_calls(): assert transformed_messages == [ {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + } + ], + }, + {"role": "tool", "tool_call_id": "call_123", "content": "72F, sunny"}, {"role": "assistant", "content": "It's 72F and sunny in NYC."}, {"role": "user", "content": "What about tomorrow?"}, {"role": "assistant", "content": "Please continue."}, @@ -820,6 +832,32 @@ def test_ensure_alternating_roles_with_tool_calls(): ] +def test_ensure_alternating_roles_three_consecutive_assistants(): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "A1"}, + {"role": "assistant", "content": "A2"}, + {"role": "assistant", "content": "A3"}, + ] + + transformed_messages = get_completion_messages( + messages=messages, + assistant_continue_message=None, + user_continue_message=None, + ensure_alternating_roles=True, + ) + + assert transformed_messages == [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "A1"}, + {"role": "user", "content": "Please 
continue."}, + {"role": "assistant", "content": "A2"}, + {"role": "user", "content": "Please continue."}, + {"role": "assistant", "content": "A3"}, + {"role": "user", "content": "Please continue."}, + ] + + def test_alternating_roles_e2e(): from litellm.llms.custom_httpx.http_handler import HTTPHandler import json From 0d70864d0905c7bdd537facb286938a73081a410 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 15:48:31 +0530 Subject: [PATCH 089/539] Add support for prompt management for responses --- docs/my-website/docs/prompt_management.md | 48 ++++ .../docs/proxy/prompt_management.md | 19 +- docs/my-website/sidebars.js | 1 + litellm/responses/main.py | 36 +++ .../test_responses_prompt_management.py | 211 ++++++++++++++++++ 5 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 docs/my-website/docs/prompt_management.md create mode 100644 tests/test_litellm/responses/test_responses_prompt_management.py diff --git a/docs/my-website/docs/prompt_management.md b/docs/my-website/docs/prompt_management.md new file mode 100644 index 0000000000..c4e606674b --- /dev/null +++ b/docs/my-website/docs/prompt_management.md @@ -0,0 +1,48 @@ +--- +title: Prompt Management with Responses API +--- + +# Prompt Management with Responses API + +Use LiteLLM Prompt Management with `/v1/responses` by passing `prompt_id` and optional `prompt_variables`. + +## Basic Usage + +```bash +curl -X POST "http://localhost:4000/v1/responses" \ + -H "Authorization: Bearer sk-1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "prompt_id": "my-responses-prompt", + "prompt_variables": {"topic": "large language models"}, + "input": [] + }' +``` + +## Multi-turn Follow-up in `input` + +To send follow-up turns in one request, pass message history in `input`. 
+ +```bash +curl -X POST "http://localhost:4000/v1/responses" \ + -H "Authorization: Bearer sk-1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "prompt_id": "my-responses-prompt", + "prompt_variables": {"topic": "large language models"}, + "input": [ + {"role": "user", "content": "Topic is LLMs. Start short."}, + {"role": "assistant", "content": "Sure, go ahead."}, + {"role": "user", "content": "Now give me 3 bullets and include pricing caveat."} + ] + }' +``` + +## Notes + +- Prompt template messages are merged with your `input` messages. +- Prompt variable substitution applies to prompt message content. +- Tool call payload fields are not substituted by prompt variables. +- For follow-ups with `previous_response_id`, include `prompt_id` again if you want prompt management applied on that turn. diff --git a/docs/my-website/docs/proxy/prompt_management.md b/docs/my-website/docs/proxy/prompt_management.md index 08307ba99e..5a3e411e98 100644 --- a/docs/my-website/docs/proxy/prompt_management.md +++ b/docs/my-website/docs/proxy/prompt_management.md @@ -311,7 +311,7 @@ litellm_settings: 1. **At Startup**: When the proxy starts, it reads the `prompts` field from `config.yaml` 2. **Initialization**: Each prompt is initialized based on its `prompt_integration` type 3. **In-Memory Storage**: Prompts are stored in the `IN_MEMORY_PROMPT_REGISTRY` -4. **Access**: Use these prompts via the `/v1/chat/completions` endpoint with `prompt_id` in the request +4. 
**Access**: Use these prompts via `/v1/chat/completions` or `/v1/responses` with `prompt_id` in the request ### Using Config-Loaded Prompts @@ -331,6 +331,23 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \ }' ``` +You can also use the same `prompt_id` with the Responses API: + +```bash +curl -L -X POST 'http://0.0.0.0:4000/v1/responses' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "gpt-4o", + "prompt_id": "coding_assistant", + "prompt_variables": { + "language": "python", + "task": "create a web scraper" + }, + "input": [] +}' +``` + ### Prompt Schema Reference Each prompt in the `prompts` list requires: diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 79a0279bad..2a61c601ef 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -687,6 +687,7 @@ const sidebars = { "proxy/realtime_webrtc", "rerank", "response_api", + "prompt_management", "response_api_compact", { type: "category", diff --git a/litellm/responses/main.py b/litellm/responses/main.py index cd9ce67c26..cec4565166 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -37,6 +37,7 @@ from litellm.responses.litellm_completion_transformation.handler import ( ) from litellm.responses.utils import ResponsesAPIRequestUtils from litellm.types.llms.openai import ( + AllMessageValues, PromptObject, Reasoning, ResponseIncludable, @@ -623,6 +624,41 @@ def responses( if dynamic_api_base is not None: litellm_params.api_base = dynamic_api_base + ######################################################### + # PROMPT MANAGEMENT + ######################################################### + prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) + prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None)) + + if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and ( + litellm_logging_obj.should_run_prompt_management_hooks( + prompt_id=prompt_id, 
non_default_params=kwargs + ) + ): + client_input: List[AllMessageValues] = ( + [{"role": "user", "content": input}] + if isinstance(input, str) + else cast(List[AllMessageValues], list(input)) + ) + ( + model, + merged_input, + merged_optional_params, + ) = litellm_logging_obj.get_chat_completion_prompt( + model=model, + messages=client_input, + non_default_params=kwargs, + prompt_id=prompt_id, + prompt_variables=prompt_variables, + prompt_label=kwargs.get("prompt_label", None), + prompt_version=kwargs.get("prompt_version", None), + ) + input = cast(Union[str, ResponseInputParam], merged_input) + local_vars["input"] = input + # Apply prompt_template_optional_params (e.g. temperature, instructions) + # by updating kwargs so they flow into local_vars → response_api_optional_params + kwargs.update(merged_optional_params) + ######################################################### # Update input and tools with provider-specific file IDs if managed files are used ######################################################### diff --git a/tests/test_litellm/responses/test_responses_prompt_management.py b/tests/test_litellm/responses/test_responses_prompt_management.py new file mode 100644 index 0000000000..788fd19534 --- /dev/null +++ b/tests/test_litellm/responses/test_responses_prompt_management.py @@ -0,0 +1,211 @@ +""" +Unit tests for prompt management support in the Responses API. 
+ +Covers: + A) str input is coerced to a message list before merging with the template + B) list input is merged with the template + C) no prompt_id → hook is skipped, input is unchanged + D) model override from the prompt template is applied +""" + +from typing import List +from unittest.mock import MagicMock, patch + +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.types.llms.openai import AllMessageValues + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_logging_obj( + merged_model: str, + merged_messages: List[AllMessageValues], + should_run: bool = True, +) -> MagicMock: + """Return a mock LiteLLMLoggingObj pre-configured for prompt management.""" + logging_obj = MagicMock() + # Make isinstance(logging_obj, LiteLLMLoggingObj) return True + logging_obj.__class__ = LiteLLMLoggingObj + logging_obj.should_run_prompt_management_hooks.return_value = should_run + logging_obj.get_chat_completion_prompt.return_value = ( + merged_model, + merged_messages, + {}, + ) + # Instance attribute accessed by post-call metadata utilities + logging_obj.model_call_details = {} + return logging_obj + + +def _patch_responses_dispatch(): + """Patch everything after the prompt management block so tests stay unit-level.""" + return [ + patch( + "litellm.responses.main.litellm.get_llm_provider", + return_value=("gpt-4o", "openai", None, None), + ), + patch( + "litellm.responses.mcp.litellm_proxy_mcp_handler." 
+ "LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway", + return_value=False, + ), + patch( + "litellm.responses.main.ProviderConfigManager" + ".get_provider_responses_api_config", + return_value=None, + ), + patch( + "litellm.responses.main.litellm_completion_transformation_handler" + ".response_api_handler", + return_value=MagicMock(), + ), + ] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestResponsesAPIPromptManagement: + + def test_str_input_coerced_and_merged(self): + """[A] str input is wrapped into a message list before being passed to the hook.""" + template_messages: List[AllMessageValues] = [ + {"role": "system", "content": "You are a summariser."}, # type: ignore[list-item] + ] + client_message: List[AllMessageValues] = [ + {"role": "user", "content": "Tell me about AI."}, # type: ignore[list-item] + ] + expected_merged = template_messages + client_message + + logging_obj = _make_logging_obj( + merged_model="openai/gpt-4o", + merged_messages=expected_merged, + ) + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3]: + import litellm + litellm.responses( + input="Tell me about AI.", + model="gpt-4o", + prompt_id="summariser-prompt", + prompt_variables={}, + litellm_logging_obj=logging_obj, + ) + + logging_obj.get_chat_completion_prompt.assert_called_once() + call_kwargs = logging_obj.get_chat_completion_prompt.call_args.kwargs + # str was coerced to a single user message before being passed to the hook + assert call_kwargs["messages"] == [ + {"role": "user", "content": "Tell me about AI."} + ] + assert call_kwargs["prompt_id"] == "summariser-prompt" + + def test_list_input_merged_with_template(self): + """[B] list input is passed directly to the hook and merged with the template.""" + template_messages: List[AllMessageValues] = [ + {"role": "system", "content": "You are 
helpful."}, # type: ignore[list-item] + ] + client_messages = [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + ] + expected_merged = template_messages + client_messages # type: ignore[operator] + + logging_obj = _make_logging_obj( + merged_model="openai/gpt-4o", + merged_messages=expected_merged, # type: ignore[arg-type] + ) + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3]: + import litellm + litellm.responses( + input=client_messages, # type: ignore[arg-type] + model="gpt-4o", + prompt_id="helper-prompt", + litellm_logging_obj=logging_obj, + ) + + logging_obj.get_chat_completion_prompt.assert_called_once() + call_kwargs = logging_obj.get_chat_completion_prompt.call_args.kwargs + assert call_kwargs["messages"] == client_messages + + def test_no_prompt_id_skips_hook(self): + """[C] When prompt_id is absent, prompt management hooks are not called.""" + logging_obj = _make_logging_obj( + merged_model="openai/gpt-4o", + merged_messages=[], + should_run=False, + ) + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3]: + import litellm + litellm.responses( + input="Hello", + model="gpt-4o", + litellm_logging_obj=logging_obj, + ) + + logging_obj.get_chat_completion_prompt.assert_not_called() + + def test_optional_params_from_template_applied(self): + """[E] prompt_template_optional_params (e.g. 
temperature) flow into the request.""" + template_messages: List[AllMessageValues] = [ + {"role": "user", "content": "Hello"}, # type: ignore[list-item] + ] + # Simulate get_chat_completion_prompt returning merged optional params + # that include a template-defined temperature + merged_kwargs = {"temperature": 0.2, "prompt_id": "t", "litellm_logging_obj": None} + + logging_obj = MagicMock() + logging_obj.__class__ = LiteLLMLoggingObj + logging_obj.should_run_prompt_management_hooks.return_value = True + logging_obj.get_chat_completion_prompt.return_value = ( + "openai/gpt-4o", + template_messages, + merged_kwargs, + ) + logging_obj.model_call_details = {} + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3] as mock_handler: + import litellm + litellm.responses( + input="Hello", + model="gpt-4o", + prompt_id="t", + litellm_logging_obj=logging_obj, + ) + + # temperature from the template should reach the downstream handler via local_vars + handler_call_kwargs = mock_handler.call_args.kwargs + request_params = handler_call_kwargs.get("responses_api_request", {}) + assert request_params.get("temperature") == 0.2 + + def test_model_override_from_template(self): + """[D] Model returned by the prompt hook overrides the original request model.""" + template_messages: List[AllMessageValues] = [ + {"role": "user", "content": "{{query}}"}, # type: ignore[list-item] + ] + logging_obj = _make_logging_obj( + merged_model="openai/gpt-4o-mini", # overridden model from template + merged_messages=template_messages, + ) + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3] as mock_handler: + import litellm + litellm.responses( + input="What is AI?", + model="gpt-4o", + prompt_id="query-prompt", + prompt_variables={"query": "What is AI?"}, + litellm_logging_obj=logging_obj, + ) + + # The model passed to the downstream handler should be the overridden one + handler_call_kwargs = 
mock_handler.call_args.kwargs + assert handler_call_kwargs.get("model") == "openai/gpt-4o-mini" From f1421d10825d299fae3f2605c4e48e2cb3655b4f Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 15:55:51 +0530 Subject: [PATCH 090/539] fix(prompting): preserve tool chains in alternation insertion Avoid inserting assistant continue messages in the middle of assistant tool_call->tool chains by inserting before the next counted user turn, and add regression coverage for this edge case. Made-with: Cursor --- .../prompt_templates/common_utils.py | 67 +++++++++---------- tests/llm_translation/test_prompt_factory.py | 44 ++++++++++++ 2 files changed, 76 insertions(+), 35 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index aa2e07234d..739c3119cc 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -257,6 +257,15 @@ def detect_first_expected_role( return None +def _counts_for_alternation(message: AllMessageValues) -> bool: + role = message.get("role") + if role == "user": + return True + if role == "assistant": + return not bool(message.get("tool_calls")) + return False + + def _insert_user_continue_message( messages: List[AllMessageValues], user_continue_message: Optional[ChatCompletionUserMessage], @@ -275,14 +284,6 @@ def _insert_user_continue_message( if not messages: return messages - def _counts_for_alternation(message: AllMessageValues) -> bool: - role = message.get("role") - if role == "user": - return True - if role == "assistant": - return not bool(message.get("tool_calls")) - return False - result_messages = messages.copy() # Don't modify the input list continue_message = user_continue_message or DEFAULT_USER_CONTINUE_MESSAGE @@ -346,37 +347,33 @@ def _insert_assistant_continue_message( """ if not ensure_alternating_roles or len(messages) <= 1: return messages 
- - def _counts_for_alternation(message: AllMessageValues) -> bool: - role = message.get("role") - if role == "user": - return True - if role == "assistant": - return not bool(message.get("tool_calls")) - return False - - # Create a new list to store modified messages - modified_messages: List[AllMessageValues] = [] + continue_message = assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE + insert_before_indexes = set() for i, message in enumerate(messages): + if message.get("role") != "user": + continue + + next_counted_index = i + 1 + while next_counted_index < len(messages) and not _counts_for_alternation( + messages[next_counted_index] + ): + next_counted_index += 1 + + if ( + next_counted_index < len(messages) + and messages[next_counted_index].get("role") == "user" + ): + # Insert before the next counted user turn. + # This avoids splitting assistant tool-call -> tool chains. + insert_before_indexes.add(next_counted_index) + + modified_messages: List[AllMessageValues] = [] + for idx, message in enumerate(messages): + if idx in insert_before_indexes: + modified_messages.append(continue_message) modified_messages.append(message) - if message.get("role") == "user" and _counts_for_alternation(message): - next_counted_index = i + 1 - while next_counted_index < len(messages) and not _counts_for_alternation( - messages[next_counted_index] - ): - next_counted_index += 1 - - if ( - next_counted_index < len(messages) - and messages[next_counted_index].get("role") == "user" - ): - continue_message = ( - assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE - ) - modified_messages.append(continue_message) - return modified_messages diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index 12efb47e06..355c23ae17 100644 --- a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -858,6 +858,50 @@ def 
test_ensure_alternating_roles_three_consecutive_assistants(): ] +def test_ensure_alternating_roles_does_not_split_tool_call_chain(): + messages = [ + {"role": "user", "content": "Search for X"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "c1", "content": "results"}, + {"role": "user", "content": "Thanks, now do Y"}, + ] + + transformed_messages = get_completion_messages( + messages=messages, + assistant_continue_message=None, + user_continue_message=None, + ensure_alternating_roles=True, + ) + + assert transformed_messages == [ + {"role": "user", "content": "Search for X"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "c1", "content": "results"}, + {"role": "assistant", "content": "Please continue."}, + {"role": "user", "content": "Thanks, now do Y"}, + ] + + def test_alternating_roles_e2e(): from litellm.llms.custom_httpx.http_handler import HTTPHandler import json From 7e4ec1000718f2f35854552eaf0d3aac43f48fa4 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 16:08:15 +0530 Subject: [PATCH 091/539] Update litellm/litellm_core_utils/prompt_templates/common_utils.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- litellm/litellm_core_utils/prompt_templates/common_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index 739c3119cc..f5e3ffdb74 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -299,9 +299,7 @@ def _insert_user_continue_message( while 
i < len(result_messages): curr_message = result_messages[i] inserted_continue_message = False - if ( - curr_message["role"] == "assistant" - and _counts_for_alternation(curr_message) + if _counts_for_alternation(curr_message) and curr_message["role"] == "assistant": ): j = i - 1 while j >= 0: From ae350ed3708b9e90a5a185e9fb5d1073f6b0dc35 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 16:09:32 +0530 Subject: [PATCH 092/539] Fix greptile comments --- tests/llm_translation/test_prompt_factory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index 355c23ae17..f146e2811a 100644 --- a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -776,6 +776,7 @@ def test_ensure_alternating_roles( def test_ensure_alternating_roles_with_tool_calls(): + """Fixes Regression in #18685 """ messages = [ {"role": "user", "content": "What's the weather?"}, { From 35b3ed58a8bdc190c9256d3bbfc706a63623c2e9 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 16:19:11 +0530 Subject: [PATCH 093/539] Fix greptile review --- litellm/responses/main.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index cec4565166..a844b61854 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -630,10 +630,8 @@ def responses( prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None)) - if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and ( - litellm_logging_obj.should_run_prompt_management_hooks( - prompt_id=prompt_id, non_default_params=kwargs - ) + if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and litellm_logging_obj.should_run_prompt_management_hooks( + prompt_id=prompt_id, non_default_params=kwargs ): client_input: 
List[AllMessageValues] = ( [{"role": "user", "content": input}] @@ -655,9 +653,9 @@ def responses( ) input = cast(Union[str, ResponseInputParam], merged_input) local_vars["input"] = input - # Apply prompt_template_optional_params (e.g. temperature, instructions) - # by updating kwargs so they flow into local_vars → response_api_optional_params - kwargs.update(merged_optional_params) + local_vars["model"] = model + for k, v in merged_optional_params.items(): + local_vars[k] = v ######################################################### # Update input and tools with provider-specific file IDs if managed files are used From 0941e4036365ad9a57db604d35d98f0356022b54 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 16:30:42 +0530 Subject: [PATCH 094/539] fix(prompting): address greptile review - fix SyntaxError, restore backward compat, add trailing tool-call test - Remove stray ): on line 303 (P0 SyntaxError) - Restore backward-compatible trailing-assistant behavior (P1) - Add test_ensure_alternating_roles_trailing_tool_call_assistant - Keep role check alongside _counts_for_alternation (P2 is false positive) Made-with: Cursor --- .../prompt_templates/common_utils.py | 10 ++---- tests/llm_translation/test_prompt_factory.py | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index f5e3ffdb74..d7c1cc708b 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -300,7 +300,6 @@ def _insert_user_continue_message( curr_message = result_messages[i] inserted_continue_message = False if _counts_for_alternation(curr_message) and curr_message["role"] == "assistant": - ): j = i - 1 while j >= 0: previous_message = result_messages[j] @@ -314,12 +313,9 @@ def _insert_user_continue_message( if not inserted_continue_message: i 
+= 1 - # Handle final message - if ( - result_messages[-1]["role"] == "assistant" - and _counts_for_alternation(result_messages[-1]) - and ensure_alternating_roles - ): + # Handle final message — append user_continue after any trailing assistant, + # including ones with tool_calls, to preserve backward compatibility. + if result_messages[-1]["role"] == "assistant" and ensure_alternating_roles: result_messages.append(continue_message) return result_messages diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index f146e2811a..b02ed3ebea 100644 --- a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -903,6 +903,40 @@ def test_ensure_alternating_roles_does_not_split_tool_call_chain(): ] +def test_ensure_alternating_roles_trailing_tool_call_assistant(): + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + } + ], + }, + ] + + transformed_messages = get_completion_messages( + messages=messages, + assistant_continue_message=None, + user_continue_message=None, + ensure_alternating_roles=True, + ) + + # Backward compat: trailing assistant (even with tool_calls) gets user_continue + # appended, then assistant_continue bridges the user→user gap. 
+ assert transformed_messages[-1] == {"role": "user", "content": "Please continue."} + assert transformed_messages[0] == {"role": "user", "content": "What's the weather?"} + assert transformed_messages[1]["role"] == "assistant" + assert transformed_messages[1].get("tool_calls") is not None + + def test_alternating_roles_e2e(): from litellm.llms.custom_httpx.http_handler import HTTPHandler import json From b32f5ea379f4f4792d1b892cf6fdc3e3d63f80af Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 16:37:52 +0530 Subject: [PATCH 095/539] Fix greptile comments --- litellm/responses/main.py | 63 +++++++++++++++-- .../test_responses_prompt_management.py | 67 ++++++++++++++++++- 2 files changed, 124 insertions(+), 6 deletions(-) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index a844b61854..862973e610 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -464,6 +464,49 @@ async def aresponses( # Update local_vars with detected provider (fixes #19782) local_vars["custom_llm_provider"] = custom_llm_provider + ######################################################### + # ASYNC PROMPT MANAGEMENT + ######################################################### + litellm_logging_obj = kwargs.get("litellm_logging_obj", None) + prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) + prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None)) + + if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and litellm_logging_obj.should_run_prompt_management_hooks( + prompt_id=prompt_id, non_default_params=kwargs + ): + if isinstance(input, str): + client_input: List[AllMessageValues] = [ + {"role": "user", "content": input} + ] + else: + client_input = [ + item # type: ignore[misc] + for item in input + if isinstance(item, dict) and "role" in item + ] + ( + model, + merged_input, + merged_optional_params, + ) = await litellm_logging_obj.async_get_chat_completion_prompt( + model=model, + 
messages=client_input, + non_default_params=kwargs, + prompt_id=prompt_id, + prompt_variables=prompt_variables, + prompt_label=kwargs.get("prompt_label", None), + prompt_version=kwargs.get("prompt_version", None), + ) + input = cast(Union[str, ResponseInputParam], merged_input) + if "/" in model: + _, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=model + ) + local_vars["custom_llm_provider"] = custom_llm_provider + for k, v in merged_optional_params.items(): + if k in local_vars: + local_vars[k] = v + func = partial( responses, input=input, @@ -633,11 +676,16 @@ def responses( if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and litellm_logging_obj.should_run_prompt_management_hooks( prompt_id=prompt_id, non_default_params=kwargs ): - client_input: List[AllMessageValues] = ( - [{"role": "user", "content": input}] - if isinstance(input, str) - else cast(List[AllMessageValues], list(input)) - ) + if isinstance(input, str): + client_input: List[AllMessageValues] = [ + {"role": "user", "content": input} + ] + else: + client_input = [ + item # type: ignore[misc] + for item in input + if isinstance(item, dict) and "role" in item + ] ( model, merged_input, @@ -654,6 +702,11 @@ def responses( input = cast(Union[str, ResponseInputParam], merged_input) local_vars["input"] = input local_vars["model"] = model + if "/" in model: + _, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=model + ) + local_vars["custom_llm_provider"] = custom_llm_provider for k, v in merged_optional_params.items(): local_vars[k] = v diff --git a/tests/test_litellm/responses/test_responses_prompt_management.py b/tests/test_litellm/responses/test_responses_prompt_management.py index 788fd19534..666555792f 100644 --- a/tests/test_litellm/responses/test_responses_prompt_management.py +++ b/tests/test_litellm/responses/test_responses_prompt_management.py @@ -158,7 +158,7 @@ class TestResponsesAPIPromptManagement: ] # Simulate get_chat_completion_prompt returning merged 
optional params # that include a template-defined temperature - merged_kwargs = {"temperature": 0.2, "prompt_id": "t", "litellm_logging_obj": None} + merged_kwargs = {"temperature": 0.2} logging_obj = MagicMock() logging_obj.__class__ = LiteLLMLoggingObj @@ -209,3 +209,68 @@ class TestResponsesAPIPromptManagement: # The model passed to the downstream handler should be the overridden one handler_call_kwargs = mock_handler.call_args.kwargs assert handler_call_kwargs.get("model") == "openai/gpt-4o-mini" + + def test_non_message_input_items_filtered(self): + """[F] Non-message items in ResponseInputParam (e.g. function_call_output) are + filtered out before being passed to the prompt hook, avoiding malformed merges.""" + template_messages: List[AllMessageValues] = [ + {"role": "system", "content": "You are helpful."}, # type: ignore[list-item] + ] + mixed_input = [ + {"role": "user", "content": "Hello"}, + {"type": "function_call_output", "call_id": "abc", "output": "42"}, + ] + logging_obj = _make_logging_obj( + merged_model="openai/gpt-4o", + merged_messages=template_messages + [{"role": "user", "content": "Hello"}], # type: ignore[operator] + ) + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3]: + import litellm + litellm.responses( + input=mixed_input, # type: ignore[arg-type] + model="gpt-4o", + prompt_id="filter-test", + litellm_logging_obj=logging_obj, + ) + + call_kwargs = logging_obj.get_chat_completion_prompt.call_args.kwargs + passed_messages = call_kwargs["messages"] + assert all(isinstance(m, dict) and "role" in m for m in passed_messages) + assert len(passed_messages) == 1 + + def test_model_override_re_resolves_provider(self): + """[G] When the prompt template overrides the model to a different provider, + custom_llm_provider is re-resolved so downstream routing uses the correct provider.""" + template_messages: List[AllMessageValues] = [ + {"role": "user", "content": "Hi"}, # type: ignore[list-item] + ] + 
logging_obj = _make_logging_obj( + merged_model="anthropic/claude-3-5-sonnet", + merged_messages=template_messages, + ) + + patches = _patch_responses_dispatch() + with ( + patch( + "litellm.responses.main.litellm.get_llm_provider", + side_effect=[ + ("gpt-4o", "openai", None, None), + ("claude-3-5-sonnet", "anthropic", None, None), + ], + ), + patches[1], + patches[2], + patches[3] as mock_handler, + ): + import litellm + litellm.responses( + input="Hi", + model="gpt-4o", + prompt_id="cross-provider", + litellm_logging_obj=logging_obj, + ) + + handler_call_kwargs = mock_handler.call_args.kwargs + assert handler_call_kwargs.get("custom_llm_provider") == "anthropic" From 67f5ce9c7c23d476eeb371885b401c625ca61a40 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 16:41:26 +0530 Subject: [PATCH 096/539] address greptile review feedback (greploop iteration 1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Restore backward compat for leading assistant(tool_calls) — always prepend user_continue - Replace partial assertions with full list assertion in trailing tool-call test Made-with: Cursor --- .../prompt_templates/common_utils.py | 8 +++--- tests/llm_translation/test_prompt_factory.py | 25 ++++++++++++++----- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index d7c1cc708b..8713d0283e 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -287,11 +287,9 @@ def _insert_user_continue_message( result_messages = messages.copy() # Don't modify the input list continue_message = user_continue_message or DEFAULT_USER_CONTINUE_MESSAGE - # Handle first counted message if it's an assistant message - if ( - result_messages[0]["role"] == "assistant" - and _counts_for_alternation(result_messages[0]) - 
): + # Handle first message if it's an assistant message — always prepend + # user_continue regardless of tool_calls, to preserve backward compatibility. + if result_messages[0]["role"] == "assistant": result_messages.insert(0, continue_message) # Handle consecutive assistant messages in the counted sequence diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index b02ed3ebea..3a9f267e6a 100644 --- a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -929,12 +929,25 @@ def test_ensure_alternating_roles_trailing_tool_call_assistant(): ensure_alternating_roles=True, ) - # Backward compat: trailing assistant (even with tool_calls) gets user_continue - # appended, then assistant_continue bridges the user→user gap. - assert transformed_messages[-1] == {"role": "user", "content": "Please continue."} - assert transformed_messages[0] == {"role": "user", "content": "What's the weather?"} - assert transformed_messages[1]["role"] == "assistant" - assert transformed_messages[1].get("tool_calls") is not None + assert transformed_messages == [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + } + ], + }, + {"role": "assistant", "content": "Please continue."}, + {"role": "user", "content": "Please continue."}, + ] def test_alternating_roles_e2e(): From d333dc4077b98012b08709ce0ed31881ec06879e Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 16:49:07 +0530 Subject: [PATCH 097/539] address greptile review feedback (greploop iteration 1) - Fix async path: call async_get_chat_completion_prompt in aresponses() before executor dispatch, mirroring acompletion() in main.py. 
Discard merged_optional_params in async path (sync responses() handles them via local_vars), avoiding TypeError from duplicate kwargs in partial(). - Fix provider re-resolution: replace "/" in model heuristic with model != original_model comparison so bare model names are handled. - Add 3 async tests covering hook invocation, optional param propagation, and non-message item filtering in aresponses(). Made-with: Cursor --- litellm/responses/main.py | 12 +- .../test_responses_prompt_management.py | 119 ++++++++++++++++-- 2 files changed, 117 insertions(+), 14 deletions(-) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 862973e610..0e56836355 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -470,6 +470,7 @@ async def aresponses( litellm_logging_obj = kwargs.get("litellm_logging_obj", None) prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None)) + original_model = model if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and litellm_logging_obj.should_run_prompt_management_hooks( prompt_id=prompt_id, non_default_params=kwargs @@ -487,7 +488,7 @@ async def aresponses( ( model, merged_input, - merged_optional_params, + _, ) = await litellm_logging_obj.async_get_chat_completion_prompt( model=model, messages=client_input, @@ -498,14 +499,10 @@ async def aresponses( prompt_version=kwargs.get("prompt_version", None), ) input = cast(Union[str, ResponseInputParam], merged_input) - if "/" in model: + if model != original_model: _, custom_llm_provider, _, _ = litellm.get_llm_provider( model=model ) - local_vars["custom_llm_provider"] = custom_llm_provider - for k, v in merged_optional_params.items(): - if k in local_vars: - local_vars[k] = v func = partial( responses, @@ -672,6 +669,7 @@ def responses( ######################################################### prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) prompt_variables = 
cast(Optional[dict], kwargs.get("prompt_variables", None)) + original_model = model if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and litellm_logging_obj.should_run_prompt_management_hooks( prompt_id=prompt_id, non_default_params=kwargs @@ -702,7 +700,7 @@ def responses( input = cast(Union[str, ResponseInputParam], merged_input) local_vars["input"] = input local_vars["model"] = model - if "/" in model: + if model != original_model: _, custom_llm_provider, _, _ = litellm.get_llm_provider( model=model ) diff --git a/tests/test_litellm/responses/test_responses_prompt_management.py b/tests/test_litellm/responses/test_responses_prompt_management.py index 666555792f..9defaceed8 100644 --- a/tests/test_litellm/responses/test_responses_prompt_management.py +++ b/tests/test_litellm/responses/test_responses_prompt_management.py @@ -6,10 +6,18 @@ Covers: B) list input is merged with the template C) no prompt_id → hook is skipped, input is unchanged D) model override from the prompt template is applied + E) prompt_template_optional_params flow into the request + F) non-message items in input are filtered out + G) model override re-resolves provider + H) async path calls async_get_chat_completion_prompt + I) async path propagates optional params to downstream handler """ +import asyncio from typing import List -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.llms.openai import AllMessageValues @@ -22,18 +30,19 @@ def _make_logging_obj( merged_model: str, merged_messages: List[AllMessageValues], should_run: bool = True, + merged_optional_params: dict = None, ) -> MagicMock: """Return a mock LiteLLMLoggingObj pre-configured for prompt management.""" + if merged_optional_params is None: + merged_optional_params = {} logging_obj = MagicMock() - # Make isinstance(logging_obj, LiteLLMLoggingObj) return True 
logging_obj.__class__ = LiteLLMLoggingObj logging_obj.should_run_prompt_management_hooks.return_value = should_run - logging_obj.get_chat_completion_prompt.return_value = ( - merged_model, - merged_messages, - {}, + prompt_return = (merged_model, merged_messages, merged_optional_params) + logging_obj.get_chat_completion_prompt.return_value = prompt_return + logging_obj.async_get_chat_completion_prompt = AsyncMock( + return_value=prompt_return ) - # Instance attribute accessed by post-call metadata utilities logging_obj.model_call_details = {} return logging_obj @@ -274,3 +283,99 @@ class TestResponsesAPIPromptManagement: handler_call_kwargs = mock_handler.call_args.kwargs assert handler_call_kwargs.get("custom_llm_provider") == "anthropic" + + +class TestAsyncResponsesAPIPromptManagement: + """Tests for the async aresponses() prompt management path. + + aresponses() calls async_get_chat_completion_prompt at the outer async level + (for async-only prompt loggers), then delegates to responses() via + run_in_executor where the sync hook also runs — mirroring acompletion() in + main.py. Optional params are handled by the sync responses() path. 
+ """ + + @pytest.mark.asyncio + async def test_async_calls_async_hook(self): + """[H] aresponses() invokes async_get_chat_completion_prompt before + dispatching to the sync responses() path.""" + template_messages: List[AllMessageValues] = [ + {"role": "system", "content": "You are helpful."}, # type: ignore[list-item] + ] + logging_obj = _make_logging_obj( + merged_model="openai/gpt-4o", + merged_messages=template_messages + [{"role": "user", "content": "Hi"}], # type: ignore[list-item] + ) + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3]: + import litellm + await litellm.aresponses( + input="Hi", + model="gpt-4o", + prompt_id="async-test", + prompt_variables={}, + litellm_logging_obj=logging_obj, + ) + + logging_obj.async_get_chat_completion_prompt.assert_called_once() + call_kwargs = logging_obj.async_get_chat_completion_prompt.call_args.kwargs + assert call_kwargs["prompt_id"] == "async-test" + + @pytest.mark.asyncio + async def test_async_optional_params_propagated(self): + """[I] Template-defined optional params (e.g. temperature) reach the downstream + handler when called via aresponses(). 
The sync responses() path applies them + via local_vars.""" + template_messages: List[AllMessageValues] = [ + {"role": "user", "content": "Hello"}, # type: ignore[list-item] + ] + logging_obj = _make_logging_obj( + merged_model="openai/gpt-4o", + merged_messages=template_messages, + merged_optional_params={"temperature": 0.7}, + ) + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3] as mock_handler: + import litellm + await litellm.aresponses( + input="Hello", + model="gpt-4o", + prompt_id="async-temp", + litellm_logging_obj=logging_obj, + ) + + handler_call_kwargs = mock_handler.call_args.kwargs + request_params = handler_call_kwargs.get("responses_api_request", {}) + assert request_params.get("temperature") == 0.7 + + @pytest.mark.asyncio + async def test_async_non_message_items_filtered(self): + """[J] Non-message items are filtered in the async path too.""" + template_messages: List[AllMessageValues] = [ + {"role": "system", "content": "Be helpful."}, # type: ignore[list-item] + ] + mixed_input = [ + {"role": "user", "content": "Hello"}, + {"type": "function_call_output", "call_id": "abc", "output": "42"}, + ] + logging_obj = _make_logging_obj( + merged_model="openai/gpt-4o", + merged_messages=template_messages + [{"role": "user", "content": "Hello"}], # type: ignore[operator] + ) + + patches = _patch_responses_dispatch() + with patches[0], patches[1], patches[2], patches[3]: + import litellm + await litellm.aresponses( + input=mixed_input, # type: ignore[arg-type] + model="gpt-4o", + prompt_id="async-filter", + litellm_logging_obj=logging_obj, + ) + + logging_obj.async_get_chat_completion_prompt.assert_called_once() + call_kwargs = logging_obj.async_get_chat_completion_prompt.call_args.kwargs + passed_messages = call_kwargs["messages"] + assert all(isinstance(m, dict) and "role" in m for m in passed_messages) + assert len(passed_messages) == 1 From 22fc08d602598f5b5cbe2293ccbd146bb748e622 Mon Sep 17 00:00:00 2001 
From: Sameer Kankute Date: Wed, 18 Mar 2026 16:59:16 +0530 Subject: [PATCH 098/539] fix(prompting): revert _insert_assistant_continue_message to adjacent-check logic Restore backward-compatible behavior: only insert assistant_continue between directly adjacent user messages, not across tool-call chains. The _counts_for_alternation skip logic was a silent behavioral change for [user, assistant(tc), tool, user] sequences. Made-with: Cursor --- .../prompt_templates/common_utils.py | 42 +++++-------------- tests/llm_translation/test_prompt_factory.py | 3 +- 2 files changed, 12 insertions(+), 33 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index 8713d0283e..eb3755b71f 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -326,45 +326,25 @@ def _insert_assistant_continue_message( ) -> List[AllMessageValues]: """ Add assistant continuation messages between consecutive user messages. - Skips tool messages and assistant messages with tool calls in the - alternation check, matching strict templates like llama.cpp. - Args: - messages: List of message dictionaries - assistant_continue_message: Optional custom assistant message - ensure_alternating_roles: Whether to enforce alternating roles - - Returns: - Modified list of messages with inserted assistant messages + Only checks directly adjacent messages to preserve backward compatibility. 
""" if not ensure_alternating_roles or len(messages) <= 1: return messages + continue_message = assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE - insert_before_indexes = set() - - for i, message in enumerate(messages): - if message.get("role") != "user": - continue - - next_counted_index = i + 1 - while next_counted_index < len(messages) and not _counts_for_alternation( - messages[next_counted_index] - ): - next_counted_index += 1 - - if ( - next_counted_index < len(messages) - and messages[next_counted_index].get("role") == "user" - ): - # Insert before the next counted user turn. - # This avoids splitting assistant tool-call -> tool chains. - insert_before_indexes.add(next_counted_index) modified_messages: List[AllMessageValues] = [] - for idx, message in enumerate(messages): - if idx in insert_before_indexes: + for i, message in enumerate(messages): + if ( + i < len(messages) - 1 + and message.get("role") == "user" + and messages[i + 1].get("role") == "user" + ): + modified_messages.append(message) modified_messages.append(continue_message) - modified_messages.append(message) + else: + modified_messages.append(message) return modified_messages diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index 3a9f267e6a..fe46c24a29 100644 --- a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -860,6 +860,7 @@ def test_ensure_alternating_roles_three_consecutive_assistants(): def test_ensure_alternating_roles_does_not_split_tool_call_chain(): + """Tool-call chains [user, assistant(tc), tool, user] are preserved as-is.""" messages = [ {"role": "user", "content": "Search for X"}, { @@ -898,7 +899,6 @@ def test_ensure_alternating_roles_does_not_split_tool_call_chain(): ], }, {"role": "tool", "tool_call_id": "c1", "content": "results"}, - {"role": "assistant", "content": "Please continue."}, {"role": "user", "content": "Thanks, now do Y"}, ] @@ -945,7 +945,6 
@@ def test_ensure_alternating_roles_trailing_tool_call_assistant(): } ], }, - {"role": "assistant", "content": "Please continue."}, {"role": "user", "content": "Please continue."}, ] From 92b89353ae6017bcc2cec821c4a9c6a71c6e0da5 Mon Sep 17 00:00:00 2001 From: Andrzej Pomirski Date: Mon, 16 Mar 2026 23:22:00 +0100 Subject: [PATCH 099/539] fix: surface Anthropic code execution results as code_interpreter_call in Responses API PR #18945 added support for capturing Anthropic server-side tool results (bash_code_execution_tool_result, etc.) in provider_specific_fields, but the data never reached the Responses API output because: 1. Non-streaming: provider_specific_fields wasn't copied into _hidden_params 2. Streaming: chunk delta's provider_specific_fields wasn't accumulated 3. Tool results weren't mapped to standard output items This fix: - Copies provider_specific_fields to _hidden_params in transform_response() - Accumulates provider_specific_fields from streaming chunk deltas - Maps bash_code_execution_tool_result to code_interpreter_call output items with code and outputs (matching OpenAI's native shape) - Removes redundant function_call items for server-side tools - Adds OutputCodeInterpreterCall type to the output union --- litellm/llms/anthropic/chat/handler.py | 86 +- litellm/llms/anthropic/chat/transformation.py | 79 +- .../streaming_iterator.py | 52 +- .../transformation.py | 57 +- litellm/types/llms/openai.py | 55 +- litellm/types/responses/main.py | 18 + .../chat/test_anthropic_chat_handler.py | 247 ++++- .../test_anthropic_chat_transformation.py | 889 +++++++++--------- 8 files changed, 971 insertions(+), 512 deletions(-) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 5eebebc2e2..51b9c9835a 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -48,6 +48,10 @@ from litellm.types.llms.openai import ( ChatCompletionToolCallChunk, 
ChatCompletionToolCallFunctionChunk, ) +from litellm.types.responses.main import ( + OutputCodeInterpreterCall, + OutputCodeInterpreterCallLog, +) from litellm.types.utils import ( Delta, GenericStreamingChunk, @@ -538,6 +542,11 @@ class ModelResponseIterator: # Accumulate compaction blocks for multi-turn reconstruction self.compaction_blocks: List[Dict[str, Any]] = [] + # Track server tool use inputs and results for code_interpreter_results + self._server_tool_inputs: Dict[str, Any] = {} + self.tool_results: List[Dict[str, Any]] = [] + self._last_code_interpreter_results_count: int = 0 + def check_empty_tool_call_args(self) -> bool: """ Check if the tool call block so far has been an empty string @@ -568,9 +577,7 @@ class ModelResponseIterator: speed=self.speed, ) - def _content_block_delta_helper( - self, chunk: dict - ) -> Tuple[ + def _content_block_delta_helper(self, chunk: dict) -> Tuple[ str, Optional[ChatCompletionToolCallChunk], List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]], @@ -682,6 +689,44 @@ class ModelResponseIterator: return content_block_start + def _build_code_interpreter_results(self) -> list: + """Convert accumulated tool_results to OutputCodeInterpreterCall objects. + + Called during streaming to produce provider-neutral code_interpreter_results + alongside the raw tool_results, so the Responses API layer doesn't need + Anthropic-specific knowledge. + """ + # Only convert tool_results added since the last call to avoid + # duplicates when _merge_provider_specific_fields extends the list. 
+ new_results = self.tool_results[self._last_code_interpreter_results_count :] + self._last_code_interpreter_results_count = len(self.tool_results) + results = [] + for tr in new_results: + call_id = tr.get("tool_use_id", "") + content = tr.get("content", {}) + if isinstance(content, dict): + parts = [] + if content.get("stdout"): + parts.append(content["stdout"]) + if content.get("stderr"): + parts.append(f"STDERR: {content['stderr']}") + logs = "".join(parts) if parts else str(content) + else: + logs = str(content) + tool_input = self._server_tool_inputs.get(call_id, {}) + code = tool_input.get("command", "") if isinstance(tool_input, dict) else "" + results.append( + OutputCodeInterpreterCall( + type="code_interpreter_call", + id=call_id, + code=code, + container_id=None, + status="completed", + outputs=[OutputCodeInterpreterCallLog(type="logs", logs=logs)], + ) + ) + return results + def chunk_parser(self, chunk: dict) -> ModelResponseStream: # noqa: PLR0915 try: type_chunk = chunk.get("type", "") or "" @@ -748,6 +793,17 @@ class ModelResponseIterator: ), index=self.tool_index, ) + # Track server tool use inputs for code_interpreter_results + if ( + content_block_start["content_block"]["type"] + == "server_tool_use" + ): + tool_input = content_block_start["content_block"].get( + "input", {} + ) + self._server_tool_inputs[ + content_block_start["content_block"]["id"] + ] = tool_input # Include caller information if present (for programmatic tool calling) if "caller" in content_block_start["content_block"]: caller_data = content_block_start["content_block"]["caller"] @@ -768,9 +824,9 @@ class ModelResponseIterator: # Handle compaction blocks # The full content comes in content_block_start self.compaction_blocks.append(content_block_start["content_block"]) - provider_specific_fields[ - "compaction_blocks" - ] = self.compaction_blocks + provider_specific_fields["compaction_blocks"] = ( + self.compaction_blocks + ) provider_specific_fields["compaction_start"] = { 
"type": "compaction", "content": content_block_start["content_block"].get( @@ -792,9 +848,9 @@ class ModelResponseIterator: self.web_search_results.append( content_block_start["content_block"] ) - provider_specific_fields[ - "web_search_results" - ] = self.web_search_results + provider_specific_fields["web_search_results"] = ( + self.web_search_results + ) elif content_type == "web_fetch_tool_result": # Capture web_fetch_tool_result for multi-turn reconstruction # The full content comes in content_block_start, not in deltas @@ -802,16 +858,18 @@ class ModelResponseIterator: self.web_search_results.append( content_block_start["content_block"] ) - provider_specific_fields[ - "web_search_results" - ] = self.web_search_results + provider_specific_fields["web_search_results"] = ( + self.web_search_results + ) elif content_type != "tool_search_tool_result": # Handle other tool results (code execution, etc.) # Skip tool_search_tool_result as it's internal metadata - if not hasattr(self, "tool_results"): - self.tool_results = [] self.tool_results.append(content_block_start["content_block"]) provider_specific_fields["tool_results"] = self.tool_results + # Convert to provider-neutral code_interpreter_results + provider_specific_fields["code_interpreter_results"] = ( + self._build_code_interpreter_results() + ) elif type_chunk == "content_block_stop": ContentBlockStop(**chunk) # type: ignore diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 47cdd8287e..033afea2ff 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -59,6 +59,10 @@ from litellm.types.utils import ( PromptTokensDetailsWrapper, ServerToolUse, ) +from litellm.types.responses.main import ( + OutputCodeInterpreterCall, + OutputCodeInterpreterCallLog, +) from litellm.utils import ( ModelResponse, Usage, @@ -960,11 +964,11 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): if mcp_servers: 
optional_params["mcp_servers"] = mcp_servers elif param == "tool_choice" or param == "parallel_tool_calls": - _tool_choice: Optional[ - AnthropicMessagesToolChoice - ] = self._map_tool_choice( - tool_choice=non_default_params.get("tool_choice"), - parallel_tool_use=non_default_params.get("parallel_tool_calls"), + _tool_choice: Optional[AnthropicMessagesToolChoice] = ( + self._map_tool_choice( + tool_choice=non_default_params.get("tool_choice"), + parallel_tool_use=non_default_params.get("parallel_tool_calls"), + ) ) if _tool_choice is not None: @@ -1062,9 +1066,9 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): self.map_openai_context_management_to_anthropic(value) ) if anthropic_context_management is not None: - optional_params[ - "context_management" - ] = anthropic_context_management + optional_params["context_management"] = ( + anthropic_context_management + ) elif param == "speed" and isinstance(value, str): # Pass through Anthropic-specific speed parameter for fast mode optional_params["speed"] = value @@ -1138,9 +1142,9 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): text=system_message_block["content"], ) if "cache_control" in system_message_block: - anthropic_system_message_content[ - "cache_control" - ] = system_message_block["cache_control"] + anthropic_system_message_content["cache_control"] = ( + system_message_block["cache_control"] + ) anthropic_system_message_list.append( anthropic_system_message_content ) @@ -1164,9 +1168,9 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): ) ) if "cache_control" in _content: - anthropic_system_message_content[ - "cache_control" - ] = _content["cache_control"] + anthropic_system_message_content["cache_control"] = ( + _content["cache_control"] + ) anthropic_system_message_list.append( anthropic_system_message_content @@ -1463,9 +1467,7 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): ) return _message - def extract_response_content( - self, completion_response: dict - ) -> Tuple[ + 
def extract_response_content(self, completion_response: dict) -> Tuple[ str, Optional[List[Any]], Optional[ @@ -1749,6 +1751,48 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): provider_specific_fields["web_search_results"] = web_search_results if tool_results is not None: provider_specific_fields["tool_results"] = tool_results + # Convert to provider-neutral OutputCodeInterpreterCall objects + # so the Responses API layer can use them without Anthropic-specific knowledge. + container_id = ( + completion_response.get("container", {}).get("id") + if isinstance(completion_response.get("container"), dict) + else None + ) + code_by_id: Dict[str, str] = {} + for tc in tool_calls: + try: + args = json.loads(tc.get("function", {}).get("arguments", "{}")) + code_by_id[tc.get("id", "")] = args.get("command", "") + except Exception: + pass + code_interpreter_results = [] + for tr in tool_results: + call_id = tr.get("tool_use_id", "") + content = tr.get("content", {}) + if isinstance(content, dict): + parts = [] + if content.get("stdout"): + parts.append(content["stdout"]) + if content.get("stderr"): + parts.append(f"STDERR: {content['stderr']}") + logs = "".join(parts) if parts else str(content) + else: + logs = str(content) + code_interpreter_results.append( + OutputCodeInterpreterCall( + type="code_interpreter_call", + id=call_id, + code=code_by_id.get(call_id, ""), + container_id=container_id, + status="completed", + outputs=[ + OutputCodeInterpreterCallLog(type="logs", logs=logs) + ], + ) + ) + provider_specific_fields["code_interpreter_results"] = ( + code_interpreter_results + ) if container is not None: provider_specific_fields["container"] = container if compaction_blocks is not None: @@ -1794,6 +1838,7 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): model_response.created = int(time.time()) model_response.model = completion_response["model"] + _hidden_params["provider_specific_fields"] = provider_specific_fields model_response._hidden_params = 
_hidden_params return model_response diff --git a/litellm/responses/litellm_completion_transformation/streaming_iterator.py b/litellm/responses/litellm_completion_transformation/streaming_iterator.py index ce037850b8..0b7d6e8a7a 100644 --- a/litellm/responses/litellm_completion_transformation/streaming_iterator.py +++ b/litellm/responses/litellm_completion_transformation/streaming_iterator.py @@ -107,6 +107,7 @@ class LiteLLMCompletionStreamingIterator(ResponsesAPIStreamingIterator): self._reasoning_done_emitted = False self._reasoning_item_id: Optional[str] = None self._accumulated_reasoning_content_parts: List[str] = [] + self._accumulated_provider_specific_fields: Dict[str, Any] = {} def _get_or_assign_tool_output_index(self, call_id: str) -> int: existing = self._tool_output_index_by_call_id.get(call_id) @@ -479,16 +480,37 @@ class LiteLLMCompletionStreamingIterator(ResponsesAPIStreamingIterator): event.__dict__["sequence_number"] = self._sequence_number return event - def create_litellm_model_response( - self, - ) -> Optional[ModelResponse]: - return cast( + def _merge_provider_specific_fields(self, src: dict) -> None: + """Merge provider_specific_fields, extending list values instead of replacing.""" + for key, val in src.items(): + existing = self._accumulated_provider_specific_fields.get(key) + if ( + existing is not None + and isinstance(val, list) + and isinstance(existing, list) + ): + existing.extend(val) + else: + self._accumulated_provider_specific_fields[key] = val + + def create_litellm_model_response(self) -> Optional[ModelResponse]: + response = cast( Optional[ModelResponse], stream_chunk_builder( chunks=self.collected_chat_completion_chunks, logging_obj=self.litellm_logging_obj, ), ) + if response is not None and self._accumulated_provider_specific_fields: + if ( + not hasattr(response, "_hidden_params") + or response._hidden_params is None + ): + response._hidden_params = {} + response._hidden_params.setdefault("provider_specific_fields", 
{}).update( + self._accumulated_provider_specific_fields + ) + return response @staticmethod def _snapshot_chunk_for_stream_chunk_builder( @@ -853,6 +875,17 @@ class LiteLLMCompletionStreamingIterator(ResponsesAPIStreamingIterator): if chunk is not None: chunk = cast(ModelResponseStream, chunk) self._ensure_output_item_for_chunk(chunk) + # Accumulate provider_specific_fields from chunk and delta + for src in ( + getattr(chunk, "provider_specific_fields", None), + getattr( + chunk.choices[0].delta if chunk.choices else None, + "provider_specific_fields", + None, + ), + ): + if src and isinstance(src, dict): + self._merge_provider_specific_fields(src) # Proceed to transformation self.collected_chat_completion_chunks.append( self._snapshot_chunk_for_stream_chunk_builder(chunk) @@ -964,6 +997,17 @@ class LiteLLMCompletionStreamingIterator(ResponsesAPIStreamingIterator): try: chunk = self.litellm_custom_stream_wrapper.__next__() self._ensure_output_item_for_chunk(chunk) + # Accumulate provider_specific_fields from chunk and delta + for src in ( + getattr(chunk, "provider_specific_fields", None), + getattr( + chunk.choices[0].delta if chunk.choices else None, + "provider_specific_fields", + None, + ), + ): + if src and isinstance(src, dict): + self._merge_provider_specific_fields(src) # Emit any just-queued output_item event if self._pending_response_events: return self._pending_response_events.pop(0) diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index 71fa88fb75..b54d5930ef 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -42,6 +42,7 @@ from litellm.types.llms.openai import ( from litellm.types.responses.main import ( GenericResponseOutputItem, GenericResponseOutputItemContentAnnotation, + OutputCodeInterpreterCall, OutputFunctionToolCall, 
OutputImageGenerationCall, OutputText, @@ -1696,6 +1697,7 @@ class LiteLLMCompletionResponsesConfig: ) -> List[ Union[ GenericResponseOutputItem, + OutputCodeInterpreterCall, OutputFunctionToolCall, OutputImageGenerationCall, ResponseFunctionToolCall, @@ -1704,6 +1706,7 @@ class LiteLLMCompletionResponsesConfig: responses_output: List[ Union[ GenericResponseOutputItem, + OutputCodeInterpreterCall, OutputFunctionToolCall, OutputImageGenerationCall, ResponseFunctionToolCall, @@ -1725,8 +1728,56 @@ class LiteLLMCompletionResponsesConfig: chat_completion_response=chat_completion_response ) ) + + # Convert server-side tool results (e.g. Anthropic code execution) + # into code_interpreter_call output items, replacing the corresponding + # function_call items so the output matches OpenAI's native shape. + tool_result_items = ( + LiteLLMCompletionResponsesConfig._extract_tool_result_output_items( + chat_completion_response + ) + ) + if tool_result_items: + result_by_id = {item.id: item for item in tool_result_items} + replaced_ids = set(result_by_id.keys()) + responses_output = [ + ( + result_by_id[getattr(item, "call_id", None)] + if ( + getattr(item, "type", None) == "function_call" + and getattr(item, "call_id", None) in replaced_ids + ) + else item + ) + for item in responses_output + ] + return responses_output + @staticmethod + def _extract_tool_result_output_items( + chat_completion_response: ModelResponse, + ) -> list: + """Extract pre-built code_interpreter_call output items from provider_specific_fields. + + Provider transformers (e.g. Anthropic) convert their native tool results + into OutputCodeInterpreterCall objects and store them in + provider_specific_fields["code_interpreter_results"]. This method + simply retrieves them — no provider-specific parsing here. 
+ """ + output_items: list = [] + for choice in chat_completion_response.choices or []: + message = getattr(choice, "message", None) + if not message: + continue + psf = getattr(message, "provider_specific_fields", None) + if not psf or not isinstance(psf, dict): + continue + results = psf.get("code_interpreter_results") + if results and isinstance(results, list): + output_items.extend(results) + return output_items + @staticmethod def _extract_reasoning_output_items( chat_completion_response: ModelResponse, @@ -2055,9 +2106,9 @@ class LiteLLMCompletionResponsesConfig: hasattr(completion_details, "reasoning_tokens") and completion_details.reasoning_tokens is not None ): - output_details_dict[ - "reasoning_tokens" - ] = completion_details.reasoning_tokens + output_details_dict["reasoning_tokens"] = ( + completion_details.reasoning_tokens + ) else: output_details_dict["reasoning_tokens"] = 0 diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index a2df3f2e0d..a265198e6b 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -84,6 +84,7 @@ from typing_extensions import Annotated, Dict, Required, TypedDict, override from litellm.types.llms.base import BaseLiteLLMOpenAIResponseObject from litellm.types.responses.main import ( GenericResponseOutputItem, + OutputCodeInterpreterCall, OutputFunctionToolCall, OutputImageGenerationCall, ) @@ -969,12 +970,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk): class Hyperparameters(BaseModel): batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch." 
- learning_rate_multiplier: Optional[ - Union[str, float] - ] = None # Scaling factor for the learning rate - n_epochs: Optional[ - Union[str, int] - ] = None # "The number of epochs to train the model for" + learning_rate_multiplier: Optional[Union[str, float]] = ( + None # Scaling factor for the learning rate + ) + n_epochs: Optional[Union[str, int]] = ( + None # "The number of epochs to train the model for" + ) model_config = {"extra": "allow"} @@ -1003,18 +1004,18 @@ class FineTuningJobCreate(BaseModel): model: str # "The name of the model to fine-tune." training_file: str # "The ID of an uploaded file that contains training data." - hyperparameters: Optional[ - Hyperparameters - ] = None # "The hyperparameters used for the fine-tuning job." - suffix: Optional[ - str - ] = None # "A string of up to 18 characters that will be added to your fine-tuned model name." - validation_file: Optional[ - str - ] = None # "The ID of an uploaded file that contains validation data." - integrations: Optional[ - List[str] - ] = None # "A list of integrations to enable for your fine-tuning job." + hyperparameters: Optional[Hyperparameters] = ( + None # "The hyperparameters used for the fine-tuning job." + ) + suffix: Optional[str] = ( + None # "A string of up to 18 characters that will be added to your fine-tuned model name." + ) + validation_file: Optional[str] = ( + None # "The ID of an uploaded file that contains validation data." + ) + integrations: Optional[List[str]] = ( + None # "A list of integrations to enable for your fine-tuning job." + ) seed: Optional[int] = None # "The seed controls the reproducibility of the job." 
@@ -1242,6 +1243,7 @@ class ResponsesAPIResponse(BaseLiteLLMOpenAIResponseObject): List[ Union[ GenericResponseOutputItem, + OutputCodeInterpreterCall, OutputFunctionToolCall, OutputImageGenerationCall, ResponseFunctionToolCall, @@ -1308,13 +1310,16 @@ class ResponsesAPIResponse(BaseLiteLLMOpenAIResponseObject): if not isinstance(serialized, list): return serialized return [ - { - k: v - for k, v in item.items() - if v is not None or k not in ("status", "content", "encrypted_content") - } - if isinstance(item, dict) and item.get("type") == "reasoning" - else item + ( + { + k: v + for k, v in item.items() + if v is not None + or k not in ("status", "content", "encrypted_content") + } + if isinstance(item, dict) and item.get("type") == "reasoning" + else item + ) for item in serialized ] diff --git a/litellm/types/responses/main.py b/litellm/types/responses/main.py index 7a666d5e65..e46857565c 100644 --- a/litellm/types/responses/main.py +++ b/litellm/types/responses/main.py @@ -49,6 +49,24 @@ class OutputImageGenerationCall(BaseLiteLLMOpenAIResponseObject): result: Optional[str] # Base64 encoded image data (without data:image prefix) +class OutputCodeInterpreterCallLog(BaseLiteLLMOpenAIResponseObject): + """Log output from a code interpreter call""" + + type: Literal["logs"] + logs: str + + +class OutputCodeInterpreterCall(BaseLiteLLMOpenAIResponseObject): + """A code interpreter / code execution call output""" + + type: Literal["code_interpreter_call"] + id: str + code: Optional[str] + container_id: Optional[str] + status: Literal["in_progress", "completed", "incomplete", "failed"] + outputs: Optional[List[OutputCodeInterpreterCallLog]] + + class GenericResponseOutputItem(BaseLiteLLMOpenAIResponseObject): """ Generic response API output item diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py index d9f513d8d1..35c7a62027 100644 --- 
a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py @@ -6,6 +6,7 @@ from litellm.types.llms.openai import ( ChatCompletionToolCallChunk, ChatCompletionToolCallFunctionChunk, ) +from litellm.types.responses.main import OutputCodeInterpreterCall def test_redacted_thinking_content_block_delta(): @@ -479,14 +480,22 @@ def test_partial_json_chunk_accumulation(): # First partial chunk should return None (still accumulating) result1 = iterator._parse_sse_data(f"data:{partial_chunk_1}") assert result1 is None, "First partial chunk should return None while accumulating" - assert iterator.chunk_type == "accumulated_json", "Should switch to accumulated_json mode" - assert iterator.accumulated_json == partial_chunk_1, "Should have accumulated first part" + assert ( + iterator.chunk_type == "accumulated_json" + ), "Should switch to accumulated_json mode" + assert ( + iterator.accumulated_json == partial_chunk_1 + ), "Should have accumulated first part" # Second partial chunk should complete the JSON and return a parsed result result2 = iterator._parse_sse_data(f"data:{partial_chunk_2}") assert result2 is not None, "Second chunk should return parsed result" - assert iterator.accumulated_json == "", "Buffer should be cleared after successful parse" - assert result2.choices[0].delta.content == "Hello", f"Expected 'Hello', got '{result2.choices[0].delta.content}'" + assert ( + iterator.accumulated_json == "" + ), "Buffer should be cleared after successful parse" + assert ( + result2.choices[0].delta.content == "Hello" + ), f"Expected 'Hello', got '{result2.choices[0].delta.content}'" def test_complete_json_chunk_no_accumulation(): @@ -503,7 +512,9 @@ def test_complete_json_chunk_no_accumulation(): assert result is not None, "Complete chunk should return parsed result immediately" assert iterator.chunk_type == "valid_json", "Should remain in valid_json mode" assert 
iterator.accumulated_json == "", "Buffer should remain empty" - assert result.choices[0].delta.content == "Hello", f"Expected 'Hello', got '{result.choices[0].delta.content}'" + assert ( + result.choices[0].delta.content == "Hello" + ), f"Expected 'Hello', got '{result.choices[0].delta.content}'" def test_multiple_partial_chunks_accumulation(): @@ -620,7 +631,9 @@ def test_web_search_tool_result_no_extra_tool_calls(): # Should have exactly 2 tool calls: # 1. From content_block_start (server_tool_use) with id and name # 2. From content_block_delta with the actual query - assert len(tool_calls_emitted) == 2, f"Expected 2 tool calls, got {len(tool_calls_emitted)}" + assert ( + len(tool_calls_emitted) == 2 + ), f"Expected 2 tool calls, got {len(tool_calls_emitted)}" # First tool call should have the id and name assert tool_calls_emitted[0]["id"] == "srvtoolu_01ABC123" @@ -722,7 +735,10 @@ def test_web_search_tool_result_captured_in_provider_specific_fields(): { "type": "content_block_delta", "index": 0, - "delta": {"type": "input_json_delta", "partial_json": '{"query": "otter facts"}'}, + "delta": { + "type": "input_json_delta", + "partial_json": '{"query": "otter facts"}', + }, }, # 4. content_block_stop for server_tool_use {"type": "content_block_stop", "index": 0}, @@ -822,7 +838,10 @@ def test_web_fetch_tool_result_captured_in_provider_specific_fields(): { "type": "content_block_delta", "index": 0, - "delta": {"type": "input_json_delta", "partial_json": '{"url": "https://example.com"}'}, + "delta": { + "type": "input_json_delta", + "partial_json": '{"url": "https://example.com"}', + }, }, # 4. content_block_stop for server_tool_use {"type": "content_block_stop", "index": 0}, @@ -946,7 +965,7 @@ def test_web_fetch_tool_result_no_extra_tool_calls(): def test_container_in_provider_specific_fields_streaming(): """ Test that container is captured in provider_specific_fields for streaming responses. 
- + When container with skills is used, the container field should be present in the provider_specific_fields of the message_delta chunk. """ @@ -1025,7 +1044,9 @@ def test_container_in_provider_specific_fields_streaming(): ] # Verify container was captured - assert container_field is not None, "container should be captured in provider_specific_fields" + assert ( + container_field is not None + ), "container should be captured in provider_specific_fields" assert ( container_field["id"] == "container_011CW9hA9zpZ8xD3bjjShy4p" ), "container id should match" @@ -1033,18 +1054,14 @@ def test_container_in_provider_specific_fields_streaming(): container_field["expires_at"] == "2025-12-16T04:57:16.913181Z" ), "expires_at should match" assert len(container_field["skills"]) == 1, "Should have 1 skill" - assert ( - container_field["skills"][0]["skill_id"] == "pptx" - ), "skill_id should be pptx" - assert ( - container_field["skills"][0]["version"] == "20251013" - ), "version should match" + assert container_field["skills"][0]["skill_id"] == "pptx", "skill_id should be pptx" + assert container_field["skills"][0]["version"] == "20251013", "version should match" def test_container_in_provider_specific_fields_non_streaming(): """ Test that container is captured in provider_specific_fields for non-streaming responses. - + When container with skills is used in non-streaming, the container field should be present in the provider_specific_fields of the response. """ @@ -1106,7 +1123,7 @@ def test_container_in_provider_specific_fields_non_streaming(): def test_container_absent_when_not_provided(): """ Test that container is not added to provider_specific_fields when not provided. - + This ensures we don't add empty or None container fields. 
""" iterator = ModelResponseIterator( @@ -1133,3 +1150,197 @@ def test_container_absent_when_not_provided(): assert ( "container" not in model_response.choices[0].delta.provider_specific_fields ), "container should not be present when not provided in delta" + + +def test_streaming_code_execution_produces_code_interpreter_results(): + """ + Test that bash_code_execution_tool_result content blocks in streaming + produce code_interpreter_results in provider_specific_fields, so the + Responses API layer can use them without Anthropic-specific knowledge. + """ + + chunks = [ + { + "type": "message_start", + "message": { + "id": "msg_01XYZ", + "type": "message", + "role": "assistant", + "content": [], + "usage": {"input_tokens": 100, "output_tokens": 1}, + }, + }, + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "text", + "text": "", + }, + }, + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Running code..."}, + }, + {"type": "content_block_stop", "index": 0}, + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": "server_tool_use", + "id": "srvtoolu_01ABC", + "name": "bash_code_execution", + "input": {"command": "echo hello"}, + }, + }, + {"type": "content_block_stop", "index": 1}, + { + "type": "content_block_start", + "index": 2, + "content_block": { + "type": "bash_code_execution_tool_result", + "tool_use_id": "srvtoolu_01ABC", + "content": { + "type": "bash_code_execution_result", + "stdout": "hello\n", + "stderr": "", + "return_code": 0, + }, + }, + }, + {"type": "content_block_stop", "index": 2}, + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn"}, + "usage": {"output_tokens": 50}, + }, + ] + + iterator = ModelResponseIterator(None, sync_stream=True) + + found_code_interpreter_results = False + for chunk in chunks: + parsed = iterator.chunk_parser(chunk) + psf = None + if parsed.choices and parsed.choices[0].delta: + psf = 
getattr(parsed.choices[0].delta, "provider_specific_fields", None) + if psf and "code_interpreter_results" in psf: + found_code_interpreter_results = True + results = psf["code_interpreter_results"] + assert len(results) == 1 + assert isinstance(results[0], OutputCodeInterpreterCall) + assert results[0].type == "code_interpreter_call" + assert results[0].id == "srvtoolu_01ABC" + assert results[0].code == "echo hello" + assert results[0].outputs is not None + assert len(results[0].outputs) == 1 + assert results[0].outputs[0].logs == "hello\n" + + assert found_code_interpreter_results, ( + "code_interpreter_results should appear in provider_specific_fields " + "when bash_code_execution_tool_result is streamed" + ) + + +def test_streaming_multiple_code_executions_no_duplicates(): + """ + Test that multiple code executions in a single streaming response produce + exactly one code_interpreter_result per execution — no duplicates from + _build_code_interpreter_results rebuilding the full list. 
+ """ + chunks = [ + { + "type": "message_start", + "message": { + "id": "msg_01XYZ", + "type": "message", + "role": "assistant", + "content": [], + "usage": {"input_tokens": 100, "output_tokens": 1}, + }, + }, + # First code execution + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "server_tool_use", + "id": "srvtoolu_01AAA", + "name": "bash_code_execution", + "input": {"command": "echo first"}, + }, + }, + {"type": "content_block_stop", "index": 0}, + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": "bash_code_execution_tool_result", + "tool_use_id": "srvtoolu_01AAA", + "content": { + "type": "bash_code_execution_result", + "stdout": "first\n", + "stderr": "", + "return_code": 0, + }, + }, + }, + {"type": "content_block_stop", "index": 1}, + # Second code execution + { + "type": "content_block_start", + "index": 2, + "content_block": { + "type": "server_tool_use", + "id": "srvtoolu_01BBB", + "name": "bash_code_execution", + "input": {"command": "echo second"}, + }, + }, + {"type": "content_block_stop", "index": 2}, + { + "type": "content_block_start", + "index": 3, + "content_block": { + "type": "bash_code_execution_tool_result", + "tool_use_id": "srvtoolu_01BBB", + "content": { + "type": "bash_code_execution_result", + "stdout": "second\n", + "stderr": "", + "return_code": 0, + }, + }, + }, + {"type": "content_block_stop", "index": 3}, + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn"}, + "usage": {"output_tokens": 50}, + }, + ] + + iterator = ModelResponseIterator(None, sync_stream=True) + + # Collect ALL code_interpreter_results emitted across all chunks + all_results = [] + for chunk in chunks: + parsed = iterator.chunk_parser(chunk) + psf = None + if parsed.choices and parsed.choices[0].delta: + psf = getattr(parsed.choices[0].delta, "provider_specific_fields", None) + if psf and "code_interpreter_results" in psf: + all_results.extend(psf["code_interpreter_results"]) + + # 
Should have exactly 2 results, one per execution — no duplicates + assert len(all_results) == 2, ( + f"Expected 2 code_interpreter_results, got {len(all_results)}. " + f"IDs: {[r.id for r in all_results]}" + ) + assert all_results[0].id == "srvtoolu_01AAA" + assert all_results[0].code == "echo first" + assert all_results[0].outputs[0].logs == "first\n" + assert all_results[1].id == "srvtoolu_01BBB" + assert all_results[1].code == "echo second" + assert all_results[1].outputs[0].logs == "second\n" diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py index a95b9413b9..a346996486 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py @@ -183,7 +183,9 @@ def test_extract_response_content_with_citations(): }, } - _, citations, _, _, _, _, _, _ = config.extract_response_content(completion_response) + _, citations, _, _, _, _, _, _ = config.extract_response_content( + completion_response + ) assert citations == [ [ { @@ -305,7 +307,7 @@ def test_web_search_tool_result_extraction(): "type": "server_tool_use", "id": "srvtoolu_01ABC123", "name": "web_search", - "input": {"query": "average weight african elephant kg"} + "input": {"query": "average weight african elephant kg"}, }, { "type": "web_search_tool_result", @@ -317,32 +319,39 @@ def test_web_search_tool_result_extraction(): "title": "African Elephant Facts", "encrypted_content": "encrypted_data_here", "page_age": "2024-01-15", - "snippet": "Adult African elephants weigh between 4,000-6,000 kg..." + "snippet": "Adult African elephants weigh between 4,000-6,000 kg...", } - ] + ], }, { "type": "text", - "text": "Based on my search, African elephants weigh around 5,000 kg." 
+ "text": "Based on my search, African elephants weigh around 5,000 kg.", }, { "type": "tool_use", "id": "toolu_01XYZ789", "name": "add_numbers", - "input": {"a": 5000, "b": 100} - } + "input": {"a": 5000, "b": 100}, + }, ], "stop_reason": "tool_use", "usage": { "input_tokens": 100, "output_tokens": 50, - "server_tool_use": {"web_search_requests": 1} - } + "server_tool_use": {"web_search_requests": 1}, + }, } - text, citations, thinking_blocks, reasoning_content, tool_calls, web_search_results, tool_results, compaction_blocks = config.extract_response_content( - completion_response - ) + ( + text, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = config.extract_response_content(completion_response) # Verify text extraction assert "Based on my search" in text @@ -388,7 +397,7 @@ def test_web_search_tool_result_in_provider_specific_fields(): "type": "server_tool_use", "id": "srvtoolu_provider_test", "name": "web_search", - "input": {"query": "test query"} + "input": {"query": "test query"}, }, { "type": "web_search_tool_result", @@ -398,21 +407,18 @@ def test_web_search_tool_result_in_provider_specific_fields(): "type": "web_search_result", "url": "https://example.com/test", "title": "Test Result", - "snippet": "Test snippet content" + "snippet": "Test snippet content", } - ] + ], }, - { - "type": "text", - "text": "Here is the result." 
- } + {"type": "text", "text": "Here is the result."}, ], "stop_reason": "end_turn", "usage": { "input_tokens": 50, "output_tokens": 25, - "server_tool_use": {"web_search_requests": 1} - } + "server_tool_use": {"web_search_requests": 1}, + }, } raw_response = httpx.Response(status_code=200, headers={}) @@ -432,7 +438,10 @@ def test_web_search_tool_result_in_provider_specific_fields(): assert "web_search_results" in provider_fields assert len(provider_fields["web_search_results"]) == 1 assert provider_fields["web_search_results"][0]["type"] == "web_search_tool_result" - assert provider_fields["web_search_results"][0]["tool_use_id"] == "srvtoolu_provider_test" + assert ( + provider_fields["web_search_results"][0]["tool_use_id"] + == "srvtoolu_provider_test" + ) def test_multiple_web_search_tool_results(): @@ -447,34 +456,52 @@ def test_multiple_web_search_tool_results(): "type": "server_tool_use", "id": "srvtoolu_search1", "name": "web_search", - "input": {"query": "african elephant weight"} + "input": {"query": "african elephant weight"}, }, { "type": "web_search_tool_result", "tool_use_id": "srvtoolu_search1", - "content": [{"type": "web_search_result", "url": "https://example1.com", "title": "Result 1", "snippet": "First result"}] + "content": [ + { + "type": "web_search_result", + "url": "https://example1.com", + "title": "Result 1", + "snippet": "First result", + } + ], }, { "type": "server_tool_use", "id": "srvtoolu_search2", "name": "web_search", - "input": {"query": "asian elephant weight"} + "input": {"query": "asian elephant weight"}, }, { "type": "web_search_tool_result", "tool_use_id": "srvtoolu_search2", - "content": [{"type": "web_search_result", "url": "https://example2.com", "title": "Result 2", "snippet": "Second result"}] + "content": [ + { + "type": "web_search_result", + "url": "https://example2.com", + "title": "Result 2", + "snippet": "Second result", + } + ], }, - { - "type": "text", - "text": "Found information about both elephants." 
- } + {"type": "text", "text": "Found information about both elephants."}, ] } - text, citations, thinking_blocks, reasoning_content, tool_calls, web_search_results, tool_results, compaction_blocks = config.extract_response_content( - completion_response - ) + ( + text, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = config.extract_response_content(completion_response) # Verify both web_search_tool_results are extracted assert web_search_results is not None @@ -751,7 +778,7 @@ def test_anthropic_beta_header_merging_with_output_format(): optional_params = { "output_format": { "type": "json_schema", - "schema": {"type": "object", "properties": {}} + "schema": {"type": "object", "properties": {}}, } } @@ -761,10 +788,12 @@ def test_anthropic_beta_header_merging_with_output_format(): # Both beta headers should be present beta_value = result_headers["anthropic-beta"] - assert "context-1m-2025-08-07" in beta_value, \ - f"User's context-1m beta header missing from: {beta_value}" - assert "structured-outputs-2025-11-13" in beta_value, \ - f"Structured output beta header missing from: {beta_value}" + assert ( + "context-1m-2025-08-07" in beta_value + ), f"User's context-1m beta header missing from: {beta_value}" + assert ( + "structured-outputs-2025-11-13" in beta_value + ), f"Structured output beta header missing from: {beta_value}" def test_anthropic_beta_header_merging_with_multiple_features(): @@ -780,10 +809,10 @@ def test_anthropic_beta_header_merging_with_multiple_features(): optional_params = { "output_format": { "type": "json_schema", - "schema": {"type": "object", "properties": {}} + "schema": {"type": "object", "properties": {}}, }, "context_management": _sample_context_management_payload(), - "tools": [{"type": "web_fetch_20250910", "name": "web_fetch"}] + "tools": [{"type": "web_fetch_20250910", "name": "web_fetch"}], } result_headers = 
config.update_headers_with_optional_anthropic_beta( @@ -950,20 +979,12 @@ def test_tool_search_regex_detection(): # Test with tool search regex tool tools = [ - { - "type": "tool_search_tool_regex_20251119", - "name": "tool_search_tool_regex" - } + {"type": "tool_search_tool_regex_20251119", "name": "tool_search_tool_regex"} ] assert config.is_tool_search_used(tools) is True # Test without tool search - tools = [ - { - "type": "function", - "function": {"name": "get_weather"} - } - ] + tools = [{"type": "function", "function": {"name": "get_weather"}}] assert config.is_tool_search_used(tools) is False @@ -975,10 +996,7 @@ def test_tool_search_bm25_detection(): # Test with tool search BM25 tool tools = [ - { - "type": "tool_search_tool_bm25_20251119", - "name": "tool_search_tool_bm25" - } + {"type": "tool_search_tool_bm25_20251119", "name": "tool_search_tool_bm25"} ] assert config.is_tool_search_used(tools) is True @@ -1002,10 +1020,7 @@ def test_tool_search_regex_mapping(): """Test that tool search regex tools are properly mapped""" config = AnthropicConfig() - tool = { - "type": "tool_search_tool_regex_20251119", - "name": "tool_search_tool_regex" - } + tool = {"type": "tool_search_tool_regex_20251119", "name": "tool_search_tool_regex"} mapped_tool, mcp_server = config._map_tool_helper(tool) @@ -1019,10 +1034,7 @@ def test_tool_search_bm25_mapping(): """Test that tool search BM25 tools are properly mapped""" config = AnthropicConfig() - tool = { - "type": "tool_search_tool_bm25_20251119", - "name": "tool_search_tool_bm25" - } + tool = {"type": "tool_search_tool_bm25_20251119", "name": "tool_search_tool_bm25"} mapped_tool, mcp_server = config._map_tool_helper(tool) @@ -1037,20 +1049,17 @@ def test_deferred_tools_separation(): config = AnthropicConfig() tools = [ - { - "type": "tool_search_tool_regex_20251119", - "name": "tool_search_tool_regex" - }, + {"type": "tool_search_tool_regex_20251119", "name": "tool_search_tool_regex"}, { "type": "function", "function": 
{"name": "get_weather"}, - "defer_loading": True + "defer_loading": True, }, { "type": "function", "function": {"name": "search_files"}, - "defer_loading": False - } + "defer_loading": False, + }, ] non_deferred, deferred = config._separate_deferred_tools(tools) @@ -1069,14 +1078,21 @@ def test_server_tool_use_in_response(): "type": "server_tool_use", "id": "srvtoolu_01ABC123", "name": "tool_search_tool_regex", - "input": {"query": "weather"} + "input": {"query": "weather"}, } ] } - text, citations, thinking_blocks, reasoning_content, tool_calls, web_search_results, tool_results, compaction_blocks = config.extract_response_content( - completion_response - ) + ( + text, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = config.extract_response_content(completion_response) assert len(tool_calls) == 1 assert tool_calls[0]["id"] == "srvtoolu_01ABC123" @@ -1091,9 +1107,7 @@ def test_tool_search_usage_tracking(): usage_object = { "input_tokens": 100, "output_tokens": 50, - "server_tool_use": { - "tool_search_requests": 2 - } + "server_tool_use": {"tool_search_requests": 2}, } usage = config.calculate_usage(usage_object=usage_object, reasoning_content=None) @@ -1109,16 +1123,13 @@ def test_tool_reference_expansion(): deferred_tools = [ { "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather" - } + "function": {"name": "get_weather", "description": "Get weather"}, } ] content = [ {"type": "text", "text": "I'll search for tools"}, - {"type": "tool_reference", "tool_name": "get_weather"} + {"type": "tool_reference", "tool_name": "get_weather"}, ] expanded = config._expand_tool_references(content, deferred_tools) @@ -1140,13 +1151,11 @@ def test_defer_loading_preserved_in_transformation(): "description": "Get weather information", "parameters": { "type": "object", - "properties": { - "location": {"type": "string"} - }, - "required": ["location"] - } + 
"properties": {"location": {"type": "string"}}, + "required": ["location"], + }, }, - "defer_loading": True + "defer_loading": True, } mapped_tool, mcp_server = config._map_tool_helper(tool) @@ -1166,45 +1175,51 @@ def test_tool_search_complete_response_parsing(): "content": [ { "type": "text", - "text": "I'll search for weather-related tools that can help you." + "text": "I'll search for weather-related tools that can help you.", }, { "type": "server_tool_use", "id": "srvtoolu_015i6aVA2niwzv4RG4DtnxDJ", "name": "tool_search_tool_regex", "input": {"pattern": "weather", "limit": 5}, - "caller": {"type": "direct"} + "caller": {"type": "direct"}, }, { "type": "tool_search_tool_result", "tool_use_id": "srvtoolu_015i6aVA2niwzv4RG4DtnxDJ", "content": { "type": "tool_search_tool_search_result", - "tool_references": [{"type": "tool_reference", "tool_name": "get_weather"}] - } - }, - { - "type": "text", - "text": "Great! I found a weather tool." + "tool_references": [ + {"type": "tool_reference", "tool_name": "get_weather"} + ], + }, }, + {"type": "text", "text": "Great! 
I found a weather tool."}, { "type": "tool_use", "id": "toolu_01CrCNx4ntSaeeV9iArT4JfQ", "name": "get_weather", - "input": {"location": "San Francisco"} - } + "input": {"location": "San Francisco"}, + }, ], "usage": { "input_tokens": 1639, "output_tokens": 170, - "server_tool_use": {"web_search_requests": 0} - } + "server_tool_use": {"web_search_requests": 0}, + }, } # Extract content - text, citations, thinking_blocks, reasoning_content, tool_calls, web_search_results, tool_results, compaction_blocks = config.extract_response_content( - completion_response - ) + ( + text, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = config.extract_response_content(completion_response) # Verify text extraction (should concatenate both text blocks) assert "I'll search for weather-related tools" in text @@ -1222,12 +1237,14 @@ def test_tool_search_complete_response_parsing(): usage = config.calculate_usage( usage_object=completion_response["usage"], reasoning_content=None, - completion_response=completion_response + completion_response=completion_response, ) assert usage.server_tool_use is not None assert usage.server_tool_use.web_search_requests == 0 - assert usage.server_tool_use.tool_search_requests == 1 # Counted from server_tool_use blocks + assert ( + usage.server_tool_use.tool_search_requests == 1 + ) # Counted from server_tool_use blocks def test_allowed_callers_field_preservation(): @@ -1242,13 +1259,11 @@ def test_allowed_callers_field_preservation(): "description": "Execute a SQL query", "parameters": { "type": "object", - "properties": { - "sql": {"type": "string"} - }, - "required": ["sql"] - } + "properties": {"sql": {"type": "string"}}, + "required": ["sql"], + }, }, - "allowed_callers": ["code_execution_20250825"] + "allowed_callers": ["code_execution_20250825"], } transformed_tool, _ = config._map_tool_helper(tool_with_allowed_callers) @@ -1265,19 +1280,16 @@ def 
test_programmatic_tool_calling_beta_header(): # Test detection with allowed_callers tools = [ - { - "type": "code_execution_20250825", - "name": "code_execution" - }, + {"type": "code_execution_20250825", "name": "code_execution"}, { "type": "function", "function": { "name": "query_database", "description": "Execute a SQL query", - "parameters": {"type": "object", "properties": {}} + "parameters": {"type": "object", "properties": {}}, }, - "allowed_callers": ["code_execution_20250825"] - } + "allowed_callers": ["code_execution_20250825"], + }, ] is_programmatic = model_info.is_programmatic_tool_calling_used(tools) @@ -1285,8 +1297,7 @@ def test_programmatic_tool_calling_beta_header(): # Test header generation headers = model_info.get_anthropic_headers( - api_key="test-key", - programmatic_tool_calling_used=True + api_key="test-key", programmatic_tool_calling_used=True ) assert "anthropic-beta" in headers @@ -1303,10 +1314,7 @@ def test_caller_field_in_response(): "type": "message", "role": "assistant", "content": [ - { - "type": "text", - "text": "I'll query the database." 
- }, + {"type": "text", "text": "I'll query the database."}, { "type": "tool_use", "id": "toolu_123", @@ -1314,15 +1322,24 @@ def test_caller_field_in_response(): "input": {"sql": "SELECT * FROM users"}, "caller": { "type": "code_execution_20250825", - "tool_id": "srvtoolu_abc" - } - } + "tool_id": "srvtoolu_abc", + }, + }, ], "stop_reason": "tool_use", - "usage": {"input_tokens": 100, "output_tokens": 50} + "usage": {"input_tokens": 100, "output_tokens": 50}, } - text, citations, thinking, reasoning, tool_calls, web_search_results, tool_results, compaction_blocks = config.extract_response_content(completion_response) + ( + text, + citations, + thinking, + reasoning, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = config.extract_response_content(completion_response) assert len(tool_calls) == 1 assert tool_calls[0]["id"] == "toolu_123" @@ -1337,10 +1354,7 @@ def test_code_execution_20250825_tool_type(): """Test that code_execution_20250825 tool type is handled correctly.""" config = AnthropicConfig() - tool = { - "type": "code_execution_20250825", - "name": "code_execution" - } + tool = {"type": "code_execution_20250825", "name": "code_execution"} transformed_tool, _ = config._map_tool_helper(tool) assert transformed_tool is not None @@ -1360,13 +1374,11 @@ def test_allowed_callers_in_function_field(): "description": "Execute a SQL query", "parameters": { "type": "object", - "properties": { - "sql": {"type": "string"} - }, - "required": ["sql"] + "properties": {"sql": {"type": "string"}}, + "required": ["sql"], }, - "allowed_callers": ["code_execution_20250825"] - } + "allowed_callers": ["code_execution_20250825"], + }, } transformed_tool, _ = config._map_tool_helper(tool) @@ -1389,15 +1401,15 @@ def test_input_examples_field_preservation(): "type": "object", "properties": { "location": {"type": "string"}, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, 
- "required": ["location"] - } + "required": ["location"], + }, }, "input_examples": [ {"location": "San Francisco, CA", "unit": "fahrenheit"}, - {"location": "Tokyo, Japan", "unit": "celsius"} - ] + {"location": "Tokyo, Japan", "unit": "celsius"}, + ], } transformed_tool, _ = config._map_tool_helper(tool_with_examples) @@ -1420,11 +1432,9 @@ def test_input_examples_beta_header(): "function": { "name": "get_weather", "description": "Get weather information", - "parameters": {"type": "object", "properties": {}} + "parameters": {"type": "object", "properties": {}}, }, - "input_examples": [ - {"location": "San Francisco, CA"} - ] + "input_examples": [{"location": "San Francisco, CA"}], } ] @@ -1433,8 +1443,7 @@ def test_input_examples_beta_header(): # Test header generation headers = model_info.get_anthropic_headers( - api_key="test-key", - input_examples_used=True + api_key="test-key", input_examples_used=True ) assert "anthropic-beta" in headers @@ -1453,16 +1462,14 @@ def test_input_examples_in_function_field(): "description": "Get weather information", "parameters": { "type": "object", - "properties": { - "location": {"type": "string"} - }, - "required": ["location"] + "properties": {"location": {"type": "string"}}, + "required": ["location"], }, "input_examples": [ {"location": "Paris, France"}, - {"location": "London, UK"} - ] - } + {"location": "London, UK"}, + ], + }, } transformed_tool, _ = config._map_tool_helper(tool) @@ -1483,17 +1490,13 @@ def test_input_examples_with_other_features(): "description": "Execute a SQL query", "parameters": { "type": "object", - "properties": { - "sql": {"type": "string"} - }, - "required": ["sql"] - } + "properties": {"sql": {"type": "string"}}, + "required": ["sql"], + }, }, - "input_examples": [ - {"sql": "SELECT * FROM users WHERE id = 1"} - ], + "input_examples": [{"sql": "SELECT * FROM users WHERE id = 1"}], "defer_loading": True, - "allowed_callers": ["code_execution_20250825"] + "allowed_callers": 
["code_execution_20250825"], } transformed_tool, _ = config._map_tool_helper(tool) @@ -1517,19 +1520,20 @@ def test_input_examples_empty_list_not_added(): "description": "Get weather information", "parameters": { "type": "object", - "properties": { - "location": {"type": "string"} - }, - "required": ["location"] - } + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, }, - "input_examples": [] + "input_examples": [], } transformed_tool, _ = config._map_tool_helper(tool) assert transformed_tool is not None # Empty list should not be added - assert "input_examples" not in transformed_tool or len(transformed_tool.get("input_examples", [])) == 0 + assert ( + "input_examples" not in transformed_tool + or len(transformed_tool.get("input_examples", [])) == 0 + ) # ============ Effort Parameter Tests ============ @@ -1540,18 +1544,14 @@ def test_effort_output_config_preservation(): config = AnthropicConfig() messages = [{"role": "user", "content": "Analyze this code"}] - optional_params = { - "output_config": { - "effort": "medium" - } - } + optional_params = {"output_config": {"effort": "medium"}} result = config.transform_request( model="claude-opus-4-5-20251101", messages=messages, optional_params=optional_params, litellm_params={}, - headers={} + headers={}, ) assert "output_config" in result @@ -1565,18 +1565,13 @@ def test_effort_beta_header_injection(): model_info = AnthropicModelInfo() # Test with effort parameter - optional_params = { - "output_config": { - "effort": "low" - } - } + optional_params = {"output_config": {"effort": "low"}} effort_used = model_info.is_effort_used(optional_params=optional_params) assert effort_used is True headers = model_info.get_anthropic_headers( - api_key="test-key", - effort_used=effort_used + api_key="test-key", effort_used=effort_used ) assert "anthropic-beta" in headers @@ -1597,7 +1592,7 @@ def test_effort_validation(): messages=messages, optional_params=optional_params, litellm_params={}, - 
headers={} + headers={}, ) assert result["output_config"]["effort"] == effort @@ -1609,7 +1604,7 @@ def test_effort_validation(): messages=messages, optional_params=optional_params, litellm_params={}, - headers={} + headers={}, ) @@ -1618,18 +1613,14 @@ def test_effort_with_claude_opus_45(): config = AnthropicConfig() messages = [{"role": "user", "content": "Complex analysis task"}] - optional_params = { - "output_config": { - "effort": "high" - } - } + optional_params = {"output_config": {"effort": "high"}} result = config.transform_request( model="claude-opus-4-5-20251101", messages=messages, optional_params=optional_params, litellm_params={}, - headers={} + headers={}, ) assert "output_config" in result @@ -1650,7 +1641,7 @@ def test_effort_validation_with_opus_46(): messages=messages, optional_params=optional_params, litellm_params={}, - headers={} + headers={}, ) assert result["output_config"]["effort"] == effort @@ -1661,14 +1652,16 @@ def test_max_effort_rejected_for_opus_45(): messages = [{"role": "user", "content": "Test"}] - with pytest.raises(ValueError, match="effort='max' is only supported by Claude Opus 4.6"): + with pytest.raises( + ValueError, match="effort='max' is only supported by Claude Opus 4.6" + ): optional_params = {"output_config": {"effort": "max"}} config.transform_request( model="claude-opus-4-5-20251101", messages=messages, optional_params=optional_params, litellm_params={}, - headers={} + headers={}, ) @@ -1685,23 +1678,16 @@ def test_effort_with_other_features(): "description": "Get data", "parameters": { "type": "object", - "properties": { - "query": {"type": "string"} - }, - "required": ["query"] - } - } + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, } ] optional_params = { - "output_config": { - "effort": "low" - }, + "output_config": {"effort": "low"}, "tools": tools, - "thinking": { - "type": "enabled", - "budget_tokens": 1000 - } + "thinking": {"type": "enabled", "budget_tokens": 1000}, } 
result = config.transform_request( @@ -1709,7 +1695,7 @@ def test_effort_with_other_features(): messages=messages, optional_params=optional_params, litellm_params={}, - headers={} + headers={}, ) # Verify all features are present @@ -1752,11 +1738,14 @@ def test_translate_system_message_skips_empty_list_content(): # Test list content with empty text block messages = [ - {"role": "system", "content": [ - {"type": "text", "text": ""}, - {"type": "text", "text": "Valid content"}, - {"type": "text", "text": ""}, - ]}, + { + "role": "system", + "content": [ + {"type": "text", "text": ""}, + {"type": "text", "text": "Valid content"}, + {"type": "text", "text": ""}, + ], + }, {"role": "user", "content": "Hello"}, ] @@ -1794,9 +1783,16 @@ def test_translate_system_message_preserves_cache_control(): # Test list content with cache_control messages = [ - {"role": "system", "content": [ - {"type": "text", "text": "Cached content", "cache_control": {"type": "ephemeral"}}, - ]}, + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Cached content", + "cache_control": {"type": "ephemeral"}, + }, + ], + }, {"role": "user", "content": "Hello"}, ] @@ -1938,7 +1934,7 @@ def test_transform_request_uses_dynamic_max_tokens(): messages=messages, optional_params={}, # No max_tokens provided litellm_params={}, - headers={} + headers={}, ) assert result["max_tokens"] == 64000 @@ -1959,7 +1955,7 @@ def test_transform_request_respects_user_max_tokens(): messages=messages, optional_params={"max_tokens": 1000}, litellm_params={}, - headers={} + headers={}, ) assert result["max_tokens"] == 1000 @@ -2006,11 +2002,12 @@ def test_calculate_usage_completion_tokens_details_with_reasoning(): "output_tokens": 500, } # Simulating reasoning content that would count as ~50 tokens - reasoning_content = "Let me think about this step by step. " * 10 # Roughly 50 tokens + reasoning_content = ( + "Let me think about this step by step. 
" * 10 + ) # Roughly 50 tokens usage = config.calculate_usage( - usage_object=usage_object, - reasoning_content=reasoning_content + usage_object=usage_object, reasoning_content=reasoning_content ) # completion_tokens_details should be populated with both reasoning and text tokens @@ -2051,7 +2048,7 @@ def test_reasoning_effort_maps_to_adaptive_thinking_for_claude_4_6_models(): non_default_params=non_default_params, optional_params=optional_params, model=model, - drop_params=False + drop_params=False, ) # Should map to adaptive thinking type @@ -2062,7 +2059,9 @@ def test_reasoning_effort_maps_to_adaptive_thinking_for_claude_4_6_models(): # reasoning_effort should not be in the result (it's transformed to thinking) assert "reasoning_effort" not in result # Should set output_config with the mapped effort value - assert "output_config" in result, f"output_config missing for {model} with effort={effort}" + assert ( + "output_config" in result + ), f"output_config missing for {model} with effort={effort}" assert result["output_config"]["effort"] == effort_map[effort] @@ -2123,10 +2122,10 @@ def test_reasoning_effort_maps_to_budget_thinking_for_non_opus_4_6(): # Test with Claude Sonnet 4.5 (non-Opus 4.6 model) test_cases = [ - ("low", 1024), # DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET - ("medium", 2048), # DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET - ("high", 4096), # DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET - ("minimal", 128), # DEFAULT_REASONING_EFFORT_MINIMAL_THINKING_BUDGET + ("low", 1024), # DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET + ("medium", 2048), # DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET + ("high", 4096), # DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET + ("minimal", 128), # DEFAULT_REASONING_EFFORT_MINIMAL_THINKING_BUDGET ] for effort, expected_budget in test_cases: @@ -2137,7 +2136,7 @@ def test_reasoning_effort_maps_to_budget_thinking_for_non_opus_4_6(): non_default_params=non_default_params, optional_params=optional_params, 
model="claude-sonnet-4-5-20250929", - drop_params=False + drop_params=False, ) # Should map to enabled thinking type with budget_tokens @@ -2166,9 +2165,9 @@ def test_reasoning_effort_sets_output_config_for_46_models(): drop_params=False, ) - assert "output_config" in result, ( - f"output_config missing for {model} with effort={effort}" - ) + assert ( + "output_config" in result + ), f"output_config missing for {model} with effort={effort}" assert result["output_config"]["effort"] == effort @@ -2207,9 +2206,9 @@ def test_reasoning_effort_does_not_set_output_config_for_older_models(): drop_params=False, ) - assert "output_config" not in result, ( - f"output_config should not be set for {model}" - ) + assert ( + "output_config" not in result + ), f"output_config should not be set for {model}" def test_max_effort_rejected_for_sonnet_46(): @@ -2217,7 +2216,9 @@ def test_max_effort_rejected_for_sonnet_46(): config = AnthropicConfig() messages = [{"role": "user", "content": "Test"}] - with pytest.raises(ValueError, match="effort='max' is only supported by Claude Opus 4.6"): + with pytest.raises( + ValueError, match="effort='max' is only supported by Claude Opus 4.6" + ): config.transform_request( model="claude-sonnet-4-6-20260219", messages=messages, @@ -2260,9 +2261,7 @@ def test_effort_beta_header_not_injected_for_46_models(): optional_params={"output_config": {"effort": "high"}}, model=model, ) - assert result is False, ( - f"is_effort_used should return False for {model}" - ) + assert result is False, f"is_effort_used should return False for {model}" def test_effort_beta_header_still_injected_for_older_models(): @@ -2302,17 +2301,12 @@ def test_code_execution_tool_results_extraction(): "role": "assistant", "model": "claude-sonnet-4-5-20250929", "content": [ - { - "type": "text", - "text": "I'll calculate that for you." 
- }, + {"type": "text", "text": "I'll calculate that for you."}, { "type": "server_tool_use", "id": "srvtoolu_01ABC", "name": "bash_code_execution", - "input": { - "command": "python3 << 'EOF'\nprint(2 + 2)\nEOF\n" - } + "input": {"command": "python3 << 'EOF'\nprint(2 + 2)\nEOF\n"}, }, { "type": "bash_code_execution_tool_result", @@ -2321,8 +2315,8 @@ def test_code_execution_tool_results_extraction(): "type": "bash_code_execution_result", "stdout": "4\n", "stderr": "", - "return_code": 0 - } + "return_code": 0, + }, }, { "type": "server_tool_use", @@ -2331,28 +2325,22 @@ def test_code_execution_tool_results_extraction(): "input": { "command": "create", "path": "test.txt", - "file_text": "Hello" - } + "file_text": "Hello", + }, }, { "type": "text_editor_code_execution_tool_result", "tool_use_id": "srvtoolu_01DEF", "content": { "type": "text_editor_code_execution_result", - "is_file_update": False - } + "is_file_update": False, + }, }, - { - "type": "text", - "text": "Done!" - } + {"type": "text", "text": "Done!"}, ], "stop_reason": "stop", "stop_sequence": None, - "usage": { - "input_tokens": 100, - "output_tokens": 50 - } + "usage": {"input_tokens": 100, "output_tokens": 50}, } # Create mock HTTP response @@ -2377,11 +2365,17 @@ def test_code_execution_tool_results_extraction(): # Verify first tool call assert transformed_response.choices[0].message.tool_calls[0].id == "srvtoolu_01ABC" - assert transformed_response.choices[0].message.tool_calls[0].function.name == "bash_code_execution" + assert ( + transformed_response.choices[0].message.tool_calls[0].function.name + == "bash_code_execution" + ) # Verify second tool call assert transformed_response.choices[0].message.tool_calls[1].id == "srvtoolu_01DEF" - assert transformed_response.choices[0].message.tool_calls[1].function.name == "text_editor_code_execution" + assert ( + transformed_response.choices[0].message.tool_calls[1].function.name + == "text_editor_code_execution" + ) # Verify tool results are in 
provider_specific_fields provider_fields = transformed_response.choices[0].message.provider_specific_fields @@ -2404,10 +2398,83 @@ def test_code_execution_tool_results_extraction(): assert editor_result["content"]["is_file_update"] is False # Verify text content is properly concatenated - assert "I'll calculate that for you." in transformed_response.choices[0].message.content + assert ( + "I'll calculate that for you." + in transformed_response.choices[0].message.content + ) assert "Done!" in transformed_response.choices[0].message.content +def test_code_execution_tool_results_in_hidden_params(): + """ + Test that tool_results reaches _hidden_params so the Responses API adapter + can surface them via provider_specific_fields. + + The Responses API adapter reads _hidden_params.get("provider_specific_fields") + to set provider_specific_fields on the response. Without this, server-side + code execution results (stdout/stderr) are lost when using responses.create(). + """ + import httpx + + from litellm.types.utils import ModelResponse + + config = AnthropicConfig() + + mock_anthropic_response = { + "id": "msg_01XYZ", + "type": "message", + "role": "assistant", + "model": "claude-sonnet-4-5-20250929", + "content": [ + {"type": "text", "text": "Here's the result."}, + { + "type": "server_tool_use", + "id": "srvtoolu_01ABC", + "name": "bash_code_execution", + "input": {"command": "echo hello"}, + }, + { + "type": "bash_code_execution_tool_result", + "tool_use_id": "srvtoolu_01ABC", + "content": { + "type": "bash_code_execution_result", + "stdout": "hello\n", + "stderr": "", + "return_code": 0, + }, + }, + ], + "stop_reason": "stop", + "stop_sequence": None, + "usage": {"input_tokens": 100, "output_tokens": 50}, + } + + mock_raw_response = MagicMock(spec=httpx.Response) + mock_raw_response.json.return_value = mock_anthropic_response + mock_raw_response.status_code = 200 + mock_raw_response.headers = {} + + model_response = ModelResponse() + + transformed_response = 
config.transform_parsed_response( + completion_response=mock_anthropic_response, + raw_response=mock_raw_response, + model_response=model_response, + json_mode=False, + prefix_prompt=None, + ) + + # Verify tool_results is in _hidden_params for the Responses API adapter + hidden = transformed_response._hidden_params + assert "provider_specific_fields" in hidden + assert "tool_results" in hidden["provider_specific_fields"] + assert len(hidden["provider_specific_fields"]["tool_results"]) == 1 + assert ( + hidden["provider_specific_fields"]["tool_results"][0]["content"]["stdout"] + == "hello\n" + ) + + def test_tool_search_tool_result_not_in_tool_results(): """ Test that tool_search_tool_result is NOT included in tool_results @@ -2425,21 +2492,12 @@ def test_tool_search_tool_result_not_in_tool_results(): "role": "assistant", "model": "claude-sonnet-4-5-20250929", "content": [ - { - "type": "text", - "text": "Found tools." - }, - { - "type": "tool_search_tool_result", - "tool_references": ["tool1", "tool2"] - } + {"type": "text", "text": "Found tools."}, + {"type": "tool_search_tool_result", "tool_references": ["tool1", "tool2"]}, ], "stop_reason": "stop", "stop_sequence": None, - "usage": { - "input_tokens": 100, - "output_tokens": 50 - } + "usage": {"input_tokens": 100, "output_tokens": 50}, } mock_raw_response = MagicMock(spec=httpx.Response) @@ -2479,22 +2537,16 @@ def test_web_search_tool_result_backwards_compatibility(): "role": "assistant", "model": "claude-sonnet-4-5-20250929", "content": [ - { - "type": "text", - "text": "Here are the results." 
- }, + {"type": "text", "text": "Here are the results."}, { "type": "web_search_tool_result", "search_query": "test query", - "results": [{"title": "Result 1", "url": "https://example.com"}] - } + "results": [{"title": "Result 1", "url": "https://example.com"}], + }, ], "stop_reason": "stop", "stop_sequence": None, - "usage": { - "input_tokens": 100, - "output_tokens": 50 - } + "usage": {"input_tokens": 100, "output_tokens": 50}, } mock_raw_response = MagicMock(spec=httpx.Response) @@ -2540,24 +2592,28 @@ def test_compaction_block_extraction(): "content": [ { "type": "compaction", - "content": "Summary of the conversation: The user requested help building a web scraper..." + "content": "Summary of the conversation: The user requested help building a web scraper...", }, { "type": "text", - "text": "I don't have access to real-time data, so I can't provide the current weather in San Francisco." - } + "text": "I don't have access to real-time data, so I can't provide the current weather in San Francisco.", + }, ], "stop_reason": "max_tokens", "stop_sequence": None, - "usage": { - "input_tokens": 86, - "output_tokens": 100 - } + "usage": {"input_tokens": 86, "output_tokens": 100}, } - text, citations, thinking_blocks, reasoning_content, tool_calls, web_search_results, tool_results, compaction_blocks = config.extract_response_content( - completion_response - ) + ( + text, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = config.extract_response_content(completion_response) # Verify compaction blocks are extracted assert compaction_blocks is not None @@ -2587,18 +2643,12 @@ def test_compaction_block_in_provider_specific_fields(): "content": [ { "type": "compaction", - "content": "Summary of the conversation: The user requested help building a web scraper..." 
+ "content": "Summary of the conversation: The user requested help building a web scraper...", }, - { - "type": "text", - "text": "Here is the response." - } + {"type": "text", "text": "Here is the response."}, ], "stop_reason": "end_turn", - "usage": { - "input_tokens": 50, - "output_tokens": 25 - } + "usage": {"input_tokens": 50, "output_tokens": 25}, } raw_response = httpx.Response(status_code=200, headers={}) @@ -2618,7 +2668,10 @@ def test_compaction_block_in_provider_specific_fields(): assert "compaction_blocks" in provider_fields assert len(provider_fields["compaction_blocks"]) == 1 assert provider_fields["compaction_blocks"][0]["type"] == "compaction" - assert "Summary of the conversation" in provider_fields["compaction_blocks"][0]["content"] + assert ( + "Summary of the conversation" + in provider_fields["compaction_blocks"][0]["content"] + ) def test_multiple_compaction_blocks(): @@ -2629,24 +2682,22 @@ def test_multiple_compaction_blocks(): completion_response = { "content": [ - { - "type": "compaction", - "content": "First summary..." - }, - { - "type": "text", - "text": "Some text." - }, - { - "type": "compaction", - "content": "Second summary..." - } + {"type": "compaction", "content": "First summary..."}, + {"type": "text", "text": "Some text."}, + {"type": "compaction", "content": "Second summary..."}, ] } - text, citations, thinking_blocks, reasoning_content, tool_calls, web_search_results, tool_results, compaction_blocks = config.extract_response_content( - completion_response - ) + ( + text, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = config.extract_response_content(completion_response) # Verify both compaction blocks are extracted assert compaction_blocks is not None @@ -2665,37 +2716,26 @@ def test_compaction_block_request_transformation(): ) messages = [ - { - "role": "user", - "content": "What is the weather in San Francisco?" 
- }, + {"role": "user", "content": "What is the weather in San Francisco?"}, { "role": "assistant", "content": [ - { - "type": "text", - "text": "I don't have access to real-time data." - } + {"type": "text", "text": "I don't have access to real-time data."} ], "provider_specific_fields": { "compaction_blocks": [ { "type": "compaction", - "content": "Summary of the conversation: The user requested help building a web scraper..." + "content": "Summary of the conversation: The user requested help building a web scraper...", } ] - } + }, }, - { - "role": "user", - "content": "What about New York?" - } + {"role": "user", "content": "What about New York?"}, ] result = anthropic_messages_pt( - messages=messages, - model="claude-opus-4-6", - llm_provider="anthropic" + messages=messages, model="claude-opus-4-6", llm_provider="anthropic" ) # Find the assistant message @@ -2727,14 +2767,8 @@ def test_compaction_with_context_management(): messages = [{"role": "user", "content": "Hello"}] optional_params = { - "context_management": { - "edits": [ - { - "type": "compact_20260112" - } - ] - }, - "max_tokens": 100 + "context_management": {"edits": [{"type": "compact_20260112"}]}, + "max_tokens": 100, } result = config.transform_request( @@ -2742,7 +2776,7 @@ def test_compaction_with_context_management(): messages=messages, optional_params=optional_params, litellm_params={}, - headers={} + headers={}, ) # Verify context_management is included @@ -2758,30 +2792,28 @@ def test_compaction_block_with_other_content_types(): completion_response = { "content": [ - { - "type": "compaction", - "content": "Summary of previous conversation..." - }, - { - "type": "thinking", - "thinking": "Let me think about this..." - }, - { - "type": "text", - "text": "Based on my analysis..." 
- }, + {"type": "compaction", "content": "Summary of previous conversation..."}, + {"type": "thinking", "thinking": "Let me think about this..."}, + {"type": "text", "text": "Based on my analysis..."}, { "type": "tool_use", "id": "toolu_123", "name": "get_weather", - "input": {"location": "San Francisco"} - } + "input": {"location": "San Francisco"}, + }, ] } - text, citations, thinking_blocks, reasoning_content, tool_calls, web_search_results, tool_results, compaction_blocks = config.extract_response_content( - completion_response - ) + ( + text, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = config.extract_response_content(completion_response) # Verify all content types are extracted assert compaction_blocks is not None @@ -2798,11 +2830,11 @@ def test_map_openai_context_management_to_anthropic(): Test mapping OpenAI Responses API context_management format to Anthropic format. """ config = AnthropicConfig() - + # Test OpenAI list format with compaction openai_format = [{"type": "compaction", "compact_threshold": 200000}] result = config.map_openai_context_management_to_anthropic(openai_format) - + assert result is not None assert "edits" in result assert len(result["edits"]) == 1 @@ -2811,26 +2843,32 @@ def test_map_openai_context_management_to_anthropic(): assert result["edits"][0]["trigger"]["value"] == 200000 # Test OpenAI format with instructions - openai_format_with_instructions = [{ - "type": "compaction", - "compact_threshold": 150000, - "instructions": "Focus on preserving code snippets" - }] - result = config.map_openai_context_management_to_anthropic(openai_format_with_instructions) - + openai_format_with_instructions = [ + { + "type": "compaction", + "compact_threshold": 150000, + "instructions": "Focus on preserving code snippets", + } + ] + result = config.map_openai_context_management_to_anthropic( + openai_format_with_instructions + ) + assert result is not None 
assert result["edits"][0]["trigger"]["value"] == 150000 assert result["edits"][0]["instructions"] == "Focus on preserving code snippets" - + # Test Anthropic format (should pass through) anthropic_format = { - "edits": [{ - "type": "compact_20260112", - "trigger": {"type": "input_tokens", "value": 150000} - }] + "edits": [ + { + "type": "compact_20260112", + "trigger": {"type": "input_tokens", "value": 150000}, + } + ] } result = config.map_openai_context_management_to_anthropic(anthropic_format) - + assert result == anthropic_format @@ -2839,46 +2877,51 @@ def test_map_openai_params_with_context_management(): Test that map_openai_params correctly transforms context_management from OpenAI to Anthropic format. """ config = AnthropicConfig() - + # Test with OpenAI list format non_default_params = { "context_management": [{"type": "compaction", "compact_threshold": 200000}] } optional_params = {} - + result = config.map_openai_params( non_default_params=non_default_params, optional_params=optional_params, model="claude-opus-4-6", - drop_params=False + drop_params=False, ) - + assert "context_management" in result assert "edits" in result["context_management"] assert result["context_management"]["edits"][0]["type"] == "compact_20260112" assert result["context_management"]["edits"][0]["trigger"]["value"] == 200000 - + # Test with Anthropic dict format (should pass through) non_default_params_anthropic = { "context_management": { - "edits": [{ - "type": "compact_20260112", - "trigger": {"type": "input_tokens", "value": 150000}, - "instructions": "Focus on preserving code" - }] + "edits": [ + { + "type": "compact_20260112", + "trigger": {"type": "input_tokens", "value": 150000}, + "instructions": "Focus on preserving code", + } + ] } } optional_params = {} - + result = config.map_openai_params( non_default_params=non_default_params_anthropic, optional_params=optional_params, model="claude-opus-4-6", - drop_params=False + drop_params=False, ) - + assert 
"context_management" in result - assert result["context_management"] == non_default_params_anthropic["context_management"] + assert ( + result["context_management"] + == non_default_params_anthropic["context_management"] + ) def test_cache_control_in_supported_params(): @@ -2897,9 +2940,7 @@ def test_map_openai_params_with_cache_control(): """ config = AnthropicConfig() - non_default_params = { - "cache_control": {"type": "ephemeral"} - } + non_default_params = {"cache_control": {"type": "ephemeral"}} optional_params = {} result = config.map_openai_params( @@ -2919,9 +2960,7 @@ def test_map_openai_params_cache_control_ignored_when_not_dict(): """ config = AnthropicConfig() - non_default_params = { - "cache_control": "ephemeral" - } + non_default_params = {"cache_control": "ephemeral"} optional_params = {} result = config.map_openai_params( @@ -2974,17 +3013,9 @@ def test_compaction_block_empty_list_not_added(): "type": "message", "role": "assistant", "model": "claude-opus-4-6", - "content": [ - { - "type": "text", - "text": "Just a regular response." 
- } - ], + "content": [{"type": "text", "text": "Just a regular response."}], "stop_reason": "end_turn", - "usage": { - "input_tokens": 10, - "output_tokens": 5 - } + "usage": {"input_tokens": 10, "output_tokens": 5}, } raw_response = httpx.Response(status_code=200, headers={}) @@ -3001,7 +3032,10 @@ def test_compaction_block_empty_list_not_added(): # Verify compaction_blocks is not in provider_specific_fields when there are none provider_fields = result.choices[0].message.provider_specific_fields if provider_fields: - assert "compaction_blocks" not in provider_fields or provider_fields.get("compaction_blocks") is None + assert ( + "compaction_blocks" not in provider_fields + or provider_fields.get("compaction_blocks") is None + ) def test_fast_mode_beta_header(): @@ -3014,8 +3048,7 @@ def test_fast_mode_beta_header(): optional_params = {"speed": "fast"} result_headers = config.update_headers_with_optional_anthropic_beta( - headers=headers, - optional_params=optional_params + headers=headers, optional_params=optional_params ) assert "anthropic-beta" in result_headers @@ -3029,14 +3062,10 @@ def test_fast_mode_with_other_beta_headers(): config = AnthropicConfig() headers = {} - optional_params = { - "speed": "fast", - "output_format": {"type": "json_object"} - } + optional_params = {"speed": "fast", "output_format": {"type": "json_object"}} result_headers = config.update_headers_with_optional_anthropic_beta( - headers=headers, - optional_params=optional_params + headers=headers, optional_params=optional_params ) assert "anthropic-beta" in result_headers @@ -3056,9 +3085,7 @@ def test_fast_mode_usage_calculation(): } usage = config.calculate_usage( - usage_object=usage_object, - reasoning_content=None, - speed="fast" + usage_object=usage_object, reasoning_content=None, speed="fast" ) assert usage.prompt_tokens == 1000 @@ -3171,7 +3198,7 @@ def test_fast_mode_parameter_mapping(): non_default_params=non_default_params, optional_params=optional_params, 
model="claude-opus-4-6", - drop_params=False + drop_params=False, ) assert "speed" in result @@ -3236,9 +3263,9 @@ def test_map_tool_helper_enforces_object_type_when_missing(): assert "properties" in result["input_schema"] assert "query" in result["input_schema"]["properties"] # Original parameters dict must not be modified in place - assert tool["function"]["parameters"] == original_params, ( - "parameters dict was mutated; _map_tool_helper should not modify caller data" - ) + assert ( + tool["function"]["parameters"] == original_params + ), "parameters dict was mutated; _map_tool_helper should not modify caller data" def test_map_tool_helper_enforces_object_type_when_wrong_type(): @@ -3264,13 +3291,13 @@ def test_map_tool_helper_enforces_object_type_when_wrong_type(): result, _ = config._map_tool_helper(tool) assert result is not None assert result["input_schema"]["type"] == "object" - assert result["input_schema"].get("properties") == {}, ( - "properties should be injected as {} when schema has non-object type and no properties key" - ) + assert ( + result["input_schema"].get("properties") == {} + ), "properties should be injected as {} when schema has non-object type and no properties key" # Original parameters dict must not be modified in place - assert tool["function"]["parameters"] == original_params, ( - "parameters dict was mutated; _map_tool_helper should not modify caller data" - ) + assert ( + tool["function"]["parameters"] == original_params + ), "parameters dict was mutated; _map_tool_helper should not modify caller data" def test_map_tool_helper_preserves_valid_object_schema(): From 2bf8751f6b0cfe30b51e4bb298c7e36ebc4ea46b Mon Sep 17 00:00:00 2001 From: Andrzej Pomirski Date: Tue, 17 Mar 2026 17:57:25 +0100 Subject: [PATCH 100/539] fix: streaming code_interpreter_results dropped for multiple code executions stream_chunk_builder uses "last value wins" for list-valued provider_specific_fields keys. 
_build_code_interpreter_results was emitting only new items (incremental), so earlier results were silently dropped when multiple sequential code executions occurred. - Emit cumulative list from _build_code_interpreter_results, matching web_search_results pattern - Assemble server_tool_use input from input_json_delta deltas at content_block_stop (Anthropic streams input: {} in start block) - Handle dict items in _extract_tool_result_output_items after model_dump() serialization in stream_chunk_builder - Simplify _merge_provider_specific_fields to last-value-wins for lists, matching stream_chunk_builder semantics --- litellm/llms/anthropic/chat/handler.py | 45 ++++-- .../streaming_iterator.py | 19 ++- .../transformation.py | 5 +- .../chat/test_anthropic_chat_handler.py | 133 +++++++++++++++--- 4 files changed, 165 insertions(+), 37 deletions(-) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 51b9c9835a..7d5fa2a559 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -545,7 +545,7 @@ class ModelResponseIterator: # Track server tool use inputs and results for code_interpreter_results self._server_tool_inputs: Dict[str, Any] = {} self.tool_results: List[Dict[str, Any]] = [] - self._last_code_interpreter_results_count: int = 0 + self._current_server_tool_id: Optional[str] = None def check_empty_tool_call_args(self) -> bool: """ @@ -695,13 +695,14 @@ class ModelResponseIterator: Called during streaming to produce provider-neutral code_interpreter_results alongside the raw tool_results, so the Responses API layer doesn't need Anthropic-specific knowledge. + + Returns the full cumulative list each time (not incremental), matching + how web_search_results works. stream_chunk_builder uses "last value + wins" for list-valued provider_specific_fields keys, so the last + emission must contain every result. 
""" - # Only convert tool_results added since the last call to avoid - # duplicates when _merge_provider_specific_fields extends the list. - new_results = self.tool_results[self._last_code_interpreter_results_count :] - self._last_code_interpreter_results_count = len(self.tool_results) results = [] - for tr in new_results: + for tr in self.tool_results: call_id = tr.get("tool_use_id", "") content = tr.get("content", {}) if isinstance(content, dict): @@ -793,17 +794,23 @@ class ModelResponseIterator: ), index=self.tool_index, ) - # Track server tool use inputs for code_interpreter_results + # Track server tool use inputs for code_interpreter_results. + # The initial input in content_block_start is typically {} + # for streaming; the full input arrives via input_json_delta + # and is assembled at content_block_stop. if ( content_block_start["content_block"]["type"] == "server_tool_use" ): + self._current_server_tool_id = content_block_start[ + "content_block" + ]["id"] tool_input = content_block_start["content_block"].get( "input", {} ) - self._server_tool_inputs[ - content_block_start["content_block"]["id"] - ] = tool_input + self._server_tool_inputs[self._current_server_tool_id] = ( + tool_input + ) # Include caller information if present (for programmatic tool calling) if "caller" in content_block_start["content_block"]: caller_data = content_block_start["content_block"]["caller"] @@ -886,6 +893,24 @@ class ModelResponseIterator: ), index=self.tool_index, ) + # Update server_tool_inputs with fully assembled input + # from input_json_delta chunks (content_block_start has {}) + if ( + self.current_content_block_type == "server_tool_use" + and self._current_server_tool_id + ): + args = "" + for block in self.content_blocks: + if block["delta"]["type"] == "input_json_delta": + args += block["delta"].get("partial_json", "") + if args: + try: + self._server_tool_inputs[ + self._current_server_tool_id + ] = json.loads(args) + except (json.JSONDecodeError, TypeError): + 
pass + self._current_server_tool_id = None # Reset response_format tool tracking when block stops self.is_response_format_tool = False # Reset current content block type diff --git a/litellm/responses/litellm_completion_transformation/streaming_iterator.py b/litellm/responses/litellm_completion_transformation/streaming_iterator.py index 0b7d6e8a7a..0672b03bcd 100644 --- a/litellm/responses/litellm_completion_transformation/streaming_iterator.py +++ b/litellm/responses/litellm_completion_transformation/streaming_iterator.py @@ -481,17 +481,16 @@ class LiteLLMCompletionStreamingIterator(ResponsesAPIStreamingIterator): return event def _merge_provider_specific_fields(self, src: dict) -> None: - """Merge provider_specific_fields, extending list values instead of replacing.""" + """Merge provider_specific_fields using last-value-wins for lists. + + List-valued keys (web_search_results, tool_results, + code_interpreter_results, etc.) are emitted cumulatively — each + emission contains the full list so far. Using "last value wins" + matches stream_chunk_builder's semantics and avoids quadratic + growth from repeated extend calls. 
+ """ for key, val in src.items(): - existing = self._accumulated_provider_specific_fields.get(key) - if ( - existing is not None - and isinstance(val, list) - and isinstance(existing, list) - ): - existing.extend(val) - else: - self._accumulated_provider_specific_fields[key] = val + self._accumulated_provider_specific_fields[key] = val def create_litellm_model_response(self) -> Optional[ModelResponse]: response = cast( diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index b54d5930ef..b7f7e9adda 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -1738,7 +1738,10 @@ class LiteLLMCompletionResponsesConfig: ) ) if tool_result_items: - result_by_id = {item.id: item for item in tool_result_items} + result_by_id = { + (item.get("id") if isinstance(item, dict) else item.id): item + for item in tool_result_items + } replaced_ids = set(result_by_id.keys()) responses_output = [ ( diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py index 35c7a62027..d7a04a054a 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py @@ -1245,9 +1245,9 @@ def test_streaming_code_execution_produces_code_interpreter_results(): def test_streaming_multiple_code_executions_no_duplicates(): """ - Test that multiple code executions in a single streaming response produce - exactly one code_interpreter_result per execution — no duplicates from - _build_code_interpreter_results rebuilding the full list. + Test that multiple code executions in a single streaming response emit + cumulative code_interpreter_results on each chunk (matching stream_chunk_builder's + "last value wins" contract). 
The final emission must contain ALL results. """ chunks = [ { @@ -1323,24 +1323,125 @@ def test_streaming_multiple_code_executions_no_duplicates(): iterator = ModelResponseIterator(None, sync_stream=True) - # Collect ALL code_interpreter_results emitted across all chunks - all_results = [] + # Collect each emission of code_interpreter_results + emissions = [] for chunk in chunks: parsed = iterator.chunk_parser(chunk) psf = None if parsed.choices and parsed.choices[0].delta: psf = getattr(parsed.choices[0].delta, "provider_specific_fields", None) if psf and "code_interpreter_results" in psf: - all_results.extend(psf["code_interpreter_results"]) + emissions.append(psf["code_interpreter_results"]) - # Should have exactly 2 results, one per execution — no duplicates - assert len(all_results) == 2, ( - f"Expected 2 code_interpreter_results, got {len(all_results)}. " - f"IDs: {[r.id for r in all_results]}" + # Should have 2 emissions (one per tool_result block) + assert len(emissions) == 2, f"Expected 2 emissions, got {len(emissions)}" + + # First emission: cumulative list with 1 result + assert len(emissions[0]) == 1 + assert emissions[0][0].id == "srvtoolu_01AAA" + assert emissions[0][0].code == "echo first" + assert emissions[0][0].outputs[0].logs == "first\n" + + # Second (final) emission: cumulative list with BOTH results + # This is what stream_chunk_builder will pick as "last value wins" + assert len(emissions[1]) == 2, ( + f"Expected final emission to have 2 results, got {len(emissions[1])}. 
" + f"IDs: {[r.id for r in emissions[1]]}" ) - assert all_results[0].id == "srvtoolu_01AAA" - assert all_results[0].code == "echo first" - assert all_results[0].outputs[0].logs == "first\n" - assert all_results[1].id == "srvtoolu_01BBB" - assert all_results[1].code == "echo second" - assert all_results[1].outputs[0].logs == "second\n" + assert emissions[1][0].id == "srvtoolu_01AAA" + assert emissions[1][0].code == "echo first" + assert emissions[1][0].outputs[0].logs == "first\n" + assert emissions[1][1].id == "srvtoolu_01BBB" + assert emissions[1][1].code == "echo second" + assert emissions[1][1].outputs[0].logs == "second\n" + + +def test_streaming_code_execution_input_assembled_from_deltas(): + """ + In real Anthropic streaming, content_block_start for server_tool_use has + input: {}. The actual input arrives via input_json_delta deltas and must + be assembled at content_block_stop so the code field is populated. + + This test uses realistic chunk shapes (empty input in start, partial JSON + in deltas) to exercise the input assembly path. 
+ """ + chunks = [ + { + "type": "message_start", + "message": { + "id": "msg_01XYZ", + "type": "message", + "role": "assistant", + "content": [], + "usage": {"input_tokens": 100, "output_tokens": 1}, + }, + }, + # server_tool_use with empty input (real streaming behaviour) + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "server_tool_use", + "id": "srvtoolu_01AAA", + "name": "code_execution", + "input": {}, + }, + }, + # Input arrives via deltas, split across two chunks + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "input_json_delta", + "partial_json": '{"comma', + }, + }, + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "input_json_delta", + "partial_json": 'nd": "echo hello"}', + }, + }, + {"type": "content_block_stop", "index": 0}, + # Tool result + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": "code_execution_tool_result", + "tool_use_id": "srvtoolu_01AAA", + "content": { + "type": "code_execution_result", + "stdout": "hello\n", + "stderr": "", + "return_code": 0, + }, + }, + }, + {"type": "content_block_stop", "index": 1}, + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn"}, + "usage": {"output_tokens": 50}, + }, + ] + + iterator = ModelResponseIterator(None, sync_stream=True) + + code_results = None + for chunk in chunks: + parsed = iterator.chunk_parser(chunk) + psf = None + if parsed.choices and parsed.choices[0].delta: + psf = getattr(parsed.choices[0].delta, "provider_specific_fields", None) + if psf and "code_interpreter_results" in psf: + code_results = psf["code_interpreter_results"] + + # The code field must contain the assembled input, not be empty + assert code_results is not None, "No code_interpreter_results emitted" + assert len(code_results) == 1 + assert code_results[0].id == "srvtoolu_01AAA" + assert code_results[0].code == "echo hello" + assert code_results[0].outputs[0].logs == "hello\n" From 
4be1d76fd7da5c4c0b04f1ead59e98bcf54d1dfb Mon Sep 17 00:00:00 2001 From: Andrzej Pomirski Date: Tue, 17 Mar 2026 18:37:32 +0100 Subject: [PATCH 101/539] fix: empty stdout/stderr produces str(content) instead of empty logs When both stdout and stderr are empty strings, the `if parts else str(content)` fallback produced the raw dict representation as logs. Drop the fallback so logs is correctly empty. --- litellm/llms/anthropic/chat/handler.py | 2 +- litellm/llms/anthropic/chat/transformation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 7d5fa2a559..88d9ee6596 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -711,7 +711,7 @@ class ModelResponseIterator: parts.append(content["stdout"]) if content.get("stderr"): parts.append(f"STDERR: {content['stderr']}") - logs = "".join(parts) if parts else str(content) + logs = "".join(parts) else: logs = str(content) tool_input = self._server_tool_inputs.get(call_id, {}) diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 033afea2ff..21ca8db825 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -1775,7 +1775,7 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): parts.append(content["stdout"]) if content.get("stderr"): parts.append(f"STDERR: {content['stderr']}") - logs = "".join(parts) if parts else str(content) + logs = "".join(parts) else: logs = str(content) code_interpreter_results.append( From 5b3e84f383627c951c4d7bf5f150e4173c39fdc4 Mon Sep 17 00:00:00 2001 From: Andrzej Pomirski Date: Tue, 17 Mar 2026 18:59:47 +0100 Subject: [PATCH 102/539] fix: address remaining review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Empty stdout/stderr now produces outputs=None (matching OpenAI 
parity) instead of outputs=[{logs:""}], in both streaming and non-streaming paths - Fix test fixture to use real Anthropic type "bash_code_execution_tool_result" instead of "code_execution_tool_result" - Add test for empty-output → outputs=None behavior - Add unit tests for _extract_tool_result_output_items: Pydantic objects, plain dicts (post-model_dump), empty/missing provider_specific_fields, and in-place substitution preserving output ordering --- litellm/llms/anthropic/chat/handler.py | 5 +- litellm/llms/anthropic/chat/transformation.py | 9 +- .../chat/test_anthropic_chat_handler.py | 72 +++++++- ...est_code_interpreter_results_extraction.py | 163 ++++++++++++++++++ 4 files changed, 243 insertions(+), 6 deletions(-) create mode 100644 tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 88d9ee6596..0b84830a8b 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -716,6 +716,9 @@ class ModelResponseIterator: logs = str(content) tool_input = self._server_tool_inputs.get(call_id, {}) code = tool_input.get("command", "") if isinstance(tool_input, dict) else "" + log_outputs = ( + [OutputCodeInterpreterCallLog(type="logs", logs=logs)] if logs else None + ) results.append( OutputCodeInterpreterCall( type="code_interpreter_call", @@ -723,7 +726,7 @@ class ModelResponseIterator: code=code, container_id=None, status="completed", - outputs=[OutputCodeInterpreterCallLog(type="logs", logs=logs)], + outputs=log_outputs, ) ) return results diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 21ca8db825..99b02e7ab4 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -1778,6 +1778,11 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): logs = "".join(parts) else: logs = 
str(content) + log_outputs = ( + [OutputCodeInterpreterCallLog(type="logs", logs=logs)] + if logs + else None + ) code_interpreter_results.append( OutputCodeInterpreterCall( type="code_interpreter_call", @@ -1785,9 +1790,7 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): code=code_by_id.get(call_id, ""), container_id=container_id, status="completed", - outputs=[ - OutputCodeInterpreterCallLog(type="logs", logs=logs) - ], + outputs=log_outputs, ) ) provider_specific_fields["code_interpreter_results"] = ( diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py index d7a04a054a..ab298f6809 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py @@ -1410,10 +1410,10 @@ def test_streaming_code_execution_input_assembled_from_deltas(): "type": "content_block_start", "index": 1, "content_block": { - "type": "code_execution_tool_result", + "type": "bash_code_execution_tool_result", "tool_use_id": "srvtoolu_01AAA", "content": { - "type": "code_execution_result", + "type": "bash_code_execution_result", "stdout": "hello\n", "stderr": "", "return_code": 0, @@ -1445,3 +1445,71 @@ def test_streaming_code_execution_input_assembled_from_deltas(): assert code_results[0].id == "srvtoolu_01AAA" assert code_results[0].code == "echo hello" assert code_results[0].outputs[0].logs == "hello\n" + + +def test_empty_output_produces_null_outputs(): + """ + When both stdout and stderr are empty, outputs should be None + (matching OpenAI's native behavior) rather than [{logs: ""}]. 
+ """ + chunks = [ + { + "type": "message_start", + "message": { + "id": "msg_01XYZ", + "type": "message", + "role": "assistant", + "content": [], + "usage": {"input_tokens": 100, "output_tokens": 1}, + }, + }, + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "server_tool_use", + "id": "srvtoolu_01AAA", + "name": "bash_code_execution", + "input": {"command": "true"}, + }, + }, + {"type": "content_block_stop", "index": 0}, + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": "bash_code_execution_tool_result", + "tool_use_id": "srvtoolu_01AAA", + "content": { + "type": "bash_code_execution_result", + "stdout": "", + "stderr": "", + "return_code": 0, + }, + }, + }, + {"type": "content_block_stop", "index": 1}, + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn"}, + "usage": {"output_tokens": 50}, + }, + ] + + iterator = ModelResponseIterator(None, sync_stream=True) + + code_results = None + for chunk in chunks: + parsed = iterator.chunk_parser(chunk) + psf = None + if parsed.choices and parsed.choices[0].delta: + psf = getattr(parsed.choices[0].delta, "provider_specific_fields", None) + if psf and "code_interpreter_results" in psf: + code_results = psf["code_interpreter_results"] + + assert code_results is not None, "No code_interpreter_results emitted" + assert len(code_results) == 1 + assert code_results[0].id == "srvtoolu_01AAA" + assert ( + code_results[0].outputs is None + ), f"Expected outputs=None for empty execution, got {code_results[0].outputs}" diff --git a/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py b/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py new file mode 100644 index 0000000000..c6ff1c7af8 --- /dev/null +++ b/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py @@ -0,0 +1,163 @@ +""" +Tests for the Responses API _extract_tool_result_output_items path +and the 
non-streaming _hidden_params propagation of code_interpreter_results. +""" + +from unittest.mock import MagicMock + +from litellm.responses.litellm_completion_transformation.transformation import ( + LiteLLMCompletionResponsesConfig, +) +from litellm.types.responses.main import ( + OutputCodeInterpreterCall, + OutputCodeInterpreterCallLog, +) +from litellm.types.utils import Choices, Message, ModelResponse + + +def _make_model_response(code_interpreter_results=None, provider_specific_fields=None): + """Helper to build a ModelResponse with provider_specific_fields on the message.""" + psf = provider_specific_fields or {} + if code_interpreter_results is not None: + psf["code_interpreter_results"] = code_interpreter_results + msg = Message(content="test", provider_specific_fields=psf if psf else None) + choice = Choices(index=0, message=msg, finish_reason="stop") + resp = ModelResponse() + resp.choices = [choice] + return resp + + +def test_extract_tool_result_output_items_from_pydantic_objects(): + """Non-streaming path: code_interpreter_results are Pydantic OutputCodeInterpreterCall objects.""" + items = [ + OutputCodeInterpreterCall( + type="code_interpreter_call", + id="srvtoolu_01AAA", + code="echo hello", + container_id=None, + status="completed", + outputs=[OutputCodeInterpreterCallLog(type="logs", logs="hello\n")], + ), + OutputCodeInterpreterCall( + type="code_interpreter_call", + id="srvtoolu_01BBB", + code="echo world", + container_id=None, + status="completed", + outputs=[OutputCodeInterpreterCallLog(type="logs", logs="world\n")], + ), + ] + resp = _make_model_response(code_interpreter_results=items) + result = LiteLLMCompletionResponsesConfig._extract_tool_result_output_items(resp) + assert len(result) == 2 + assert result[0].id == "srvtoolu_01AAA" + assert result[1].id == "srvtoolu_01BBB" + + +def test_extract_tool_result_output_items_from_dicts(): + """Streaming path: after model_dump(), code_interpreter_results are plain dicts.""" + items = [ + { + 
"type": "code_interpreter_call", + "id": "srvtoolu_01AAA", + "code": "echo hello", + "container_id": None, + "status": "completed", + "outputs": [{"type": "logs", "logs": "hello\n"}], + }, + ] + resp = _make_model_response(code_interpreter_results=items) + result = LiteLLMCompletionResponsesConfig._extract_tool_result_output_items(resp) + assert len(result) == 1 + assert result[0]["id"] == "srvtoolu_01AAA" + + +def test_extract_tool_result_output_items_empty(): + """No code_interpreter_results → empty list.""" + resp = _make_model_response() + result = LiteLLMCompletionResponsesConfig._extract_tool_result_output_items(resp) + assert result == [] + + +def test_extract_tool_result_output_items_no_provider_specific_fields(): + """Message with no provider_specific_fields → empty list.""" + msg = Message(content="test") + choice = Choices(index=0, message=msg, finish_reason="stop") + resp = ModelResponse() + resp.choices = [choice] + result = LiteLLMCompletionResponsesConfig._extract_tool_result_output_items(resp) + assert result == [] + + +def test_in_place_substitution_preserves_ordering(): + """ + function_call items matching code_interpreter_results should be replaced + in-place, preserving the original output ordering. 
+ + Simulates: [message, function_call(exec1), function_call(regular), function_call(exec2)] + Expected: [message, code_interpreter_call(exec1), function_call(regular), code_interpreter_call(exec2)] + """ + code_results = [ + OutputCodeInterpreterCall( + type="code_interpreter_call", + id="srvtoolu_01AAA", + code="echo first", + container_id=None, + status="completed", + outputs=[OutputCodeInterpreterCallLog(type="logs", logs="first\n")], + ), + OutputCodeInterpreterCall( + type="code_interpreter_call", + id="srvtoolu_01CCC", + code="echo third", + container_id=None, + status="completed", + outputs=[OutputCodeInterpreterCallLog(type="logs", logs="third\n")], + ), + ] + resp = _make_model_response(code_interpreter_results=code_results) + + # Build a mock responses_output list with interleaved items + class MockItem: + def __init__(self, type, call_id=None): + self.type = type + self.call_id = call_id + + msg_item = MockItem(type="message") + fc_exec1 = MockItem(type="function_call", call_id="srvtoolu_01AAA") + fc_regular = MockItem(type="function_call", call_id="srvtoolu_01BBB") + fc_exec2 = MockItem(type="function_call", call_id="srvtoolu_01CCC") + + responses_output = [msg_item, fc_exec1, fc_regular, fc_exec2] + + # Apply the same logic as _transform_chat_completion_choices_to_responses_output + tool_result_items = ( + LiteLLMCompletionResponsesConfig._extract_tool_result_output_items(resp) + ) + if tool_result_items: + result_by_id = { + (item.get("id") if isinstance(item, dict) else item.id): item + for item in tool_result_items + } + replaced_ids = set(result_by_id.keys()) + responses_output = [ + ( + result_by_id[getattr(item, "call_id", None)] + if ( + getattr(item, "type", None) == "function_call" + and getattr(item, "call_id", None) in replaced_ids + ) + else item + ) + for item in responses_output + ] + + # Verify ordering: message, code_interpreter(AAA), function_call(BBB), code_interpreter(CCC) + assert len(responses_output) == 4 + assert 
responses_output[0].type == "message" + assert responses_output[1].type == "code_interpreter_call" + assert responses_output[1].id == "srvtoolu_01AAA" + assert responses_output[2].type == "function_call" + assert responses_output[2].call_id == "srvtoolu_01BBB" + assert responses_output[3].type == "code_interpreter_call" + assert responses_output[3].id == "srvtoolu_01CCC" From 3962fbc33ac12c8a6e232dea488779f5570cb5f2 Mon Sep 17 00:00:00 2001 From: Andrzej Pomirski Date: Tue, 17 Mar 2026 19:26:03 +0100 Subject: [PATCH 103/539] fix: non-dict tool result content falls back to outputs=None Replace str(content) fallback with empty string so non-dict content (e.g. list-shaped text_editor results) produces outputs=None instead of raw Python object representations in logs. --- litellm/llms/anthropic/chat/handler.py | 2 +- litellm/llms/anthropic/chat/transformation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 0b84830a8b..387901588f 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -713,7 +713,7 @@ class ModelResponseIterator: parts.append(f"STDERR: {content['stderr']}") logs = "".join(parts) else: - logs = str(content) + logs = "" tool_input = self._server_tool_inputs.get(call_id, {}) code = tool_input.get("command", "") if isinstance(tool_input, dict) else "" log_outputs = ( diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 99b02e7ab4..238d72eda4 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -1777,7 +1777,7 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): parts.append(f"STDERR: {content['stderr']}") logs = "".join(parts) else: - logs = str(content) + logs = "" log_outputs = ( [OutputCodeInterpreterCallLog(type="logs", logs=logs)] if logs From 
8f60117228821ccde41d4845382afee507dfb70c Mon Sep 17 00:00:00 2001 From: Andrzej Pomirski Date: Wed, 18 Mar 2026 00:42:03 +0100 Subject: [PATCH 104/539] fix: guard code_interpreter conversion to bash_code_execution results only Skip non-bash tool result types (e.g. text_editor_code_execution_tool_result) to avoid producing empty code_interpreter_call items in Responses API output. --- litellm/llms/anthropic/chat/handler.py | 2 ++ litellm/llms/anthropic/chat/transformation.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 387901588f..91fd303406 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -703,6 +703,8 @@ class ModelResponseIterator: """ results = [] for tr in self.tool_results: + if tr.get("type") != "bash_code_execution_tool_result": + continue call_id = tr.get("tool_use_id", "") content = tr.get("content", {}) if isinstance(content, dict): diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 238d72eda4..ca101df0e9 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -1767,6 +1767,8 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): pass code_interpreter_results = [] for tr in tool_results: + if tr.get("type") != "bash_code_execution_tool_result": + continue call_id = tr.get("tool_use_id", "") content = tr.get("content", {}) if isinstance(content, dict): From d10007cef49ae278e97c05e40ce367481e983975 Mon Sep 17 00:00:00 2001 From: Andrzej Pomirski Date: Wed, 18 Mar 2026 11:15:15 +0100 Subject: [PATCH 105/539] test: add non-bash skip test and mock end-to-end streaming integration test - test_non_bash_tool_result_skipped: verifies text_editor results produce zero code_interpreter_call items - test_end_to_end_streaming_chunks_to_code_interpreter_output: exercises full path from Anthropic SSE chunks 
through ModelResponseIterator, stream_chunk_builder, and _extract_tool_result_output_items without a live server --- .../chat/test_anthropic_chat_handler.py | 68 +++++++++++ ...est_code_interpreter_results_extraction.py | 106 +++++++++++++++++- 2 files changed, 172 insertions(+), 2 deletions(-) diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py index ab298f6809..20427e8cc9 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py @@ -1513,3 +1513,71 @@ def test_empty_output_produces_null_outputs(): assert ( code_results[0].outputs is None ), f"Expected outputs=None for empty execution, got {code_results[0].outputs}" + + +def test_non_bash_tool_result_skipped(): + """ + Tool result types other than bash_code_execution_tool_result (e.g. + text_editor_code_execution_tool_result) should be skipped and NOT + produce code_interpreter_call items. 
+ """ + chunks = [ + { + "type": "message_start", + "message": { + "id": "msg_01XYZ", + "type": "message", + "role": "assistant", + "content": [], + "usage": {"input_tokens": 100, "output_tokens": 1}, + }, + }, + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "server_tool_use", + "id": "srvtoolu_01AAA", + "name": "text_editor", + "input": {"command": "view", "path": "/tmp/test.py"}, + }, + }, + {"type": "content_block_stop", "index": 0}, + # text_editor result — should NOT become a code_interpreter_call + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": "text_editor_code_execution_tool_result", + "tool_use_id": "srvtoolu_01AAA", + "content": [ + {"type": "text", "text": "file contents here"}, + ], + }, + }, + {"type": "content_block_stop", "index": 1}, + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn"}, + "usage": {"output_tokens": 50}, + }, + ] + + iterator = ModelResponseIterator(None, sync_stream=True) + + code_results = None + for chunk in chunks: + parsed = iterator.chunk_parser(chunk) + psf = None + if parsed.choices and parsed.choices[0].delta: + psf = getattr(parsed.choices[0].delta, "provider_specific_fields", None) + if psf and "code_interpreter_results" in psf: + code_results = psf["code_interpreter_results"] + + # code_interpreter_results should be emitted but empty (no bash results) + assert ( + code_results is not None + ), "Expected code_interpreter_results key to be emitted" + assert ( + len(code_results) == 0 + ), f"Expected 0 code_interpreter_results for text_editor result, got {len(code_results)}" diff --git a/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py b/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py index c6ff1c7af8..eea9be38fa 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py +++ 
b/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py @@ -1,10 +1,13 @@ """ -Tests for the Responses API _extract_tool_result_output_items path -and the non-streaming _hidden_params propagation of code_interpreter_results. +Tests for the Responses API _extract_tool_result_output_items path, +the non-streaming _hidden_params propagation of code_interpreter_results, +and mock end-to-end streaming integration. """ from unittest.mock import MagicMock +from litellm.llms.anthropic.chat.handler import ModelResponseIterator +from litellm.main import stream_chunk_builder from litellm.responses.litellm_completion_transformation.transformation import ( LiteLLMCompletionResponsesConfig, ) @@ -161,3 +164,102 @@ def test_in_place_substitution_preserves_ordering(): assert responses_output[2].call_id == "srvtoolu_01BBB" assert responses_output[3].type == "code_interpreter_call" assert responses_output[3].id == "srvtoolu_01CCC" + + +def test_end_to_end_streaming_chunks_to_code_interpreter_output(): + """ + Mock end-to-end test: Anthropic SSE chunks → ModelResponseIterator → + stream_chunk_builder → _extract_tool_result_output_items → final output + with code_interpreter_call items replacing function_call items. + + This exercises the full streaming data flow without a live server. 
+ """ + # Realistic Anthropic streaming chunks for a single code execution + raw_chunks = [ + { + "type": "message_start", + "message": { + "id": "msg_01XYZ", + "type": "message", + "role": "assistant", + "content": [], + "usage": {"input_tokens": 100, "output_tokens": 1}, + }, + }, + { + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "server_tool_use", + "id": "srvtoolu_01AAA", + "name": "bash_code_execution", + "input": {}, + }, + }, + { + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "input_json_delta", + "partial_json": '{"command": "echo e2e_test"}', + }, + }, + {"type": "content_block_stop", "index": 0}, + { + "type": "content_block_start", + "index": 1, + "content_block": { + "type": "bash_code_execution_tool_result", + "tool_use_id": "srvtoolu_01AAA", + "content": { + "type": "bash_code_execution_result", + "stdout": "e2e_test\n", + "stderr": "", + "return_code": 0, + }, + }, + }, + {"type": "content_block_stop", "index": 1}, + { + "type": "message_delta", + "delta": {"stop_reason": "end_turn"}, + "usage": {"output_tokens": 50}, + }, + ] + + # Step 1: Parse chunks through ModelResponseIterator (Anthropic handler) + iterator = ModelResponseIterator(None, sync_stream=True) + parsed_chunks = [] + for chunk in raw_chunks: + parsed = iterator.chunk_parser(chunk) + d = parsed.model_dump() + # In production, CustomStreamWrapper sets the model on each chunk; + # stream_chunk_builder requires it. 
+ d["model"] = "claude-sonnet-4-20250514" + parsed_chunks.append(d) + + # Step 2: Assemble via stream_chunk_builder (simulates end-of-stream) + assembled = stream_chunk_builder(chunks=parsed_chunks) + assert assembled is not None + + # Verify stream_chunk_builder picked up code_interpreter_results via last-value-wins + psf = assembled.choices[0].message.provider_specific_fields + assert psf is not None + assert "code_interpreter_results" in psf + code_results = psf["code_interpreter_results"] + assert len(code_results) == 1 + # After model_dump + stream_chunk_builder, results are plain dicts + assert code_results[0]["id"] == "srvtoolu_01AAA" + assert code_results[0]["code"] == "echo e2e_test" + + # Step 3: Extract via _extract_tool_result_output_items (Responses API layer) + tool_result_items = ( + LiteLLMCompletionResponsesConfig._extract_tool_result_output_items(assembled) + ) + assert len(tool_result_items) == 1 + item = tool_result_items[0] + # Items are dicts after the model_dump path + assert item["type"] == "code_interpreter_call" + assert item["id"] == "srvtoolu_01AAA" + assert item["code"] == "echo e2e_test" + assert item["outputs"][0]["logs"] == "e2e_test\n" From cf8d1ac521648fea10bc121cd51da166e96493a4 Mon Sep 17 00:00:00 2001 From: Andrzej Pomirski Date: Wed, 18 Mar 2026 12:05:25 +0100 Subject: [PATCH 106/539] fix: streaming container_id and consistent Pydantic types in output - Populate container_id on streaming code_interpreter_results by re-emitting at message_delta when container info arrives - Reconstruct Pydantic OutputCodeInterpreterCall objects from plain dicts in _extract_tool_result_output_items so responses_output has uniform types across streaming and non-streaming paths --- litellm/llms/anthropic/chat/handler.py | 14 +++++++++++++- .../transformation.py | 14 +++++++++----- .../test_code_interpreter_results_extraction.py | 17 ++++++++++------- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git 
a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 91fd303406..70ecf91725 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -546,6 +546,7 @@ class ModelResponseIterator: self._server_tool_inputs: Dict[str, Any] = {} self.tool_results: List[Dict[str, Any]] = [] self._current_server_tool_id: Optional[str] = None + self._container_id: Optional[str] = None def check_empty_tool_call_args(self) -> bool: """ @@ -726,7 +727,7 @@ class ModelResponseIterator: type="code_interpreter_call", id=call_id, code=code, - container_id=None, + container_id=self._container_id, status="completed", outputs=log_outputs, ) @@ -928,6 +929,17 @@ class ModelResponseIterator: finish_reason, usage, container = self._handle_message_delta(chunk) if container: provider_specific_fields["container"] = container + # Store container_id and re-emit code_interpreter_results + # so stream_chunk_builder's last-value-wins picks up the + # version with container_id populated. 
+ container_id = ( + container.get("id") if isinstance(container, dict) else None + ) + if container_id and self.tool_results: + self._container_id = container_id + provider_specific_fields["code_interpreter_results"] = ( + self._build_code_interpreter_results() + ) elif type_chunk == "message_start": """ Anthropic diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index b7f7e9adda..cf18511bfa 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -1738,10 +1738,7 @@ class LiteLLMCompletionResponsesConfig: ) ) if tool_result_items: - result_by_id = { - (item.get("id") if isinstance(item, dict) else item.id): item - for item in tool_result_items - } + result_by_id = {item.id: item for item in tool_result_items} replaced_ids = set(result_by_id.keys()) responses_output = [ ( @@ -1778,7 +1775,14 @@ class LiteLLMCompletionResponsesConfig: continue results = psf.get("code_interpreter_results") if results and isinstance(results, list): - output_items.extend(results) + for item in results: + # In the streaming path, items are plain dicts after + # model_dump() in stream_chunk_builder. Reconstruct + # Pydantic objects so responses_output has a uniform type. 
+ if isinstance(item, dict): + output_items.append(OutputCodeInterpreterCall(**item)) + else: + output_items.append(item) return output_items @staticmethod diff --git a/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py b/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py index eea9be38fa..60e45c9b8c 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py +++ b/tests/test_litellm/llms/anthropic/chat/test_code_interpreter_results_extraction.py @@ -58,7 +58,8 @@ def test_extract_tool_result_output_items_from_pydantic_objects(): def test_extract_tool_result_output_items_from_dicts(): - """Streaming path: after model_dump(), code_interpreter_results are plain dicts.""" + """Streaming path: after model_dump(), code_interpreter_results are plain dicts. + _extract_tool_result_output_items reconstructs them as Pydantic objects.""" items = [ { "type": "code_interpreter_call", @@ -72,7 +73,8 @@ def test_extract_tool_result_output_items_from_dicts(): resp = _make_model_response(code_interpreter_results=items) result = LiteLLMCompletionResponsesConfig._extract_tool_result_output_items(resp) assert len(result) == 1 - assert result[0]["id"] == "srvtoolu_01AAA" + assert isinstance(result[0], OutputCodeInterpreterCall) + assert result[0].id == "srvtoolu_01AAA" def test_extract_tool_result_output_items_empty(): @@ -258,8 +260,9 @@ def test_end_to_end_streaming_chunks_to_code_interpreter_output(): ) assert len(tool_result_items) == 1 item = tool_result_items[0] - # Items are dicts after the model_dump path - assert item["type"] == "code_interpreter_call" - assert item["id"] == "srvtoolu_01AAA" - assert item["code"] == "echo e2e_test" - assert item["outputs"][0]["logs"] == "e2e_test\n" + # Items are reconstructed as Pydantic OutputCodeInterpreterCall objects + assert isinstance(item, OutputCodeInterpreterCall) + assert item.type == "code_interpreter_call" + assert item.id == 
"srvtoolu_01AAA" + assert item.code == "echo e2e_test" + assert item.outputs[0].logs == "e2e_test\n" From 021540b2e2d7083cecfadfdcc4cd1f08abb43e31 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 17:09:28 +0530 Subject: [PATCH 107/539] fix: prevent double prompt management in async path, preserve optional params MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - aresponses() now pops prompt_id from kwargs after the async hook runs and passes merged_optional_params via _async_prompt_merged_params. responses() checks for this internal kwarg first and skips the sync hook entirely when present — eliminating double-merge of template messages. - merged_optional_params from async_get_chat_completion_prompt is no longer discarded (_); it flows through to local_vars in responses(). - Async tests now assert get_chat_completion_prompt.assert_not_called() to directly detect any double-execution regression. Made-with: Cursor --- litellm/responses/main.py | 97 +++++++++++-------- .../test_responses_prompt_management.py | 22 +++-- 2 files changed, 69 insertions(+), 50 deletions(-) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 0e56836355..af2976cd54 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -466,6 +466,11 @@ async def aresponses( ######################################################### # ASYNC PROMPT MANAGEMENT + # Run the async hook here so async-only prompt loggers are honoured. + # Then pop prompt_id from kwargs so the sync responses() path does NOT + # re-run the hook (which would double-prepend template messages). + # Pass merged_optional_params via an internal kwarg so responses() + # can apply them to local_vars without re-invoking the hook. 
######################################################### litellm_logging_obj = kwargs.get("litellm_logging_obj", None) prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) @@ -488,7 +493,7 @@ async def aresponses( ( model, merged_input, - _, + merged_optional_params, ) = await litellm_logging_obj.async_get_chat_completion_prompt( model=model, messages=client_input, @@ -503,6 +508,8 @@ async def aresponses( _, custom_llm_provider, _, _ = litellm.get_llm_provider( model=model ) + kwargs.pop("prompt_id", None) + kwargs["_async_prompt_merged_params"] = merged_optional_params func = partial( responses, @@ -666,47 +673,57 @@ def responses( ######################################################### # PROMPT MANAGEMENT + # If aresponses() already ran the async hook, it pops prompt_id and + # passes the result via _async_prompt_merged_params — apply those + # directly and skip the sync hook to avoid double-merging. ######################################################### - prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) - prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None)) - original_model = model - - if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and litellm_logging_obj.should_run_prompt_management_hooks( - prompt_id=prompt_id, non_default_params=kwargs - ): - if isinstance(input, str): - client_input: List[AllMessageValues] = [ - {"role": "user", "content": input} - ] - else: - client_input = [ - item # type: ignore[misc] - for item in input - if isinstance(item, dict) and "role" in item - ] - ( - model, - merged_input, - merged_optional_params, - ) = litellm_logging_obj.get_chat_completion_prompt( - model=model, - messages=client_input, - non_default_params=kwargs, - prompt_id=prompt_id, - prompt_variables=prompt_variables, - prompt_label=kwargs.get("prompt_label", None), - prompt_version=kwargs.get("prompt_version", None), - ) - input = cast(Union[str, ResponseInputParam], merged_input) - local_vars["input"] = 
input - local_vars["model"] = model - if model != original_model: - _, custom_llm_provider, _, _ = litellm.get_llm_provider( - model=model - ) - local_vars["custom_llm_provider"] = custom_llm_provider - for k, v in merged_optional_params.items(): + _async_merged = kwargs.pop("_async_prompt_merged_params", None) + if _async_merged is not None: + for k, v in _async_merged.items(): local_vars[k] = v + else: + prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) + prompt_variables = cast( + Optional[dict], kwargs.get("prompt_variables", None) + ) + original_model = model + + if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and litellm_logging_obj.should_run_prompt_management_hooks( + prompt_id=prompt_id, non_default_params=kwargs + ): + if isinstance(input, str): + client_input: List[AllMessageValues] = [ + {"role": "user", "content": input} + ] + else: + client_input = [ + item # type: ignore[misc] + for item in input + if isinstance(item, dict) and "role" in item + ] + ( + model, + merged_input, + merged_optional_params, + ) = litellm_logging_obj.get_chat_completion_prompt( + model=model, + messages=client_input, + non_default_params=kwargs, + prompt_id=prompt_id, + prompt_variables=prompt_variables, + prompt_label=kwargs.get("prompt_label", None), + prompt_version=kwargs.get("prompt_version", None), + ) + input = cast(Union[str, ResponseInputParam], merged_input) + local_vars["input"] = input + local_vars["model"] = model + if model != original_model: + _, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=model + ) + local_vars["custom_llm_provider"] = custom_llm_provider + for k, v in merged_optional_params.items(): + local_vars[k] = v ######################################################### # Update input and tools with provider-specific file IDs if managed files are used diff --git a/tests/test_litellm/responses/test_responses_prompt_management.py b/tests/test_litellm/responses/test_responses_prompt_management.py index 
9defaceed8..f49679fc40 100644 --- a/tests/test_litellm/responses/test_responses_prompt_management.py +++ b/tests/test_litellm/responses/test_responses_prompt_management.py @@ -288,16 +288,16 @@ class TestResponsesAPIPromptManagement: class TestAsyncResponsesAPIPromptManagement: """Tests for the async aresponses() prompt management path. - aresponses() calls async_get_chat_completion_prompt at the outer async level - (for async-only prompt loggers), then delegates to responses() via - run_in_executor where the sync hook also runs — mirroring acompletion() in - main.py. Optional params are handled by the sync responses() path. + aresponses() calls async_get_chat_completion_prompt at the outer async + level, then pops prompt_id from kwargs and passes merged_optional_params + via an internal kwarg. The sync responses() path sees no prompt_id and + skips the sync hook entirely — preventing double-merge of template messages. """ @pytest.mark.asyncio - async def test_async_calls_async_hook(self): - """[H] aresponses() invokes async_get_chat_completion_prompt before - dispatching to the sync responses() path.""" + async def test_async_calls_async_hook_not_sync(self): + """[H] aresponses() invokes async_get_chat_completion_prompt and the + sync get_chat_completion_prompt is NOT called (no double-merge).""" template_messages: List[AllMessageValues] = [ {"role": "system", "content": "You are helpful."}, # type: ignore[list-item] ] @@ -318,14 +318,14 @@ class TestAsyncResponsesAPIPromptManagement: ) logging_obj.async_get_chat_completion_prompt.assert_called_once() + logging_obj.get_chat_completion_prompt.assert_not_called() call_kwargs = logging_obj.async_get_chat_completion_prompt.call_args.kwargs assert call_kwargs["prompt_id"] == "async-test" @pytest.mark.asyncio async def test_async_optional_params_propagated(self): - """[I] Template-defined optional params (e.g. temperature) reach the downstream - handler when called via aresponses(). 
The sync responses() path applies them - via local_vars.""" + """[I] Template-defined optional params (e.g. temperature) from the async + hook reach the downstream handler — they are NOT silently discarded.""" template_messages: List[AllMessageValues] = [ {"role": "user", "content": "Hello"}, # type: ignore[list-item] ] @@ -345,6 +345,7 @@ class TestAsyncResponsesAPIPromptManagement: litellm_logging_obj=logging_obj, ) + logging_obj.get_chat_completion_prompt.assert_not_called() handler_call_kwargs = mock_handler.call_args.kwargs request_params = handler_call_kwargs.get("responses_api_request", {}) assert request_params.get("temperature") == 0.7 @@ -375,6 +376,7 @@ class TestAsyncResponsesAPIPromptManagement: ) logging_obj.async_get_chat_completion_prompt.assert_called_once() + logging_obj.get_chat_completion_prompt.assert_not_called() call_kwargs = logging_obj.async_get_chat_completion_prompt.call_args.kwargs passed_messages = call_kwargs["messages"] assert all(isinstance(m, dict) and "role" in m for m in passed_messages) From f29b4981a0fca42eb7c3918da1649a9036d74e49 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 17:13:42 +0530 Subject: [PATCH 108/539] fix(prompting): preserve separator for assistant(tc)->assistant edge case When scanning backward over counted messages, preserve old behavior for adjacent assistant turns by inserting user_continue if the immediate previous raw message is assistant. This handles malformed assistant(tool_calls)->assistant(no-tool-calls) inputs without splitting valid assistant(tool_calls)->tool chains. 
Made-with: Cursor --- .../prompt_templates/common_utils.py | 27 +++++++---- tests/llm_translation/test_prompt_factory.py | 48 +++++++++++++++++++ 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index eb3755b71f..2efd90e0c2 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -298,16 +298,23 @@ def _insert_user_continue_message( curr_message = result_messages[i] inserted_continue_message = False if _counts_for_alternation(curr_message) and curr_message["role"] == "assistant": - j = i - 1 - while j >= 0: - previous_message = result_messages[j] - if _counts_for_alternation(previous_message): - if previous_message["role"] == "assistant": - result_messages.insert(i, continue_message) - i += 2 - inserted_continue_message = True - break - j -= 1 + # Preserve old behavior for malformed adjacent assistant sequences like + # assistant(tool_calls) -> assistant(no-tool-calls) with no tool message. 
+ if i > 0 and result_messages[i - 1].get("role") == "assistant": + result_messages.insert(i, continue_message) + i += 2 + inserted_continue_message = True + else: + j = i - 1 + while j >= 0: + previous_message = result_messages[j] + if _counts_for_alternation(previous_message): + if previous_message["role"] == "assistant": + result_messages.insert(i, continue_message) + i += 2 + inserted_continue_message = True + break + j -= 1 if not inserted_continue_message: i += 1 diff --git a/tests/llm_translation/test_prompt_factory.py b/tests/llm_translation/test_prompt_factory.py index fe46c24a29..64556c3f26 100644 --- a/tests/llm_translation/test_prompt_factory.py +++ b/tests/llm_translation/test_prompt_factory.py @@ -903,6 +903,54 @@ def test_ensure_alternating_roles_does_not_split_tool_call_chain(): ] +def test_ensure_alternating_roles_assistant_tool_call_then_assistant(): + """ + Preserve old behavior for malformed adjacent assistant turns: + [assistant(tool_calls), assistant(no-tool-calls), user] should insert + user_continue between assistant messages. 
+ """ + messages = [ + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + }, + {"role": "assistant", "content": "Here's what I found."}, + {"role": "user", "content": "Thanks"}, + ] + + transformed_messages = get_completion_messages( + messages=messages, + assistant_continue_message=None, + user_continue_message=None, + ensure_alternating_roles=True, + ) + + assert transformed_messages == [ + {"role": "user", "content": "Please continue."}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + }, + {"role": "user", "content": "Please continue."}, + {"role": "assistant", "content": "Here's what I found."}, + {"role": "user", "content": "Thanks"}, + ] + + def test_ensure_alternating_roles_trailing_tool_call_assistant(): messages = [ {"role": "user", "content": "What's the weather?"}, From 17efd96e6173bd2eee3ee5cc8ee1640dc0799cb4 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 17:49:03 +0530 Subject: [PATCH 109/539] docs(vertex): add concise PayGo/Priority guide with cost-tracking flow Document how to send Vertex Priority PayGo headers and explain how trafficType maps to service-tier pricing in LiteLLM, including an embedded flow diagram for quick understanding. 
Made-with: Cursor --- .../docs/tutorials/vertex_ai_pay_go.md | 151 ++++++++++++++++++ .../static/img/vertex_cost_tracking_flow.svg | 62 +++++++ 2 files changed, 213 insertions(+) create mode 100644 docs/my-website/docs/tutorials/vertex_ai_pay_go.md create mode 100644 docs/my-website/static/img/vertex_cost_tracking_flow.svg diff --git a/docs/my-website/docs/tutorials/vertex_ai_pay_go.md b/docs/my-website/docs/tutorials/vertex_ai_pay_go.md new file mode 100644 index 0000000000..625aff35e1 --- /dev/null +++ b/docs/my-website/docs/tutorials/vertex_ai_pay_go.md @@ -0,0 +1,151 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Vertex AI PayGo and Priority + +## Priority PayGo + +LiteLLM supports Priority PayGo. +Send a priority header, get priority queueing, and pay priority token rates. + +:::info Which models support Priority PayGo? +As of this writing: `gemini/gemini-2.5-pro`, `vertex_ai/gemini-3-pro-preview`, `vertex_ai/gemini-3.1-pro-preview`, `vertex_ai/gemini-3-flash-preview`, and their variants. +Check `supports_service_tier: true` in LiteLLM's [model pricing JSON](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). 
+::: + +### Send a priority request + +Use this header: + +`X-Vertex-AI-LLM-Shared-Request-Type: priority` + + + + +```python +import litellm + +response = litellm.completion( + model="vertex_ai/gemini-3-pro-preview", + messages=[{"role": "user", "content": "Summarize the Gettysburg Address."}], + vertex_project="YOUR_PROJECT_ID", + vertex_location="us-central1", + extra_headers={"X-Vertex-AI-LLM-Shared-Request-Type": "priority"}, +) + +print(response.choices[0].message.content) +``` + + + + +```yaml title="config.yaml" +model_list: + - model_name: gemini-priority + litellm_params: + model: vertex_ai/gemini-3-pro-preview + vertex_project: "YOUR_PROJECT_ID" + vertex_location: "us-central1" + vertex_credentials: os.environ/GOOGLE_APPLICATION_CREDENTIALS + extra_headers: + X-Vertex-AI-LLM-Shared-Request-Type: priority +``` + +```bash +curl http://localhost:4000/v1/chat/completions \ + -H "Authorization: Bearer sk-your-key" \ + -H "Content-Type: application/json" \ + -d '{"model": "gemini-priority", "messages": [{"role": "user", "content": "Hello"}]}' +``` + + + + +Use `x-pass-` so LiteLLM forwards provider-specific headers. 
+ +```bash +MODEL_ID="gemini-3-pro-preview-0325" +PROJECT_ID="YOUR_PROJECT_ID" + +curl -X POST \ + "${LITELLM_PROXY_BASE_URL}/vertex_ai/v1/projects/${PROJECT_ID}/locations/global/publishers/google/models/${MODEL_ID}:generateContent" \ + -H "Authorization: Bearer sk-your-litellm-key" \ + -H "Content-Type: application/json" \ + -H "x-pass-X-Vertex-AI-LLM-Shared-Request-Type: priority" \ + -d '{"contents": [{"role": "user", "parts": [{"text": "Hello!"}]}]}' +``` + + + + +### How cost tracking works + +![Vertex AI Priority PayGo Cost Tracking Flow](/img/vertex_cost_tracking_flow.svg) + +**`trafficType` → `service_tier` mapping** + +| `usageMetadata.trafficType` | `service_tier` | Pricing keys used | +|---|---|---| +| `ON_DEMAND` | `None` | `input_cost_per_token` | +| `ON_DEMAND_PRIORITY` | `"priority"` | `input_cost_per_token_priority` | +| `FLEX` / `BATCH` | `"flex"` | `input_cost_per_token_flex` | + +If a tier-specific key is missing, LiteLLM falls back to standard pricing keys. + +--- + +## Standard PayGo vs Provisioned Throughput + +This is a different header from priority routing: + +| Header value | Behavior | +|---|---| +| `X-Vertex-AI-LLM-Request-Type: shared` | Force standard PayGo (bypass PT) | +| `X-Vertex-AI-LLM-Request-Type: dedicated` | Force Provisioned Throughput only (`429` if exhausted) | + +### Native route example + +```python +import litellm + +response = litellm.completion( + model="vertex_ai/gemini-2.0-flash", + messages=[{"role": "user", "content": "Hello!"}], + vertex_project="YOUR_PROJECT_ID", + vertex_location="us-central1", + extra_headers={"X-Vertex-AI-LLM-Request-Type": "shared"}, +) +``` + +### Pass-through example + +```bash +MODEL_ID="gemini-2.0-flash-001" +PROJECT_ID="YOUR_PROJECT_ID" + +curl -X POST \ + "${LITELLM_PROXY_BASE_URL}/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:generateContent" \ + -H "Authorization: Bearer sk-your-litellm-key" \ + -H "Content-Type: application/json" \ + 
-H "x-pass-X-Vertex-AI-LLM-Request-Type: shared" \ + -d '{ + "contents": [{"role": "user", "parts": [{"text": "Hello!"}]}] + }' +``` + +--- + +## Troubleshooting + +**Q: What does `403 Permission denied` or `IAM_PERMISSION_DENIED` mean?** +A: The service account or Application Default Credentials (ADC) user does not have the `roles/aiplatform.user` role. To resolve this, re-run the `gcloud projects add-iam-policy-binding` command as shown above in the guide. + +**Q: What should I do if I get a `429 Quota exceeded` error?** +A: This means you've hit the per-region QPM (queries per minute) or TPM (tokens per minute) quota. You can: +- Request a quota increase from the [GCP Quotas console](https://console.cloud.google.com/iam-admin/quotas) +- Add more regions to your LiteLLM configuration for load balancing (see the region balancing guide above) +- Upgrade to [Provisioned Throughput](https://cloud.google.com/vertex-ai/generative-ai/docs/provisioned-throughput) for guaranteed capacity + +**Q: How do I fix the `VERTEXAI_PROJECT not set` error?** +A: Either pass the `vertex_project` parameter explicitly in your LiteLLM call, or set the `VERTEXAI_PROJECT` environment variable before running your code. 
+ diff --git a/docs/my-website/static/img/vertex_cost_tracking_flow.svg b/docs/my-website/static/img/vertex_cost_tracking_flow.svg new file mode 100644 index 0000000000..d808dd2e36 --- /dev/null +++ b/docs/my-website/static/img/vertex_cost_tracking_flow.svg @@ -0,0 +1,62 @@ + + + + + + + + + + + HTTP request + X-Vertex-AI-LLM-Shared-Request-Type: priority + + + + + Vertex AI + + + + + Vertex response + usageMetadata.trafficType = ON_DEMAND_PRIORITY + + + + + + + + + LiteLLM stores it + _hidden_params.provider_specific_fields.traffic_type + + + + + + + + + completion_cost() + Maps traffic_type service_tier = "priority" + + + + + + + + + Pricing lookup + input/output_cost_per_token_priority + + + + ` + a + b + c + d + From 8e943929a2a3802f03036437521ce73cdd2a2c50 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 17:52:51 +0530 Subject: [PATCH 110/539] docs(sidebar): add vertex PayGo tutorial under Spend Tracking Made-with: Cursor --- docs/my-website/sidebars.js | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 79a0279bad..56d1bb8b55 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -528,6 +528,7 @@ const sidebars = { label: "Spend Tracking", items: [ "proxy/cost_tracking", + "tutorials/vertex_ai_pay_go", "proxy/request_tags", "proxy/custom_pricing", "proxy/pricing_calculator", From b56fdf188ed95d7db6457caec8f3008c4f0e97f4 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 17:57:27 +0530 Subject: [PATCH 111/539] Fix greptile review --- .../docs/tutorials/vertex_ai_pay_go.md | 4 +-- .../static/img/vertex_cost_tracking_flow.svg | 27 ++++++++++--------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/docs/my-website/docs/tutorials/vertex_ai_pay_go.md b/docs/my-website/docs/tutorials/vertex_ai_pay_go.md index 625aff35e1..b40a8b2157 100644 --- a/docs/my-website/docs/tutorials/vertex_ai_pay_go.md +++ 
b/docs/my-website/docs/tutorials/vertex_ai_pay_go.md @@ -124,7 +124,7 @@ MODEL_ID="gemini-2.0-flash-001" PROJECT_ID="YOUR_PROJECT_ID" curl -X POST \ - "${LITELLM_PROXY_BASE_URL}/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:generateContent" \ + "${LITELLM_PROXY_BASE_URL}/vertex_ai/v1/projects/${PROJECT_ID}/locations/global/publishers/google/models/${MODEL_ID}:generateContent" \ -H "Authorization: Bearer sk-your-litellm-key" \ -H "Content-Type: application/json" \ -H "x-pass-X-Vertex-AI-LLM-Request-Type: shared" \ @@ -143,7 +143,7 @@ A: The service account or Application Default Credentials (ADC) user does not ha **Q: What should I do if I get a `429 Quota exceeded` error?** A: This means you've hit the per-region QPM (queries per minute) or TPM (tokens per minute) quota. You can: - Request a quota increase from the [GCP Quotas console](https://console.cloud.google.com/iam-admin/quotas) -- Add more regions to your LiteLLM configuration for load balancing (see the region balancing guide above) +- Add more regions to your LiteLLM configuration for load balancing - Upgrade to [Provisioned Throughput](https://cloud.google.com/vertex-ai/generative-ai/docs/provisioned-throughput) for guaranteed capacity **Q: How do I fix the `VERTEXAI_PROJECT not set` error?** diff --git a/docs/my-website/static/img/vertex_cost_tracking_flow.svg b/docs/my-website/static/img/vertex_cost_tracking_flow.svg index d808dd2e36..d607d072dd 100644 --- a/docs/my-website/static/img/vertex_cost_tracking_flow.svg +++ b/docs/my-website/static/img/vertex_cost_tracking_flow.svg @@ -6,7 +6,7 @@ - + HTTP request X-Vertex-AI-LLM-Shared-Request-Type: priority @@ -17,7 +17,7 @@ Vertex AI - + Vertex response usageMetadata.trafficType = ON_DEMAND_PRIORITY @@ -27,7 +27,7 @@ - + LiteLLM stores it _hidden_params.provider_specific_fields.traffic_type @@ -37,26 +37,27 @@ - + completion_cost() - Maps traffic_type service_tier = "priority" + Maps traffic_type → 
service_tier = "priority" - + Pricing lookup input/output_cost_per_token_priority - - ` - a - b - c - d - + + + + + + + + \ No newline at end of file From ea80a19a3970f338b398dafb6f57881f52bad30a Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Wed, 18 Mar 2026 18:07:25 +0530 Subject: [PATCH 112/539] Fix greptile review --- docs/my-website/docs/tutorials/vertex_ai_pay_go.md | 2 +- .../static/img/vertex_cost_tracking_flow.svg | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/my-website/docs/tutorials/vertex_ai_pay_go.md b/docs/my-website/docs/tutorials/vertex_ai_pay_go.md index b40a8b2157..87197e5bad 100644 --- a/docs/my-website/docs/tutorials/vertex_ai_pay_go.md +++ b/docs/my-website/docs/tutorials/vertex_ai_pay_go.md @@ -138,7 +138,7 @@ curl -X POST \ ## Troubleshooting **Q: What does `403 Permission denied` or `IAM_PERMISSION_DENIED` mean?** -A: The service account or Application Default Credentials (ADC) user does not have the `roles/aiplatform.user` role. To resolve this, re-run the `gcloud projects add-iam-policy-binding` command as shown above in the guide. +A: The service account or Application Default Credentials (ADC) user does not have the `roles/aiplatform.user` role. To resolve this, re-run the `gcloud projects add-iam-policy-binding` command. **Q: What should I do if I get a `429 Quota exceeded` error?** A: This means you've hit the per-region QPM (queries per minute) or TPM (tokens per minute) quota. 
You can: diff --git a/docs/my-website/static/img/vertex_cost_tracking_flow.svg b/docs/my-website/static/img/vertex_cost_tracking_flow.svg index d607d072dd..c3b2e33a07 100644 --- a/docs/my-website/static/img/vertex_cost_tracking_flow.svg +++ b/docs/my-website/static/img/vertex_cost_tracking_flow.svg @@ -6,7 +6,7 @@ - + HTTP request X-Vertex-AI-LLM-Shared-Request-Type: priority @@ -17,7 +17,7 @@ Vertex AI - + Vertex response usageMetadata.trafficType = ON_DEMAND_PRIORITY @@ -27,7 +27,7 @@ - + LiteLLM stores it _hidden_params.provider_specific_fields.traffic_type @@ -37,7 +37,7 @@ - + completion_cost() Maps traffic_type → service_tier = "priority" @@ -47,7 +47,7 @@ - + Pricing lookup input/output_cost_per_token_priority From b13a7c679033ceea1b93ca82440f5b6bf15a8bd7 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 12:17:19 -0700 Subject: [PATCH 113/539] Fix guardrail_mode.replace crash when backend returns non-string value The backend type for guardrail_mode is Optional[Union[str, List[str], Dict]] but the UI typed it as just string, causing a crash when .replace() was called on null/object/array values. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../GuardrailViewer/GuardrailViewer.test.tsx | 29 ++++++++++++++++ .../GuardrailViewer/GuardrailViewer.tsx | 33 +++++++++++++++---- .../GuardrailViewer/__tests__/fixtures.ts | 2 +- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.test.tsx b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.test.tsx index 28991ebfa0..60778a1337 100644 --- a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.test.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.test.tsx @@ -151,6 +151,35 @@ describe("GuardrailViewer", () => { expect(screen.queryByText(/Raw Bedrock Guardrail Response/)).not.toBeInTheDocument(); }); + it("renders without crashing when guardrail_mode is null", () => { + const data = makeGuardrailInformation({ guardrail_mode: null }); + renderWithProviders(); + + expect(screen.getByText("Guardrails & Policy Compliance")).toBeInTheDocument(); + // Null mode should display as dash + expect(screen.getByText("—")).toBeInTheDocument(); + }); + + it("renders without crashing when guardrail_mode is an object", () => { + const data = makeGuardrailInformation({ + guardrail_mode: { default: "pre_call", tags: {} }, + }); + renderWithProviders(); + + expect(screen.getByText("Guardrails & Policy Compliance")).toBeInTheDocument(); + expect(screen.getByText("PRE-CALL")).toBeInTheDocument(); + }); + + it("renders without crashing when guardrail_mode is an array", () => { + const data = makeGuardrailInformation({ + guardrail_mode: ["pre_call", "post_call"], + }); + renderWithProviders(); + + expect(screen.getByText("Guardrails & Policy Compliance")).toBeInTheDocument(); + expect(screen.getByText("PRE-CALL")).toBeInTheDocument(); + }); + it("integration: renders with real Bedrock details without mocks", async () => { const user = userEvent.setup(); const 
data = makeGuardrailInformation({ diff --git a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx index 1ab4744b89..0c37b70c8b 100644 --- a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx @@ -40,7 +40,7 @@ interface GuardrailInformation { duration: number; end_time: number; start_time: number; - guardrail_mode: string; + guardrail_mode: string | string[] | Record | null; guardrail_name: string; guardrail_status: string; guardrail_response: GuardrailEntity[] | BedrockGuardrailResponse | any; @@ -77,9 +77,25 @@ const PROVIDERS_WITH_CUSTOM_RENDERERS = new Set([ "litellm_content_filter", ]); -const formatMode = (mode: unknown): string => { - if (mode == null || mode === "") return "—"; - const s = typeof mode === "string" ? mode : String(mode); +/** + * Extracts a plain string from guardrail_mode, which may be a string, + * an array of strings, an object with a "default" key, or null. + */ +const resolveMode = (mode: GuardrailInformation["guardrail_mode"]): string | null => { + if (mode == null) return null; + if (typeof mode === "string") return mode; + if (Array.isArray(mode)) return mode[0] ?? null; + if (typeof mode === "object" && "default" in mode) { + const def = mode.default; + if (typeof def === "string") return def; + if (Array.isArray(def)) return def[0] ?? 
null; + } + return null; +}; + +const formatMode = (mode: GuardrailInformation["guardrail_mode"]): string => { + const s = resolveMode(mode); + if (s == null || s === "") return "—"; return s.replace(/_/g, "-").toUpperCase(); }; @@ -302,9 +318,12 @@ const RequestLifecycle = ({ entries }: { entries: GuardrailInformation[] }) => { items.push({ type: "request", label: "Request received", offsetMs: 0 }); // Pre-call guardrails - const preCalls = sorted.filter((e) => e.guardrail_mode === "pre_call"); - const postCalls = sorted.filter((e) => e.guardrail_mode === "post_call" || e.guardrail_mode === "logging_only"); - const duringCalls = sorted.filter((e) => e.guardrail_mode === "during_call"); + const preCalls = sorted.filter((e) => resolveMode(e.guardrail_mode) === "pre_call"); + const postCalls = sorted.filter((e) => { + const m = resolveMode(e.guardrail_mode); + return m === "post_call" || m === "logging_only"; + }); + const duringCalls = sorted.filter((e) => resolveMode(e.guardrail_mode) === "during_call"); for (const e of preCalls) { const offsetMs = Math.round((e.end_time - baseTime) * 1000); diff --git a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/__tests__/fixtures.ts b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/__tests__/fixtures.ts index 3c78aacf85..fc487b04d7 100644 --- a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/__tests__/fixtures.ts +++ b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/__tests__/fixtures.ts @@ -23,7 +23,7 @@ export interface GuardrailInformation { duration: number; end_time: number; start_time: number; - guardrail_mode: string; + guardrail_mode: string | string[] | Record | null; guardrail_name: string; guardrail_status: string; guardrail_response: GuardrailEntity[] | BedrockGuardrailResponse; From 20c1d984a610935368bacb215773135ac109cd5c Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 12:18:42 -0700 Subject: [PATCH 114/539] [Test] UI: Add unit tests for 
10 previously untested components Add Vitest + RTL tests covering DebugWarningBanner, HelpLink, ExportFormatSelector, ExportSummary, ExportTypeSelector, UsageExportHeader, MetricCard, ScoreChart, GuardrailConfig, and AgentHubTableColumns. 73 tests total. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../AIHub/AgentHubTableColumns.test.tsx | 132 ++++++++++++++++++ .../components/DebugWarningBanner.test.tsx | 38 +++++ .../ExportFormatSelector.test.tsx | 36 +++++ .../EntityUsageExport/ExportSummary.test.tsx | 35 +++++ .../ExportTypeSelector.test.tsx | 37 +++++ .../UsageExportHeader.test.tsx | 73 ++++++++++ .../GuardrailConfig.test.tsx | 84 +++++++++++ .../GuardrailsMonitor/MetricCard.test.tsx | 34 +++++ .../GuardrailsMonitor/ScoreChart.test.tsx | 47 +++++++ .../src/components/HelpLink.test.tsx | 128 +++++++++++++++++ 10 files changed, 644 insertions(+) create mode 100644 ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx create mode 100644 ui/litellm-dashboard/src/components/DebugWarningBanner.test.tsx create mode 100644 ui/litellm-dashboard/src/components/EntityUsageExport/ExportFormatSelector.test.tsx create mode 100644 ui/litellm-dashboard/src/components/EntityUsageExport/ExportSummary.test.tsx create mode 100644 ui/litellm-dashboard/src/components/EntityUsageExport/ExportTypeSelector.test.tsx create mode 100644 ui/litellm-dashboard/src/components/EntityUsageExport/UsageExportHeader.test.tsx create mode 100644 ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx create mode 100644 ui/litellm-dashboard/src/components/GuardrailsMonitor/MetricCard.test.tsx create mode 100644 ui/litellm-dashboard/src/components/GuardrailsMonitor/ScoreChart.test.tsx create mode 100644 ui/litellm-dashboard/src/components/HelpLink.test.tsx diff --git a/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx b/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx new file mode 100644 index 
0000000000..c32df2f546 --- /dev/null +++ b/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx @@ -0,0 +1,132 @@ +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { vi } from "vitest"; +import { flexRender, getCoreRowModel, useReactTable } from "@tanstack/react-table"; +import { getAgentHubTableColumns, AgentHubData } from "./AgentHubTableColumns"; + +const mockAgent: AgentHubData = { + agent_id: "agent-1", + protocolVersion: "1.0", + name: "Test Agent", + description: "A test agent for unit testing", + url: "https://agent.example.com", + version: "2.0", + capabilities: { streaming: true, caching: false }, + defaultInputModes: ["text"], + defaultOutputModes: ["text", "image"], + skills: [ + { id: "s1", name: "Skill One", description: "First skill" }, + { id: "s2", name: "Skill Two", description: "Second skill" }, + { id: "s3", name: "Skill Three", description: "Third skill" }, + ], + is_public: true, +}; + +function TestTable({ data, publicPage = false }: { data: AgentHubData[]; publicPage?: boolean }) { + const showModal = vi.fn(); + const copyToClipboard = vi.fn(); + const columns = getAgentHubTableColumns(showModal, copyToClipboard, publicPage); + const table = useReactTable({ data, columns, getCoreRowModel: getCoreRowModel() }); + + return ( + + + {table.getHeaderGroups().map((hg) => ( + + {hg.headers.map((h) => ( + + ))} + + ))} + + + {table.getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + ))} + + ))} + +
{flexRender(h.column.columnDef.header, h.getContext())}
{flexRender(cell.column.columnDef.cell, cell.getContext())}
+ ); +} + +describe("AgentHubTableColumns", () => { + it("should render", () => { + render(); + expect(screen.getByText("Test Agent")).toBeInTheDocument(); + }); + + it("should display the agent description", () => { + render(); + // Description appears in both the description column and the mobile view within agent name column + expect(screen.getAllByText("A test agent for unit testing").length).toBeGreaterThanOrEqual(1); + }); + + it("should display the version with a 'v' prefix", () => { + render(); + expect(screen.getByText("v2.0")).toBeInTheDocument(); + }); + + it("should display the protocol version", () => { + render(); + expect(screen.getByText("1.0")).toBeInTheDocument(); + }); + + it("should show skill count with correct pluralization", () => { + render(); + expect(screen.getByText("3 skills")).toBeInTheDocument(); + }); + + it("should show first two skills and '+1' for overflow", () => { + render(); + expect(screen.getByText("Skill One")).toBeInTheDocument(); + expect(screen.getByText("Skill Two")).toBeInTheDocument(); + expect(screen.getByText("+1")).toBeInTheDocument(); + }); + + it("should show only true capabilities as badges", () => { + render(); + expect(screen.getByText("streaming")).toBeInTheDocument(); + expect(screen.queryByText("caching")).not.toBeInTheDocument(); + }); + + it("should display I/O modes", () => { + render(); + expect(screen.getByText("text")).toBeInTheDocument(); + expect(screen.getByText("text, image")).toBeInTheDocument(); + }); + + it("should display 'Yes' badge for public agents", () => { + render(); + expect(screen.getByText("Yes")).toBeInTheDocument(); + }); + + it("should display 'No' badge for non-public agents", () => { + const privateAgent = { ...mockAgent, is_public: false }; + render(); + expect(screen.getByText("No")).toBeInTheDocument(); + }); + + it("should display a Details button", () => { + render(); + expect(screen.getByRole("button", { name: /details|info/i })).toBeInTheDocument(); + }); + + it("should show 
'-' when agent has no capabilities", () => { + const noCapAgent = { ...mockAgent, capabilities: {} }; + render(); + // The dash is rendered in the capabilities column + expect(screen.getByText("-")).toBeInTheDocument(); + }); + + it("should show singular 'skill' for one skill", () => { + const oneSkillAgent = { + ...mockAgent, + skills: [{ id: "s1", name: "Only Skill", description: "One" }], + }; + render(); + expect(screen.getByText("1 skill")).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/DebugWarningBanner.test.tsx b/ui/litellm-dashboard/src/components/DebugWarningBanner.test.tsx new file mode 100644 index 0000000000..4c99175162 --- /dev/null +++ b/ui/litellm-dashboard/src/components/DebugWarningBanner.test.tsx @@ -0,0 +1,38 @@ +import { renderWithProviders, screen } from "../../tests/test-utils"; +import { vi } from "vitest"; +import { DebugWarningBanner } from "./DebugWarningBanner"; + +const mockUseHealthReadiness = vi.fn(); +vi.mock("@/app/(dashboard)/hooks/healthReadiness/useHealthReadiness", () => ({ + useHealthReadiness: () => mockUseHealthReadiness(), +})); + +describe("DebugWarningBanner", () => { + afterEach(() => { + vi.resetAllMocks(); + }); + + it("should render", () => { + mockUseHealthReadiness.mockReturnValue({ data: { is_detailed_debug: true } }); + renderWithProviders(); + expect(screen.getByText(/Performance Warning/i)).toBeInTheDocument(); + }); + + it("should render nothing when debug mode is disabled", () => { + mockUseHealthReadiness.mockReturnValue({ data: { is_detailed_debug: false } }); + const { container } = renderWithProviders(); + expect(container.firstChild).toBeNull(); + }); + + it("should render nothing when health data is undefined", () => { + mockUseHealthReadiness.mockReturnValue({ data: undefined }); + const { container } = renderWithProviders(); + expect(container.firstChild).toBeNull(); + }); + + it("should mention LITELLM_LOG=DEBUG in the description", () => { + 
mockUseHealthReadiness.mockReturnValue({ data: { is_detailed_debug: true } }); + renderWithProviders(); + expect(screen.getByText(/LITELLM_LOG=DEBUG/)).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportFormatSelector.test.tsx b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportFormatSelector.test.tsx new file mode 100644 index 0000000000..a88f9cb116 --- /dev/null +++ b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportFormatSelector.test.tsx @@ -0,0 +1,36 @@ +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { vi } from "vitest"; +import ExportFormatSelector from "./ExportFormatSelector"; + +describe("ExportFormatSelector", () => { + it("should render", () => { + render(); + expect(screen.getByText("Format")).toBeInTheDocument(); + }); + + it("should display the current value as csv", () => { + render(); + expect(screen.getByText("CSV (Excel, Google Sheets)")).toBeInTheDocument(); + }); + + it("should display the current value as json", () => { + render(); + expect(screen.getByText("JSON (includes metadata)")).toBeInTheDocument(); + }); + + it("should call onChange when a different format is selected", async () => { + const onChange = vi.fn(); + const user = userEvent.setup(); + render(); + + // Open the Ant Design Select dropdown + await user.click(screen.getByText("CSV (Excel, Google Sheets)")); + // Select JSON option from the dropdown + const jsonOption = await screen.findByText("JSON (includes metadata)", { + selector: ".ant-select-item-option-content", + }); + await user.click(jsonOption); + expect(onChange).toHaveBeenCalledWith("json", expect.anything()); + }); +}); diff --git a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportSummary.test.tsx b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportSummary.test.tsx new file mode 100644 index 0000000000..ab426e291b --- /dev/null +++ 
b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportSummary.test.tsx @@ -0,0 +1,35 @@ +import { render, screen } from "@testing-library/react"; +import ExportSummary from "./ExportSummary"; + +describe("ExportSummary", () => { + const dateRange = { + from: new Date("2025-01-01"), + to: new Date("2025-01-31"), + }; + + it("should render", () => { + render(); + expect(screen.getByText(/2025/)).toBeInTheDocument(); + }); + + it("should display formatted date range", () => { + render(); + const text = screen.getByText(/\d+.*-.*\d+/); + expect(text).toBeInTheDocument(); + }); + + it("should show singular 'filter' for one filter", () => { + render(); + expect(screen.getByText(/1 filter$/)).toBeInTheDocument(); + }); + + it("should show plural 'filters' for multiple filters", () => { + render(); + expect(screen.getByText(/2 filters/)).toBeInTheDocument(); + }); + + it("should not show filter text when no filters applied", () => { + render(); + expect(screen.queryByText(/filter/)).not.toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportTypeSelector.test.tsx b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportTypeSelector.test.tsx new file mode 100644 index 0000000000..6ccf8822e0 --- /dev/null +++ b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportTypeSelector.test.tsx @@ -0,0 +1,37 @@ +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { vi } from "vitest"; +import ExportTypeSelector from "./ExportTypeSelector"; + +describe("ExportTypeSelector", () => { + it("should render", () => { + render(); + expect(screen.getByText("Export type")).toBeInTheDocument(); + }); + + it("should render all three radio options", () => { + render(); + expect(screen.getAllByRole("radio")).toHaveLength(3); + }); + + it("should interpolate entity type in labels", () => { + render(); + expect(screen.getByText(/Day-by-day breakdown by 
organization$/)).toBeInTheDocument(); + expect(screen.getByText(/organization and key/)).toBeInTheDocument(); + expect(screen.getByText(/organization and model/)).toBeInTheDocument(); + }); + + it("should call onChange when a different option is selected", async () => { + const onChange = vi.fn(); + const user = userEvent.setup(); + render(); + await user.click(screen.getByText(/by team and key/)); + expect(onChange).toHaveBeenCalledWith("daily_with_keys"); + }); + + it("should have the correct radio checked based on value prop", () => { + render(); + const radios = screen.getAllByRole("radio"); + expect(radios[2]).toBeChecked(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/EntityUsageExport/UsageExportHeader.test.tsx b/ui/litellm-dashboard/src/components/EntityUsageExport/UsageExportHeader.test.tsx new file mode 100644 index 0000000000..729d6fd340 --- /dev/null +++ b/ui/litellm-dashboard/src/components/EntityUsageExport/UsageExportHeader.test.tsx @@ -0,0 +1,73 @@ +import { renderWithProviders, screen } from "../../../tests/test-utils"; +import userEvent from "@testing-library/user-event"; +import { vi } from "vitest"; +import UsageExportHeader from "./UsageExportHeader"; +import type { EntitySpendData } from "./types"; + +vi.mock("./EntityUsageExportModal", () => ({ + default: ({ isOpen, onClose }: { isOpen: boolean; onClose: () => void }) => + isOpen ? ( +
+ +
+ ) : null, +})); + +const defaultProps = { + dateValue: { from: new Date("2025-01-01"), to: new Date("2025-01-31") }, + entityType: "team" as const, + spendData: { + results: [], + metadata: { + total_spend: 0, + total_api_requests: 0, + total_successful_requests: 0, + total_failed_requests: 0, + total_tokens: 0, + }, + } satisfies EntitySpendData, +}; + +describe("UsageExportHeader", () => { + it("should render", () => { + renderWithProviders(); + expect(screen.getByRole("button", { name: /export data/i })).toBeInTheDocument(); + }); + + it("should open the export modal when the export button is clicked", async () => { + const user = userEvent.setup(); + renderWithProviders(); + await user.click(screen.getByRole("button", { name: /export data/i })); + expect(screen.getByTestId("export-modal")).toBeInTheDocument(); + }); + + it("should close the export modal when onClose is called", async () => { + const user = userEvent.setup(); + renderWithProviders(); + await user.click(screen.getByRole("button", { name: /export data/i })); + await user.click(screen.getByRole("button", { name: /close/i })); + expect(screen.queryByTestId("export-modal")).not.toBeInTheDocument(); + }); + + it("should not show filter dropdown when showFilters is false", () => { + renderWithProviders(); + expect(screen.queryByText(/filter/i)).not.toBeInTheDocument(); + }); + + it("should show filter dropdown when showFilters is true and options provided", () => { + renderWithProviders( + , + ); + expect(screen.getByText("Team")).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx b/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx new file mode 100644 index 0000000000..fd9ad61599 --- /dev/null +++ b/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx @@ -0,0 +1,84 @@ +import { render, screen, act } from "@testing-library/react"; +import userEvent from 
"@testing-library/user-event"; +import { vi } from "vitest"; +import { GuardrailConfig } from "./GuardrailConfig"; + +describe("GuardrailConfig", () => { + const defaultProps = { + guardrailName: "Content Safety", + guardrailType: "Content Safety", + provider: "bedrock", + }; + + it("should render", () => { + render(); + expect(screen.getByText("Parameters")).toBeInTheDocument(); + }); + + it("should display the guardrail name in the parameters description", () => { + render(); + expect(screen.getByText(/Configure Content Safety behavior/)).toBeInTheDocument(); + }); + + it("should show version history when 'View history' is clicked", async () => { + const user = userEvent.setup(); + render(); + await user.click(screen.getByRole("button", { name: /view history/i })); + expect(screen.getByText("Initial configuration")).toBeInTheDocument(); + expect(screen.getByText("Added custom categories list")).toBeInTheDocument(); + }); + + it("should toggle version history text between View/Hide", async () => { + const user = userEvent.setup(); + render(); + const button = screen.getByRole("button", { name: /view history/i }); + await user.click(button); + expect(screen.getByRole("button", { name: /hide history/i })).toBeInTheDocument(); + }); + + it("should show custom code textarea when custom code override is toggled on", async () => { + const user = userEvent.setup(); + render(); + const switches = screen.getAllByRole("switch"); + // The second switch is the custom code override toggle + const customCodeSwitch = switches[1]; + await user.click(customCodeSwitch); + expect(screen.getByPlaceholderText(/async def evaluate/)).toBeInTheDocument(); + }); + + it("should hide custom code textarea when custom code override is off", () => { + render(); + // There's an input for categories, but no textarea + expect(screen.queryByPlaceholderText(/async def evaluate/)).not.toBeInTheDocument(); + }); + + it("should show the re-run button in idle state", () => { + render(); + 
expect(screen.getByRole("button", { name: /re-run on failing logs/i })).toBeInTheDocument(); + }); + + it("should show loading state when re-run is clicked", async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }); + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); + render(); + await user.click(screen.getByRole("button", { name: /re-run on failing logs/i })); + expect(screen.getByText(/Running on 10 samples/)).toBeInTheDocument(); + vi.useRealTimers(); + }); + + it("should show success message after re-run completes", async () => { + vi.useFakeTimers({ shouldAdvanceTime: true }); + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); + render(); + await user.click(screen.getByRole("button", { name: /re-run on failing logs/i })); + act(() => { vi.advanceTimersByTime(2500); }); + expect(screen.getByText(/7\/10 would now pass/)).toBeInTheDocument(); + vi.useRealTimers(); + }); + + it("should display the Revert and Save buttons", () => { + render(); + expect(screen.getByRole("button", { name: /revert/i })).toBeInTheDocument(); + expect(screen.getByRole("button", { name: /save as v4/i })).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/GuardrailsMonitor/MetricCard.test.tsx b/ui/litellm-dashboard/src/components/GuardrailsMonitor/MetricCard.test.tsx new file mode 100644 index 0000000000..9352cf4655 --- /dev/null +++ b/ui/litellm-dashboard/src/components/GuardrailsMonitor/MetricCard.test.tsx @@ -0,0 +1,34 @@ +import { render, screen } from "@testing-library/react"; +import { MetricCard } from "./MetricCard"; + +describe("MetricCard", () => { + it("should render", () => { + render(); + expect(screen.getByText("Total Requests")).toBeInTheDocument(); + }); + + it("should display the numeric value", () => { + render(); + expect(screen.getByText("1234")).toBeInTheDocument(); + }); + + it("should display a string value", () => { + render(); + expect(screen.getByText("95.2%")).toBeInTheDocument(); 
+ }); + + it("should display subtitle when provided", () => { + render(); + expect(screen.getByText("Last 7 days")).toBeInTheDocument(); + }); + + it("should not display subtitle when not provided", () => { + render(); + expect(screen.queryByText(/days/)).not.toBeInTheDocument(); + }); + + it("should display icon when provided", () => { + render(!} />); + expect(screen.getByTestId("icon")).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/GuardrailsMonitor/ScoreChart.test.tsx b/ui/litellm-dashboard/src/components/GuardrailsMonitor/ScoreChart.test.tsx new file mode 100644 index 0000000000..b950dd9a30 --- /dev/null +++ b/ui/litellm-dashboard/src/components/GuardrailsMonitor/ScoreChart.test.tsx @@ -0,0 +1,47 @@ +import { render, screen } from "@testing-library/react"; +import { ScoreChart } from "./ScoreChart"; + +vi.mock("@tremor/react", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + BarChart: ({ data, categories }: { data: unknown[]; categories: string[] }) => ( +
+ {data.length} data points +
+ ), + }; +}); + +describe("ScoreChart", () => { + it("should render", () => { + render(); + expect(screen.getByText("Request Outcomes Over Time")).toBeInTheDocument(); + }); + + it("should show empty state when no data provided", () => { + render(); + expect(screen.getByText("No chart data for this period")).toBeInTheDocument(); + }); + + it("should show empty state when data is an empty array", () => { + render(); + expect(screen.getByText("No chart data for this period")).toBeInTheDocument(); + }); + + it("should render chart when data is provided", () => { + const data = [ + { date: "2025-01-01", passed: 100, blocked: 5 }, + { date: "2025-01-02", passed: 120, blocked: 3 }, + ]; + render(); + expect(screen.getByTestId("bar-chart")).toBeInTheDocument(); + expect(screen.getByText("2 data points")).toBeInTheDocument(); + }); + + it("should pass correct categories to the chart", () => { + const data = [{ date: "2025-01-01", passed: 100, blocked: 5 }]; + render(); + expect(screen.getByTestId("bar-chart")).toHaveAttribute("data-categories", "passed,blocked"); + }); +}); diff --git a/ui/litellm-dashboard/src/components/HelpLink.test.tsx b/ui/litellm-dashboard/src/components/HelpLink.test.tsx new file mode 100644 index 0000000000..2e9b44a9e9 --- /dev/null +++ b/ui/litellm-dashboard/src/components/HelpLink.test.tsx @@ -0,0 +1,128 @@ +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { HelpLink, HelpIcon, DocsMenu } from "./HelpLink"; + +describe("HelpLink", () => { + it("should render", () => { + render(); + expect(screen.getByRole("link")).toBeInTheDocument(); + }); + + it("should display default 'Learn more' text when no children provided", () => { + render(); + expect(screen.getByText("Learn more")).toBeInTheDocument(); + }); + + it("should display custom children text", () => { + render(Custom docs link); + expect(screen.getByText("Custom docs link")).toBeInTheDocument(); + }); + + it("should open 
in a new tab with noopener noreferrer", () => { + render(); + const link = screen.getByRole("link"); + expect(link).toHaveAttribute("target", "_blank"); + expect(link).toHaveAttribute("rel", "noopener noreferrer"); + }); + + it("should have accessible screen reader text", () => { + render(); + expect(screen.getByText("(opens in a new tab)")).toBeInTheDocument(); + }); +}); + +describe("HelpIcon", () => { + it("should render", () => { + render(); + expect(screen.getByRole("button", { name: /help information/i })).toBeInTheDocument(); + }); + + it("should show tooltip content on mouse enter", async () => { + const user = userEvent.setup(); + render(); + await user.hover(screen.getByRole("button", { name: /help information/i })); + expect(screen.getByText("Helpful tooltip text")).toBeInTheDocument(); + }); + + it("should hide tooltip content on mouse leave", async () => { + const user = userEvent.setup(); + render(); + const button = screen.getByRole("button", { name: /help information/i }); + await user.hover(button); + await user.unhover(button); + expect(screen.queryByText("Helpful tooltip text")).not.toBeInTheDocument(); + }); + + it("should show learn more link when learnMoreHref is provided", async () => { + const user = userEvent.setup(); + render(); + await user.hover(screen.getByRole("button", { name: /help information/i })); + expect(screen.getByText("Learn more")).toBeInTheDocument(); + }); + + it("should use custom learnMoreText when provided", async () => { + const user = userEvent.setup(); + render( + , + ); + await user.hover(screen.getByRole("button", { name: /help information/i })); + expect(screen.getByText("Read docs")).toBeInTheDocument(); + }); +}); + +describe("DocsMenu", () => { + const items = [ + { label: "Custom pricing", href: "https://docs.example.com/pricing" }, + { label: "Spend tracking", href: "https://docs.example.com/spend" }, + ]; + + it("should render", () => { + render(); + expect(screen.getByRole("button", { name: /docs/i 
})).toBeInTheDocument(); + }); + + it("should show menu items when clicked", async () => { + const user = userEvent.setup(); + render(); + await user.click(screen.getByRole("button", { name: /docs/i })); + expect(screen.getByText("Custom pricing")).toBeInTheDocument(); + expect(screen.getByText("Spend tracking")).toBeInTheDocument(); + }); + + it("should hide menu items when clicked again", async () => { + const user = userEvent.setup(); + render(); + const button = screen.getByRole("button", { name: /docs/i }); + await user.click(button); + await user.click(button); + expect(screen.queryByText("Custom pricing")).not.toBeInTheDocument(); + }); + + it("should set aria-expanded correctly", async () => { + const user = userEvent.setup(); + render(); + const button = screen.getByRole("button", { name: /docs/i }); + expect(button).toHaveAttribute("aria-expanded", "false"); + await user.click(button); + expect(button).toHaveAttribute("aria-expanded", "true"); + }); + + it("should close menu when clicking outside", async () => { + const user = userEvent.setup(); + render( +
+ + +
, + ); + await user.click(screen.getByRole("button", { name: /docs/i })); + expect(screen.getByText("Custom pricing")).toBeInTheDocument(); + await user.click(screen.getByRole("button", { name: /outside/i })); + expect(screen.queryByText("Custom pricing")).not.toBeInTheDocument(); + }); + + it("should display custom children text", () => { + render(Help); + expect(screen.getByText("Help")).toBeInTheDocument(); + }); +}); From 1c8b5f77c96be048e83a2b35c7679fd8d546e53b Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 12:29:57 -0700 Subject: [PATCH 115/539] address greptile review feedback (greploop iteration 1) Add modeMatches() helper so array guardrail_mode values (e.g. ["pre_call", "post_call"]) place the entry in all matching timeline buckets, not just the first. Updated test to verify both buckets. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../GuardrailViewer/GuardrailViewer.test.tsx | 6 ++- .../GuardrailViewer/GuardrailViewer.tsx | 37 ++++++++++++++----- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.test.tsx b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.test.tsx index 60778a1337..b5e04c7244 100644 --- a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.test.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.test.tsx @@ -170,14 +170,18 @@ describe("GuardrailViewer", () => { expect(screen.getByText("PRE-CALL")).toBeInTheDocument(); }); - it("renders without crashing when guardrail_mode is an array", () => { + it("renders without crashing when guardrail_mode is an array and shows in both timeline buckets", () => { const data = makeGuardrailInformation({ guardrail_mode: ["pre_call", "post_call"], }); renderWithProviders(); expect(screen.getByText("Guardrails & Policy Compliance")).toBeInTheDocument(); + // Mode badge shows first element formatted 
expect(screen.getByText("PRE-CALL")).toBeInTheDocument(); + // Entry should appear in both pre-call and post-call timeline sections + expect(screen.getByText(/Pre-call guardrail:/)).toBeInTheDocument(); + expect(screen.getByText(/Post-call guardrail:/)).toBeInTheDocument(); }); it("integration: renders with real Bedrock details without mocks", async () => { diff --git a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx index 0c37b70c8b..1056076fcd 100644 --- a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx @@ -78,8 +78,8 @@ const PROVIDERS_WITH_CUSTOM_RENDERERS = new Set([ ]); /** - * Extracts a plain string from guardrail_mode, which may be a string, - * an array of strings, an object with a "default" key, or null. + * Extracts a plain string from guardrail_mode for display purposes. + * Returns the first mode when multiple are present. */ const resolveMode = (mode: GuardrailInformation["guardrail_mode"]): string | null => { if (mode == null) return null; @@ -93,6 +93,25 @@ const resolveMode = (mode: GuardrailInformation["guardrail_mode"]): string | nul return null; }; +/** + * Checks whether guardrail_mode includes the given target stage. + * Handles arrays (multi-stage guardrails) by checking all elements. 
+ */ +const modeMatches = ( + mode: GuardrailInformation["guardrail_mode"], + target: string, +): boolean => { + if (mode == null) return false; + if (typeof mode === "string") return mode === target; + if (Array.isArray(mode)) return mode.includes(target); + if (typeof mode === "object" && "default" in mode) { + const def = mode.default; + if (typeof def === "string") return def === target; + if (Array.isArray(def)) return (def as string[]).includes(target); + } + return false; +}; + const formatMode = (mode: GuardrailInformation["guardrail_mode"]): string => { const s = resolveMode(mode); if (s == null || s === "") return "—"; @@ -317,13 +336,13 @@ const RequestLifecycle = ({ entries }: { entries: GuardrailInformation[] }) => { // Request received items.push({ type: "request", label: "Request received", offsetMs: 0 }); - // Pre-call guardrails - const preCalls = sorted.filter((e) => resolveMode(e.guardrail_mode) === "pre_call"); - const postCalls = sorted.filter((e) => { - const m = resolveMode(e.guardrail_mode); - return m === "post_call" || m === "logging_only"; - }); - const duringCalls = sorted.filter((e) => resolveMode(e.guardrail_mode) === "during_call"); + // Pre-call guardrails — use modeMatches so array modes (e.g. ["pre_call", "post_call"]) + // place the entry in every matching bucket. 
+ const preCalls = sorted.filter((e) => modeMatches(e.guardrail_mode, "pre_call")); + const postCalls = sorted.filter( + (e) => modeMatches(e.guardrail_mode, "post_call") || modeMatches(e.guardrail_mode, "logging_only"), + ); + const duringCalls = sorted.filter((e) => modeMatches(e.guardrail_mode, "during_call")); for (const e of preCalls) { const offsetMs = Math.round((e.end_time - baseTime) * 1000); From 40721ab18f15bf4e3df2c1c04d8aba2be155e9e0 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 12:31:25 -0700 Subject: [PATCH 116/539] address greptile review feedback (greploop iteration 1) - Move vi.useRealTimers() to afterEach for proper cleanup - Use label-based DOM queries instead of fragile positional indexes - Remove leftover debug console.log from AgentHubTableColumns.tsx Co-Authored-By: Claude Opus 4.6 (1M context) --- .../components/AIHub/AgentHubTableColumns.tsx | 1 - .../ExportTypeSelector.test.tsx | 4 ++-- .../GuardrailConfig.test.tsx | 20 +++++++++++++------ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.tsx b/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.tsx index c6a8c0b9da..09b1c14761 100644 --- a/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.tsx +++ b/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.tsx @@ -194,7 +194,6 @@ export const getAgentHubTableColumns = ( return publicA - publicB; }, cell: ({ row }) => { - console.log(`CHECKPOINT 1: ${JSON.stringify(row.original)}`); const agent = row.original; return agent.is_public === true ? 
( diff --git a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportTypeSelector.test.tsx b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportTypeSelector.test.tsx index 6ccf8822e0..d9469f095a 100644 --- a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportTypeSelector.test.tsx +++ b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportTypeSelector.test.tsx @@ -31,7 +31,7 @@ describe("ExportTypeSelector", () => { it("should have the correct radio checked based on value prop", () => { render(); - const radios = screen.getAllByRole("radio"); - expect(radios[2]).toBeChecked(); + const modelRadio = screen.getByRole("radio", { name: /by team and model/i }); + expect(modelRadio).toBeChecked(); }); }); diff --git a/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx b/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx index fd9ad61599..38f6568198 100644 --- a/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx +++ b/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx @@ -10,6 +10,10 @@ describe("GuardrailConfig", () => { provider: "bedrock", }; + afterEach(() => { + vi.useRealTimers(); + }); + it("should render", () => { render(); expect(screen.getByText("Parameters")).toBeInTheDocument(); @@ -39,10 +43,16 @@ describe("GuardrailConfig", () => { it("should show custom code textarea when custom code override is toggled on", async () => { const user = userEvent.setup(); render(); - const switches = screen.getAllByRole("switch"); - // The second switch is the custom code override toggle - const customCodeSwitch = switches[1]; - await user.click(customCodeSwitch); + // Walk up from "Custom Code Override" heading to find the enclosing section, + // then locate the switch within it + const heading = screen.getByText("Custom Code Override"); + let container = heading.parentElement; + let customCodeSwitch: Element | null = null; + while 
(container && !customCodeSwitch) { + customCodeSwitch = container.querySelector('[role="switch"]'); + container = container.parentElement; + } + await user.click(customCodeSwitch!); expect(screen.getByPlaceholderText(/async def evaluate/)).toBeInTheDocument(); }); @@ -63,7 +73,6 @@ describe("GuardrailConfig", () => { render(); await user.click(screen.getByRole("button", { name: /re-run on failing logs/i })); expect(screen.getByText(/Running on 10 samples/)).toBeInTheDocument(); - vi.useRealTimers(); }); it("should show success message after re-run completes", async () => { @@ -73,7 +82,6 @@ describe("GuardrailConfig", () => { await user.click(screen.getByRole("button", { name: /re-run on failing logs/i })); act(() => { vi.advanceTimersByTime(2500); }); expect(screen.getByText(/7\/10 would now pass/)).toBeInTheDocument(); - vi.useRealTimers(); }); it("should display the Revert and Save buttons", () => { From 6902355f5ba7ae6de8af2366d8ec0ce372c20b9b Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 12:34:03 -0700 Subject: [PATCH 117/539] address greptile review feedback (greploop iteration 2) Replace unsafe `as string[]` cast in modeMatches with runtime type check via `.some()`. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../components/view_logs/GuardrailViewer/GuardrailViewer.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx index 1056076fcd..28247d4b2a 100644 --- a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx @@ -107,7 +107,7 @@ const modeMatches = ( if (typeof mode === "object" && "default" in mode) { const def = mode.default; if (typeof def === "string") return def === target; - if (Array.isArray(def)) return (def as string[]).includes(target); + if (Array.isArray(def)) return def.some((x) => typeof x === "string" && x === target); } return false; }; From 7714d1be0b236871ffd24424d8513fd4630eaffb Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 12:40:56 -0700 Subject: [PATCH 118/539] address greptile review feedback (greploop iteration 3) Add typeof string guards to all array element returns in resolveMode to prevent non-string values from sneaking through via any-widening. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../view_logs/GuardrailViewer/GuardrailViewer.tsx | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx index 28247d4b2a..5608e3677f 100644 --- a/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/GuardrailViewer/GuardrailViewer.tsx @@ -84,11 +84,17 @@ const PROVIDERS_WITH_CUSTOM_RENDERERS = new Set([ const resolveMode = (mode: GuardrailInformation["guardrail_mode"]): string | null => { if (mode == null) return null; if (typeof mode === "string") return mode; - if (Array.isArray(mode)) return mode[0] ?? null; + if (Array.isArray(mode)) { + const first = mode[0]; + return typeof first === "string" ? first : null; + } if (typeof mode === "object" && "default" in mode) { const def = mode.default; if (typeof def === "string") return def; - if (Array.isArray(def)) return def[0] ?? null; + if (Array.isArray(def)) { + const first = def[0]; + return typeof first === "string" ? 
first : null; + } } return null; }; From 56ed8379e297e20fb219817690d3b197c440ab69 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 12:42:19 -0700 Subject: [PATCH 119/539] address greptile review feedback (greploop iteration 2) - Add explicit vi import to ScoreChart.test.tsx - Use custom matcher for I/O modes to avoid cross-element text issues - Use version-agnostic regex for Save button assertion - Add comments noting placeholder data in GuardrailConfig tests Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/components/AIHub/AgentHubTableColumns.test.tsx | 10 ++++++++-- .../GuardrailsMonitor/GuardrailConfig.test.tsx | 5 ++++- .../components/GuardrailsMonitor/ScoreChart.test.tsx | 1 + 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx b/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx index c32df2f546..703a24d8d6 100644 --- a/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx +++ b/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx @@ -94,8 +94,14 @@ describe("AgentHubTableColumns", () => { it("should display I/O modes", () => { render(); - expect(screen.getByText("text")).toBeInTheDocument(); - expect(screen.getByText("text, image")).toBeInTheDocument(); + // "In:" and "Out:" are in children; getByText with exact:false + // matches against the element's full textContent across child nodes + expect(screen.getByText((_, el) => + el?.tagName === "P" && el.textContent === "In: text" + )).toBeInTheDocument(); + expect(screen.getByText((_, el) => + el?.tagName === "P" && el.textContent === "Out: text, image" + )).toBeInTheDocument(); }); it("should display 'Yes' badge for public agents", () => { diff --git a/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx b/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx index 38f6568198..ac84c270d8 100644 
--- a/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx +++ b/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx @@ -24,6 +24,8 @@ describe("GuardrailConfig", () => { expect(screen.getByText(/Configure Content Safety behavior/)).toBeInTheDocument(); }); + // Note: Version history entries are hardcoded placeholders in the component. + // These assertions will need updating when wired to real API data. it("should show version history when 'View history' is clicked", async () => { const user = userEvent.setup(); render(); @@ -87,6 +89,7 @@ describe("GuardrailConfig", () => { it("should display the Revert and Save buttons", () => { render(); expect(screen.getByRole("button", { name: /revert/i })).toBeInTheDocument(); - expect(screen.getByRole("button", { name: /save as v4/i })).toBeInTheDocument(); + // The component's hardcoded default version is "v3", so Save shows "v4" + expect(screen.getByRole("button", { name: /save as v\d+/i })).toBeInTheDocument(); }); }); diff --git a/ui/litellm-dashboard/src/components/GuardrailsMonitor/ScoreChart.test.tsx b/ui/litellm-dashboard/src/components/GuardrailsMonitor/ScoreChart.test.tsx index b950dd9a30..c32e674578 100644 --- a/ui/litellm-dashboard/src/components/GuardrailsMonitor/ScoreChart.test.tsx +++ b/ui/litellm-dashboard/src/components/GuardrailsMonitor/ScoreChart.test.tsx @@ -1,4 +1,5 @@ import { render, screen } from "@testing-library/react"; +import { vi } from "vitest"; import { ScoreChart } from "./ScoreChart"; vi.mock("@tremor/react", async (importOriginal) => { From bbc120095e2dcc9477f76d647d01490699861114 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 12:48:01 -0700 Subject: [PATCH 120/539] address greptile review feedback (greploop iteration 3) - Remove Ant Design CSS class selector coupling in ExportFormatSelector test - Lift mock fns out of TestTable component body to enable callback assertions Co-Authored-By: Claude Opus 4.6 (1M context) 
--- .../components/AIHub/AgentHubTableColumns.test.tsx | 14 +++++++++++--- .../ExportFormatSelector.test.tsx | 4 +--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx b/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx index 703a24d8d6..083e67c297 100644 --- a/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx +++ b/ui/litellm-dashboard/src/components/AIHub/AgentHubTableColumns.test.tsx @@ -22,9 +22,17 @@ const mockAgent: AgentHubData = { is_public: true, }; -function TestTable({ data, publicPage = false }: { data: AgentHubData[]; publicPage?: boolean }) { - const showModal = vi.fn(); - const copyToClipboard = vi.fn(); +function TestTable({ + data, + publicPage = false, + showModal = vi.fn(), + copyToClipboard = vi.fn(), +}: { + data: AgentHubData[]; + publicPage?: boolean; + showModal?: ReturnType; + copyToClipboard?: ReturnType; +}) { const columns = getAgentHubTableColumns(showModal, copyToClipboard, publicPage); const table = useReactTable({ data, columns, getCoreRowModel: getCoreRowModel() }); diff --git a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportFormatSelector.test.tsx b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportFormatSelector.test.tsx index a88f9cb116..d20d24992f 100644 --- a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportFormatSelector.test.tsx +++ b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportFormatSelector.test.tsx @@ -27,9 +27,7 @@ describe("ExportFormatSelector", () => { // Open the Ant Design Select dropdown await user.click(screen.getByText("CSV (Excel, Google Sheets)")); // Select JSON option from the dropdown - const jsonOption = await screen.findByText("JSON (includes metadata)", { - selector: ".ant-select-item-option-content", - }); + const jsonOption = await screen.findByText("JSON (includes metadata)"); await user.click(jsonOption); 
expect(onChange).toHaveBeenCalledWith("json", expect.anything()); }); From fcea5606827552b506fb1de478ec1c6d095e7152 Mon Sep 17 00:00:00 2001 From: Avik Kumar Date: Wed, 18 Mar 2026 15:58:13 -0400 Subject: [PATCH 121/539] fix(langsmith): populate usage_metadata in outputs for Cost column LangSmith reads the Cost column from outputs.usage_metadata.total_cost, but LangsmithLogger._prepare_log_data never wrote to that key. The response_cost was already computed in StandardLoggingPayload but was not forwarded to the outputs dict. Inject usage_metadata with input_tokens, output_tokens, total_tokens, and total_cost into the outputs dict so LangSmith can display cost. Fixes #24001 Made-with: Cursor --- litellm/integrations/langsmith.py | 14 ++++- .../integrations/test_langsmith_init.py | 54 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py index 03845af521..ef2d30bb26 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -153,11 +153,23 @@ class LangsmithLogger(CustomBatchLogger): if key in requester_metadata and key not in extra_metadata: extra_metadata[key] = requester_metadata[key] + outputs = payload["response"] + if isinstance(outputs, dict): + outputs = {**outputs} + else: + outputs = {"output": outputs} + outputs["usage_metadata"] = { + "input_tokens": payload.get("prompt_tokens", 0), + "output_tokens": payload.get("completion_tokens", 0), + "total_tokens": payload.get("total_tokens", 0), + "total_cost": payload.get("response_cost", 0), + } + data = { "name": run_name, "run_type": "llm", # this should always be llm, since litellm always logs llm calls. 
Langsmith allow us to log "chain" "inputs": payload, - "outputs": payload["response"], + "outputs": outputs, "session_name": project_name, "start_time": payload["startTime"], "end_time": payload["endTime"], diff --git a/tests/test_litellm/integrations/test_langsmith_init.py b/tests/test_litellm/integrations/test_langsmith_init.py index 9f7db4095b..779e7b4c94 100644 --- a/tests/test_litellm/integrations/test_langsmith_init.py +++ b/tests/test_litellm/integrations/test_langsmith_init.py @@ -132,3 +132,57 @@ class TestLangsmithLoggerInit: assert ( logger.sampling_rate >= 0.0 ), f"sampling_rate should be non-negative, got {logger.sampling_rate}" + + +class TestLangsmithPrepareLogData: + """Regression test for #24001: _prepare_log_data must inject + usage_metadata into outputs so LangSmith's Cost column is populated.""" + + @patch("asyncio.create_task") + @patch.dict(os.environ, {"LANGSMITH_SAMPLING_RATE": "1"}, clear=False) + def test_outputs_contain_usage_metadata(self, mock_create_task): + logger = LangsmithLogger( + langsmith_api_key="test-key", + langsmith_project="test-project", + ) + + payload = { + "id": "test-id", + "response": {"choices": [{"message": {"content": "hi"}}]}, + "metadata": {}, + "startTime": 1.0, + "endTime": 2.0, + "request_tags": [], + "error_str": None, + "status": "success", + "response_cost": 0.0042, + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + } + + kwargs = { + "litellm_params": {"metadata": {}}, + "standard_logging_object": payload, + } + + credentials = { + "LANGSMITH_API_KEY": "test-key", + "LANGSMITH_PROJECT": "test-project", + "LANGSMITH_BASE_URL": "https://api.smith.langchain.com", + } + + data = logger._prepare_log_data( + kwargs=kwargs, + response_obj=None, + start_time=1.0, + end_time=2.0, + credentials=credentials, + ) + + assert "usage_metadata" in data["outputs"] + um = data["outputs"]["usage_metadata"] + assert um["total_cost"] == 0.0042 + assert um["input_tokens"] == 100 + assert 
um["output_tokens"] == 50 + assert um["total_tokens"] == 150 From eb7efa36daf03f23d81d137122c0b8e3f1575067 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 13:05:53 -0700 Subject: [PATCH 122/539] Add missing permission options to PERMISSION_OPTIONS list Adds /key/info, /key/list, /key/aliases, and /team/daily/activity to the hardcoded PERMISSION_OPTIONS in TeamSSOSettings.tsx. Co-Authored-By: Claude Opus 4.6 (1M context) --- ui/litellm-dashboard/src/components/TeamSSOSettings.tsx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ui/litellm-dashboard/src/components/TeamSSOSettings.tsx b/ui/litellm-dashboard/src/components/TeamSSOSettings.tsx index a9c07cdbcf..98745d388a 100644 --- a/ui/litellm-dashboard/src/components/TeamSSOSettings.tsx +++ b/ui/litellm-dashboard/src/components/TeamSSOSettings.tsx @@ -26,6 +26,10 @@ const PERMISSION_OPTIONS = [ "/key/unblock", "/key/bulk_update", "/key/{key_id}/reset_spend", + "/key/info", + "/key/list", + "/key/aliases", + "/team/daily/activity", ]; interface SettingRowProps { From 51f78b7d7250a86f127d02fbe964968b2a8a41e2 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 13:09:17 -0700 Subject: [PATCH 123/539] address greptile review feedback (greploop iteration 4) - Add guard assertion before non-null click on custom code switch - Use await act(async ...) 
for timer advancement to avoid act warnings - Pin locale in date range assertion for CI determinism Co-Authored-By: Claude Opus 4.6 (1M context) --- .../components/EntityUsageExport/ExportSummary.test.tsx | 6 ++++-- .../components/GuardrailsMonitor/GuardrailConfig.test.tsx | 7 +++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportSummary.test.tsx b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportSummary.test.tsx index ab426e291b..1aeee42f74 100644 --- a/ui/litellm-dashboard/src/components/EntityUsageExport/ExportSummary.test.tsx +++ b/ui/litellm-dashboard/src/components/EntityUsageExport/ExportSummary.test.tsx @@ -14,8 +14,10 @@ describe("ExportSummary", () => { it("should display formatted date range", () => { render(); - const text = screen.getByText(/\d+.*-.*\d+/); - expect(text).toBeInTheDocument(); + // Pin locale to en-US so test is deterministic regardless of CI runner locale + const expectedFrom = dateRange.from!.toLocaleDateString("en-US"); + const expectedTo = dateRange.to!.toLocaleDateString("en-US"); + expect(screen.getByText(`${expectedFrom} - ${expectedTo}`)).toBeInTheDocument(); }); it("should show singular 'filter' for one filter", () => { diff --git a/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx b/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx index ac84c270d8..54c7ebabe7 100644 --- a/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx +++ b/ui/litellm-dashboard/src/components/GuardrailsMonitor/GuardrailConfig.test.tsx @@ -54,7 +54,10 @@ describe("GuardrailConfig", () => { customCodeSwitch = container.querySelector('[role="switch"]'); container = container.parentElement; } - await user.click(customCodeSwitch!); + if (!customCodeSwitch) { + throw new Error("Could not find the Custom Code Override switch via DOM traversal"); + } + await user.click(customCodeSwitch); 
expect(screen.getByPlaceholderText(/async def evaluate/)).toBeInTheDocument(); }); @@ -82,7 +85,7 @@ describe("GuardrailConfig", () => { const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); render(); await user.click(screen.getByRole("button", { name: /re-run on failing logs/i })); - act(() => { vi.advanceTimersByTime(2500); }); + await act(async () => { vi.advanceTimersByTime(2500); }); expect(screen.getByText(/7\/10 would now pass/)).toBeInTheDocument(); }); From 98311e0f0a57094eb8cba5b0ec82e5af671b890d Mon Sep 17 00:00:00 2001 From: Emerson Gomes Date: Wed, 18 Mar 2026 15:07:50 -0500 Subject: [PATCH 124/539] Preserve router model_group in generic API logs --- litellm/router.py | 8 ++++ .../test_router_helper_utils.py | 37 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/litellm/router.py b/litellm/router.py index f34368172a..d2acd4c1f5 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -3874,10 +3874,18 @@ class Router: The response from the handler function """ handler_name = original_function.__name__ + metadata_variable_name = _get_router_metadata_variable_name( + function_name="generic_api_call" + ) try: verbose_router_logger.debug( f"Inside _generic_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}" ) + self._update_kwargs_before_fallbacks( + model=model, + kwargs=kwargs, + metadata_variable_name=metadata_variable_name, + ) deployment = self.get_available_deployment( model=model, messages=kwargs.get("messages", None), diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index 3001aec8b8..75c5250b3d 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ b/tests/router_unit_tests/test_router_helper_utils.py @@ -1568,6 +1568,43 @@ def test_handle_clientside_credential_with_deployment_model_name(model_list): print("✓ _handle_clientside_credential test passed!") +def 
test_sync_generic_api_call_preserves_requested_model_group_in_logs(): + router = Router( + model_list=[ + { + "model_name": "claude-sonnet-4-6", + "litellm_params": { + "model": "bedrock/global.anthropic.claude-sonnet-4-6", + "aws_access_key_id": "test-access-key", + "aws_secret_access_key": "test-secret-key", + "aws_region_name": "us-west-2", + }, + } + ] + ) + + captured_kwargs = {} + + def mock_original_function(**kwargs): + captured_kwargs.update(kwargs) + return {"status": "ok"} + + response = router._generic_api_call_with_fallbacks( + model="claude-sonnet-4-6", + original_function=mock_original_function, + ) + + assert response == {"status": "ok"} + assert captured_kwargs["model"] == "bedrock/global.anthropic.claude-sonnet-4-6" + assert ( + captured_kwargs["litellm_metadata"]["model_group"] == "claude-sonnet-4-6" + ) + assert ( + captured_kwargs["litellm_metadata"]["deployment"] + == "bedrock/global.anthropic.claude-sonnet-4-6" + ) + + @pytest.mark.parametrize( "function_name, expected_metadata_key", [ From 845ad042913dd3d85d7f5d484fc0acc9ea990f14 Mon Sep 17 00:00:00 2001 From: Emerson Gomes Date: Wed, 18 Mar 2026 15:25:44 -0500 Subject: [PATCH 125/539] Address router generic API review feedback --- litellm/router.py | 1 + .../test_router_helper_utils.py | 80 +++++++++++++++---- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index d2acd4c1f5..5263e966ab 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -3890,6 +3890,7 @@ class Router: model=model, messages=kwargs.get("messages", None), specific_deployment=kwargs.pop("specific_deployment", None), + request_kwargs=kwargs, ) self._update_kwargs_with_deployment( deployment=deployment, kwargs=kwargs, function_name="generic_api_call" diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index 75c5250b3d..34a19f5ce7 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ 
b/tests/router_unit_tests/test_router_helper_utils.py @@ -1583,26 +1583,74 @@ def test_sync_generic_api_call_preserves_requested_model_group_in_logs(): ] ) - captured_kwargs = {} + try: + captured_kwargs = {} - def mock_original_function(**kwargs): - captured_kwargs.update(kwargs) - return {"status": "ok"} + def mock_original_function(**kwargs): + captured_kwargs.update(kwargs) + return {"status": "ok"} - response = router._generic_api_call_with_fallbacks( - model="claude-sonnet-4-6", - original_function=mock_original_function, + response = router._generic_api_call_with_fallbacks( + model="claude-sonnet-4-6", + original_function=mock_original_function, + ) + + assert response == {"status": "ok"} + assert ( + captured_kwargs["model"] == "bedrock/global.anthropic.claude-sonnet-4-6" + ) + assert ( + captured_kwargs["litellm_metadata"]["model_group"] == "claude-sonnet-4-6" + ) + assert ( + captured_kwargs["litellm_metadata"]["deployment"] + == "bedrock/global.anthropic.claude-sonnet-4-6" + ) + finally: + router.discard() + + +def test_sync_generic_api_call_uses_request_kwargs_for_deployment_selection(): + router = Router( + model_list=[ + { + "model_name": "regional-model", + "litellm_params": { + "model": "anthropic/us-model", + "api_key": "test-api-key", + "region_name": "us", + }, + }, + { + "model_name": "regional-model", + "litellm_params": { + "model": "anthropic/eu-model", + "api_key": "test-api-key", + "region_name": "eu", + }, + }, + ], + enable_pre_call_checks=True, ) - assert response == {"status": "ok"} - assert captured_kwargs["model"] == "bedrock/global.anthropic.claude-sonnet-4-6" - assert ( - captured_kwargs["litellm_metadata"]["model_group"] == "claude-sonnet-4-6" - ) - assert ( - captured_kwargs["litellm_metadata"]["deployment"] - == "bedrock/global.anthropic.claude-sonnet-4-6" - ) + try: + captured_kwargs = {} + + def mock_original_function(**kwargs): + captured_kwargs.update(kwargs) + return {"status": "ok"} + + response = 
router._generic_api_call_with_fallbacks( + model="regional-model", + original_function=mock_original_function, + messages=[{"role": "user", "content": "Hello from Europe"}], + allowed_model_region="eu", + ) + + assert response == {"status": "ok"} + assert captured_kwargs["model"] == "anthropic/eu-model" + finally: + router.discard() @pytest.mark.parametrize( From b00096f2a08a5d56d2a1d84c663b833dcdda5f66 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 18 Mar 2026 20:33:09 +0000 Subject: [PATCH 126/539] chore: regenerate poetry.lock to match pyproject.toml (#2) Co-authored-by: github-actions[bot] --- poetry.lock | 73 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 32 deletions(-) diff --git a/poetry.lock b/poetry.lock index 591d0c270e..be87e51f80 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand. 
[[package]] name = "a2a-sdk" @@ -7,11 +7,11 @@ description = "A2A Python SDK" optional = false python-versions = ">=3.10" groups = ["main", "proxy-dev"] -markers = "python_version >= \"3.10\"" files = [ {file = "a2a_sdk-0.3.22-py3-none-any.whl", hash = "sha256:b98701135bb90b0ff85d35f31533b6b7a299bf810658c1c65f3814a6c15ea385"}, {file = "a2a_sdk-0.3.22.tar.gz", hash = "sha256:77a5694bfc4f26679c11b70c7f1062522206d430b34bc1215cfbb1eba67b7e7d"}, ] +markers = {main = "python_version >= \"3.10\" and extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] google-api-core = ">=1.26.0" @@ -385,6 +385,7 @@ files = [ {file = "azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b"}, {file = "azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7"}, ] +markers = {main = "extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] requests = ">=2.21.0" @@ -405,6 +406,7 @@ files = [ {file = "azure_identity-1.25.1-py3-none-any.whl", hash = "sha256:e9edd720af03dff020223cd269fa3a61e8f345ea75443858273bcb44844ab651"}, {file = "azure_identity-1.25.1.tar.gz", hash = "sha256:87ca8328883de6036443e1c37b40e8dc8fb74898240f61071e09d2e369361456"}, ] +markers = {main = "extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] azure-core = ">=1.31.0" @@ -598,7 +600,7 @@ files = [ {file = "cachetools-6.2.2-py3-none-any.whl", hash = "sha256:6c09c98183bf58560c97b2abfcedcbaf6a896a490f534b031b661d3723b45ace"}, {file = "cachetools-6.2.2.tar.gz", hash = "sha256:8e6d266b25e539df852251cfd6f990b4bc3a141db73b939058d809ebd2590fc6"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == 
\"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [[package]] name = "certifi" @@ -705,7 +707,7 @@ files = [ {file = "cffi-2.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9"}, {file = "cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529"}, ] -markers = {main = "platform_python_implementation != \"PyPy\" or extra == \"proxy\"", dev = "platform_python_implementation != \"PyPy\"", proxy-dev = "platform_python_implementation != \"PyPy\""} +markers = {main = "(platform_python_implementation != \"PyPy\" or extra == \"proxy\") and (python_version >= \"3.10\" or extra == \"proxy\" or extra == \"extra-proxy\") and (extra == \"proxy\" or extra == \"extra-proxy\" or extra == \"mlflow\")", dev = "platform_python_implementation != \"PyPy\"", proxy-dev = "platform_python_implementation != \"PyPy\""} [package.dependencies] pycparser = {version = "*", markers = "implementation_name != \"PyPy\""} @@ -1055,6 +1057,7 @@ files = [ {file = "cryptography-43.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ce6fae5bdad59577b44e4dfed356944fbf1d925269114c28be377692643b4ff"}, {file = "cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805"}, ] +markers = {main = "python_version >= \"3.10\" and (extra == \"proxy\" or extra == \"extra-proxy\" or extra == \"mlflow\") or extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""} @@ -1837,11 +1840,11 @@ description = "Google API client core library" optional = false python-versions = ">=3.7" groups = ["main", "proxy-dev"] -markers = "python_version >= \"3.14\"" files = [ {file = "google_api_core-2.25.2-py3-none-any.whl", hash = "sha256:e9a8f62d363dc8424a8497f4c2a47d6bcda6c16514c935629c257ab5d10210e7"}, {file = "google_api_core-2.25.2.tar.gz", hash = 
"sha256:1c63aa6af0d0d5e37966f157a77f9396d820fba59f9e43e9415bc3dc5baff300"}, ] +markers = {main = "python_version >= \"3.14\" and (extra == \"extra-proxy\" or extra == \"google\")", proxy-dev = "python_version >= \"3.14\""} [package.dependencies] google-auth = ">=2.14.1,<3.0.0" @@ -1869,7 +1872,7 @@ files = [ {file = "google_api_core-2.28.1-py3-none-any.whl", hash = "sha256:4021b0f8ceb77a6fb4de6fde4502cecab45062e66ff4f2895169e0b35bc9466c"}, {file = "google_api_core-2.28.1.tar.gz", hash = "sha256:2b405df02d68e68ce0fbc138559e6036559e685159d148ae5861013dc201baf8"}, ] -markers = {main = "(python_version >= \"3.10\" or extra == \"google\" or extra == \"extra-proxy\") and python_version < \"3.14\"", proxy-dev = "python_version >= \"3.10\" and python_version < \"3.14\""} +markers = {main = "python_version < \"3.14\" and (extra == \"extra-proxy\" or extra == \"google\")", proxy-dev = "python_version >= \"3.10\" and python_version < \"3.14\""} [package.dependencies] google-auth = ">=2.14.1,<3.0.0" @@ -1906,7 +1909,7 @@ files = [ {file = "google_auth-2.43.0-py2.py3-none-any.whl", hash = "sha256:af628ba6fa493f75c7e9dbe9373d148ca9f4399b5ea29976519e0a3848eddd16"}, {file = "google_auth-2.43.0.tar.gz", hash = "sha256:88228eee5fc21b62a1b5fe773ca15e67778cb07dc8363adcb4a8827b52d81483"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] cachetools = ">=2.0.0,<7.0" @@ -2078,11 +2081,11 @@ files = [ ] [package.dependencies] -google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} -google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" -grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev" -proto-plus = 
">=1.22.3,<2.0.0dev" -protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0.dev0", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +grpc-google-iam-v1 = ">=0.12.4,<1.0.0.dev0" +proto-plus = ">=1.22.3,<2.0.0.dev0" +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" [[package]] name = "google-cloud-resource-manager" @@ -2264,7 +2267,7 @@ files = [ {file = "googleapis_common_protos-1.72.0-py3-none-any.whl", hash = "sha256:4299c5a82d5ae1a9702ada957347726b167f9f8d1fc352477702a1e851ff4038"}, {file = "googleapis_common_protos-1.72.0.tar.gz", hash = "sha256:e55a601c1b32b52d7a3e65f43563e2aa61bcd737998ee672ac9b951cd49319f5"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\") or extra == \"google\" or extra == \"extra-proxy\""} [package.dependencies] grpcio = {version = ">=1.44.0,<2.0.0", optional = true, markers = "extra == \"grpc\""} @@ -2673,11 +2676,11 @@ description = "Consume Server-Sent Event (SSE) messages with HTTPX." 
optional = false python-versions = ">=3.9" groups = ["main", "proxy-dev"] -markers = "python_version >= \"3.10\"" files = [ {file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"}, {file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"}, ] +markers = {main = "python_version >= \"3.10\" and (extra == \"proxy\" or extra == \"extra-proxy\")", proxy-dev = "python_version >= \"3.10\""} [[package]] name = "huey" @@ -3042,7 +3045,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.03.6" +jsonschema-specifications = ">=2023.3.6" referencing = ">=0.28.4" rpds-py = ">=0.7.1" @@ -3219,15 +3222,15 @@ files = [ [[package]] name = "litellm-proxy-extras" -version = "0.4.56" +version = "0.4.57" description = "Additional files for the LiteLLM Proxy. Reduces the size of the main litellm package." optional = true python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" groups = ["main"] markers = "extra == \"proxy\"" files = [ - {file = "litellm_proxy_extras-0.4.56-py3-none-any.whl", hash = "sha256:52dbe3b5358c790e77e12f1ec5ef8e7508b383c2aaf41299750b6fb400908ee7"}, - {file = "litellm_proxy_extras-0.4.56.tar.gz", hash = "sha256:63ad59baa0defccc5c929cfd933ee7e32a6614b0fc5fa0fc45a12d7608e33f08"}, + {file = "litellm_proxy_extras-0.4.57-py3-none-any.whl", hash = "sha256:04538223cd80318a72d70c6e10f701598e58c763368296a6503c674c92fbdb62"}, + {file = "litellm_proxy_extras-0.4.57.tar.gz", hash = "sha256:ef9b95dc42237614216833bd5d46ebf9dea1caa5ea14ea1a66d7f7842b224ec2"}, ] [[package]] @@ -3713,6 +3716,7 @@ files = [ {file = "msal-1.34.0-py3-none-any.whl", hash = "sha256:f669b1644e4950115da7a176441b0e13ec2975c29528d8b9e81316023676d6e1"}, {file = "msal-1.34.0.tar.gz", hash = "sha256:76ba83b716ea5a6d75b0279c0ac353a0e05b820ca1f6682c0eb7f45190c43c2f"}, ] +markers = {main = "extra == 
\"proxy\" or extra == \"extra-proxy\""} [package.dependencies] cryptography = ">=2.5,<49" @@ -3733,6 +3737,7 @@ files = [ {file = "msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca"}, {file = "msal_extensions-1.3.1.tar.gz", hash = "sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4"}, ] +markers = {main = "extra == \"proxy\" or extra == \"extra-proxy\""} [package.dependencies] msal = ">=1.29,<2" @@ -3983,6 +3988,7 @@ files = [ {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, ] +markers = {main = "extra == \"extra-proxy\""} [[package]] name = "numpy" @@ -4105,7 +4111,7 @@ files = [ {file = "opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950"}, {file = "opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c"}, ] -markers = {main = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and extra == \"mlflow\""} [package.dependencies] importlib-metadata = ">=6.0,<8.8.0" @@ -4220,7 +4226,7 @@ files = [ {file = "opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c"}, {file = "opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6"}, ] -markers = {main = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and extra == \"mlflow\""} [package.dependencies] opentelemetry-api = "1.39.1" @@ -4238,7 +4244,7 @@ files = [ {file = "opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb"}, {file = 
"opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953"}, ] -markers = {main = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and extra == \"mlflow\""} [package.dependencies] opentelemetry-api = "1.39.1" @@ -4722,6 +4728,7 @@ files = [ {file = "prisma-0.11.0-py3-none-any.whl", hash = "sha256:22bb869e59a2968b99f3483bb417717273ffbc569fd1e9ceed95e5614cbaf53a"}, {file = "prisma-0.11.0.tar.gz", hash = "sha256:3f2f2fd2361e1ec5ff655f2a04c7860c2f2a5bc4c91f78ca9c5c6349735bf693"}, ] +markers = {main = "extra == \"extra-proxy\""} [package.dependencies] click = ">=7.1.2" @@ -4895,7 +4902,7 @@ files = [ {file = "proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66"}, {file = "proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] protobuf = ">=3.19.0,<7.0.0" @@ -4923,7 +4930,7 @@ files = [ {file = "protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5"}, {file = "protobuf-5.29.5.tar.gz", hash = "sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\""} [[package]] name = "psutil" @@ -5083,7 +5090,7 @@ files = [ {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, 
{file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [[package]] name = "pyasn1-modules" @@ -5096,7 +5103,7 @@ files = [ {file = "pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a"}, {file = "pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] pyasn1 = ">=0.6.1,<0.7.0" @@ -5124,7 +5131,7 @@ files = [ {file = "pycparser-2.23-py3-none-any.whl", hash = "sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934"}, {file = "pycparser-2.23.tar.gz", hash = "sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2"}, ] -markers = {main = "implementation_name != \"PyPy\" and (platform_python_implementation != \"PyPy\" or extra == \"proxy\")", dev = "platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\"", proxy-dev = "platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\""} +markers = {main = "implementation_name != \"PyPy\" and (platform_python_implementation != \"PyPy\" or extra == \"proxy\") and (python_version >= \"3.10\" or extra == \"proxy\" or extra == \"extra-proxy\") 
and (extra == \"proxy\" or extra == \"extra-proxy\" or extra == \"mlflow\")", dev = "platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\"", proxy-dev = "platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\""} [[package]] name = "pydantic" @@ -5347,6 +5354,7 @@ files = [ {file = "pyjwt-2.12.1-py3-none-any.whl", hash = "sha256:28ca37c070cad8ba8cd9790cd940535d40274d22f80ab87f3ac6a713e6e8454c"}, {file = "pyjwt-2.12.1.tar.gz", hash = "sha256:c74a7a2adf861c04d002db713dd85f84beb242228e671280bf709d765b03672b"}, ] +markers = {main = "(python_version <= \"3.13\" or extra == \"proxy\" or extra == \"extra-proxy\") and (extra == \"extra-proxy\" or extra == \"proxy\")"} [package.dependencies] cryptography = {version = ">=3.4.0", optional = true, markers = "extra == \"crypto\""} @@ -6290,7 +6298,7 @@ files = [ {file = "rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762"}, {file = "rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75"}, ] -markers = {main = "extra == \"google\" or extra == \"extra-proxy\" or python_version >= \"3.10\"", proxy-dev = "python_version >= \"3.10\""} +markers = {main = "python_version >= \"3.10\" and (extra == \"extra-proxy\" or extra == \"google\" or extra == \"mlflow\") or extra == \"google\" or extra == \"extra-proxy\"", proxy-dev = "python_version >= \"3.10\""} [package.dependencies] pyasn1 = ">=0.1.3" @@ -6336,10 +6344,10 @@ files = [ ] [package.dependencies] -botocore = ">=1.37.4,<2.0a.0" +botocore = ">=1.37.4,<2.0a0" [package.extras] -crt = ["botocore[crt] (>=1.37.4,<2.0a.0)"] +crt = ["botocore[crt] (>=1.37.4,<2.0a0)"] [[package]] name = "scikit-learn" @@ -6492,9 +6500,9 @@ tornado = ">=6.4.2,<7" urllib3 = ">=1.26,<3" [package.extras] -all = ["boto3 (>=1.34.98,<2)", "botocore (>=1.34.110,<2)", "cohere (>=5.9.4,<6.00)", "dagger-io (>=0.1.1) ; python_version >= \"3.11\"", "fastembed 
(>=0.3.0,<0.4) ; python_version < \"3.13\"", "google-cloud-aiplatform (>=1.45.0,<2)", "ipykernel (>=6.25.0,<7)", "llama-cpp-python (>=0.2.28,<0.2.86) ; python_version < \"3.13\"", "mistralai (>=0.0.12,<0.1.0)", "mypy (>=1.7.1,<2)", "ollama (>=0.1.7)", "pillow (>=10.2.0,<11.0.0) ; python_version < \"3.13\"", "pinecone[asyncio] (>=7.0.0,<8.0.0)", "psycopg[binary] (>=3.1.0,<4)", "pytest (>=8.2,<9.0)", "pytest-asyncio (>=0.24.0,<0.25)", "pytest-cov (>=4.1.0,<5)", "pytest-mock (>=3.12.0,<4)", "pytest-timeout", "pytest-xdist (>=3.5.0,<4)", "python-dotenv (>=1.0.0,<2)", "qdrant-client (>=1.11.1,<2)", "requests-mock (>=1.12.1,<2)", "ruff (>=0.11.2,<0.12)", "sentence-transformers (>=5.0.0) ; python_version < \"3.13\"", "tokenizers (>=0.19) ; python_version < \"3.13\"", "torch (>=2.6.0) ; python_version < \"3.13\"", "torchvision (>=0.17.0) ; python_version < \"3.13\"", "transformers (>=4.36.2) ; python_version < \"3.13\"", "types-pyyaml (>=6.0.12.12,<7)", "types-requests (>=2.31.0,<3)"] +all = ["boto3 (>=1.34.98,<2)", "botocore (>=1.34.110,<2)", "cohere (>=5.9.4,<6.0)", "dagger-io (>=0.1.1) ; python_version >= \"3.11\"", "fastembed (>=0.3.0,<0.4) ; python_version < \"3.13\"", "google-cloud-aiplatform (>=1.45.0,<2)", "ipykernel (>=6.25.0,<7)", "llama-cpp-python (>=0.2.28,<0.2.86) ; python_version < \"3.13\"", "mistralai (>=0.0.12,<0.1.0)", "mypy (>=1.7.1,<2)", "ollama (>=0.1.7)", "pillow (>=10.2.0,<11.0.0) ; python_version < \"3.13\"", "pinecone[asyncio] (>=7.0.0,<8.0.0)", "psycopg[binary] (>=3.1.0,<4)", "pytest (>=8.2,<9.0)", "pytest-asyncio (>=0.24.0,<0.25)", "pytest-cov (>=4.1.0,<5)", "pytest-mock (>=3.12.0,<4)", "pytest-timeout", "pytest-xdist (>=3.5.0,<4)", "python-dotenv (>=1.0.0,<2)", "qdrant-client (>=1.11.1,<2)", "requests-mock (>=1.12.1,<2)", "ruff (>=0.11.2,<0.12)", "sentence-transformers (>=5.0.0) ; python_version < \"3.13\"", "tokenizers (>=0.19) ; python_version < \"3.13\"", "torch (>=2.6.0) ; python_version < \"3.13\"", "torchvision (>=0.17.0) ; python_version 
< \"3.13\"", "transformers (>=4.36.2) ; python_version < \"3.13\"", "types-pyyaml (>=6.0.12.12,<7)", "types-requests (>=2.31.0,<3)"] bedrock = ["boto3 (>=1.34.98,<2)", "botocore (>=1.34.110,<2)"] -cohere = ["cohere (>=5.9.4,<6.00)"] +cohere = ["cohere (>=5.9.4,<6.0)"] dev = ["dagger-io (>=0.1.1) ; python_version >= \"3.11\"", "ipykernel (>=6.25.0,<7)", "mypy (>=1.7.1,<2)", "pytest (>=8.2,<9.0)", "pytest-asyncio (>=0.24.0,<0.25)", "pytest-cov (>=4.1.0,<5)", "pytest-mock (>=3.12.0,<4)", "pytest-timeout", "pytest-xdist (>=3.5.0,<4)", "python-dotenv (>=1.0.0,<2)", "requests-mock (>=1.12.1,<2)", "ruff (>=0.11.2,<0.12)", "types-pyyaml (>=6.0.12.12,<7)", "types-requests (>=2.31.0,<3)"] docs = ["pydoc-markdown (>=4.8.2) ; python_version < \"3.12\""] fastembed = ["fastembed (>=0.3.0,<0.4) ; python_version < \"3.13\""] @@ -7222,6 +7230,7 @@ files = [ {file = "tomlkit-0.13.3-py3-none-any.whl", hash = "sha256:c89c649d79ee40629a9fda55f8ace8c6a1b42deb912b2a8fd8d942ddadb606b0"}, {file = "tomlkit-0.13.3.tar.gz", hash = "sha256:430cf247ee57df2b94ee3fbe588e71d362a941ebb545dec29b53961d61add2a1"}, ] +markers = {main = "extra == \"extra-proxy\""} [[package]] name = "tornado" @@ -7994,4 +8003,4 @@ utils = ["numpydoc"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "1f3bbf967451633fb6290ba88980bdf4fbf83420024b14e862d1da717d903684" +content-hash = "22cdc8096e0c8296827734393f5ab6e66088f397b5295caa1d277466d1fde1e8" From 71b687e00a43486ae3171102772a3be02f72ba26 Mon Sep 17 00:00:00 2001 From: Alexey <5122340@mail.ru> Date: Wed, 18 Mar 2026 23:45:57 +0300 Subject: [PATCH 127/539] fix(proxy): sync normalized call_type into model_call_details for proxy-only errors --- litellm/proxy/utils.py | 13 ++-- tests/proxy_unit_tests/test_proxy_utils.py | 69 ++++++++++++++++++++++ 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 76662be175..e9946ebf97 100644 --- a/litellm/proxy/utils.py +++ 
b/litellm/proxy/utils.py @@ -1863,28 +1863,33 @@ class ProxyLogging: ) input: Union[list, str, dict] = "" + normalized_call_type: Optional[str] = None if "messages" in request_data and isinstance( request_data["messages"], list ): input = request_data["messages"] litellm_logging_obj.model_call_details["messages"] = input if litellm_logging_obj.call_type != CallTypes.pass_through.value: - litellm_logging_obj.call_type = CallTypes.acompletion.value + normalized_call_type = CallTypes.acompletion.value elif "prompt" in request_data and isinstance(request_data["prompt"], str): input = request_data["prompt"] litellm_logging_obj.model_call_details["prompt"] = input if litellm_logging_obj.call_type != CallTypes.pass_through.value: - litellm_logging_obj.call_type = CallTypes.atext_completion.value + normalized_call_type = CallTypes.atext_completion.value elif "input" in request_data and isinstance(request_data["input"], list): input = request_data["input"] litellm_logging_obj.model_call_details["input"] = input if litellm_logging_obj.call_type != CallTypes.pass_through.value: - litellm_logging_obj.call_type = CallTypes.aembedding.value + normalized_call_type = CallTypes.aembedding.value + if normalized_call_type is not None: + litellm_logging_obj.call_type = normalized_call_type + litellm_logging_obj.model_call_details["call_type"] = ( + normalized_call_type + ) # Pass-through endpoints are logged via the callback loop's # async_post_call_failure_hook — skip pre_call and failure handlers. 
if litellm_logging_obj.call_type == CallTypes.pass_through.value: return - litellm_logging_obj.pre_call( input=input, api_key="", diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index d77a7465c1..00d4cd24e4 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -2278,6 +2278,75 @@ async def test_post_call_failure_hook_auth_error_llm_api_route(): mock_handle_logging.assert_called_once() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "request_data, route, expected_call_type", + [ + ( + {"model": "bad-model", "messages": [{"role": "user", "content": "hello"}]}, + "/v1/chat/completions", + "acompletion", + ), + ( + {"model": "bad-model", "prompt": "hello"}, + "/v1/completions", + "atext_completion", + ), + ( + {"model": "bad-model", "input": ["hello"]}, + "/v1/embeddings", + "aembedding", + ), + ], +) +async def test_handle_logging_proxy_only_error_syncs_normalized_call_type( + request_data, route, expected_call_type +): + from fastapi import HTTPException + + from litellm.caching.caching import DualCache + from litellm.litellm_core_utils.litellm_logging import Logging + from litellm.proxy.utils import ProxyLogging + + cache = DualCache() + proxy_logging = ProxyLogging(user_api_key_cache=cache) + captured_logging_obj = {} + original_function_setup = litellm.utils.function_setup + + def _capture_function_setup(*args, **kwargs): + logging_obj, data = original_function_setup(*args, **kwargs) + captured_logging_obj["logging_obj"] = logging_obj + return logging_obj, data + + with patch( + "litellm.proxy.utils.litellm.utils.function_setup", + side_effect=_capture_function_setup, + ), patch.object( + Logging, "async_failure_handler", new=AsyncMock(return_value=None) + ), patch.object( + Logging, "failure_handler", return_value=None + ), patch( + "litellm.proxy.utils.threading.Thread" + ) as mock_thread: + mock_thread.return_value.start = Mock() + + await 
proxy_logging._handle_logging_proxy_only_error( + request_data=request_data, + user_api_key_dict=UserAPIKeyAuth( + api_key="test_key", + user_id="test_user", + token="test_token", + request_route=route, + ), + route=route, + original_exception=HTTPException(status_code=400, detail="bad request"), + ) + + logging_obj = captured_logging_obj["logging_obj"] + assert logging_obj.call_type == expected_call_type + assert logging_obj.model_call_details["call_type"] == expected_call_type + + @pytest.mark.asyncio async def test_during_call_hook_parallel_execution(): """ From 3de0a2a834c00f740f02de11db8f8b394efe4f69 Mon Sep 17 00:00:00 2001 From: Milan Date: Wed, 18 Mar 2026 22:59:11 +0000 Subject: [PATCH 128/539] docs: encrypted_content_affinity requires LiteLLM >= 1.82.3 Made-with: Cursor --- docs/my-website/docs/proxy/config_settings.md | 2 +- docs/my-website/docs/proxy/load_balancing.md | 2 +- docs/my-website/docs/response_api.md | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 1d6fc1b03b..ae5f5a4d0e 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -361,7 +361,7 @@ router_settings: | redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** | | cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. | | router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) | -| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. 
Supported: `router_budget_limiting`, `prompt_caching`, `responses_api_deployment_check`, `encrypted_content_affinity` (**requires LiteLLM >= 1.82.1**), `deployment_affinity`, `session_affinity`, `forward_client_headers_by_model_group` | +| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Supported: `router_budget_limiting`, `prompt_caching`, `responses_api_deployment_check`, `encrypted_content_affinity` (requires LiteLLM >= 1.82.3), `deployment_affinity`, `session_affinity`, `forward_client_headers_by_model_group` | | deployment_affinity_ttl_seconds | int | TTL (seconds) for user-key → deployment affinity mapping when `deployment_affinity` is enabled (configured at Router init / proxy startup). Defaults to `3600` (1 hour). | | ignore_invalid_deployments | boolean | If true, ignores invalid deployments. Default for proxy is True - to prevent invalid models from blocking other models from being loaded. | | search_tools | List[SearchToolTypedDict] | List of search tool configurations for Search API integration. Each tool specifies a search_tool_name and litellm_params with search_provider, api_key, api_base, etc. [Further Docs](../search.md) | diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index 313df99b25..74b3e8a511 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -352,7 +352,7 @@ If `order=1` deployment is unavailable (e.g., rate-limited), the router falls ba When load balancing OpenAI's Responses API across deployments with **different API keys** (e.g., different Azure regions or organizations), encrypted content items (like `rs_...` reasoning items) can only be decrypted by the originating API key. 
-**Solution:** Use the `encrypted_content_affinity` pre-call check (**requires LiteLLM >= 1.82.1**) to automatically route follow-up requests containing encrypted items to the correct deployment: +**Solution:** Use the `encrypted_content_affinity` pre-call check (requires LiteLLM >= 1.82.3) to automatically route follow-up requests containing encrypted items to the correct deployment: ```yaml model_list: diff --git a/docs/my-website/docs/response_api.md b/docs/my-website/docs/response_api.md index 66aa2e1ad9..84f662812f 100644 --- a/docs/my-website/docs/response_api.md +++ b/docs/my-website/docs/response_api.md @@ -1160,12 +1160,12 @@ follow_up = await router.aresponses( To enable session continuity for Responses API in your LiteLLM proxy, set `optional_pre_call_checks` in your proxy config.yaml. - `responses_api_deployment_check`: high priority routing when `previous_response_id` is provided -- `encrypted_content_affinity`: **[Recommended]** content-aware routing for encrypted items (e.g., `rs_...` reasoning items) (**requires LiteLLM >= 1.82.1**) +- `encrypted_content_affinity`: **[Recommended]** content-aware routing for encrypted items (e.g., `rs_...` reasoning items) (**requires LiteLLM >= 1.82.3**) - `session_affinity`: sticky sessions based on session id (takes priority over `deployment_affinity`) - `deployment_affinity`: sticky sessions based on user key (applies even without `previous_response_id`) :::tip Recommended: Use `encrypted_content_affinity` -For Responses API with load balancing across deployments with **different API keys**, use `encrypted_content_affinity` instead of `deployment_affinity`. It only pins requests that contain encrypted content, avoiding quota reduction while preventing `invalid_encrypted_content` errors. +For Responses API with load balancing across deployments with **different API keys**, use `encrypted_content_affinity` instead of `deployment_affinity`. 
It only pins requests that contain encrypted content, avoiding quota reduction while preventing `invalid_encrypted_content` errors. (Requires LiteLLM >= 1.82.3.) ::: Notes: From 3ba18d708472869f7d6f7bbdb37c1cc2b471b56a Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 16:38:58 -0700 Subject: [PATCH 129/539] [Refactor] UI - Playground: Extract ChatMessageBubble from ChatUI Extract the chat message bubble rendering (~165 lines) into a dedicated ChatMessageBubble component with 15 Vitest tests covering all display branches. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../chat_ui/ChatMessageBubble.test.tsx | 280 ++++++++++++++++++ .../playground/chat_ui/ChatMessageBubble.tsx | 214 +++++++++++++ .../components/playground/chat_ui/ChatUI.tsx | 171 +---------- 3 files changed, 503 insertions(+), 162 deletions(-) create mode 100644 ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.test.tsx create mode 100644 ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.test.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.test.tsx new file mode 100644 index 0000000000..368043abee --- /dev/null +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.test.tsx @@ -0,0 +1,280 @@ +import { render, screen } from "@testing-library/react"; +import { describe, it, expect, vi } from "vitest"; +import ChatMessageBubble from "./ChatMessageBubble"; +import { EndpointType } from "./mode_endpoint_mapping"; +import { MessageType } from "./types"; + +// Mock child components to isolate bubble rendering logic +vi.mock("react-markdown", () => ({ + default: ({ children }: { children: string }) =>
{children}
, +})); + +vi.mock("react-syntax-highlighter", () => ({ + Prism: ({ children }: { children: string }) =>
{children}
, +})); + +vi.mock("react-syntax-highlighter/dist/esm/styles/prism", () => ({ + coy: {}, +})); + +vi.mock("./ReasoningContent", () => ({ + default: ({ reasoningContent }: { reasoningContent: string }) => ( +
{reasoningContent}
+ ), +})); + +vi.mock("./MCPEventsDisplay", () => ({ + default: ({ events }: { events: unknown[] }) => ( +
{events.length} events
+ ), +})); + +vi.mock("./SearchResultsDisplay", () => ({ + SearchResultsDisplay: ({ searchResults }: { searchResults: unknown[] }) => ( +
{searchResults.length} results
+ ), +})); + +vi.mock("./ResponseMetrics", () => ({ + default: ({ timeToFirstToken }: { timeToFirstToken?: number }) => ( +
TTFT: {timeToFirstToken}
+ ), +})); + +vi.mock("./A2AMetrics", () => ({ + default: ({ a2aMetadata }: { a2aMetadata: unknown }) => ( +
A2A
+ ), +})); + +vi.mock("./CodeInterpreterOutput", () => ({ + default: ({ code }: { code: string }) =>
{code}
, +})); + +vi.mock("./AudioRenderer", () => ({ + default: ({ message }: { message: MessageType }) => ( +
{typeof message.content === "string" ? message.content : ""}
+ ), +})); + +vi.mock("./ResponsesImageRenderer", () => ({ + default: () =>
, +})); + +vi.mock("./ChatImageRenderer", () => ({ + default: () =>
, +})); + +const defaultProps = { + isLastMessage: false, + endpointType: EndpointType.CHAT, + mcpEvents: [], + codeInterpreterResult: null, + accessToken: "test-token", +}; + +describe("ChatMessageBubble", () => { + it("should render a user message with right-aligned text", () => { + render( + , + ); + + expect(screen.getByText("user")).toBeInTheDocument(); + expect(screen.getByText("Hello")).toBeInTheDocument(); + }); + + it("should render an assistant message with left-aligned text", () => { + render( + , + ); + + expect(screen.getByText("assistant")).toBeInTheDocument(); + expect(screen.getByText("Hi there")).toBeInTheDocument(); + }); + + it("should show model badge for assistant messages when model is provided", () => { + render( + , + ); + + expect(screen.getByText("gpt-4")).toBeInTheDocument(); + }); + + it("should not show model badge for user messages even when model is set", () => { + render( + , + ); + + expect(screen.queryByText("gpt-4")).not.toBeInTheDocument(); + }); + + it("should render markdown content via ReactMarkdown", () => { + render( + , + ); + + expect(screen.getByTestId("react-markdown")).toHaveTextContent("**bold text**"); + }); + + it("should render an image when isImage is true", () => { + render( + , + ); + + expect(screen.getByAltText("Generated image")).toHaveAttribute("src", "https://example.com/img.png"); + }); + + it("should render AudioRenderer when isAudio is true", () => { + render( + , + ); + + expect(screen.getByTestId("audio-renderer")).toBeInTheDocument(); + }); + + it("should show ReasoningContent when reasoningContent is present", () => { + render( + , + ); + + expect(screen.getByTestId("reasoning-content")).toHaveTextContent("thinking..."); + }); + + it("should show MCP events on the last assistant message for RESPONSES endpoint", () => { + const mcpEvents = [{ type: "tool_call", item_id: "1" }]; + + render( + , + ); + + expect(screen.getByTestId("mcp-events-display")).toHaveTextContent("1 events"); + }); + + it("should 
not show MCP events when isLastMessage is false", () => { + const mcpEvents = [{ type: "tool_call", item_id: "1" }]; + + render( + , + ); + + expect(screen.queryByTestId("mcp-events-display")).not.toBeInTheDocument(); + }); + + it("should show SearchResultsDisplay when searchResults are present", () => { + render( + , + ); + + expect(screen.getByTestId("search-results-display")).toBeInTheDocument(); + }); + + it("should show ResponseMetrics when usage data is present and no a2aMetadata", () => { + render( + , + ); + + expect(screen.getByTestId("response-metrics")).toBeInTheDocument(); + }); + + it("should show A2AMetrics when a2aMetadata is present instead of ResponseMetrics", () => { + render( + , + ); + + expect(screen.getByTestId("a2a-metrics")).toBeInTheDocument(); + expect(screen.queryByTestId("response-metrics")).not.toBeInTheDocument(); + }); + + it("should show CodeInterpreterOutput on the last assistant message for RESPONSES endpoint", () => { + render( + , + ); + + expect(screen.getByTestId("code-interpreter-output")).toHaveTextContent("print('hello')"); + }); + + it("should render generated image from chat completions via message.image", () => { + render( + , + ); + + const images = screen.getAllByAltText("Generated image"); + expect(images.some((img) => img.getAttribute("src") === "https://example.com/generated.png")).toBe(true); + }); +}); diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx new file mode 100644 index 0000000000..3e1913cb2b --- /dev/null +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx @@ -0,0 +1,214 @@ +import { RobotOutlined, UserOutlined } from "@ant-design/icons"; +import React from "react"; +import ReactMarkdown from "react-markdown"; +import { Prism as SyntaxHighlighter } from "react-syntax-highlighter"; +import { coy } from "react-syntax-highlighter/dist/esm/styles/prism"; +import 
{ CodeInterpreterResult } from "../llm_calls/code_interpreter_handler"; +import A2AMetrics from "./A2AMetrics"; +import AudioRenderer from "./AudioRenderer"; +import ChatImageRenderer from "./ChatImageRenderer"; +import CodeInterpreterOutput from "./CodeInterpreterOutput"; +import { EndpointType } from "./mode_endpoint_mapping"; +import MCPEventsDisplay from "./MCPEventsDisplay"; +import type { MCPEvent } from "../../mcp_tools/types"; +import ReasoningContent from "./ReasoningContent"; +import ResponseMetrics from "./ResponseMetrics"; +import ResponsesImageRenderer from "./ResponsesImageRenderer"; +import { SearchResultsDisplay } from "./SearchResultsDisplay"; +import { MessageType } from "./types"; + +interface ChatMessageBubbleProps { + message: MessageType; + /** Whether this is the last message in the chat history. */ + isLastMessage: boolean; + endpointType: string; + /** MCP events to display on the last assistant message. */ + mcpEvents: MCPEvent[]; + /** Code interpreter result to display on the last assistant message. */ + codeInterpreterResult: CodeInterpreterResult | null; + /** API key used to fetch code interpreter file downloads. */ + accessToken: string; +} + +function ChatMessageBubble({ + message, + isLastMessage, + endpointType, + mcpEvents, + codeInterpreterResult, + accessToken, +}: ChatMessageBubbleProps) { + const isUser = message.role === "user"; + + return ( +
+
+ {/* Header: role icon + name + model badge */} +
+
+ {isUser ? ( + + ) : ( + + )} +
+ {message.role} + {message.role === "assistant" && message.model && ( + + {message.model} + + )} +
+ + {/* Reasoning content (chain-of-thought) */} + {message.reasoningContent && } + + {/* MCP events at the start of the last assistant message */} + {message.role === "assistant" && + isLastMessage && + mcpEvents.length > 0 && + (endpointType === EndpointType.RESPONSES || endpointType === EndpointType.CHAT) && ( +
+ +
+ )} + + {/* Search results */} + {message.role === "assistant" && message.searchResults && ( + + )} + + {/* Code Interpreter output for the last assistant message */} + {message.role === "assistant" && + isLastMessage && + codeInterpreterResult && + endpointType === EndpointType.RESPONSES && ( + + )} + + {/* Message body */} +
+ {message.isImage ? ( + Generated image + ) : message.isAudio ? ( + + ) : ( + <> + {/* Attached image for user messages based on endpoint */} + {endpointType === EndpointType.RESPONSES && } + {endpointType === EndpointType.CHAT && } + + & { + inline?: boolean; + node?: unknown; + }) { + const match = /language-(\w+)/.exec(className || ""); + return !inline && match ? ( + } + language={match[1]} + PreTag="div" + className="rounded-md my-2" + wrapLines={true} + wrapLongLines={true} + {...props} + > + {String(children).replace(/\n$/, "")} + + ) : ( + + {children} + + ); + }, + pre: ({ node, ...props }) => ( +
+                  ),
+                }}
+              >
+                {typeof message.content === "string" ? message.content : ""}
+              
+
+              {/* Generated image from chat completions */}
+              {message.image && (
+                
+ Generated image +
+ )} + + )} + + {/* Response metrics */} + {message.role === "assistant" && + (message.timeToFirstToken || message.totalLatency || message.usage) && + !message.a2aMetadata && ( + + )} + + {/* A2A Metrics */} + {message.role === "assistant" && message.a2aMetadata && ( + + )} +
+
+
+ ); +} + +export default ChatMessageBubble; diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx index bc0d52cf58..ec37a9e1e9 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx @@ -63,6 +63,7 @@ import EndpointSelector from "./EndpointSelector"; import FilePreviewCard from "./FilePreviewCard"; import MCPEventsDisplay from "./MCPEventsDisplay"; import type { MCPEvent } from "../../mcp_tools/types"; +import ChatMessageBubble from "./ChatMessageBubble"; import { EndpointType, getEndpointType } from "./mode_endpoint_mapping"; import ReasoningContent from "./ReasoningContent"; import ResponseMetrics, { TokenUsage } from "./ResponseMetrics"; @@ -1932,168 +1933,14 @@ const ChatUI: React.FC = ({ {chatHistory.map((message, index) => (
-
-
-
-
- {message.role === "user" ? ( - - ) : ( - - )} -
- {message.role} - {message.role === "assistant" && message.model && ( - - {message.model} - - )} -
- {message.reasoningContent && } - - {/* Show MCP events at the start of assistant messages */} - {message.role === "assistant" && - index === chatHistory.length - 1 && - mcpEvents.length > 0 && - (endpointType === EndpointType.RESPONSES || endpointType === EndpointType.CHAT) && ( -
- -
- )} - - {/* Show search results at the start of assistant messages */} - {message.role === "assistant" && message.searchResults && ( - - )} - - {/* Show Code Interpreter output for the last assistant message */} - {message.role === "assistant" && - index === chatHistory.length - 1 && - codeInterpreter.result && - endpointType === EndpointType.RESPONSES && ( - - )} - -
- {message.isImage ? ( - Generated image - ) : message.isAudio ? ( - - ) : ( - <> - {/* Show attached image for user messages based on current endpoint */} - {endpointType === EndpointType.RESPONSES && } - {endpointType === EndpointType.CHAT && } - - & { - inline?: boolean; - node?: any; - }) { - const match = /language-(\w+)/.exec(className || ""); - return !inline && match ? ( - - {String(children).replace(/\n$/, "")} - - ) : ( - - {children} - - ); - }, - pre: ({ node, ...props }) => ( -
-                                ),
-                              }}
-                            >
-                              {typeof message.content === "string" ? message.content : ""}
-                            
-
-                            {/* Show generated image from chat completions */}
-                            {message.image && (
-                              
- Generated image -
- )} - - )} - - {message.role === "assistant" && - (message.timeToFirstToken || message.totalLatency || message.usage) && - !message.a2aMetadata && ( - - )} - - {/* A2A Metrics - show for A2A agent responses */} - {message.role === "assistant" && message.a2aMetadata && ( - - )} -
-
-
+
))} From b55cb249fe276996e9ca351e31d08e7cff4a9542 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 16:46:41 -0700 Subject: [PATCH 130/539] Address Greptile feedback: use EndpointType enum, add CHAT MCP test - Narrow endpointType prop from string to EndpointType enum - Add missing test for MCP events on CHAT endpoint Co-Authored-By: Claude Opus 4.6 (1M context) --- .../chat_ui/ChatMessageBubble.test.tsx | 16 ++++++++++++++++ .../playground/chat_ui/ChatMessageBubble.tsx | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.test.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.test.tsx index 368043abee..70c3fdd4f2 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.test.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.test.tsx @@ -180,6 +180,22 @@ describe("ChatMessageBubble", () => { expect(screen.getByTestId("mcp-events-display")).toHaveTextContent("1 events"); }); + it("should show MCP events on the last assistant message for CHAT endpoint", () => { + const mcpEvents = [{ type: "tool_call", item_id: "1" }]; + + render( + , + ); + + expect(screen.getByTestId("mcp-events-display")).toHaveTextContent("1 events"); + }); + it("should not show MCP events when isLastMessage is false", () => { const mcpEvents = [{ type: "tool_call", item_id: "1" }]; diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx index 3e1913cb2b..148cd082e2 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx @@ -21,7 +21,7 @@ interface ChatMessageBubbleProps { message: MessageType; /** Whether this is the last message in the chat history. 
*/ isLastMessage: boolean; - endpointType: string; + endpointType: EndpointType; /** MCP events to display on the last assistant message. */ mcpEvents: MCPEvent[]; /** Code interpreter result to display on the last assistant message. */ From f6cd0a827ae84cffa838eac859eaea504bd6464a Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 16:53:29 -0700 Subject: [PATCH 131/539] fix: /key/update returns 404 (not 401) for nonexistent body key The /key/update endpoint's get_data() call raises a 401 when the body `key` field doesn't exist in the DB, because get_data() treats the token as an auth credential. This caused the auth layer to resolve the body key instead of the Authorization header bearer token. Replace prisma_client.get_data() with direct Prisma find_unique() in both _get_and_validate_existing_key() and update_key_fn(), matching the pattern used in the /key/block and /key/unblock fix (PR #23977). Also fix the incorrect "Team not found" error message in update_key_fn. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../key_management_endpoints.py | 40 +++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index ec33816b78..2884c8d374 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -1662,16 +1662,23 @@ async def _get_and_validate_existing_key( detail={"error": "Database not connected"}, ) - existing_key_row = await prisma_client.get_data( - token=token, - table_name="key", - query_type="find_unique", + from litellm.proxy.proxy_server import hash_token + + if token.startswith("sk-"): + hashed_token = hash_token(token=token) + else: + hashed_token = token + + existing_key_row = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_token} ) if existing_key_row is None: 
- raise HTTPException( - status_code=404, - detail={"error": f"Key not found: {token}"}, + raise ProxyException( + message=f"Key not found. Passed key={token}", + type=ProxyErrorTypes.not_found_error, + param="key", + code=status.HTTP_404_NOT_FOUND, ) return existing_key_row @@ -2112,14 +2119,23 @@ async def update_key_fn( if prisma_client is None: raise Exception("Not connected to DB!") - existing_key_row = await prisma_client.get_data( - token=data.key, table_name="key", query_type="find_unique" + from litellm.proxy.proxy_server import hash_token + + if data.key.startswith("sk-"): + hashed_token = hash_token(token=data.key) + else: + hashed_token = data.key + + existing_key_row = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_token} ) if existing_key_row is None: - raise HTTPException( - status_code=404, - detail={"error": f"Team not found, passed team_id={data.team_id}"}, + raise ProxyException( + message=f"Key not found. Passed key={data.key}", + type=ProxyErrorTypes.not_found_error, + param="key", + code=status.HTTP_404_NOT_FOUND, ) await _validate_update_key_data( From ebe329cdce15edfd3f0c783ae2bad9053a66bd83 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 17:02:18 -0700 Subject: [PATCH 132/539] Fix build: use `as any` for SyntaxHighlighter style prop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Matches the cast used in ChatUI.tsx — the react-syntax-highlighter type definitions don't accept CSSProperties directly. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/components/playground/chat_ui/ChatMessageBubble.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx index 148cd082e2..15978c17f7 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatMessageBubble.tsx @@ -143,7 +143,7 @@ function ChatMessageBubble({ const match = /language-(\w+)/.exec(className || ""); return !inline && match ? ( } + style={coy as any} language={match[1]} PreTag="div" className="rounded-md my-2" From eceb4981b851659e6ca2152660216577e5f06feb Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 17:03:43 -0700 Subject: [PATCH 133/539] fix: address review feedback - dedup logic, use module-level helper, add test - Deduplicate: update_key_fn now delegates to _get_and_validate_existing_key() instead of inlining its own copy of the lookup logic - Use _hash_token_if_needed (already imported at module level) instead of inline `from proxy_server import hash_token` + manual conditional - Fix stale docstring: _get_and_validate_existing_key raises ProxyException, not HTTPException - Add unit test: test_update_key_nonexistent_key_returns_404 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../key_management_endpoints.py | 32 ++---------- .../test_key_management_endpoints.py | 50 +++++++++++++++++++ 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 2884c8d374..7be1a5a5e0 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -1654,7 +1654,7 @@ async def _get_and_validate_existing_key( 
LiteLLM_VerificationToken: The existing key row Raises: - HTTPException: If key is not found + ProxyException: 404 if key is not found """ if prisma_client is None: raise HTTPException( @@ -1662,12 +1662,7 @@ async def _get_and_validate_existing_key( detail={"error": "Database not connected"}, ) - from litellm.proxy.proxy_server import hash_token - - if token.startswith("sk-"): - hashed_token = hash_token(token=token) - else: - hashed_token = token + hashed_token = _hash_token_if_needed(token=token) existing_key_row = await prisma_client.db.litellm_verificationtoken.find_unique( where={"token": hashed_token} @@ -2116,28 +2111,11 @@ async def update_key_fn( key = data_json.pop("key") # get the row from db - if prisma_client is None: - raise Exception("Not connected to DB!") - - from litellm.proxy.proxy_server import hash_token - - if data.key.startswith("sk-"): - hashed_token = hash_token(token=data.key) - else: - hashed_token = data.key - - existing_key_row = await prisma_client.db.litellm_verificationtoken.find_unique( - where={"token": hashed_token} + existing_key_row = await _get_and_validate_existing_key( + token=data.key, + prisma_client=prisma_client, ) - if existing_key_row is None: - raise ProxyException( - message=f"Key not found. 
Passed key={data.key}", - type=ProxyErrorTypes.not_found_error, - param="key", - code=status.HTTP_404_NOT_FOUND, - ) - await _validate_update_key_data( data=data, existing_key_row=existing_key_row, diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index 3bff35fbd5..48a1f7936c 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -1678,6 +1678,56 @@ async def test_unblock_key_nonexistent_key_returns_404(monkeypatch): mock_prisma_client.db.litellm_verificationtoken.update.assert_not_called() +@pytest.mark.asyncio +async def test_update_key_nonexistent_key_returns_404(monkeypatch): + """ + Test that update_key_fn returns 404 (not misleading 401) when the body + key doesn't exist in the database, even when the caller is authenticated + as a proxy admin via the Authorization header. 
+ """ + from litellm.proxy.management_endpoints.key_management_endpoints import ( + update_key_fn, + ) + + mock_prisma_client = AsyncMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + + # find_unique returns None → key does not exist + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=None + ) + + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", None) + monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", True) + + mock_request = MagicMock() + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-admin", user_id="admin_user" + ) + + data = UpdateKeyRequest(key="sk-does-not-exist-key") + + with pytest.raises(ProxyException) as exc_info: + await update_key_fn( + request=mock_request, + data=data, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.code == "404" + assert "not found" in str(exc_info.value.message).lower() + assert "Authentication Error" not in str(exc_info.value.message) + + @pytest.mark.asyncio async def test_block_key_existing_key_succeeds(monkeypatch): """ From 0b63979d4572bfa6225df9bc62f9fbae8d6b3710 Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Wed, 18 Mar 2026 17:04:42 -0700 Subject: [PATCH 134/539] Fix build: cast endpointType to EndpointType at call site MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ChatUI stores endpointType as string but the narrowed prop expects EndpointType — add explicit cast at the call site. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/components/playground/chat_ui/ChatUI.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx index ec37a9e1e9..ef57a75062 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx @@ -1936,7 +1936,7 @@ const ChatUI: React.FC = ({ Date: Wed, 18 Mar 2026 17:56:25 -0700 Subject: [PATCH 135/539] [Feature] UI - Leftnav: Add external link icon to Learning Resources Add ExportOutlined icon next to nav items that link to external pages, making it clear to users when a link opens in a new tab. Co-Authored-By: Claude Opus 4.6 (1M context) --- ui/litellm-dashboard/src/components/leftnav.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ui/litellm-dashboard/src/components/leftnav.tsx b/ui/litellm-dashboard/src/components/leftnav.tsx index d01fc06bc0..d3789fcffa 100644 --- a/ui/litellm-dashboard/src/components/leftnav.tsx +++ b/ui/litellm-dashboard/src/components/leftnav.tsx @@ -13,6 +13,7 @@ import { CreditCardOutlined, DatabaseOutlined, ExperimentOutlined, + ExportOutlined, FileTextOutlined, FolderOutlined, KeyOutlined, @@ -400,7 +401,7 @@ const Sidebar: React.FC = ({ setPage, defaultSelectedKey, collapse onClick={(e) => e.stopPropagation()} style={{ color: "inherit", textDecoration: "none" }} > - {label} + {label} ); } From 4770b657e15a25815eccf211f87843db356625e8 Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 18 Mar 2026 22:05:27 -0300 Subject: [PATCH 136/539] =?UTF-8?q?refactor:=20extract=20duplicated=20stdo?= =?UTF-8?q?ut/stderr=20=E2=86=92=20logs=20logic=20to=20shared=20helper?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- litellm/llms/anthropic/chat/handler.py | 15 ++------------- 
litellm/llms/anthropic/chat/transformation.py | 17 ++--------------- litellm/types/responses/main.py | 18 ++++++++++++++++++ 3 files changed, 22 insertions(+), 28 deletions(-) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 70ecf91725..7dce72f1e8 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -50,7 +50,7 @@ from litellm.types.llms.openai import ( ) from litellm.types.responses.main import ( OutputCodeInterpreterCall, - OutputCodeInterpreterCallLog, + build_code_interpreter_log_outputs, ) from litellm.types.utils import ( Delta, @@ -708,20 +708,9 @@ class ModelResponseIterator: continue call_id = tr.get("tool_use_id", "") content = tr.get("content", {}) - if isinstance(content, dict): - parts = [] - if content.get("stdout"): - parts.append(content["stdout"]) - if content.get("stderr"): - parts.append(f"STDERR: {content['stderr']}") - logs = "".join(parts) - else: - logs = "" + log_outputs = build_code_interpreter_log_outputs(content) tool_input = self._server_tool_inputs.get(call_id, {}) code = tool_input.get("command", "") if isinstance(tool_input, dict) else "" - log_outputs = ( - [OutputCodeInterpreterCallLog(type="logs", logs=logs)] if logs else None - ) results.append( OutputCodeInterpreterCall( type="code_interpreter_call", diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index ca101df0e9..bbc73fcfd4 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -61,7 +61,7 @@ from litellm.types.utils import ( ) from litellm.types.responses.main import ( OutputCodeInterpreterCall, - OutputCodeInterpreterCallLog, + build_code_interpreter_log_outputs, ) from litellm.utils import ( ModelResponse, @@ -1771,20 +1771,7 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): continue call_id = tr.get("tool_use_id", "") content = tr.get("content", {}) - if 
isinstance(content, dict): - parts = [] - if content.get("stdout"): - parts.append(content["stdout"]) - if content.get("stderr"): - parts.append(f"STDERR: {content['stderr']}") - logs = "".join(parts) - else: - logs = "" - log_outputs = ( - [OutputCodeInterpreterCallLog(type="logs", logs=logs)] - if logs - else None - ) + log_outputs = build_code_interpreter_log_outputs(content) code_interpreter_results.append( OutputCodeInterpreterCall( type="code_interpreter_call", diff --git a/litellm/types/responses/main.py b/litellm/types/responses/main.py index e46857565c..ebd2ad5b5a 100644 --- a/litellm/types/responses/main.py +++ b/litellm/types/responses/main.py @@ -67,6 +67,24 @@ class OutputCodeInterpreterCall(BaseLiteLLMOpenAIResponseObject): outputs: Optional[List[OutputCodeInterpreterCallLog]] +def build_code_interpreter_log_outputs( + content: Any, +) -> Optional[List[OutputCodeInterpreterCallLog]]: + """Convert Anthropic bash_code_execution stdout/stderr to log outputs. + + Shared by streaming (handler.py) and non-streaming (transformation.py) paths. + """ + if not isinstance(content, dict): + return None + parts = [] + if content.get("stdout"): + parts.append(content["stdout"]) + if content.get("stderr"): + parts.append(f"STDERR: {content['stderr']}") + logs = "".join(parts) + return [OutputCodeInterpreterCallLog(type="logs", logs=logs)] if logs else None + + class GenericResponseOutputItem(BaseLiteLLMOpenAIResponseObject): """ Generic response API output item From 8969a3d1763e7e2bb3b0c48dbb855780cf992ebe Mon Sep 17 00:00:00 2001 From: xianren Date: Thu, 19 Mar 2026 09:10:21 +0800 Subject: [PATCH 137/539] Fixed thinking blocks dropped when thinking field is null (#24026) The check `content.get("thinking", None) is not None` incorrectly drops thinking blocks when the `thinking` key is explicitly null or absent. Changed to `content.get("type") == "thinking"` to match the fix already applied in the experimental pass-through path (PR #15501). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- litellm/llms/anthropic/chat/transformation.py | 2 +- .../test_anthropic_chat_transformation.py | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 47cdd8287e..7552becef6 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -1522,7 +1522,7 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): tool_results = [] tool_results.append(content) - elif content.get("thinking", None) is not None: + elif content.get("type") == "thinking": if thinking_blocks is None: thinking_blocks = [] thinking_blocks.append(cast(ChatCompletionThinkingBlock, content)) diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py index a95b9413b9..7f5b7d6158 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py @@ -3321,3 +3321,54 @@ def test_map_tool_helper_empty_parameters_get_default(): assert result is not None assert result["input_schema"]["type"] == "object" assert result["input_schema"].get("properties") == {} + + +def test_extract_response_content_thinking_block_null_thinking(): + """ + Test that thinking blocks are not dropped when the 'thinking' field is null + or missing. 
Regression test for https://github.com/BerriAI/litellm/issues/24026 + """ + config = AnthropicConfig() + + # Case 1: thinking key is explicitly null + completion_response_null = { + "content": [ + {"type": "thinking", "thinking": None, "signature": "sig123"}, + {"type": "text", "text": "Hello"}, + ] + } + text, _, thinking_blocks, _, _, _, _, _ = config.extract_response_content( + completion_response_null + ) + assert thinking_blocks is not None, "thinking blocks should not be None when thinking=null" + assert len(thinking_blocks) == 1 + assert "Hello" in text + + # Case 2: thinking key is absent entirely + completion_response_missing = { + "content": [ + {"type": "thinking", "signature": "sig456"}, + {"type": "text", "text": "World"}, + ] + } + text, _, thinking_blocks, _, _, _, _, _ = config.extract_response_content( + completion_response_missing + ) + assert thinking_blocks is not None, "thinking blocks should not be None when thinking key is absent" + assert len(thinking_blocks) == 1 + assert "World" in text + + # Case 3: thinking key has actual content (should still work) + completion_response_text = { + "content": [ + {"type": "thinking", "thinking": "Let me think...", "signature": "sig789"}, + {"type": "text", "text": "Done"}, + ] + } + text, _, thinking_blocks, _, _, _, _, _ = config.extract_response_content( + completion_response_text + ) + assert thinking_blocks is not None + assert len(thinking_blocks) == 1 + assert thinking_blocks[0]["thinking"] == "Let me think..." + assert "Done" in text From 4bd7bdcf43da7106bd07240acc2bd0b1c997936c Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 18 Mar 2026 22:30:49 -0300 Subject: [PATCH 138/539] fix: add additionalProperties: false for OpenAI strict mode in Anthropic adapter When translating Anthropic output_format to OpenAI response_format, the adapter sets strict: true but didn't add additionalProperties: false, which OpenAI requires at every object nesting level. 
This caused BadRequestError for structured output requests routed to OpenAI models. Fixes #20997 --- .../adapters/transformation.py | 37 +++++++ ...al_pass_through_adapters_transformation.py | 101 ++++++++++++++++++ 2 files changed, 138 insertions(+) diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py index 43a6fa8045..47c9a223f8 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py @@ -1,3 +1,4 @@ +import copy import hashlib import json from typing import ( @@ -824,6 +825,11 @@ class LiteLLMAnthropicMessagesAdapter: if not schema: return None + # Deep copy to avoid mutating the original schema + schema = copy.deepcopy(schema) + # OpenAI strict mode requires additionalProperties: false on every object + self._add_additional_properties_false(schema) + # Convert to OpenAI response_format structure return { "type": "json_schema", @@ -834,6 +840,37 @@ class LiteLLMAnthropicMessagesAdapter: }, } + @staticmethod + def _add_additional_properties_false(schema: dict) -> None: + """ + Recursively add 'additionalProperties': false to all object schemas. + + OpenAI's strict mode requires this at every object nesting level. 
+ """ + if not isinstance(schema, dict): + return + + if schema.get("type") == "object" and "properties" in schema: + schema["additionalProperties"] = False + for prop in schema["properties"].values(): + LiteLLMAnthropicMessagesAdapter._add_additional_properties_false(prop) + + # Handle array items + if "items" in schema: + LiteLLMAnthropicMessagesAdapter._add_additional_properties_false(schema["items"]) + + # Handle anyOf/oneOf/allOf + for key in ("anyOf", "oneOf", "allOf"): + if key in schema: + for sub_schema in schema[key]: + LiteLLMAnthropicMessagesAdapter._add_additional_properties_false(sub_schema) + + # Handle $defs / definitions + for key in ("$defs", "definitions"): + if key in schema: + for def_schema in schema[key].values(): + LiteLLMAnthropicMessagesAdapter._add_additional_properties_false(def_schema) + def _add_system_message_to_messages( self, new_messages: List[AllMessageValues], diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py index 839d032c43..b0442ebf0e 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py @@ -1984,3 +1984,104 @@ def test_translate_anthropic_to_openai_with_mixed_tools(): # tool_name_mapping should be empty for short tool names assert tool_name_mapping == {} + + +class TestTranslateAnthropicOutputFormatToOpenAI: + """Tests for translate_anthropic_output_format_to_openai adding additionalProperties: false.""" + + def setup_method(self): + self.adapter = LiteLLMAnthropicMessagesAdapter() + + def test_simple_object_adds_additional_properties_false(self): + 
output_format = { + "type": "json_schema", + "schema": { + "type": "object", + "properties": {"name": {"type": "string"}}, + }, + } + result = self.adapter.translate_anthropic_output_format_to_openai(output_format) + assert result is not None + schema = result["json_schema"]["schema"] + assert schema["additionalProperties"] is False + + def test_nested_objects_adds_additional_properties_false(self): + output_format = { + "type": "json_schema", + "schema": { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "address": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + }, + } + }, + }, + } + result = self.adapter.translate_anthropic_output_format_to_openai(output_format) + assert result is not None + schema = result["json_schema"]["schema"] + assert schema["additionalProperties"] is False + assert schema["properties"]["user"]["additionalProperties"] is False + assert schema["properties"]["user"]["properties"]["address"]["additionalProperties"] is False + + def test_array_items_object_adds_additional_properties_false(self): + output_format = { + "type": "json_schema", + "schema": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": {"id": {"type": "integer"}}, + }, + } + }, + }, + } + result = self.adapter.translate_anthropic_output_format_to_openai(output_format) + assert result is not None + schema = result["json_schema"]["schema"] + assert schema["additionalProperties"] is False + assert schema["properties"]["items"]["items"]["additionalProperties"] is False + + def test_does_not_mutate_original_schema(self): + original_schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + output_format = {"type": "json_schema", "schema": original_schema} + self.adapter.translate_anthropic_output_format_to_openai(output_format) + assert "additionalProperties" not in original_schema + + def 
test_defs_adds_additional_properties_false(self): + output_format = { + "type": "json_schema", + "schema": { + "type": "object", + "properties": {"ref": {"$ref": "#/$defs/Item"}}, + "$defs": { + "Item": { + "type": "object", + "properties": {"value": {"type": "string"}}, + } + }, + }, + } + result = self.adapter.translate_anthropic_output_format_to_openai(output_format) + assert result is not None + schema = result["json_schema"]["schema"] + assert schema["$defs"]["Item"]["additionalProperties"] is False + + def test_invalid_output_format_returns_none(self): + assert self.adapter.translate_anthropic_output_format_to_openai("invalid") is None + assert self.adapter.translate_anthropic_output_format_to_openai({"type": "text"}) is None + assert self.adapter.translate_anthropic_output_format_to_openai({"type": "json_schema"}) is None From 6f4b4d3c42c73641cb08f78a9587c8328c23afef Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 18 Mar 2026 22:33:01 -0300 Subject: [PATCH 139/539] feat(gemini): support context circulation for server-side tool combination Enables Gemini 3+ models to combine built-in tools (Google Search, etc.) with custom functions via `include_server_side_tool_invocations=True`. Server-side invocations are surfaced in provider_specific_fields and automatically re-injected on subsequent turns for multi-turn coherence. 
Closes #24047 --- docs/my-website/docs/providers/gemini.md | 108 +++++++- .../llms/vertex_ai/gemini/transformation.py | 40 +++ .../vertex_and_google_ai_studio_gemini.py | 78 ++++++ litellm/types/llms/vertex_ai.py | 3 +- .../gemini/test_context_circulation.py | 234 ++++++++++++++++++ 5 files changed, 461 insertions(+), 2 deletions(-) create mode 100644 tests/test_litellm/llms/vertex_ai/gemini/test_context_circulation.py diff --git a/docs/my-website/docs/providers/gemini.md b/docs/my-website/docs/providers/gemini.md index 0aaf3d5ae8..c8c9114ea8 100644 --- a/docs/my-website/docs/providers/gemini.md +++ b/docs/my-website/docs/providers/gemini.md @@ -54,6 +54,7 @@ response = completion( - stream - tools - tool_choice +- include_server_side_tool_invocations - functions - response_format - n @@ -856,7 +857,112 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \ -### URL Context +### Context Circulation (Server-Side Tool Combination) + +Context circulation allows Gemini 3+ models to combine **built-in tools** (like Google Search) with **your custom functions** in the same request. Without it, Gemini returns an error if you try to use both. + +When enabled, Gemini can execute Google Search server-side, use those results to decide whether to call your custom functions, and return the full chain of reasoning. + +**How it works:** +1. You pass `include_server_side_tool_invocations=True` along with both Google Search and your function tools +2. Gemini executes server-side tools internally and returns `toolCall`/`toolResponse` parts alongside any `functionCall` parts +3. LiteLLM extracts the server-side invocations into `provider_specific_fields["server_side_tool_invocations"]` +4. 
On subsequent turns, include the full assistant message in your conversation history — LiteLLM re-injects the server-side parts automatically + + + + +```python +from litellm import completion + +response = completion( + model="gemini/gemini-3-flash-preview", + messages=[{"role": "user", "content": "What's the weather in Buenos Aires? If it's raining, schedule a meeting."}], + tools=[ + {"type": "web_search_preview"}, # Google Search (server-side) + { + "type": "function", + "function": { + "name": "schedule_meeting", + "description": "Schedule a meeting", + "parameters": { + "type": "object", + "properties": {"reason": {"type": "string"}}, + "required": ["reason"], + }, + }, + }, + ], + include_server_side_tool_invocations=True, +) + +msg = response.choices[0].message + +# Server-side tool results are in provider_specific_fields +psf = msg.provider_specific_fields or {} +for invocation in psf.get("server_side_tool_invocations", []): + print(invocation["tool_type"]) # e.g. "GOOGLE_SEARCH_WEB" + print(invocation["id"]) + print(invocation["args"]) # e.g. {"queries": ["weather Buenos Aires"]} + print(invocation["response"]) # Search results from Google + +# For multi-turn: append the full message to your history (`messages` and `tools` below are the same lists passed to the first call) +messages.append(msg) +messages.append({"role": "user", "content": "Thanks!"}) +# LiteLLM automatically re-injects the server-side parts + thought signatures +response2 = completion( + model="gemini/gemini-3-flash-preview", + messages=messages, + tools=tools, + include_server_side_tool_invocations=True, +) +``` + + + + +1. Setup config.yaml +```yaml +model_list: + - model_name: gemini-3-flash + litellm_params: + model: gemini/gemini-3-flash-preview + api_key: os.environ/GEMINI_API_KEY +``` + +2. Start Proxy +```bash +$ litellm --config /path/to/config.yaml +``` + +3. 
Make Request +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "gemini-3-flash", + "messages": [{"role": "user", "content": "What is the weather in Buenos Aires?"}], + "tools": [ + {"type": "web_search_preview"}, + {"type": "function", "function": {"name": "schedule_meeting", "description": "Schedule a meeting", "parameters": {"type": "object", "properties": {"reason": {"type": "string"}}}}} + ], + "include_server_side_tool_invocations": true +}' +``` + + + + +:::info + +- Context circulation requires **Gemini 3+** models +- Server-side tool invocations (`toolCall`/`toolResponse`) are **not** included in `tool_calls` — they are in `provider_specific_fields["server_side_tool_invocations"]` because they were already executed by Google, not by your code +- `thought_signatures` are automatically preserved alongside server-side invocations for multi-turn coherence + +::: + +### URL Context diff --git a/litellm/llms/vertex_ai/gemini/transformation.py b/litellm/llms/vertex_ai/gemini/transformation.py index d7b96b4db7..f6310778c7 100644 --- a/litellm/llms/vertex_ai/gemini/transformation.py +++ b/litellm/llms/vertex_ai/gemini/transformation.py @@ -540,6 +540,39 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915 assistant_content.append(gemini_tool_call_part) last_message_with_tool_calls = assistant_msg + ## HANDLE SERVER-SIDE TOOL INVOCATIONS (context circulation) + _psf = assistant_msg.get("provider_specific_fields") + if isinstance(_psf, dict): + _ss_invocations = _psf.get("server_side_tool_invocations") + if isinstance(_ss_invocations, list): + for invocation in _ss_invocations: + # Re-inject toolCall part + tc_part: Dict[str, Any] = { + "toolCall": { + "toolType": invocation.get("tool_type"), + "id": invocation.get("id"), + "args": invocation.get("args"), + } + } + if "thought_signature" in invocation: + tc_part["thoughtSignature"] = 
invocation["thought_signature"] + assistant_content.append(tc_part) # type: ignore + + # Re-inject toolResponse part if response is present + if "response" in invocation: + tr_dict: Dict[str, Any] = { + "id": invocation.get("id"), + "response": invocation.get("response"), + } + if invocation.get("tool_type"): + tr_dict["toolType"] = invocation["tool_type"] + tr_part: Dict[str, Any] = { + "toolResponse": tr_dict + } + if "thought_signature" in invocation: + tr_part["thoughtSignature"] = invocation["thought_signature"] + assistant_content.append(tr_part) # type: ignore + msg_i += 1 if assistant_content: @@ -666,6 +699,9 @@ def _transform_request_body( # noqa: PLR0915 ) tools: Optional[Tools] = optional_params.pop("tools", None) tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) + include_server_side_tool_invocations: bool = optional_params.pop( + "include_server_side_tool_invocations", False + ) safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( "safety_settings", None ) # type: ignore @@ -715,6 +751,10 @@ def _transform_request_body( # noqa: PLR0915 data["tools"] = tools if tool_choice is not None: data["toolConfig"] = tool_choice + if include_server_side_tool_invocations: + if "toolConfig" not in data: + data["toolConfig"] = {} + data["toolConfig"]["includeServerSideToolInvocations"] = True if safety_settings is not None: data["safetySettings"] = safety_settings if generation_config is not None and len(generation_config) > 0: diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 3f1bccaccf..3555d3c719 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -316,6 +316,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): "audio", "parallel_tool_calls", "web_search_options", + 
"include_server_side_tool_invocations", ] # Add penalty parameters only for non-preview models @@ -1119,6 +1120,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): optional_params = self._add_tools_to_optional_params( optional_params, [_tools] ) + elif param == "include_server_side_tool_invocations" and value is True: + optional_params["include_server_side_tool_invocations"] = True if litellm.vertex_ai_safety_settings is not None: optional_params["safety_settings"] = litellm.vertex_ai_safety_settings @@ -1360,6 +1363,67 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): signatures.append(signature) return signatures if signatures else None + @staticmethod + def _extract_server_side_tool_invocations( + parts: List[HttpxPartType], + ) -> Optional[List[Dict[str, Any]]]: + """Extract server-side tool invocations (toolCall/toolResponse) from parts. + + These are returned by Gemini when context circulation is enabled + (includeServerSideToolInvocations=true). They represent tools executed + server-side (e.g. Google Search) and must be circulated back in + subsequent turns for multi-turn coherence. + + Returns: + List of server-side invocation dicts if any found, None otherwise. 
+ """ + invocations: List[Dict[str, Any]] = [] + # Index toolCalls by id so we can pair them with responses + tool_calls_by_id: Dict[str, Dict[str, Any]] = {} + tool_responses_by_id: Dict[str, Dict[str, Any]] = {} + + for part in parts: + if "toolCall" in part: + tc = part["toolCall"] + entry: Dict[str, Any] = { + "tool_type": tc.get("toolType"), + "id": tc.get("id"), + "args": tc.get("args"), + } + signature = part.get("thoughtSignature") + if signature is not None: + entry["thought_signature"] = signature + tool_calls_by_id[tc.get("id", "")] = entry + + elif "toolResponse" in part: + tr = part["toolResponse"] + entry = { + "id": tr.get("id"), + "tool_type": tr.get("toolType"), + "response": tr.get("response"), + } + signature = part.get("thoughtSignature") + if signature is not None: + entry["thought_signature"] = signature + tool_responses_by_id[tr.get("id", "")] = entry + + # Merge calls with their responses + for call_id, call_entry in tool_calls_by_id.items(): + merged = dict(call_entry) + resp = tool_responses_by_id.pop(call_id, None) + if resp is not None: + merged["response"] = resp.get("response") + # Keep response signature if call didn't have one + if "thought_signature" not in merged and "thought_signature" in resp: + merged["thought_signature"] = resp["thought_signature"] + invocations.append(merged) + + # Any orphan responses (shouldn't happen, but be safe) + for resp_id, resp_entry in tool_responses_by_id.items(): + invocations.append(resp_entry) + + return invocations if invocations else None + def _extract_image_response_from_parts( self, parts: List[HttpxPartType] ) -> Optional[List[ImageURLListItem]]: @@ -2018,6 +2082,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None reasoning_content: Optional[str] = None thought_signatures: Optional[Any] = None + server_side_tool_invocations: Optional[List[Dict[str, Any]]] = None for idx, candidate in enumerate(_candidates): if 
"content" not in candidate: @@ -2068,6 +2133,13 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): ) ) + # Extract server-side tool invocations (context circulation) + server_side_tool_invocations = ( + VertexGeminiConfig._extract_server_side_tool_invocations( + parts=candidate["content"]["parts"] + ) + ) + if audio_response is not None: cast(Dict[str, Any], chat_completion_message)[ "audio" @@ -2139,6 +2211,12 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): chat_completion_message["provider_specific_fields"] = {} chat_completion_message["provider_specific_fields"]["thought_signatures"] = thought_signatures # type: ignore + # Store server-side tool invocations in provider_specific_fields + if server_side_tool_invocations is not None: + if "provider_specific_fields" not in chat_completion_message: + chat_completion_message["provider_specific_fields"] = {} + chat_completion_message["provider_specific_fields"]["server_side_tool_invocations"] = server_side_tool_invocations # type: ignore + if isinstance(model_response, ModelResponseStream): choice = VertexGeminiConfig._create_streaming_choice( chat_completion_message=chat_completion_message, diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 201854369f..66c6ca436c 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -244,8 +244,9 @@ class Tools(TypedDict, total=False): retrieval: Retrieval -class ToolConfig(TypedDict): +class ToolConfig(TypedDict, total=False): functionCallingConfig: FunctionCallingConfig + includeServerSideToolInvocations: bool class TTL(TypedDict, total=False): diff --git a/tests/test_litellm/llms/vertex_ai/gemini/test_context_circulation.py b/tests/test_litellm/llms/vertex_ai/gemini/test_context_circulation.py new file mode 100644 index 0000000000..c3038840d8 --- /dev/null +++ b/tests/test_litellm/llms/vertex_ai/gemini/test_context_circulation.py @@ -0,0 +1,234 @@ +""" +Tests for Gemini context circulation 
(server-side tool invocations). + +When includeServerSideToolInvocations=true is set, Gemini returns toolCall/toolResponse +parts for server-side tools (e.g. Google Search). These must be: +1. Extracted from the response into provider_specific_fields["server_side_tool_invocations"] +2. Re-injected as raw toolCall/toolResponse parts when converting messages back to Gemini format +3. The includeServerSideToolInvocations flag must be passed through to toolConfig +""" + +import json +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import pytest + +from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( + VertexGeminiConfig, +) +from litellm.llms.vertex_ai.gemini.transformation import ( + _gemini_convert_messages_with_history, +) +from litellm.types.llms.vertex_ai import HttpxPartType + + +# --- Response extraction tests --- + + +class TestExtractServerSideToolInvocations: + """Test _extract_server_side_tool_invocations from response parts.""" + + def test_extracts_tool_call_and_response(self): + """Basic case: one toolCall + one toolResponse with same id.""" + parts: List[HttpxPartType] = [ + { + "thoughtSignature": "sig_call_1", + "toolCall": { + "toolType": "GOOGLE_SEARCH_WEB", + "id": "abc123", + "args": {"queries": ["weather Buenos Aires"]}, + }, + }, + { + "thoughtSignature": "sig_resp_1", + "toolResponse": { + "toolType": "GOOGLE_SEARCH_WEB", + "id": "abc123", + "response": {"weather": "Sunny, 20°C"}, + }, + }, + { + "text": "The weather in Buenos Aires is sunny.", + "thoughtSignature": "sig_text", + }, + ] + + result = VertexGeminiConfig._extract_server_side_tool_invocations(parts) + + assert result is not None + assert len(result) == 1 + assert result[0]["tool_type"] == "GOOGLE_SEARCH_WEB" + assert result[0]["id"] == "abc123" + assert result[0]["args"] == {"queries": ["weather Buenos Aires"]} + assert result[0]["response"] == {"weather": "Sunny, 20°C"} + assert result[0]["thought_signature"] == "sig_call_1" + + 
def test_returns_none_when_no_server_side_tools(self): + """No toolCall/toolResponse parts → returns None.""" + parts: List[HttpxPartType] = [ + {"text": "Hello world", "thoughtSignature": "sig1"}, + { + "functionCall": { + "name": "get_weather", + "args": {"location": "Paris"}, + }, + "thoughtSignature": "sig2", + }, + ] + + result = VertexGeminiConfig._extract_server_side_tool_invocations(parts) + assert result is None + + def test_multiple_server_side_invocations(self): + """Multiple toolCall/toolResponse pairs.""" + parts: List[HttpxPartType] = [ + { + "toolCall": { + "toolType": "GOOGLE_SEARCH_WEB", + "id": "search1", + "args": {"queries": ["query1"]}, + }, + "thoughtSignature": "sig1", + }, + { + "toolResponse": {"toolType": "GOOGLE_SEARCH_WEB", "id": "search1", "response": "result1"}, + "thoughtSignature": "sig2", + }, + { + "toolCall": { + "toolType": "GOOGLE_SEARCH_WEB", + "id": "search2", + "args": {"queries": ["query2"]}, + }, + "thoughtSignature": "sig3", + }, + { + "toolResponse": {"toolType": "GOOGLE_SEARCH_WEB", "id": "search2", "response": "result2"}, + "thoughtSignature": "sig4", + }, + ] + + result = VertexGeminiConfig._extract_server_side_tool_invocations(parts) + + assert result is not None + assert len(result) == 2 + assert result[0]["id"] == "search1" + assert result[0]["response"] == "result1" + assert result[1]["id"] == "search2" + assert result[1]["response"] == "result2" + + def test_tool_call_without_response(self): + """toolCall without matching toolResponse is still captured.""" + parts: List[HttpxPartType] = [ + { + "toolCall": { + "toolType": "CODE_EXECUTION", + "id": "exec1", + "args": {"code": "print('hello')"}, + }, + }, + ] + + result = VertexGeminiConfig._extract_server_side_tool_invocations(parts) + + assert result is not None + assert len(result) == 1 + assert result[0]["id"] == "exec1" + assert "response" not in result[0] + + +# --- Input re-injection tests --- + + +class TestReInjectServerSideToolInvocations: + """Test that 
server_side_tool_invocations are re-injected into Gemini parts.""" + + def test_roundtrip_single_invocation(self): + """Server-side invocations from assistant message are converted back to Gemini parts.""" + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "It's sunny in Buenos Aires.", + "provider_specific_fields": { + "server_side_tool_invocations": [ + { + "tool_type": "GOOGLE_SEARCH_WEB", + "id": "abc123", + "args": {"queries": ["weather Buenos Aires"]}, + "response": {"weather": "Sunny, 20°C"}, + "thought_signature": "sig_abc", + } + ] + }, + }, + {"role": "user", "content": "Thanks!"}, + ] + + contents = _gemini_convert_messages_with_history(messages) + + # Find the model turn + model_turn = [c for c in contents if c["role"] == "model"] + assert len(model_turn) == 1 + + parts = model_turn[0]["parts"] + # Should have: text part + toolCall part + toolResponse part + tool_call_parts = [p for p in parts if "toolCall" in p] + tool_response_parts = [p for p in parts if "toolResponse" in p] + + assert len(tool_call_parts) == 1 + assert tool_call_parts[0]["toolCall"]["toolType"] == "GOOGLE_SEARCH_WEB" + assert tool_call_parts[0]["toolCall"]["id"] == "abc123" + assert tool_call_parts[0]["toolCall"]["args"] == {"queries": ["weather Buenos Aires"]} + assert tool_call_parts[0]["thoughtSignature"] == "sig_abc" + + assert len(tool_response_parts) == 1 + assert tool_response_parts[0]["toolResponse"]["id"] == "abc123" + assert tool_response_parts[0]["toolResponse"]["toolType"] == "GOOGLE_SEARCH_WEB" + assert tool_response_parts[0]["toolResponse"]["response"] == {"weather": "Sunny, 20°C"} + + def test_no_invocations_no_extra_parts(self): + """Without server_side_tool_invocations, no extra parts are added.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "Bye"}, + ] + + contents = _gemini_convert_messages_with_history(messages) + 
model_turn = [c for c in contents if c["role"] == "model"] + assert len(model_turn) == 1 + + parts = model_turn[0]["parts"] + assert len(parts) == 1 + assert "text" in parts[0] + assert "toolCall" not in parts[0] + + +# --- toolConfig flag tests --- + + +class TestIncludeServerSideToolInvocationsConfig: + """Test that the flag is passed through to toolConfig.""" + + def test_flag_added_to_tool_config(self): + """include_server_side_tool_invocations=True should be mapped to optional_params.""" + config = VertexGeminiConfig() + non_default_params = {"include_server_side_tool_invocations": True} + optional_params: Dict[str, Any] = {} + + result = config.map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model="gemini-3-flash-preview", + drop_params=False, + ) + + assert result["include_server_side_tool_invocations"] is True + + def test_flag_in_supported_params(self): + """include_server_side_tool_invocations should be in supported params.""" + config = VertexGeminiConfig() + supported = config.get_supported_openai_params(model="gemini-3-flash-preview") + assert "include_server_side_tool_invocations" in supported From 286b8d14604c9ad5f736038fb5bb6146ecaed7bf Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 18 Mar 2026 23:03:20 -0300 Subject: [PATCH 140/539] fix: also populate required for all properties in strict mode OpenAI strict mode requires both additionalProperties:false AND all property keys in required. Without required, OpenAI rejects the schema even with additionalProperties:false set. 
--- .../adapters/transformation.py | 7 +++-- ...al_pass_through_adapters_transformation.py | 26 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py index 47c9a223f8..3ee6c5f0d2 100644 --- a/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/adapters/transformation.py @@ -843,15 +843,18 @@ class LiteLLMAnthropicMessagesAdapter: @staticmethod def _add_additional_properties_false(schema: dict) -> None: """ - Recursively add 'additionalProperties': false to all object schemas. + Recursively ensure object schemas comply with OpenAI strict mode. - OpenAI's strict mode requires this at every object nesting level. + OpenAI's strict mode requires: + 1. 'additionalProperties': false at every object nesting level + 2. All property keys listed in 'required' """ if not isinstance(schema, dict): return if schema.get("type") == "object" and "properties" in schema: schema["additionalProperties"] = False + schema["required"] = list(schema["properties"].keys()) for prop in schema["properties"].values(): LiteLLMAnthropicMessagesAdapter._add_additional_properties_false(prop) diff --git a/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py b/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py index b0442ebf0e..ae970e1ff0 100644 --- a/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py +++ b/tests/test_litellm/llms/anthropic/experimental_pass_through/adapters/test_anthropic_experimental_pass_through_adapters_transformation.py @@ -2004,6 +2004,7 @@ class 
TestTranslateAnthropicOutputFormatToOpenAI: assert result is not None schema = result["json_schema"]["schema"] assert schema["additionalProperties"] is False + assert schema["required"] == ["name"] def test_nested_objects_adds_additional_properties_false(self): output_format = { @@ -2028,8 +2029,11 @@ class TestTranslateAnthropicOutputFormatToOpenAI: assert result is not None schema = result["json_schema"]["schema"] assert schema["additionalProperties"] is False + assert schema["required"] == ["user"] assert schema["properties"]["user"]["additionalProperties"] is False + assert schema["properties"]["user"]["required"] == ["name", "address"] assert schema["properties"]["user"]["properties"]["address"]["additionalProperties"] is False + assert schema["properties"]["user"]["properties"]["address"]["required"] == ["city"] def test_array_items_object_adds_additional_properties_false(self): output_format = { @@ -2061,6 +2065,7 @@ class TestTranslateAnthropicOutputFormatToOpenAI: output_format = {"type": "json_schema", "schema": original_schema} self.adapter.translate_anthropic_output_format_to_openai(output_format) assert "additionalProperties" not in original_schema + assert "required" not in original_schema def test_defs_adds_additional_properties_false(self): output_format = { @@ -2080,6 +2085,27 @@ class TestTranslateAnthropicOutputFormatToOpenAI: assert result is not None schema = result["json_schema"]["schema"] assert schema["$defs"]["Item"]["additionalProperties"] is False + assert schema["$defs"]["Item"]["required"] == ["value"] + + def test_incomplete_required_gets_completed(self): + """OpenAI strict mode requires ALL properties in required.""" + output_format = { + "type": "json_schema", + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "email": {"type": "string"}, + }, + "required": ["name"], # only 1 of 3 + }, + } + result = self.adapter.translate_anthropic_output_format_to_openai(output_format) 
+ assert result is not None + schema = result["json_schema"]["schema"] + assert schema["additionalProperties"] is False + assert sorted(schema["required"]) == ["age", "email", "name"] def test_invalid_output_format_returns_none(self): assert self.adapter.translate_anthropic_output_format_to_openai("invalid") is None From 60c234270a4b42cdf20888ac353071b5eae5a378 Mon Sep 17 00:00:00 2001 From: Chesars Date: Wed, 18 Mar 2026 23:07:56 -0300 Subject: [PATCH 141/539] feat(bedrock): support cache_control_injection_points for tool_config location Add support for {"location": "tool_config"} in cache_control_injection_points, which appends a cachePoint block to the Bedrock Converse toolConfig.tools array. This enables prompt caching of tool definitions on Bedrock Claude models. Also update the cache control hook to pass through non-message injection points to provider-specific handling instead of silently dropping them. Fixes #21969 --- .../anthropic_cache_control_hook.py | 9 +- .../bedrock/chat/converse_transformation.py | 10 ++ .../chat/test_converse_transformation.py | 97 +++++++++++++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) diff --git a/litellm/integrations/anthropic_cache_control_hook.py b/litellm/integrations/anthropic_cache_control_hook.py index 8e4d40c460..0e99537d5d 100644 --- a/litellm/integrations/anthropic_cache_control_hook.py +++ b/litellm/integrations/anthropic_cache_control_hook.py @@ -60,13 +60,20 @@ class AnthropicCacheControlHook(CustomPromptManagement): # Create a deep copy of messages to avoid modifying the original list processed_messages = copy.deepcopy(messages) - # Process message-level cache controls + # Separate message-level and non-message-level injection points + remaining_points = [] for point in injection_points: if point.get("location") == "message": point = cast(CacheControlMessageInjectionPoint, point) processed_messages = self._process_message_injection( point=point, messages=processed_messages ) + else: + 
remaining_points.append(point) + + # Pass through non-message injection points for provider-specific handling + if remaining_points: + non_default_params["cache_control_injection_points"] = remaining_points return model, processed_messages, non_default_params diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 229457a73b..dd8b1b0a69 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -1446,6 +1446,16 @@ class AmazonConverseConfig(BaseConfig): original_tools, model, headers, additional_request_params ) + # Append cachePoint to tools if cache_control_injection_points has tool_config + cache_injection_points = additional_request_params.pop( + "cache_control_injection_points", None + ) + if cache_injection_points and len(bedrock_tools) > 0: + for point in cache_injection_points: + if point.get("location") == "tool_config": + bedrock_tools.append({"cachePoint": {"type": "default"}}) + break + bedrock_tool_config: Optional[ToolConfigBlock] = None if len(bedrock_tools) > 0: tool_choice_values: ToolChoiceValuesBlock = inference_params.pop( diff --git a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py index a305009659..e9aaa97a42 100644 --- a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py +++ b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py @@ -3803,3 +3803,100 @@ def test_streaming_without_json_mode_passes_all_tools(): assert tool_use_delta is not None assert tool_use_delta["function"]["arguments"] == '{"data": 1}' + +def test_cache_control_injection_tool_config(): + """Test that cache_control_injection_points with location=tool_config appends cachePoint to tools.""" + config = AmazonConverseConfig() + messages = [ + {"role": "user", "content": "What is the weather?"}, + ] + optional_params = 
{ + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + }, + }, + } + ], + "cache_control_injection_points": [ + {"location": "tool_config"}, + ], + } + result = config._transform_request( + model="anthropic.claude-3-5-haiku-20241022-v1:0", + messages=messages, + optional_params=optional_params, + litellm_params={}, + ) + tool_config = result["toolConfig"] + tools = tool_config["tools"] + # Last element should be a cachePoint block + assert tools[-1] == {"cachePoint": {"type": "default"}} + # First element should be the actual tool + assert "toolSpec" in tools[0] + + +def test_cache_control_injection_tool_config_no_tools(): + """Test that tool_config injection is ignored when no tools are provided.""" + config = AmazonConverseConfig() + messages = [ + {"role": "user", "content": "Hello"}, + ] + optional_params = { + "cache_control_injection_points": [ + {"location": "tool_config"}, + ], + } + result = config._transform_request( + model="anthropic.claude-3-5-haiku-20241022-v1:0", + messages=messages, + optional_params=optional_params, + litellm_params={}, + ) + assert "toolConfig" not in result + + +def test_cache_control_injection_tool_config_not_added_without_injection_point(): + """Test that cachePoint is NOT appended when cache_control_injection_points doesn't include tool_config.""" + config = AmazonConverseConfig() + messages = [ + {"role": "user", "content": "What is the weather?"}, + ] + optional_params = { + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, + } + ], + "cache_control_injection_points": [ + {"location": "message", "role": "system"}, + ], + } + result = 
config._transform_request( + model="anthropic.claude-3-5-haiku-20241022-v1:0", + messages=messages, + optional_params=optional_params, + litellm_params={}, + ) + tools = result["toolConfig"]["tools"] + # No cachePoint should be appended + assert all("cachePoint" not in tool for tool in tools) + From bd0c3bfdc4d9cd364c8b68b975ed8398c55b0dc0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 18 Mar 2026 20:58:41 -0700 Subject: [PATCH 142/539] fix: fix logging for response incomplete streaming --- litellm/litellm_core_utils/litellm_logging.py | 294 ++++++++---------- .../anthropic_passthrough_logging_handler.py | 40 ++- litellm/responses/streaming_iterator.py | 52 ++-- .../test_litellm_logging.py | 192 ++++++++++-- 4 files changed, 342 insertions(+), 236 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 4e63dd7076..5e34a4c992 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -12,47 +12,28 @@ import time import traceback from datetime import datetime as dt_object from functools import lru_cache -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Literal, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, + Optional, Tuple, Type, Union, cast) from httpx import Response from pydantic import BaseModel import litellm -from litellm import ( - _custom_logger_compatible_callbacks_literal, - json_logs, - log_raw_request_response, - turn_off_message_logging, -) +from litellm import (_custom_logger_compatible_callbacks_literal, json_logs, + log_raw_request_response, turn_off_message_logging) from litellm._logging import _is_debugging_on, verbose_logger from litellm._uuid import uuid from litellm.batches.batch_utils import _handle_completed_batch from litellm.caching.caching import DualCache, InMemoryCache from litellm.caching.caching_handler 
import LLMCachingHandler -from litellm.constants import ( - DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT, - DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT, - SENTRY_DENYLIST, - SENTRY_PII_DENYLIST, -) -from litellm.cost_calculator import ( - RealtimeAPITokenUsageProcessor, - _select_model_name_for_cost_calc, -) +from litellm.constants import (DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT, + DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT, + SENTRY_DENYLIST, SENTRY_PII_DENYLIST) +from litellm.cost_calculator import (RealtimeAPITokenUsageProcessor, + _select_model_name_for_cost_calc) from litellm.integrations.agentops import AgentOps -from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook +from litellm.integrations.anthropic_cache_control_hook import \ + AnthropicCacheControlHook from litellm.integrations.arize.arize import ArizeLogger from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger @@ -61,70 +42,48 @@ from litellm.integrations.mlflow import MlflowLogger from litellm.integrations.sqs import SQSLogger from litellm.litellm_core_utils.core_helpers import reconstruct_model_name from litellm.litellm_core_utils.get_litellm_params import get_litellm_params -from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import ( - StandardBuiltInToolCostTracking, -) -from litellm.litellm_core_utils.logging_utils import truncate_base64_in_messages +from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import \ + StandardBuiltInToolCostTracking +from litellm.litellm_core_utils.logging_utils import \ + truncate_base64_in_messages from litellm.litellm_core_utils.model_param_helper import ModelParamHelper from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_custom_logger, - redact_message_input_output_from_logging, -) + redact_message_input_output_from_logging) from litellm.llms.base_llm.ocr.transformation import OCRResponse 
from litellm.llms.base_llm.search.transformation import SearchResponse from litellm.responses.utils import ResponseAPILoggingUtils from litellm.types.agents import LiteLLMSendMessageResponse from litellm.types.containers.main import ContainerObject -from litellm.types.llms.openai import ( - AllMessageValues, - Batch, - FineTuningJob, - HttpxBinaryResponseContent, - OpenAIFileObject, - OpenAIModerationResponse, - ResponseAPIUsage, - ResponseCompletedEvent, - ResponsesAPIResponse, -) +from litellm.types.llms.openai import (AllMessageValues, Batch, FineTuningJob, + HttpxBinaryResponseContent, + OpenAIFileObject, + OpenAIModerationResponse, + ResponseAPIUsage, + ResponseCompletedEvent, + ResponseFailedEvent, + ResponseIncompleteEvent, + ResponsesAPIResponse) from litellm.types.mcp import MCPPostCallResponseObject from litellm.types.prompts.init_prompts import PromptSpec from litellm.types.rerank import RerankResponse from litellm.types.utils import ( - CachingDetails, - CallTypes, - CostBreakdown, - CostResponseTypes, - CustomPricingLiteLLMParams, - DynamicPromptManagementParamLiteral, - EmbeddingResponse, - GuardrailStatus, - ImageResponse, - LiteLLMBatch, - LiteLLMLoggingBaseClass, - LiteLLMRealtimeStreamLoggingObject, - ModelResponse, - ModelResponseStream, - RawRequestTypedDict, - StandardBuiltInToolsParams, - StandardCallbackDynamicParams, - StandardLoggingAdditionalHeaders, - StandardLoggingHiddenParams, - StandardLoggingMCPToolCall, - StandardLoggingMetadata, - StandardLoggingModelCostFailureDebugInformation, - StandardLoggingModelInformation, - StandardLoggingPayload, - StandardLoggingPayloadErrorInformation, - StandardLoggingPayloadStatus, + CachingDetails, CallTypes, CostBreakdown, CostResponseTypes, + CustomPricingLiteLLMParams, DynamicPromptManagementParamLiteral, + EmbeddingResponse, GuardrailStatus, ImageResponse, LiteLLMBatch, + LiteLLMLoggingBaseClass, LiteLLMRealtimeStreamLoggingObject, ModelResponse, + ModelResponseStream, RawRequestTypedDict, 
StandardBuiltInToolsParams, + StandardCallbackDynamicParams, StandardLoggingAdditionalHeaders, + StandardLoggingHiddenParams, StandardLoggingMCPToolCall, + StandardLoggingMetadata, StandardLoggingModelCostFailureDebugInformation, + StandardLoggingModelInformation, StandardLoggingPayload, + StandardLoggingPayloadErrorInformation, StandardLoggingPayloadStatus, StandardLoggingPayloadStatusFields, - StandardLoggingPromptManagementMetadata, - StandardLoggingVectorStoreRequest, - TextCompletionResponse, - TranscriptionResponse, - Usage, -) + StandardLoggingPromptManagementMetadata, StandardLoggingVectorStoreRequest, + TextCompletionResponse, TranscriptionResponse, Usage) from litellm.types.videos.main import VideoObject -from litellm.utils import _get_base_model_from_metadata, executor, print_verbose +from litellm.utils import (_get_base_model_from_metadata, executor, + print_verbose) from ..integrations.argilla import ArgillaLogger from ..integrations.arize.arize_phoenix import ArizePhoenixLogger @@ -146,7 +105,8 @@ from ..integrations.humanloop import HumanloopLogger from ..integrations.lago import LagoLogger from ..integrations.langfuse.langfuse import LangFuseLogger from ..integrations.langfuse.langfuse_handler import LangFuseHandler -from ..integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement +from ..integrations.langfuse.langfuse_prompt_management import \ + LangfusePromptManagement from ..integrations.langsmith import LangsmithLogger from ..integrations.litellm_agent import LiteLLMAgentModelResolver from ..integrations.literal_ai import LiteralAILogger @@ -161,34 +121,30 @@ from ..integrations.s3_v2 import S3Logger as S3V2Logger from ..integrations.supabase import Supabase from ..integrations.traceloop import TraceloopLogger from .exception_mapping_utils import _get_response_headers -from .initialize_dynamic_callback_params import ( - initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params, -) +from 
.initialize_dynamic_callback_params import \ + initialize_standard_callback_dynamic_params as \ + _initialize_standard_callback_dynamic_params from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache if TYPE_CHECKING: - from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig + from litellm.llms.base_llm.passthrough.transformation import \ + BasePassthroughConfig try: - from litellm_enterprise.enterprise_callbacks.callback_controls import ( - EnterpriseCallbackControls, - ) - from litellm_enterprise.enterprise_callbacks.pagerduty.pagerduty import ( - PagerDutyAlerting, - ) - from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ( - ResendEmailLogger, - ) - from litellm_enterprise.enterprise_callbacks.send_emails.sendgrid_email import ( - SendGridEmailLogger, - ) - from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import ( - SMTPEmailLogger, - ) - from litellm_enterprise.litellm_core_utils.litellm_logging import ( - StandardLoggingPayloadSetup as EnterpriseStandardLoggingPayloadSetup, - ) + from litellm_enterprise.enterprise_callbacks.callback_controls import \ + EnterpriseCallbackControls + from litellm_enterprise.enterprise_callbacks.pagerduty.pagerduty import \ + PagerDutyAlerting + from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import \ + ResendEmailLogger + from litellm_enterprise.enterprise_callbacks.send_emails.sendgrid_email import \ + SendGridEmailLogger + from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import \ + SMTPEmailLogger + from litellm_enterprise.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup as EnterpriseStandardLoggingPayloadSetup - from litellm.integrations.generic_api.generic_api_callback import GenericAPILogger + from litellm.integrations.generic_api.generic_api_callback import \ + GenericAPILogger EnterpriseStandardLoggingPayloadSetupVAR: Optional[ Type[EnterpriseStandardLoggingPayloadSetup] 
@@ -516,6 +472,23 @@ class Logging(LiteLLMLoggingBaseClass): ), ) + def get_router_model_id(self) -> Optional[str]: + """Extract the router deployment model_id from litellm_params. + + Checks both litellm_metadata and metadata for model_info.id. + Used by cost calculators to look up custom pricing registered + under the deployment's model_info.id in litellm.model_cost. + """ + if not hasattr(self, "litellm_params"): + return None + for key in ("litellm_metadata", "metadata"): + meta = self.litellm_params.get(key, {}) or {} + info = meta.get("model_info", {}) or {} + model_id = info.get("id") + if model_id is not None: + return model_id + return None + def update_environment_variables( self, litellm_params: Dict, @@ -1458,16 +1431,8 @@ class Logging(LiteLLMLoggingBaseClass): # Fallback: extract router_model_id from litellm_params when not available # from the result object. ResponsesAPIResponse objects (used by /v1/responses # streaming) don't carry _hidden_params["model_id"] like ModelResponse does. 
- if router_model_id is None and hasattr(self, "litellm_params"): - for metadata_key in ("litellm_metadata", "metadata"): - _metadata: dict = ( - self.litellm_params.get(metadata_key, {}) or {} - ) - _model_info: dict = _metadata.get("model_info", {}) or {} - _model_id = _model_info.get("id") - if _model_id is not None: - router_model_id = _model_id - break + if router_model_id is None: + router_model_id = self.get_router_model_id() ## RESPONSE COST ## custom_pricing = use_custom_pricing_for_model( @@ -1758,9 +1723,8 @@ class Logging(LiteLLMLoggingBaseClass): ) standard_logging_payload["response"] = response_dict elif isinstance(result, TranscriptionResponse): - from litellm.litellm_core_utils.llm_cost_calc.usage_object_transformation import ( - TranscriptionUsageObjectTransformation, - ) + from litellm.litellm_core_utils.llm_cost_calc.usage_object_transformation import \ + TranscriptionUsageObjectTransformation result = result.model_copy() transformed_usage = TranscriptionUsageObjectTransformation.transform_transcription_usage_object(result.usage) # type: ignore @@ -2443,9 +2407,8 @@ class Logging(LiteLLMLoggingBaseClass): ): # polling job will query these frequently, don't spam db logs return - from litellm.proxy.openai_files_endpoints.common_utils import ( - _is_base64_encoded_unified_file_id, - ) + from litellm.proxy.openai_files_endpoints.common_utils import \ + _is_base64_encoded_unified_file_id # check if file id is a unified file id is_base64_unified_file_id = _is_base64_encoded_unified_file_id(result.id) @@ -3321,7 +3284,7 @@ class Logging(LiteLLMLoggingBaseClass): return result elif isinstance(result, TextCompletionResponse): return result - elif isinstance(result, ResponseCompletedEvent): + elif isinstance(result, (ResponseCompletedEvent, ResponseIncompleteEvent, ResponseFailedEvent)): ## return unified Usage object if isinstance(result.response.usage, ResponseAPIUsage): transformed_usage = ( @@ -3588,7 +3551,8 @@ def set_callbacks(callback_list, 
function_id=None): # noqa: PLR0915 elif callback == "s3": s3Logger = S3Logger() elif callback == "wandb": - from litellm.integrations.weights_biases import WeightsBiasesLogger + from litellm.integrations.weights_biases import \ + WeightsBiasesLogger weightsBiasesLogger = WeightsBiasesLogger() elif callback == "logfire": @@ -3652,7 +3616,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_posthog_logger) return _posthog_logger # type: ignore elif logging_integration == "braintrust": - from litellm.integrations.braintrust_logging import BraintrustLogger + from litellm.integrations.braintrust_logging import \ + BraintrustLogger for callback in _in_memory_loggers: if isinstance(callback, BraintrustLogger): @@ -3773,9 +3738,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 return _opik_logger # type: ignore elif logging_integration == "arize": from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) + OpenTelemetry, OpenTelemetryConfig) arize_config = ArizeLogger.get_arize_config() if arize_config.endpoint is None: @@ -3802,9 +3765,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 return _arize_otel_logger # type: ignore elif logging_integration == "arize_phoenix": from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) + OpenTelemetry, OpenTelemetryConfig) arize_phoenix_config = ArizePhoenixLogger.get_arize_phoenix_config() otel_config = OpenTelemetryConfig( @@ -3858,9 +3819,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 elif logging_integration == "levo": from litellm.integrations.levo.levo import LevoLogger from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) + OpenTelemetry, OpenTelemetryConfig) levo_config = LevoLogger.get_levo_config() otel_config = OpenTelemetryConfig( @@ -3909,7 +3868,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 
_in_memory_loggers.append(galileo_logger) return galileo_logger # type: ignore elif logging_integration == "cloudzero": - from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger + from litellm.integrations.cloudzero.cloudzero import \ + CloudZeroLogger for callback in _in_memory_loggers: if isinstance(callback, CloudZeroLogger): @@ -3929,7 +3889,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(focus_logger) return focus_logger # type: ignore elif logging_integration == "vantage": - from litellm.integrations.vantage.vantage_logger import VantageLogger + from litellm.integrations.vantage.vantage_logger import \ + VantageLogger for callback in _in_memory_loggers: if isinstance(callback, VantageLogger): @@ -3949,9 +3910,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 if "LOGFIRE_TOKEN" not in os.environ: raise ValueError("LOGFIRE_TOKEN not found in environment variables") from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) + OpenTelemetry, OpenTelemetryConfig) logfire_base_url = os.getenv( "LOGFIRE_BASE_URL", "https://logfire-api.pydantic.dev" @@ -3969,9 +3928,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_otel_logger) return _otel_logger # type: ignore elif logging_integration == "dynamic_rate_limiter": - from litellm.proxy.hooks.dynamic_rate_limiter import ( - _PROXY_DynamicRateLimitHandler, - ) + from litellm.proxy.hooks.dynamic_rate_limiter import \ + _PROXY_DynamicRateLimitHandler for callback in _in_memory_loggers: if isinstance(callback, _PROXY_DynamicRateLimitHandler): @@ -3993,9 +3951,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(dynamic_rate_limiter_obj) return dynamic_rate_limiter_obj # type: ignore elif logging_integration == "dynamic_rate_limiter_v3": - from litellm.proxy.hooks.dynamic_rate_limiter_v3 import ( - _PROXY_DynamicRateLimitHandlerV3, - ) + from 
litellm.proxy.hooks.dynamic_rate_limiter_v3 import \ + _PROXY_DynamicRateLimitHandlerV3 for callback in _in_memory_loggers: if isinstance(callback, _PROXY_DynamicRateLimitHandlerV3): @@ -4021,9 +3978,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 raise ValueError("LANGTRACE_API_KEY not found in environment variables") from litellm.integrations.opentelemetry import ( - OpenTelemetry, - OpenTelemetryConfig, - ) + OpenTelemetry, OpenTelemetryConfig) otel_config = OpenTelemetryConfig( exporter="otlp_http", @@ -4059,7 +4014,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(langfuse_logger) return langfuse_logger # type: ignore elif logging_integration == "langfuse_otel": - from litellm.integrations.langfuse.langfuse_otel import LangfuseOtelLogger + from litellm.integrations.langfuse.langfuse_otel import \ + LangfuseOtelLogger for callback in _in_memory_loggers: if ( @@ -4077,9 +4033,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 elif logging_integration == "weave_otel": from litellm.integrations.opentelemetry import OpenTelemetryConfig from litellm.integrations.weave.weave_otel import ( - WeaveOtelLogger, - get_weave_otel_config, - ) + WeaveOtelLogger, get_weave_otel_config) weave_otel_config = get_weave_otel_config() @@ -4115,9 +4069,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(anthropic_cache_control_hook) return anthropic_cache_control_hook # type: ignore elif logging_integration == "vector_store_pre_call_hook": - from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( - VectorStorePreCallHook, - ) + from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import \ + VectorStorePreCallHook for callback in _in_memory_loggers: if isinstance(callback, VectorStorePreCallHook): @@ -4177,9 +4130,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(dotprompt_logger) return 
dotprompt_logger # type: ignore elif logging_integration == "bitbucket": - from litellm.integrations.bitbucket.bitbucket_prompt_manager import ( - BitBucketPromptManager, - ) + from litellm.integrations.bitbucket.bitbucket_prompt_manager import \ + BitBucketPromptManager for callback in _in_memory_loggers: if isinstance(callback, BitBucketPromptManager): @@ -4196,9 +4148,8 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(bitbucket_logger) return bitbucket_logger # type: ignore elif logging_integration == "gitlab": - from litellm.integrations.gitlab.gitlab_prompt_manager import ( - GitLabPromptManager, - ) + from litellm.integrations.gitlab.gitlab_prompt_manager import \ + GitLabPromptManager for callback in _in_memory_loggers: if isinstance(callback, GitLabPromptManager): @@ -4286,7 +4237,8 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 if isinstance(callback, OpenMeterLogger): return callback elif logging_integration == "braintrust": - from litellm.integrations.braintrust_logging import BraintrustLogger + from litellm.integrations.braintrust_logging import \ + BraintrustLogger for callback in _in_memory_loggers: if isinstance(callback, BraintrustLogger): @@ -4296,7 +4248,8 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 if isinstance(callback, GalileoObserve): return callback elif logging_integration == "cloudzero": - from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger + from litellm.integrations.cloudzero.cloudzero import \ + CloudZeroLogger for callback in _in_memory_loggers: if isinstance(callback, CloudZeroLogger): @@ -4310,7 +4263,8 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 ): # exact match; exclude subclasses like VantageLogger return callback elif logging_integration == "vantage": - from litellm.integrations.vantage.vantage_logger import VantageLogger + from litellm.integrations.vantage.vantage_logger import \ + VantageLogger for callback in _in_memory_loggers: if 
isinstance(callback, VantageLogger): @@ -4410,17 +4364,15 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 return callback # type: ignore elif logging_integration == "dynamic_rate_limiter": - from litellm.proxy.hooks.dynamic_rate_limiter import ( - _PROXY_DynamicRateLimitHandler, - ) + from litellm.proxy.hooks.dynamic_rate_limiter import \ + _PROXY_DynamicRateLimitHandler for callback in _in_memory_loggers: if isinstance(callback, _PROXY_DynamicRateLimitHandler): return callback # type: ignore elif logging_integration == "dynamic_rate_limiter_v3": - from litellm.proxy.hooks.dynamic_rate_limiter_v3 import ( - _PROXY_DynamicRateLimitHandlerV3, - ) + from litellm.proxy.hooks.dynamic_rate_limiter_v3 import \ + _PROXY_DynamicRateLimitHandlerV3 for callback in _in_memory_loggers: if isinstance(callback, _PROXY_DynamicRateLimitHandlerV3): @@ -4452,9 +4404,8 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 if isinstance(callback, AnthropicCacheControlHook): return callback elif logging_integration == "vector_store_pre_call_hook": - from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( - VectorStorePreCallHook, - ) + from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import \ + VectorStorePreCallHook for callback in _in_memory_loggers: if isinstance(callback, VectorStorePreCallHook): @@ -5626,7 +5577,6 @@ def _get_traceback_str_for_error(error_str: str) -> str: from decimal import Decimal - # used for unit testing from typing import Any, Dict, List, Optional, Union diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 20d06b7d53..3241c1ca93 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ 
b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -6,20 +6,23 @@ import httpx import litellm from litellm._logging import verbose_proxy_logger -from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.litellm_logging import \ + use_custom_pricing_for_model from litellm.llms.anthropic import get_anthropic_config -from litellm.llms.anthropic.chat.handler import ( - ModelResponseIterator as AnthropicModelResponseIterator, -) +from litellm.llms.anthropic.chat.handler import \ + ModelResponseIterator as AnthropicModelResponseIterator from litellm.proxy._types import PassThroughEndpointLoggingTypedDict from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body -from litellm.types.passthrough_endpoints.pass_through_endpoints import ( - PassthroughStandardLoggingPayload, -) -from litellm.types.utils import LiteLLMBatch, ModelResponse, TextCompletionResponse +from litellm.types.passthrough_endpoints.pass_through_endpoints import \ + PassthroughStandardLoggingPayload +from litellm.types.utils import (LiteLLMBatch, ModelResponse, + TextCompletionResponse) if TYPE_CHECKING: - from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType + from litellm.types.passthrough_endpoints.pass_through_endpoints import \ + EndpointType from ..success_handler import PassThroughEndpointLogging else: @@ -124,10 +127,21 @@ class AnthropicPassthroughLoggingHandler: if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"): model_for_cost = f"{custom_llm_provider}/{model}" + router_model_id = logging_obj.get_router_model_id() + custom_pricing = use_custom_pricing_for_model( + litellm_params=( + logging_obj.litellm_params + if hasattr(logging_obj, "litellm_params") + else None + ) + ) + response_cost = litellm.completion_cost( 
completion_response=litellm_model_response, model=model_for_cost, custom_llm_provider=custom_llm_provider, + custom_pricing=custom_pricing, + router_model_id=router_model_id, ) kwargs["response_cost"] = response_cost @@ -319,9 +333,8 @@ class AnthropicPassthroughLoggingHandler: import base64 from litellm._uuid import uuid - from litellm.llms.anthropic.batches.transformation import ( - AnthropicBatchesConfig, - ) + from litellm.llms.anthropic.batches.transformation import \ + AnthropicBatchesConfig from litellm.types.utils import Choices, SpecialEnums try: @@ -537,7 +550,8 @@ class AnthropicPassthroughLoggingHandler: managed_files_hook, "store_unified_object_id" ): # Create a mock user API key dict for the managed object storage - from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + from litellm.proxy._types import (LitellmUserRoles, + UserAPIKeyAuth) user_api_key_dict = UserAPIKeyAuth( user_id=kwargs.get("user_id", "default-user"), diff --git a/litellm/responses/streaming_iterator.py b/litellm/responses/streaming_iterator.py index 073ee92606..a6a1074067 100644 --- a/litellm/responses/streaming_iterator.py +++ b/litellm/responses/streaming_iterator.py @@ -8,31 +8,29 @@ from typing import Any, Dict, List, Optional import httpx import litellm -from litellm.constants import ( - LITELLM_MAX_STREAMING_DURATION_SECONDS, - STREAM_SSE_DONE_STRING, -) +from litellm.constants import (LITELLM_MAX_STREAMING_DURATION_SECONDS, + STREAM_SSE_DONE_STRING) from litellm.litellm_core_utils.asyncify import run_async_function from litellm.litellm_core_utils.core_helpers import process_response_headers -from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj -from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base -from litellm.litellm_core_utils.llm_response_utils.response_metadata import ( - update_response_metadata, -) +from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj +from 
litellm.litellm_core_utils.llm_response_utils.get_api_base import \ + get_api_base +from litellm.litellm_core_utils.llm_response_utils.response_metadata import \ + update_response_metadata from litellm.litellm_core_utils.thread_pool_executor import executor -from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig +from litellm.llms.base_llm.responses.transformation import \ + BaseResponsesAPIConfig from litellm.responses.utils import ResponsesAPIRequestUtils -from litellm.types.llms.openai import ( - OutputTextDeltaEvent, - ResponseAPIUsage, - ResponseCompletedEvent, - ResponsesAPIRequestParams, - ResponsesAPIResponse, - ResponsesAPIStreamEvents, - ResponsesAPIStreamingResponse, -) +from litellm.types.llms.openai import (OutputTextDeltaEvent, ResponseAPIUsage, + ResponseCompletedEvent, + ResponsesAPIRequestParams, + ResponsesAPIResponse, + ResponsesAPIStreamEvents, + ResponsesAPIStreamingResponse) from litellm.types.utils import CallTypes -from litellm.utils import CustomStreamWrapper, async_post_call_success_deployment_hook +from litellm.utils import (CustomStreamWrapper, + async_post_call_success_deployment_hook) class BaseResponsesAPIStreamingIterator: @@ -166,11 +164,16 @@ class BaseResponsesAPIStreamingIterator: ) setattr(item, "encrypted_content", wrapped_content) - # Store the completed response + # Store the completed response (also for incomplete/failed so logging still fires) + _chunk_type = getattr(openai_responses_api_chunk, "type", None) if ( openai_responses_api_chunk - and getattr(openai_responses_api_chunk, "type", None) - == ResponsesAPIStreamEvents.RESPONSE_COMPLETED + and _chunk_type + in ( + ResponsesAPIStreamEvents.RESPONSE_COMPLETED, + ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE, + ResponsesAPIStreamEvents.RESPONSE_FAILED, + ) ): self.completed_response = openai_responses_api_chunk # Add cost to usage object if include_cost_in_streaming_usage is True @@ -694,7 +697,8 @@ class 
MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator): # --------------------------------------------------------------------------- from litellm._logging import verbose_logger -from litellm.litellm_core_utils.thread_pool_executor import executor as _ws_executor +from litellm.litellm_core_utils.thread_pool_executor import \ + executor as _ws_executor RESPONSES_WS_LOGGED_EVENT_TYPES = [ "response.created", diff --git a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py index fe4851283f..0f950f6da7 100644 --- a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py +++ b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py @@ -11,7 +11,8 @@ sys.path.insert( import time from litellm.constants import SENTRY_DENYLIST, SENTRY_PII_DENYLIST -from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging +from litellm.litellm_core_utils.litellm_logging import \ + Logging as LitellmLogging from litellm.litellm_core_utils.litellm_logging import set_callbacks from litellm.types.utils import ModelResponse, TextCompletionResponse @@ -139,7 +140,8 @@ def test_sentry_environment(): def test_use_custom_pricing_for_model(): - from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model + from litellm.litellm_core_utils.litellm_logging import \ + use_custom_pricing_for_model litellm_params = { "custom_llm_provider": "azure", @@ -154,7 +156,8 @@ def test_use_custom_pricing_for_model_via_litellm_metadata(): Generic API call routes (/messages, /responses) store model_info under litellm_metadata, not metadata. Regression test for #23185. 
""" - from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model + from litellm.litellm_core_utils.litellm_logging import \ + use_custom_pricing_for_model litellm_params = { "litellm_metadata": { @@ -170,7 +173,8 @@ def test_use_custom_pricing_for_model_via_litellm_metadata(): def test_use_custom_pricing_not_detected_litellm_metadata_no_pricing(): """Should return False when litellm_metadata.model_info has no pricing keys.""" - from litellm.litellm_core_utils.litellm_logging import use_custom_pricing_for_model + from litellm.litellm_core_utils.litellm_logging import \ + use_custom_pricing_for_model litellm_params = { "litellm_metadata": { @@ -186,7 +190,8 @@ def test_response_cost_calculator_uses_router_model_id_from_litellm_metadata(): does not carry _hidden_params (e.g. ResponsesAPIResponse from /v1/responses streaming). Regression test for custom pricing on streaming responses.""" import litellm - from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj from litellm.types.llms.openai import ResponsesAPIResponse custom_model_id = "gpt-5-custom-pricing" @@ -256,6 +261,121 @@ def test_response_cost_calculator_uses_router_model_id_from_litellm_metadata(): litellm.model_cost.pop(custom_model_id, None) +class TestGetRouterModelId: + """Tests for the get_router_model_id helper method.""" + + def test_returns_id_from_litellm_metadata(self, logging_obj): + """Should extract model_info.id from litellm_metadata.""" + logging_obj.litellm_params = { + "litellm_metadata": { + "model_info": {"id": "custom-deploy-1"}, + }, + } + assert logging_obj.get_router_model_id() == "custom-deploy-1" + + def test_returns_id_from_metadata(self, logging_obj): + """Should fall back to metadata when litellm_metadata has no model_info.""" + logging_obj.litellm_params = { + "metadata": { + "model_info": {"id": "custom-deploy-2"}, + }, + } + assert 
logging_obj.get_router_model_id() == "custom-deploy-2" + + def test_prefers_litellm_metadata_over_metadata(self, logging_obj): + """litellm_metadata should take priority over metadata.""" + logging_obj.litellm_params = { + "litellm_metadata": { + "model_info": {"id": "from-litellm-meta"}, + }, + "metadata": { + "model_info": {"id": "from-meta"}, + }, + } + assert logging_obj.get_router_model_id() == "from-litellm-meta" + + def test_returns_none_when_no_model_info(self, logging_obj): + """Should return None when no model_info is present.""" + logging_obj.litellm_params = {"api_base": ""} + assert logging_obj.get_router_model_id() is None + + def test_returns_none_when_no_litellm_params(self): + """Should return None when litellm_params is not set.""" + from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj + + obj = LiteLLMLoggingObj( + model="test", + messages=[], + stream=False, + call_type="completion", + start_time=time.time(), + litellm_call_id="x", + function_id="x", + ) + # litellm_params exists but is empty by default + assert obj.get_router_model_id() is None + + +class TestAnthropicPassthroughCustomPricing: + """Verify the Anthropic pass-through handler forwards custom pricing.""" + + def test_completion_cost_receives_custom_pricing_args(self): + """_create_anthropic_response_logging_payload should pass + custom_pricing and router_model_id to litellm.completion_cost + when the logging object carries custom pricing in model_info.""" + from unittest.mock import patch + + from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj + from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import \ + AnthropicPassthroughLoggingHandler + + logging_obj = LiteLLMLoggingObj( + model="claude-sonnet-4-20250514", + messages=[{"role": "user", "content": "Hi"}], + stream=False, + call_type="anthropic_messages", + start_time=time.time(), + 
litellm_call_id="test-456", + function_id="test-fn", + ) + logging_obj.update_environment_variables( + model="claude-sonnet-4-20250514", + user="", + optional_params={}, + litellm_params={ + "api_base": "", + "litellm_metadata": { + "model_info": { + "id": "claude-custom-pricing", + "input_cost_per_token": 0.5, + "output_cost_per_token": 1.5, + }, + }, + }, + ) + logging_obj.model_call_details["custom_llm_provider"] = "anthropic" + + mock_response = ModelResponse() + mock_response.usage = {"prompt_tokens": 10, "completion_tokens": 5} # type: ignore + + with patch("litellm.completion_cost", return_value=42.0) as mock_cost: + AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=mock_response, + model="claude-sonnet-4-20250514", + kwargs={}, + start_time=time.time(), + end_time=time.time(), + logging_obj=logging_obj, + ) + + mock_cost.assert_called_once() + call_kwargs = mock_cost.call_args + assert call_kwargs.kwargs.get("custom_pricing") is True + assert call_kwargs.kwargs.get("router_model_id") == "claude-custom-pricing" + + class TestUpdateFromKwargs: """Tests for the update_from_kwargs convenience wrapper.""" @@ -321,9 +441,8 @@ class TestUpdateFromKwargs: def test_custom_pricing_detected_via_litellm_metadata(self, logging_obj): """Custom pricing in litellm_metadata.model_info should set custom_pricing flag.""" - from litellm.litellm_core_utils.litellm_logging import ( - use_custom_pricing_for_model, - ) + from litellm.litellm_core_utils.litellm_logging import \ + use_custom_pricing_for_model lm_meta = { "model_info": { @@ -382,7 +501,8 @@ async def test_datadog_logger_not_shadowed_by_llm_obs(monkeypatch): monkeypatch.setenv("DD_SITE", "us5.datadoghq.com") from litellm.integrations.datadog.datadog import DataDogLogger - from litellm.integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger + from litellm.integrations.datadog.datadog_llm_obs import \ + DataDogLLMObsLogger from litellm.litellm_core_utils 
import litellm_logging as logging_module logging_module._in_memory_loggers.clear() @@ -423,7 +543,8 @@ async def test_logfire_logger_accepts_env_vars_for_base_url(monkeypatch): ) # no trailing slash on purpose # Import after env vars are set (important if module-level caching exists) - from litellm.integrations.opentelemetry import OpenTelemetry # logger class + from litellm.integrations.opentelemetry import \ + OpenTelemetry # logger class from litellm.litellm_core_utils import litellm_logging as logging_module logging_module._in_memory_loggers.clear() @@ -752,7 +873,8 @@ def test_success_handler_runs_guardrail_logging_hook_when_enabled(logging_obj): def test_get_user_agent_tags(): - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup tags = StandardLoggingPayloadSetup._get_user_agent_tags( proxy_server_request={ @@ -767,7 +889,8 @@ def test_get_user_agent_tags(): def test_get_request_tags(): - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup tags = StandardLoggingPayloadSetup._get_request_tags( litellm_params={"metadata": {"tags": ["test-tag"]}}, @@ -794,7 +917,8 @@ def test_get_request_tags_from_metadata_and_litellm_metadata(): 4. No tags in either 5. None values for metadata/litellm_metadata """ - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Test case 1: Tags in metadata only tags = StandardLoggingPayloadSetup._get_request_tags( @@ -875,7 +999,8 @@ def test_get_request_tags_does_not_mutate_original_tags(): would cause User-Agent tags to be duplicated because the function was mutating the original tags list instead of creating a copy. 
""" - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Create metadata with original tags original_tags = ["custom-tag-1", "custom-tag-2"] @@ -935,7 +1060,8 @@ def test_get_request_tags_does_not_mutate_original_tags(): def test_get_extra_header_tags(): """Test the _get_extra_header_tags method with various scenarios.""" import litellm - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Store original value to restore later original_extra_headers = getattr(litellm, "extra_spend_tag_headers", None) @@ -1156,7 +1282,8 @@ async def test_e2e_generate_cold_storage_object_key_successful(): from datetime import datetime, timezone from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Create test data start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc) @@ -1198,7 +1325,8 @@ async def test_e2e_generate_cold_storage_object_key_with_custom_logger_s3_path() from datetime import datetime, timezone from unittest.mock import MagicMock, patch - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Create test data start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc) @@ -1249,7 +1377,8 @@ async def test_e2e_generate_cold_storage_object_key_with_logger_no_s3_path(): from datetime import datetime, timezone from unittest.mock import MagicMock, patch - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Create test data 
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc) @@ -1296,7 +1425,8 @@ async def test_e2e_generate_cold_storage_object_key_not_configured(): from unittest.mock import patch import litellm - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Create test data start_time = datetime(2025, 1, 15, 10, 30, 45, 123456, timezone.utc) @@ -1320,7 +1450,8 @@ def test_get_final_response_obj_with_empty_response_obj_and_list_init(): When response_obj is empty (falsy), the method should return init_response_obj if it's a list. """ - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Create test objects class TestObject1: @@ -1356,7 +1487,8 @@ def test_get_usage_as_dict(): """ Test get_usage_as_dict returns usage as plain dict from response_obj or combined_usage_object. """ - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup from litellm.types.utils import Usage # Test case 1: None response_obj returns empty usage dict @@ -1394,7 +1526,8 @@ def test_append_system_prompt_messages(): """ Test append_system_prompt_messages prepends system message from kwargs to messages list. 
""" - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Test case 1: system in kwargs with existing messages kwargs = {"system": "You are a helpful assistant"} @@ -1465,7 +1598,8 @@ async def test_async_success_handler_sets_standard_logging_object_for_pass_throu from datetime import datetime from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj from litellm.types.utils import StandardPassThroughResponseObject # Create a logging object for a pass-through endpoint @@ -1546,7 +1680,8 @@ async def test_async_success_handler_prevents_reprocessing_for_pass_through_endp from datetime import datetime from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj from litellm.types.utils import StandardPassThroughResponseObject # Create a logging object for a pass-through endpoint @@ -1622,7 +1757,8 @@ async def test_async_success_handler_sets_standard_logging_object_for_streaming_ from datetime import datetime from unittest.mock import patch - from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj from litellm.types.utils import StandardPassThroughResponseObject # Create a logging object for a streaming pass-through endpoint @@ -1678,7 +1814,8 @@ def test_get_error_information_error_code_priority(): Test get_error_information prioritizes 'code' attribute over 'status_code' attribute and handles edge cases like empty strings and "None" string values. 
""" - from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + from litellm.litellm_core_utils.litellm_logging import \ + StandardLoggingPayloadSetup # Test case 1: Exception with 'code' attribute (ProxyException style) class ProxyException(Exception): @@ -1871,7 +2008,8 @@ async def test_async_success_handler_preserves_response_cost_for_pass_through_en by pass-through handlers (Gemini/Vertex).""" from datetime import datetime - from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.litellm_core_utils.litellm_logging import \ + Logging as LiteLLMLoggingObj from litellm.types.utils import ModelResponse, Usage logging_obj = LiteLLMLoggingObj( From 4829de610278774d360868475316ca144b2c0e27 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 09:28:55 +0530 Subject: [PATCH 143/539] fix(proxy): allow non-admin users to access pass-through subpath routes with auth When a pass-through endpoint has both auth=true and include_subpath=true, non-admin users got 401 errors on subpath requests because only the base path was registered in openai_routes. Now the wildcard path is also registered so the auth check recognizes subpath requests as LLM API routes. Also fixes pre-existing pyright error where logging_obj was possibly unbound in the except block. 
--- .../pass_through_endpoints.py | 7 ++++ .../proxy/auth/test_route_checks.py | 38 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 0aa9968520..cf6ead974f 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -651,6 +651,7 @@ async def pass_through_request( # noqa: PLR0915 _parsed_body: Optional[dict] = None # kwargs for pass through endpoint, contains metadata, litellm_params, call_type, litellm_call_id, passthrough_logging_payload kwargs: Optional[dict] = None + logging_obj: Optional[Logging] = None ######################################################### try: @@ -2262,6 +2263,12 @@ async def initialize_pass_through_endpoints( # Add wildcard route for sub-paths if endpoint.get("include_subpath", False) is True: + # Register wildcard path in openai_routes so non-admin users + # can access subpath routes when auth is enabled + if _auth is not None and str(_auth).lower() == "true": + _wildcard_path = _path.rstrip("/") + "/*" + if _wildcard_path not in LiteLLMRoutes.openai_routes.value: + LiteLLMRoutes.openai_routes.value.append(_wildcard_path) InitPassThroughEndpointHelpers.add_subpath_route( app=app, path=_path, diff --git a/tests/test_litellm/proxy/auth/test_route_checks.py b/tests/test_litellm/proxy/auth/test_route_checks.py index f20c14aa61..20fb56e2dc 100644 --- a/tests/test_litellm/proxy/auth/test_route_checks.py +++ b/tests/test_litellm/proxy/auth/test_route_checks.py @@ -1329,3 +1329,41 @@ def test_non_org_admin_with_organizations_list(): organization_memberships=[membership], ) assert _user_is_org_admin({"organizations": ["org-1"]}, user_obj) is False + + +def test_pass_through_subpath_auth_with_wildcard_in_openai_routes(): + """ + Test that pass-through endpoints with include_subpath=true and auth=true + are 
accessible to non-admin users via wildcard route matching. + + When auth=true and include_subpath=true, the wildcard path (e.g. /custom-endpoint/*) + should be added to openai_routes so that subpath requests like + /custom-endpoint/v1/infer are recognized as LLM API routes. + + Regression test for: non-admin users getting 401 "Only proxy admin" error + on pass-through subpath requests. + """ + from litellm.proxy._types import LiteLLMRoutes + + base_path = "/v1/ocr/nvidia/community/nemoretriever-ocr-v1" + wildcard_path = base_path + "/*" + + # Simulate what init_pass_through_endpoints does when auth=true + include_subpath=true + original_routes = LiteLLMRoutes.openai_routes.value[:] + try: + LiteLLMRoutes.openai_routes.value.append(base_path) + LiteLLMRoutes.openai_routes.value.append(wildcard_path) + + # Exact path should match + assert RouteChecks.is_llm_api_route(base_path) is True + + # Subpath should match via wildcard + assert RouteChecks.is_llm_api_route(base_path + "/v1/infer") is True + + # Deeper subpath should also match + assert RouteChecks.is_llm_api_route(base_path + "/v1/some/deep/path") is True + + # Unrelated route should not match + assert RouteChecks.is_llm_api_route("/v1/some-other-endpoint") is False + finally: + LiteLLMRoutes.openai_routes.value[:] = original_routes From 97b7358791052b4056200002d17c76b88479e8f8 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 09:44:30 +0530 Subject: [PATCH 144/539] fix(proxy): dedup openai_routes on reload and clean up on endpoint removal - Add dedup guard for base path registration (prevents unbounded list growth on config reload) - Clean up base path and wildcard path from openai_routes when an endpoint is removed via remove_endpoint_routes - Rewrite test to exercise initialize_pass_through_endpoints directly, covering registration, dedup on reload, and cleanup on removal --- .../pass_through_endpoints.py | 16 ++++- .../proxy/auth/test_route_checks.py | 70 +++++++++++++------ 2 files 
changed, 64 insertions(+), 22 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index cf6ead974f..3beaa32b31 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -2022,13 +2022,24 @@ class InitPassThroughEndpointHelpers: @staticmethod def remove_endpoint_routes(endpoint_id: str): - """Remove all routes for a specific endpoint ID from the registry""" + """Remove all routes for a specific endpoint ID from the registry + and clean up corresponding entries from LiteLLMRoutes.openai_routes.""" keys_to_remove = [ key for key, value in _registered_pass_through_routes.items() if value["endpoint_id"] == endpoint_id ] for key in keys_to_remove: + route_info = _registered_pass_through_routes[key] + path = route_info.get("path") + if isinstance(path, str): + # Remove base path and wildcard path from openai_routes + openai_routes = LiteLLMRoutes.openai_routes.value + if path in openai_routes: + openai_routes.remove(path) + wildcard_path = path.rstrip("/") + "/*" + if wildcard_path in openai_routes: + openai_routes.remove(wildcard_path) del _registered_pass_through_routes[key] verbose_proxy_logger.debug( "Removed pass-through route from registry: %s", key @@ -2224,7 +2235,8 @@ async def initialize_pass_through_endpoints( ) ) _dependencies = [Depends(user_api_key_auth)] - LiteLLMRoutes.openai_routes.value.append(_path) + if _path not in LiteLLMRoutes.openai_routes.value: + LiteLLMRoutes.openai_routes.value.append(_path) if _target is None: continue diff --git a/tests/test_litellm/proxy/auth/test_route_checks.py b/tests/test_litellm/proxy/auth/test_route_checks.py index 20fb56e2dc..b6c4fa73df 100644 --- a/tests/test_litellm/proxy/auth/test_route_checks.py +++ b/tests/test_litellm/proxy/auth/test_route_checks.py @@ -1331,39 +1331,69 @@ def test_non_org_admin_with_organizations_list(): 
assert _user_is_org_admin({"organizations": ["org-1"]}, user_obj) is False -def test_pass_through_subpath_auth_with_wildcard_in_openai_routes(): +@pytest.mark.asyncio +async def test_initialize_pass_through_registers_wildcard_for_auth_subpath(): """ - Test that pass-through endpoints with include_subpath=true and auth=true - are accessible to non-admin users via wildcard route matching. + Test that initialize_pass_through_endpoints registers both base path and + wildcard path in openai_routes when auth=true and include_subpath=true, + and that subpath requests pass is_llm_api_route. - When auth=true and include_subpath=true, the wildcard path (e.g. /custom-endpoint/*) - should be added to openai_routes so that subpath requests like - /custom-endpoint/v1/infer are recognized as LLM API routes. - - Regression test for: non-admin users getting 401 "Only proxy admin" error - on pass-through subpath requests. + Also verifies: + - Dedup: calling init twice does not duplicate entries + - Cleanup: removing the endpoint cleans up openai_routes """ from litellm.proxy._types import LiteLLMRoutes + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + InitPassThroughEndpointHelpers, + initialize_pass_through_endpoints, + ) base_path = "/v1/ocr/nvidia/community/nemoretriever-ocr-v1" wildcard_path = base_path + "/*" - # Simulate what init_pass_through_endpoints does when auth=true + include_subpath=true + endpoint_config = { + "path": base_path, + "target": "https://httpbin.org/post", + "include_subpath": True, + "auth": True, + "headers": {"content-type": "application/json"}, + } + original_routes = LiteLLMRoutes.openai_routes.value[:] try: - LiteLLMRoutes.openai_routes.value.append(base_path) - LiteLLMRoutes.openai_routes.value.append(wildcard_path) + with patch( + "litellm.proxy.proxy_server.app", + MagicMock(), + ), patch( + "litellm.proxy.proxy_server.premium_user", + True, + ), patch( + "litellm.proxy.proxy_server.config_passthrough_endpoints", + 
None, + ): + await initialize_pass_through_endpoints([endpoint_config]) - # Exact path should match - assert RouteChecks.is_llm_api_route(base_path) is True + # Both base and wildcard paths should be registered + assert base_path in LiteLLMRoutes.openai_routes.value + assert wildcard_path in LiteLLMRoutes.openai_routes.value - # Subpath should match via wildcard - assert RouteChecks.is_llm_api_route(base_path + "/v1/infer") is True + # Subpath requests should pass the auth route check + assert RouteChecks.is_llm_api_route(base_path) is True + assert RouteChecks.is_llm_api_route(base_path + "/v1/infer") is True - # Deeper subpath should also match - assert RouteChecks.is_llm_api_route(base_path + "/v1/some/deep/path") is True + # Calling init again should not duplicate entries + await initialize_pass_through_endpoints([endpoint_config]) + assert LiteLLMRoutes.openai_routes.value.count(base_path) == 1 + assert LiteLLMRoutes.openai_routes.value.count(wildcard_path) == 1 - # Unrelated route should not match - assert RouteChecks.is_llm_api_route("/v1/some-other-endpoint") is False + # Removing the endpoint should clean up openai_routes + # remove_endpoint_routes takes endpoint_id (UUID portion of + # the route key "{id}:exact:{path}:{methods}") + registered = InitPassThroughEndpointHelpers.get_all_registered_pass_through_routes() + endpoint_ids = {k.split(":")[0] for k in registered} + for eid in endpoint_ids: + InitPassThroughEndpointHelpers.remove_endpoint_routes(eid) + assert base_path not in LiteLLMRoutes.openai_routes.value + assert wildcard_path not in LiteLLMRoutes.openai_routes.value finally: LiteLLMRoutes.openai_routes.value[:] = original_routes From f7803d2d6d337d94faf34bac82d9441b52507c2f Mon Sep 17 00:00:00 2001 From: joereyna Date: Wed, 18 Mar 2026 21:21:07 -0700 Subject: [PATCH 145/539] chore: regenerate poetry.lock to unblock CI (pyproject.toml content hash drift) --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/poetry.lock b/poetry.lock index e3b083d778..dc25864442 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8018,4 +8018,4 @@ utils = ["numpydoc"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "eda34dfd8b35474beffee18893d6782c7b3d0d3d2c610f66237eb97176f43527" +content-hash = "2cf958f1a04fd5f1ab0e5cfc33bdbf441b518ed6c82d0f2546bf64cd3d2f89be" From ab1744f9fe197208085483e641c1abd27a30d516 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 09:53:22 +0530 Subject: [PATCH 146/539] fix(proxy): scope wildcard cleanup to subpath entries and restore registry in test - Only remove wildcard path from openai_routes when the route entry has type="subpath", avoiding accidental removal when two endpoints share the same base path but differ in include_subpath - Clean up _registered_pass_through_routes in the test finally block to prevent stale entries from polluting subsequent tests on failure --- .../pass_through_endpoints/pass_through_endpoints.py | 8 ++++---- tests/test_litellm/proxy/auth/test_route_checks.py | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 3beaa32b31..0f676a1feb 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -2033,13 +2033,13 @@ class InitPassThroughEndpointHelpers: route_info = _registered_pass_through_routes[key] path = route_info.get("path") if isinstance(path, str): - # Remove base path and wildcard path from openai_routes openai_routes = LiteLLMRoutes.openai_routes.value if path in openai_routes: openai_routes.remove(path) - wildcard_path = path.rstrip("/") + "/*" - if wildcard_path in openai_routes: - openai_routes.remove(wildcard_path) + if route_info.get("type") == "subpath": + wildcard_path = path.rstrip("/") + "/*" + if wildcard_path in 
openai_routes: + openai_routes.remove(wildcard_path) del _registered_pass_through_routes[key] verbose_proxy_logger.debug( "Removed pass-through route from registry: %s", key diff --git a/tests/test_litellm/proxy/auth/test_route_checks.py b/tests/test_litellm/proxy/auth/test_route_checks.py index b6c4fa73df..83703cd4ed 100644 --- a/tests/test_litellm/proxy/auth/test_route_checks.py +++ b/tests/test_litellm/proxy/auth/test_route_checks.py @@ -1397,3 +1397,10 @@ async def test_initialize_pass_through_registers_wildcard_for_auth_subpath(): assert wildcard_path not in LiteLLMRoutes.openai_routes.value finally: LiteLLMRoutes.openai_routes.value[:] = original_routes + # Clean up any routes registered during this test to avoid + # polluting the module-level _registered_pass_through_routes + registered = InitPassThroughEndpointHelpers.get_all_registered_pass_through_routes() + for k in registered: + InitPassThroughEndpointHelpers.remove_endpoint_routes( + k.split(":")[0] + ) From 08f0cbc2e939c6456feb2f7ef8edf118644748fa Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 18 Mar 2026 21:36:39 -0700 Subject: [PATCH 147/539] fix: address greptile feedback --- litellm/litellm_core_utils/litellm_logging.py | 266 +++++++++++------- litellm/responses/streaming_iterator.py | 73 +++-- ...t_base_responses_api_streaming_iterator.py | 156 +++++++++- 3 files changed, 369 insertions(+), 126 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 5e34a4c992..27cef85818 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -12,28 +12,47 @@ import time import traceback from datetime import datetime as dt_object from functools import lru_cache -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, - Optional, Tuple, Type, Union, cast) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + 
Type, + Union, + cast, +) from httpx import Response from pydantic import BaseModel import litellm -from litellm import (_custom_logger_compatible_callbacks_literal, json_logs, - log_raw_request_response, turn_off_message_logging) +from litellm import ( + _custom_logger_compatible_callbacks_literal, + json_logs, + log_raw_request_response, + turn_off_message_logging, +) from litellm._logging import _is_debugging_on, verbose_logger from litellm._uuid import uuid from litellm.batches.batch_utils import _handle_completed_batch from litellm.caching.caching import DualCache, InMemoryCache from litellm.caching.caching_handler import LLMCachingHandler -from litellm.constants import (DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT, - DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT, - SENTRY_DENYLIST, SENTRY_PII_DENYLIST) -from litellm.cost_calculator import (RealtimeAPITokenUsageProcessor, - _select_model_name_for_cost_calc) +from litellm.constants import ( + DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT, + DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT, + SENTRY_DENYLIST, + SENTRY_PII_DENYLIST, +) +from litellm.cost_calculator import ( + RealtimeAPITokenUsageProcessor, + _select_model_name_for_cost_calc, +) from litellm.integrations.agentops import AgentOps -from litellm.integrations.anthropic_cache_control_hook import \ - AnthropicCacheControlHook +from litellm.integrations.anthropic_cache_control_hook import AnthropicCacheControlHook from litellm.integrations.arize.arize import ArizeLogger from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger @@ -42,48 +61,72 @@ from litellm.integrations.mlflow import MlflowLogger from litellm.integrations.sqs import SQSLogger from litellm.litellm_core_utils.core_helpers import reconstruct_model_name from litellm.litellm_core_utils.get_litellm_params import get_litellm_params -from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import \ - 
StandardBuiltInToolCostTracking -from litellm.litellm_core_utils.logging_utils import \ - truncate_base64_in_messages +from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import ( + StandardBuiltInToolCostTracking, +) +from litellm.litellm_core_utils.logging_utils import truncate_base64_in_messages from litellm.litellm_core_utils.model_param_helper import ModelParamHelper from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_custom_logger, - redact_message_input_output_from_logging) + redact_message_input_output_from_logging, +) from litellm.llms.base_llm.ocr.transformation import OCRResponse from litellm.llms.base_llm.search.transformation import SearchResponse from litellm.responses.utils import ResponseAPILoggingUtils from litellm.types.agents import LiteLLMSendMessageResponse from litellm.types.containers.main import ContainerObject -from litellm.types.llms.openai import (AllMessageValues, Batch, FineTuningJob, - HttpxBinaryResponseContent, - OpenAIFileObject, - OpenAIModerationResponse, - ResponseAPIUsage, - ResponseCompletedEvent, - ResponseFailedEvent, - ResponseIncompleteEvent, - ResponsesAPIResponse) +from litellm.types.llms.openai import ( + AllMessageValues, + Batch, + FineTuningJob, + HttpxBinaryResponseContent, + OpenAIFileObject, + OpenAIModerationResponse, + ResponseAPIUsage, + ResponseCompletedEvent, + ResponseFailedEvent, + ResponseIncompleteEvent, + ResponsesAPIResponse, +) from litellm.types.mcp import MCPPostCallResponseObject from litellm.types.prompts.init_prompts import PromptSpec from litellm.types.rerank import RerankResponse from litellm.types.utils import ( - CachingDetails, CallTypes, CostBreakdown, CostResponseTypes, - CustomPricingLiteLLMParams, DynamicPromptManagementParamLiteral, - EmbeddingResponse, GuardrailStatus, ImageResponse, LiteLLMBatch, - LiteLLMLoggingBaseClass, LiteLLMRealtimeStreamLoggingObject, ModelResponse, - ModelResponseStream, RawRequestTypedDict, 
StandardBuiltInToolsParams, - StandardCallbackDynamicParams, StandardLoggingAdditionalHeaders, - StandardLoggingHiddenParams, StandardLoggingMCPToolCall, - StandardLoggingMetadata, StandardLoggingModelCostFailureDebugInformation, - StandardLoggingModelInformation, StandardLoggingPayload, - StandardLoggingPayloadErrorInformation, StandardLoggingPayloadStatus, + CachingDetails, + CallTypes, + CostBreakdown, + CostResponseTypes, + CustomPricingLiteLLMParams, + DynamicPromptManagementParamLiteral, + EmbeddingResponse, + GuardrailStatus, + ImageResponse, + LiteLLMBatch, + LiteLLMLoggingBaseClass, + LiteLLMRealtimeStreamLoggingObject, + ModelResponse, + ModelResponseStream, + RawRequestTypedDict, + StandardBuiltInToolsParams, + StandardCallbackDynamicParams, + StandardLoggingAdditionalHeaders, + StandardLoggingHiddenParams, + StandardLoggingMCPToolCall, + StandardLoggingMetadata, + StandardLoggingModelCostFailureDebugInformation, + StandardLoggingModelInformation, + StandardLoggingPayload, + StandardLoggingPayloadErrorInformation, + StandardLoggingPayloadStatus, StandardLoggingPayloadStatusFields, - StandardLoggingPromptManagementMetadata, StandardLoggingVectorStoreRequest, - TextCompletionResponse, TranscriptionResponse, Usage) + StandardLoggingPromptManagementMetadata, + StandardLoggingVectorStoreRequest, + TextCompletionResponse, + TranscriptionResponse, + Usage, +) from litellm.types.videos.main import VideoObject -from litellm.utils import (_get_base_model_from_metadata, executor, - print_verbose) +from litellm.utils import _get_base_model_from_metadata, executor, print_verbose from ..integrations.argilla import ArgillaLogger from ..integrations.arize.arize_phoenix import ArizePhoenixLogger @@ -105,8 +148,7 @@ from ..integrations.humanloop import HumanloopLogger from ..integrations.lago import LagoLogger from ..integrations.langfuse.langfuse import LangFuseLogger from ..integrations.langfuse.langfuse_handler import LangFuseHandler -from 
..integrations.langfuse.langfuse_prompt_management import \ - LangfusePromptManagement +from ..integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement from ..integrations.langsmith import LangsmithLogger from ..integrations.litellm_agent import LiteLLMAgentModelResolver from ..integrations.literal_ai import LiteralAILogger @@ -121,30 +163,34 @@ from ..integrations.s3_v2 import S3Logger as S3V2Logger from ..integrations.supabase import Supabase from ..integrations.traceloop import TraceloopLogger from .exception_mapping_utils import _get_response_headers -from .initialize_dynamic_callback_params import \ - initialize_standard_callback_dynamic_params as \ - _initialize_standard_callback_dynamic_params +from .initialize_dynamic_callback_params import ( + initialize_standard_callback_dynamic_params as _initialize_standard_callback_dynamic_params, +) from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache if TYPE_CHECKING: - from litellm.llms.base_llm.passthrough.transformation import \ - BasePassthroughConfig + from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig try: - from litellm_enterprise.enterprise_callbacks.callback_controls import \ - EnterpriseCallbackControls - from litellm_enterprise.enterprise_callbacks.pagerduty.pagerduty import \ - PagerDutyAlerting - from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import \ - ResendEmailLogger - from litellm_enterprise.enterprise_callbacks.send_emails.sendgrid_email import \ - SendGridEmailLogger - from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import \ - SMTPEmailLogger - from litellm_enterprise.litellm_core_utils.litellm_logging import \ - StandardLoggingPayloadSetup as EnterpriseStandardLoggingPayloadSetup + from litellm_enterprise.enterprise_callbacks.callback_controls import ( + EnterpriseCallbackControls, + ) + from litellm_enterprise.enterprise_callbacks.pagerduty.pagerduty import ( + PagerDutyAlerting, 
+ ) + from litellm_enterprise.enterprise_callbacks.send_emails.resend_email import ( + ResendEmailLogger, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.sendgrid_email import ( + SendGridEmailLogger, + ) + from litellm_enterprise.enterprise_callbacks.send_emails.smtp_email import ( + SMTPEmailLogger, + ) + from litellm_enterprise.litellm_core_utils.litellm_logging import ( + StandardLoggingPayloadSetup as EnterpriseStandardLoggingPayloadSetup, + ) - from litellm.integrations.generic_api.generic_api_callback import \ - GenericAPILogger + from litellm.integrations.generic_api.generic_api_callback import GenericAPILogger EnterpriseStandardLoggingPayloadSetupVAR: Optional[ Type[EnterpriseStandardLoggingPayloadSetup] @@ -1723,8 +1769,9 @@ class Logging(LiteLLMLoggingBaseClass): ) standard_logging_payload["response"] = response_dict elif isinstance(result, TranscriptionResponse): - from litellm.litellm_core_utils.llm_cost_calc.usage_object_transformation import \ - TranscriptionUsageObjectTransformation + from litellm.litellm_core_utils.llm_cost_calc.usage_object_transformation import ( + TranscriptionUsageObjectTransformation, + ) result = result.model_copy() transformed_usage = TranscriptionUsageObjectTransformation.transform_transcription_usage_object(result.usage) # type: ignore @@ -2407,8 +2454,9 @@ class Logging(LiteLLMLoggingBaseClass): ): # polling job will query these frequently, don't spam db logs return - from litellm.proxy.openai_files_endpoints.common_utils import \ - _is_base64_encoded_unified_file_id + from litellm.proxy.openai_files_endpoints.common_utils import ( + _is_base64_encoded_unified_file_id, + ) # check if file id is a unified file id is_base64_unified_file_id = _is_base64_encoded_unified_file_id(result.id) @@ -3305,7 +3353,6 @@ class Logging(LiteLLMLoggingBaseClass): return result.response else: return None - return None def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse: """ @@ -3551,8 +3598,7 
@@ def set_callbacks(callback_list, function_id=None): # noqa: PLR0915 elif callback == "s3": s3Logger = S3Logger() elif callback == "wandb": - from litellm.integrations.weights_biases import \ - WeightsBiasesLogger + from litellm.integrations.weights_biases import WeightsBiasesLogger weightsBiasesLogger = WeightsBiasesLogger() elif callback == "logfire": @@ -3616,8 +3662,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_posthog_logger) return _posthog_logger # type: ignore elif logging_integration == "braintrust": - from litellm.integrations.braintrust_logging import \ - BraintrustLogger + from litellm.integrations.braintrust_logging import BraintrustLogger for callback in _in_memory_loggers: if isinstance(callback, BraintrustLogger): @@ -3738,7 +3783,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 return _opik_logger # type: ignore elif logging_integration == "arize": from litellm.integrations.opentelemetry import ( - OpenTelemetry, OpenTelemetryConfig) + OpenTelemetry, + OpenTelemetryConfig, + ) arize_config = ArizeLogger.get_arize_config() if arize_config.endpoint is None: @@ -3765,7 +3812,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 return _arize_otel_logger # type: ignore elif logging_integration == "arize_phoenix": from litellm.integrations.opentelemetry import ( - OpenTelemetry, OpenTelemetryConfig) + OpenTelemetry, + OpenTelemetryConfig, + ) arize_phoenix_config = ArizePhoenixLogger.get_arize_phoenix_config() otel_config = OpenTelemetryConfig( @@ -3819,7 +3868,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 elif logging_integration == "levo": from litellm.integrations.levo.levo import LevoLogger from litellm.integrations.opentelemetry import ( - OpenTelemetry, OpenTelemetryConfig) + OpenTelemetry, + OpenTelemetryConfig, + ) levo_config = LevoLogger.get_levo_config() otel_config = OpenTelemetryConfig( @@ -3868,8 +3919,7 @@ def _init_custom_logger_compatible_class( # 
noqa: PLR0915 _in_memory_loggers.append(galileo_logger) return galileo_logger # type: ignore elif logging_integration == "cloudzero": - from litellm.integrations.cloudzero.cloudzero import \ - CloudZeroLogger + from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger for callback in _in_memory_loggers: if isinstance(callback, CloudZeroLogger): @@ -3889,8 +3939,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(focus_logger) return focus_logger # type: ignore elif logging_integration == "vantage": - from litellm.integrations.vantage.vantage_logger import \ - VantageLogger + from litellm.integrations.vantage.vantage_logger import VantageLogger for callback in _in_memory_loggers: if isinstance(callback, VantageLogger): @@ -3910,7 +3959,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 if "LOGFIRE_TOKEN" not in os.environ: raise ValueError("LOGFIRE_TOKEN not found in environment variables") from litellm.integrations.opentelemetry import ( - OpenTelemetry, OpenTelemetryConfig) + OpenTelemetry, + OpenTelemetryConfig, + ) logfire_base_url = os.getenv( "LOGFIRE_BASE_URL", "https://logfire-api.pydantic.dev" @@ -3928,8 +3979,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_otel_logger) return _otel_logger # type: ignore elif logging_integration == "dynamic_rate_limiter": - from litellm.proxy.hooks.dynamic_rate_limiter import \ - _PROXY_DynamicRateLimitHandler + from litellm.proxy.hooks.dynamic_rate_limiter import ( + _PROXY_DynamicRateLimitHandler, + ) for callback in _in_memory_loggers: if isinstance(callback, _PROXY_DynamicRateLimitHandler): @@ -3951,8 +4003,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(dynamic_rate_limiter_obj) return dynamic_rate_limiter_obj # type: ignore elif logging_integration == "dynamic_rate_limiter_v3": - from litellm.proxy.hooks.dynamic_rate_limiter_v3 import \ - _PROXY_DynamicRateLimitHandlerV3 + from 
litellm.proxy.hooks.dynamic_rate_limiter_v3 import ( + _PROXY_DynamicRateLimitHandlerV3, + ) for callback in _in_memory_loggers: if isinstance(callback, _PROXY_DynamicRateLimitHandlerV3): @@ -3978,7 +4031,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 raise ValueError("LANGTRACE_API_KEY not found in environment variables") from litellm.integrations.opentelemetry import ( - OpenTelemetry, OpenTelemetryConfig) + OpenTelemetry, + OpenTelemetryConfig, + ) otel_config = OpenTelemetryConfig( exporter="otlp_http", @@ -4014,8 +4069,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(langfuse_logger) return langfuse_logger # type: ignore elif logging_integration == "langfuse_otel": - from litellm.integrations.langfuse.langfuse_otel import \ - LangfuseOtelLogger + from litellm.integrations.langfuse.langfuse_otel import LangfuseOtelLogger for callback in _in_memory_loggers: if ( @@ -4033,7 +4087,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 elif logging_integration == "weave_otel": from litellm.integrations.opentelemetry import OpenTelemetryConfig from litellm.integrations.weave.weave_otel import ( - WeaveOtelLogger, get_weave_otel_config) + WeaveOtelLogger, + get_weave_otel_config, + ) weave_otel_config = get_weave_otel_config() @@ -4069,8 +4125,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(anthropic_cache_control_hook) return anthropic_cache_control_hook # type: ignore elif logging_integration == "vector_store_pre_call_hook": - from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import \ - VectorStorePreCallHook + from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( + VectorStorePreCallHook, + ) for callback in _in_memory_loggers: if isinstance(callback, VectorStorePreCallHook): @@ -4130,8 +4187,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(dotprompt_logger) 
return dotprompt_logger # type: ignore elif logging_integration == "bitbucket": - from litellm.integrations.bitbucket.bitbucket_prompt_manager import \ - BitBucketPromptManager + from litellm.integrations.bitbucket.bitbucket_prompt_manager import ( + BitBucketPromptManager, + ) for callback in _in_memory_loggers: if isinstance(callback, BitBucketPromptManager): @@ -4148,8 +4206,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(bitbucket_logger) return bitbucket_logger # type: ignore elif logging_integration == "gitlab": - from litellm.integrations.gitlab.gitlab_prompt_manager import \ - GitLabPromptManager + from litellm.integrations.gitlab.gitlab_prompt_manager import ( + GitLabPromptManager, + ) for callback in _in_memory_loggers: if isinstance(callback, GitLabPromptManager): @@ -4237,8 +4296,7 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 if isinstance(callback, OpenMeterLogger): return callback elif logging_integration == "braintrust": - from litellm.integrations.braintrust_logging import \ - BraintrustLogger + from litellm.integrations.braintrust_logging import BraintrustLogger for callback in _in_memory_loggers: if isinstance(callback, BraintrustLogger): @@ -4248,8 +4306,7 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 if isinstance(callback, GalileoObserve): return callback elif logging_integration == "cloudzero": - from litellm.integrations.cloudzero.cloudzero import \ - CloudZeroLogger + from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger for callback in _in_memory_loggers: if isinstance(callback, CloudZeroLogger): @@ -4263,8 +4320,7 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 ): # exact match; exclude subclasses like VantageLogger return callback elif logging_integration == "vantage": - from litellm.integrations.vantage.vantage_logger import \ - VantageLogger + from litellm.integrations.vantage.vantage_logger import VantageLogger for callback in 
_in_memory_loggers: if isinstance(callback, VantageLogger): @@ -4364,15 +4420,17 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 return callback # type: ignore elif logging_integration == "dynamic_rate_limiter": - from litellm.proxy.hooks.dynamic_rate_limiter import \ - _PROXY_DynamicRateLimitHandler + from litellm.proxy.hooks.dynamic_rate_limiter import ( + _PROXY_DynamicRateLimitHandler, + ) for callback in _in_memory_loggers: if isinstance(callback, _PROXY_DynamicRateLimitHandler): return callback # type: ignore elif logging_integration == "dynamic_rate_limiter_v3": - from litellm.proxy.hooks.dynamic_rate_limiter_v3 import \ - _PROXY_DynamicRateLimitHandlerV3 + from litellm.proxy.hooks.dynamic_rate_limiter_v3 import ( + _PROXY_DynamicRateLimitHandlerV3, + ) for callback in _in_memory_loggers: if isinstance(callback, _PROXY_DynamicRateLimitHandlerV3): @@ -4404,8 +4462,9 @@ def get_custom_logger_compatible_class( # noqa: PLR0915 if isinstance(callback, AnthropicCacheControlHook): return callback elif logging_integration == "vector_store_pre_call_hook": - from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import \ - VectorStorePreCallHook + from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import ( + VectorStorePreCallHook, + ) for callback in _in_memory_loggers: if isinstance(callback, VectorStorePreCallHook): @@ -5577,6 +5636,7 @@ def _get_traceback_str_for_error(error_str: str) -> str: from decimal import Decimal + # used for unit testing from typing import Any, Dict, List, Optional, Union diff --git a/litellm/responses/streaming_iterator.py b/litellm/responses/streaming_iterator.py index a6a1074067..8a91368dd6 100644 --- a/litellm/responses/streaming_iterator.py +++ b/litellm/responses/streaming_iterator.py @@ -8,29 +8,31 @@ from typing import Any, Dict, List, Optional import httpx import litellm -from litellm.constants import (LITELLM_MAX_STREAMING_DURATION_SECONDS, - STREAM_SSE_DONE_STRING) 
+from litellm.constants import ( + LITELLM_MAX_STREAMING_DURATION_SECONDS, + STREAM_SSE_DONE_STRING, +) from litellm.litellm_core_utils.asyncify import run_async_function from litellm.litellm_core_utils.core_helpers import process_response_headers -from litellm.litellm_core_utils.litellm_logging import \ - Logging as LiteLLMLoggingObj -from litellm.litellm_core_utils.llm_response_utils.get_api_base import \ - get_api_base -from litellm.litellm_core_utils.llm_response_utils.response_metadata import \ - update_response_metadata +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base +from litellm.litellm_core_utils.llm_response_utils.response_metadata import ( + update_response_metadata, +) from litellm.litellm_core_utils.thread_pool_executor import executor -from litellm.llms.base_llm.responses.transformation import \ - BaseResponsesAPIConfig +from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig from litellm.responses.utils import ResponsesAPIRequestUtils -from litellm.types.llms.openai import (OutputTextDeltaEvent, ResponseAPIUsage, - ResponseCompletedEvent, - ResponsesAPIRequestParams, - ResponsesAPIResponse, - ResponsesAPIStreamEvents, - ResponsesAPIStreamingResponse) +from litellm.types.llms.openai import ( + OutputTextDeltaEvent, + ResponseAPIUsage, + ResponseCompletedEvent, + ResponsesAPIRequestParams, + ResponsesAPIResponse, + ResponsesAPIStreamEvents, + ResponsesAPIStreamingResponse, +) from litellm.types.utils import CallTypes -from litellm.utils import (CustomStreamWrapper, - async_post_call_success_deployment_hook) +from litellm.utils import CustomStreamWrapper, async_post_call_success_deployment_hook class BaseResponsesAPIStreamingIterator: @@ -198,10 +200,12 @@ class BaseResponsesAPIStreamingIterator: if cost is not None: setattr(usage_obj, "cost", cost) except Exception: - # If cost calculation fails, continue 
without cost pass - self._handle_logging_completed_response() + if _chunk_type == ResponsesAPIStreamEvents.RESPONSE_FAILED: + self._handle_logging_failed_response() + else: + self._handle_logging_completed_response() return openai_responses_api_chunk @@ -219,6 +223,32 @@ class BaseResponsesAPIStreamingIterator: """Base implementation - should be overridden by subclasses""" pass + def _handle_logging_failed_response(self): + """ + Handle logging for RESPONSE_FAILED events by routing to failure handlers. + + Unlike _handle_logging_completed_response (which calls success handlers), + this constructs an exception from the response error and routes to + async_failure_handler / failure_handler so logging integrations correctly + record the call as failed. + """ + response_obj = ( + getattr(self.completed_response, "response", None) + if self.completed_response + else None + ) + error_info = getattr(response_obj, "error", None) if response_obj else None + error_message = "Response failed" + if isinstance(error_info, dict): + error_message = error_info.get("message", str(error_info)) + exception = litellm.APIError( + status_code=500, + message=error_message, + llm_provider=self.custom_llm_provider or "", + model=self.model or "", + ) + self._handle_failure(exception) + async def _call_post_streaming_deployment_hook(self, chunk): """ Allow callbacks to modify streaming chunks before returning (parity with chat). 
@@ -697,8 +727,7 @@ class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator): # --------------------------------------------------------------------------- from litellm._logging import verbose_logger -from litellm.litellm_core_utils.thread_pool_executor import \ - executor as _ws_executor +from litellm.litellm_core_utils.thread_pool_executor import executor as _ws_executor RESPONSES_WS_LOGGED_EVENT_TYPES = [ "response.created", diff --git a/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py b/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py index 860445d875..e9181d810e 100644 --- a/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py +++ b/tests/llm_responses_api_testing/test_base_responses_api_streaming_iterator.py @@ -30,9 +30,11 @@ from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterat from litellm.responses.utils import ResponsesAPIRequestUtils from litellm.types.llms.openai import ( ResponseCompletedEvent, + ResponseFailedEvent, + ResponseIncompleteEvent, ResponsesAPIResponse, ResponsesAPIStreamEvents, - OutputTextDeltaEvent + OutputTextDeltaEvent, ) @@ -429,3 +431,155 @@ class TestBaseResponsesAPIStreamingIterator: mock_logging_obj.async_failure_handler.assert_not_called() mock_logging_obj.failure_handler.assert_not_called() + def test_process_chunk_response_failed_calls_failure_handler(self): + """ + Test that a RESPONSE_FAILED event routes to failure handlers, + not success handlers. Failed responses represent genuine LLM-level + errors and should be logged as failures. 
+ """ + from litellm.responses.streaming_iterator import ResponsesAPIStreamingIterator + + mock_response = Mock() + mock_response.headers = {} + mock_response.aiter_lines = Mock() + mock_logging_obj = Mock(spec=LiteLLMLoggingObj) + mock_logging_obj.model_call_details = {"litellm_params": {}} + mock_logging_obj.async_failure_handler = Mock() + mock_logging_obj.failure_handler = Mock() + mock_logging_obj.async_success_handler = Mock() + mock_logging_obj.success_handler = Mock() + mock_config = Mock(spec=BaseResponsesAPIConfig) + + mock_responses_api_response = Mock(spec=ResponsesAPIResponse) + mock_responses_api_response.id = "resp_failed_123" + mock_responses_api_response.error = { + "type": "server_error", + "message": "The model encountered an error", + } + mock_responses_api_response.usage = None + + mock_failed_event = Mock(spec=ResponseFailedEvent) + mock_failed_event.type = ResponsesAPIStreamEvents.RESPONSE_FAILED + mock_failed_event.response = mock_responses_api_response + + mock_config.transform_streaming_response.return_value = mock_failed_event + + iterator = ResponsesAPIStreamingIterator( + response=mock_response, + model="gpt-4", + responses_api_provider_config=mock_config, + logging_obj=mock_logging_obj, + litellm_metadata={"model_info": {"id": "model_123"}}, + custom_llm_provider="openai", + ) + + test_chunk_data = { + "type": "response.failed", + "response": { + "id": "resp_failed_123", + "error": { + "type": "server_error", + "message": "The model encountered an error", + }, + }, + } + + with patch.object( + ResponsesAPIRequestUtils, + "_update_responses_api_response_id_with_model_id", + return_value=mock_responses_api_response, + ), patch( + "litellm.responses.streaming_iterator.run_async_function" + ) as mock_run_async, patch( + "litellm.responses.streaming_iterator.executor" + ) as mock_executor: + result = iterator._process_chunk(json.dumps(test_chunk_data)) + + assert result is not None + assert result.type == 
ResponsesAPIStreamEvents.RESPONSE_FAILED + assert iterator.completed_response == result + + # Failure handler should have been called via _handle_failure + mock_run_async.assert_called_once() + call_kwargs = mock_run_async.call_args + assert ( + call_kwargs[1]["async_function"] + == mock_logging_obj.async_failure_handler + ) + + mock_executor.submit.assert_called_once() + submit_args = mock_executor.submit.call_args + assert submit_args[0][0] == mock_logging_obj.failure_handler + + def test_process_chunk_response_incomplete_calls_success_handler(self): + """ + Test that a RESPONSE_INCOMPLETE event routes to success handlers. + Incomplete responses (e.g. max_output_tokens reached) are still valid + responses with usage data — analogous to finish_reason='length' in chat. + """ + from litellm.responses.streaming_iterator import ResponsesAPIStreamingIterator + + mock_response = Mock() + mock_response.headers = {} + mock_response.aiter_lines = Mock() + mock_logging_obj = Mock(spec=LiteLLMLoggingObj) + mock_logging_obj.model_call_details = {"litellm_params": {}} + mock_logging_obj.async_failure_handler = Mock() + mock_logging_obj.failure_handler = Mock() + mock_logging_obj.async_success_handler = Mock() + mock_logging_obj.success_handler = Mock() + mock_config = Mock(spec=BaseResponsesAPIConfig) + + mock_responses_api_response = Mock(spec=ResponsesAPIResponse) + mock_responses_api_response.id = "resp_incomplete_123" + mock_responses_api_response.incomplete_details = { + "reason": "max_output_tokens" + } + mock_responses_api_response.usage = None + + mock_incomplete_event = Mock(spec=ResponseIncompleteEvent) + mock_incomplete_event.type = ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE + mock_incomplete_event.response = mock_responses_api_response + + mock_config.transform_streaming_response.return_value = mock_incomplete_event + + iterator = ResponsesAPIStreamingIterator( + response=mock_response, + model="gpt-4", + responses_api_provider_config=mock_config, + 
logging_obj=mock_logging_obj, + litellm_metadata={"model_info": {"id": "model_123"}}, + custom_llm_provider="openai", + ) + + test_chunk_data = { + "type": "response.incomplete", + "response": { + "id": "resp_incomplete_123", + "incomplete_details": {"reason": "max_output_tokens"}, + }, + } + + with patch.object( + ResponsesAPIRequestUtils, + "_update_responses_api_response_id_with_model_id", + return_value=mock_responses_api_response, + ), patch( + "asyncio.create_task" + ) as mock_create_task, patch( + "litellm.responses.streaming_iterator.executor" + ) as mock_executor: + result = iterator._process_chunk(json.dumps(test_chunk_data)) + + assert result is not None + assert result.type == ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE + assert iterator.completed_response == result + + # Success handler should have been called (via _handle_logging_completed_response) + mock_create_task.assert_called_once() + mock_executor.submit.assert_called_once() + + # Failure handlers should NOT have been called + mock_logging_obj.async_failure_handler.assert_not_called() + mock_logging_obj.failure_handler.assert_not_called() + From df38fbcc973b269d70d6c6c6891d444d70e9f04d Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Thu, 19 Mar 2026 04:47:30 +0000 Subject: [PATCH 148/539] docs: add Contributing to Guardrails section to Guardrail Providers sidebar - Add 'Contributing to Guardrails' category with links to: - Generic Guardrail API (integrate without PR) - Adding a New Guardrail Integration tutorial - Adding Guardrail Support to Endpoints - Add 'Team Bring-Your-Own Guardrails' link for team BYOG workflow These docs existed but were only accessible from the 'LiteLLM AI Gateway' sidebar. Now they're also accessible when browsing the 'Guardrail Providers' section. 
Co-authored-by: Krish Dholakia --- docs/my-website/sidebars.js | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 7db72da276..4c0471fb8f 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -48,6 +48,20 @@ const sidebars = { slug: "/guardrail_providers" }, items: [ + { + type: "category", + label: "Contributing to Guardrails", + items: [ + "adding_provider/generic_guardrail_api", + "adding_provider/simple_guardrail_tutorial", + "adding_provider/adding_guardrail_support", + ] + }, + { + type: "doc", + id: "proxy/guardrails/team_based_guardrails", + label: "Team Bring-Your-Own Guardrails", + }, ...[ "proxy/guardrails/qualifire", "proxy/guardrails/aim_security", From cac685014ff2ad795c9d21a9e42475e46fdcd4b5 Mon Sep 17 00:00:00 2001 From: Ephrim Stanley Date: Thu, 19 Mar 2026 01:30:18 -0400 Subject: [PATCH 149/539] feat: add proxy-wide default tpm/rpm limits per deployment Adds `default_api_key_tpm_limit` and `default_api_key_rpm_limit` to `GenericLiteLLMParams` so operators can set per-deployment rate limit defaults in config.yaml. When a key has no model-specific tpm/rpm limit configured, the proxy falls back to these deployment defaults (Case 2 in spec). Key-level limits always take priority (Case 1). 
- Extends `get_key_model_tpm_limit` / `get_key_model_rpm_limit` with a `model_name` param and a priority-4 deployment-default fallback - Passes `model_name=requested_model` in the parallel request limiter so the fallback is triggered at enforcement time - Adds `"limit"` to `SensitiveDataMasker` non-sensitive overrides so `*_limit` fields are not masked in `/model/info` responses - Adds 17 unit tests covering both spec cases and the `/model/info` path Co-Authored-By: Claude (claude-sonnet-4-6) --- .../sensitive_data_masker.py | 4 +- litellm/proxy/auth/auth_utils.py | 64 ++++++- .../hooks/parallel_request_limiter_v3.py | 8 +- litellm/types/router.py | 5 + .../proxy/auth/test_auth_utils.py | 131 +++++++++++++- .../proxy/test_model_info_default_limits.py | 163 ++++++++++++++++++ 6 files changed, 369 insertions(+), 6 deletions(-) create mode 100644 tests/test_litellm/proxy/test_model_info_default_limits.py diff --git a/litellm/litellm_core_utils/sensitive_data_masker.py b/litellm/litellm_core_utils/sensitive_data_masker.py index 663c3fac80..f22cfa11a3 100644 --- a/litellm/litellm_core_utils/sensitive_data_masker.py +++ b/litellm/litellm_core_utils/sensitive_data_masker.py @@ -30,7 +30,9 @@ class SensitiveDataMasker: # If any key segment matches one of these, the key is not considered sensitive # even if it also matches a sensitive pattern. For example, "input_cost_per_token" # contains "token" but "cost" overrides that — it's a pricing field, not a secret. - self.non_sensitive_overrides = non_sensitive_overrides or {"cost"} + # Similarly, "*_limit" fields (tpm_limit, rpm_limit, etc.) are rate/budget caps, + # not credentials, even though their names may contain "key" (e.g. default_api_key_tpm_limit). 
+ self.non_sensitive_overrides = non_sensitive_overrides or {"cost", "limit"} self.visible_prefix = visible_prefix self.visible_suffix = visible_suffix diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index a03e1fb94c..235b217610 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -539,8 +539,49 @@ def bytes_to_mb(bytes_value: int): # helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key +def _get_deployment_default_rpm_limit(model_name: str) -> Optional[int]: + """ + Return the default_api_key_rpm_limit configured on the deployment for model_name, + or None if not set. + """ + from litellm.proxy.proxy_server import llm_router + + if llm_router is None: + return None + deployments = llm_router.get_model_list(model_name=model_name) + if not deployments: + return None + for deployment in deployments: + litellm_params = deployment.get("litellm_params", {}) + limit = litellm_params.get("default_api_key_rpm_limit") + if limit is not None: + return int(limit) + return None + + +def _get_deployment_default_tpm_limit(model_name: str) -> Optional[int]: + """ + Return the default_api_key_tpm_limit configured on the deployment for model_name, + or None if not set. + """ + from litellm.proxy.proxy_server import llm_router + + if llm_router is None: + return None + deployments = llm_router.get_model_list(model_name=model_name) + if not deployments: + return None + for deployment in deployments: + litellm_params = deployment.get("litellm_params", {}) + limit = litellm_params.get("default_api_key_tpm_limit") + if limit is not None: + return int(limit) + return None + + def get_key_model_rpm_limit( user_api_key_dict: UserAPIKeyAuth, + model_name: Optional[str] = None, ) -> Optional[Dict[str, int]]: """ Get the model rpm limit for a given api key. @@ -549,6 +590,7 @@ def get_key_model_rpm_limit( 1. Key metadata (model_rpm_limit) 2. Key model_max_budget (rpm_limit per model) 3. 
Team metadata (model_rpm_limit) + 4. Deployment default_api_key_rpm_limit (when model_name is provided) """ # 1. Check key metadata first (takes priority) if user_api_key_dict.metadata: @@ -567,13 +609,22 @@ def get_key_model_rpm_limit( # 3. Fallback to team metadata if user_api_key_dict.team_metadata: - return user_api_key_dict.team_metadata.get("model_rpm_limit") + team_limit = user_api_key_dict.team_metadata.get("model_rpm_limit") + if team_limit: + return team_limit + + # 4. Fallback to deployment default_api_key_rpm_limit + if model_name is not None: + default_limit = _get_deployment_default_rpm_limit(model_name) + if default_limit is not None: + return {model_name: default_limit} return None def get_key_model_tpm_limit( user_api_key_dict: UserAPIKeyAuth, + model_name: Optional[str] = None, ) -> Optional[Dict[str, int]]: """ Get the model tpm limit for a given api key. @@ -582,6 +633,7 @@ def get_key_model_tpm_limit( 1. Key metadata (model_tpm_limit) 2. Key model_max_budget (tpm_limit per model) 3. Team metadata (model_tpm_limit) + 4. Deployment default_api_key_tpm_limit (when model_name is provided) """ # 1. Check key metadata first (takes priority) if user_api_key_dict.metadata: @@ -600,7 +652,15 @@ def get_key_model_tpm_limit( # 3. Fallback to team metadata if user_api_key_dict.team_metadata: - return user_api_key_dict.team_metadata.get("model_tpm_limit") + team_limit = user_api_key_dict.team_metadata.get("model_tpm_limit") + if team_limit: + return team_limit + + # 4. 
Fallback to deployment default_api_key_tpm_limit + if model_name is not None: + default_limit = _get_deployment_default_tpm_limit(model_name) + if default_limit is not None: + return {model_name: default_limit} return None diff --git a/litellm/proxy/hooks/parallel_request_limiter_v3.py b/litellm/proxy/hooks/parallel_request_limiter_v3.py index 19c8c484b4..5aaac088dc 100644 --- a/litellm/proxy/hooks/parallel_request_limiter_v3.py +++ b/litellm/proxy/hooks/parallel_request_limiter_v3.py @@ -687,8 +687,12 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): if not requested_model: return - _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) - _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) + _tpm_limit_for_key_model = get_key_model_tpm_limit( + user_api_key_dict, model_name=requested_model + ) + _rpm_limit_for_key_model = get_key_model_rpm_limit( + user_api_key_dict, model_name=requested_model + ) if _tpm_limit_for_key_model is None and _rpm_limit_for_key_model is None: return diff --git a/litellm/types/router.py b/litellm/types/router.py index e8ff2115ff..5d28349b5e 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -188,6 +188,11 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams): max_file_size_mb: Optional[float] = None + # Proxy-wide default rate limits applied to any API key using this deployment + # when the key does not have a model-specific tpm/rpm limit configured. 
+ default_api_key_tpm_limit: Optional[int] = None + default_api_key_rpm_limit: Optional[int] = None + # Deployment budgets max_budget: Optional[float] = None budget_duration: Optional[str] = None diff --git a/tests/test_litellm/proxy/auth/test_auth_utils.py b/tests/test_litellm/proxy/auth/test_auth_utils.py index 5e42b110aa..be4db666a0 100644 --- a/tests/test_litellm/proxy/auth/test_auth_utils.py +++ b/tests/test_litellm/proxy/auth/test_auth_utils.py @@ -2,7 +2,7 @@ Unit tests for auth_utils functions related to rate limiting and customer ID extraction. """ -from unittest.mock import patch +from unittest.mock import MagicMock, patch from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.auth_utils import ( @@ -315,3 +315,132 @@ def test_get_end_user_id_falls_back_to_deprecated_user_header_name(): result = get_end_user_id_from_request_body(request_body={}, request_headers=headers) assert result == "user-legacy" + + +def _make_deployment_dict(model_name: str, tpm: int = None, rpm: int = None) -> dict: + """Helper to build a minimal deployment dict as returned by router.get_model_list.""" + litellm_params: dict = {"model": model_name} + if tpm is not None: + litellm_params["default_api_key_tpm_limit"] = tpm + if rpm is not None: + litellm_params["default_api_key_rpm_limit"] = rpm + return {"model_name": model_name, "litellm_params": litellm_params} + + +_ROUTER_PATCH = "litellm.proxy.proxy_server.llm_router" + + +class TestDeploymentDefaultRpmLimit: + """Tests for deployment default_api_key_rpm_limit fallback in get_key_model_rpm_limit.""" + + def test_returns_deployment_default_when_key_has_no_limits(self): + """Case 2 from spec: key has no model-specific limits, falls back to deployment default.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1", rpm=200) + ] + with patch(_ROUTER_PATCH, mock_router): + result = 
get_key_model_rpm_limit(user_api_key_dict, model_name="model1") + assert result == {"model1": 200} + + def test_key_model_limit_takes_priority_over_deployment_default(self): + """Case 1 from spec: key model-specific limit wins over deployment default.""" + user_api_key_dict = UserAPIKeyAuth( + api_key="sk-123", + metadata={"model_rpm_limit": {"model1": 10}}, + ) + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1", rpm=200) + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_rpm_limit(user_api_key_dict, model_name="model1") + assert result == {"model1": 10} + + def test_returns_none_when_no_deployment_default_and_no_key_limits(self): + """Returns None when neither the key nor the deployment has any rpm limit.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1") # no rpm default + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_rpm_limit(user_api_key_dict, model_name="model1") + assert result is None + + def test_returns_none_without_model_name_even_when_deployment_has_default(self): + """No model_name means deployment fallback is skipped.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1", rpm=200) + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_rpm_limit(user_api_key_dict) + assert result is None + + def test_returns_none_when_llm_router_is_none(self): + """No router means deployment fallback returns None gracefully.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + with patch(_ROUTER_PATCH, None): + result = get_key_model_rpm_limit(user_api_key_dict, model_name="model1") + assert result is None + + +class TestDeploymentDefaultTpmLimit: + """Tests for deployment default_api_key_tpm_limit fallback in 
get_key_model_tpm_limit.""" + + def test_returns_deployment_default_when_key_has_no_limits(self): + """Case 2 from spec: key has no model-specific limits, falls back to deployment default.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1", tpm=100) + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_tpm_limit(user_api_key_dict, model_name="model1") + assert result == {"model1": 100} + + def test_key_model_limit_takes_priority_over_deployment_default(self): + """Case 1 from spec: key model-specific limit wins over deployment default.""" + user_api_key_dict = UserAPIKeyAuth( + api_key="sk-123", + metadata={"model_tpm_limit": {"model1": 20}}, + ) + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1", tpm=100) + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_tpm_limit(user_api_key_dict, model_name="model1") + assert result == {"model1": 20} + + def test_returns_none_when_no_deployment_default_and_no_key_limits(self): + """Returns None when neither the key nor the deployment has any tpm limit.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1") # no tpm default + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_tpm_limit(user_api_key_dict, model_name="model1") + assert result is None + + def test_returns_none_without_model_name_even_when_deployment_has_default(self): + """No model_name means deployment fallback is skipped.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1", tpm=100) + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_tpm_limit(user_api_key_dict) + assert result is None + + def 
test_returns_none_when_llm_router_is_none(self): + """No router means deployment fallback returns None gracefully.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + with patch(_ROUTER_PATCH, None): + result = get_key_model_tpm_limit(user_api_key_dict, model_name="model1") + assert result is None diff --git a/tests/test_litellm/proxy/test_model_info_default_limits.py b/tests/test_litellm/proxy/test_model_info_default_limits.py new file mode 100644 index 0000000000..e749c84dfb --- /dev/null +++ b/tests/test_litellm/proxy/test_model_info_default_limits.py @@ -0,0 +1,163 @@ +""" +Tests verifying that default_api_key_tpm_limit and default_api_key_rpm_limit set in +litellm_params are returned by the /model/info endpoint. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from litellm.proxy.proxy_server import _get_proxy_model_info +from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo + + +def _make_deployment( + model_name: str, + default_tpm: int = None, + default_rpm: int = None, +) -> Deployment: + params: dict = {"model": f"openai/{model_name}"} + if default_tpm is not None: + params["default_api_key_tpm_limit"] = default_tpm + if default_rpm is not None: + params["default_api_key_rpm_limit"] = default_rpm + return Deployment( + model_name=model_name, + litellm_params=LiteLLM_Params(**params), + model_info=ModelInfo(), + ) + + +class TestModelInfoDefaultLimitsInResponse: + """ + Verify _get_proxy_model_info (the helper used by the /model/info endpoint) returns + default_api_key_tpm_limit and default_api_key_rpm_limit from litellm_params. 
+ """ + + def test_default_tpm_and_rpm_present_in_model_info_response(self): + """Both defaults should appear in the litellm_params section of the response.""" + deployment = _make_deployment("model1", default_tpm=100, default_rpm=200) + model_dict = deployment.model_dump(exclude_none=True) + + result = _get_proxy_model_info(model=model_dict) + + litellm_params = result["litellm_params"] + assert litellm_params.get("default_api_key_tpm_limit") == 100 + assert litellm_params.get("default_api_key_rpm_limit") == 200 + + def test_default_tpm_only_present_when_only_tpm_configured(self): + """Only the configured default appears; the other stays absent.""" + deployment = _make_deployment("model1", default_tpm=500) + model_dict = deployment.model_dump(exclude_none=True) + + result = _get_proxy_model_info(model=model_dict) + + litellm_params = result["litellm_params"] + assert litellm_params.get("default_api_key_tpm_limit") == 500 + assert "default_api_key_rpm_limit" not in litellm_params + + def test_default_rpm_only_present_when_only_rpm_configured(self): + """Only the configured default appears; the other stays absent.""" + deployment = _make_deployment("model1", default_rpm=300) + model_dict = deployment.model_dump(exclude_none=True) + + result = _get_proxy_model_info(model=model_dict) + + litellm_params = result["litellm_params"] + assert litellm_params.get("default_api_key_rpm_limit") == 300 + assert "default_api_key_tpm_limit" not in litellm_params + + def test_defaults_absent_when_not_configured(self): + """Neither field appears when not set on the deployment.""" + deployment = _make_deployment("model1") + model_dict = deployment.model_dump(exclude_none=True) + + result = _get_proxy_model_info(model=model_dict) + + litellm_params = result["litellm_params"] + assert "default_api_key_tpm_limit" not in litellm_params + assert "default_api_key_rpm_limit" not in litellm_params + + def test_defaults_not_masked_or_stripped_by_sensitive_data_filter(self): + """ + 
default_api_key_tpm_limit / default_api_key_rpm_limit must not be + treated as sensitive and must survive remove_sensitive_info_from_deployment. + """ + deployment = _make_deployment("model1", default_tpm=100, default_rpm=200) + model_dict = deployment.model_dump(exclude_none=True) + + result = _get_proxy_model_info(model=model_dict) + + # Values should be unchanged integers, not masked strings + assert result["litellm_params"]["default_api_key_tpm_limit"] == 100 + assert result["litellm_params"]["default_api_key_rpm_limit"] == 200 + + +class TestModelInfoEndpointWithRouter: + """ + Integration-style tests simulating the /model/info endpoint reading from the router. + """ + + @pytest.mark.asyncio + async def test_model_info_endpoint_returns_defaults_for_specific_model_id(self): + """ + When litellm_model_id is provided, the endpoint should return the deployment's + default limits in litellm_params. + """ + from litellm.proxy.proxy_server import model_info_v1 + from litellm.proxy._types import UserAPIKeyAuth + + deployment = _make_deployment("model1", default_tpm=100, default_rpm=200) + + mock_router = MagicMock() + mock_router.get_deployment.return_value = deployment + + user_api_key_dict = UserAPIKeyAuth(api_key="sk-test") + + with patch("litellm.proxy.proxy_server.llm_router", mock_router), \ + patch("litellm.proxy.proxy_server.llm_model_list", [{}]), \ + patch("litellm.proxy.proxy_server.user_model", None): + response = await model_info_v1( + user_api_key_dict=user_api_key_dict, + litellm_model_id="some-model-id", + ) + + assert len(response["data"]) == 1 + litellm_params = response["data"][0]["litellm_params"] + assert litellm_params.get("default_api_key_tpm_limit") == 100 + assert litellm_params.get("default_api_key_rpm_limit") == 200 + + @pytest.mark.asyncio + async def test_model_info_endpoint_returns_defaults_in_full_model_list(self): + """ + Without litellm_model_id, the endpoint iterates all models. 
Each deployment's + default limits should appear in its litellm_params entry. + """ + from litellm.proxy.proxy_server import model_info_v1 + from litellm.proxy._types import UserAPIKeyAuth + + deployment = _make_deployment("model1", default_tpm=100, default_rpm=200) + deployment_dict = deployment.model_dump(exclude_none=True) + + mock_router = MagicMock() + mock_router.get_model_names.return_value = ["model1"] + mock_router.get_model_access_groups.return_value = {} + mock_router.get_model_list.return_value = [deployment_dict] + + user_api_key_dict = UserAPIKeyAuth(api_key="sk-test") + + with patch("litellm.proxy.proxy_server.llm_router", mock_router), \ + patch("litellm.proxy.proxy_server.llm_model_list", [deployment_dict]), \ + patch("litellm.proxy.proxy_server.user_model", None), \ + patch("litellm.proxy.proxy_server.get_key_models", return_value=["model1"]), \ + patch("litellm.proxy.proxy_server.get_team_models", return_value=["model1"]), \ + patch("litellm.proxy.proxy_server.get_complete_model_list", return_value=["model1"]): + response = await model_info_v1( + user_api_key_dict=user_api_key_dict, + litellm_model_id=None, + ) + + assert len(response["data"]) >= 1 + litellm_params = response["data"][0]["litellm_params"] + assert litellm_params.get("default_api_key_tpm_limit") == 100 + assert litellm_params.get("default_api_key_rpm_limit") == 200 From bba3b1fe4c468576c145affbb08f6a6587eedc2c Mon Sep 17 00:00:00 2001 From: joereyna Date: Wed, 18 Mar 2026 22:42:25 -0700 Subject: [PATCH 150/539] docs(release-notes): add missing Helicone and Langfuse entries to v1.82.3 changelog Helicone (PRs #19288, #22603) and Langfuse (#22390) were present in the v1.82.0-stable...v1.82.3-stable diff but omitted from the AI Integrations logging section. Also updates the AI Integrations diff summary count from 2 to 4. 
Co-Authored-By: Claude Sonnet 4.6 --- docs/my-website/release_notes/v1.82.3/index.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/my-website/release_notes/v1.82.3/index.md b/docs/my-website/release_notes/v1.82.3/index.md index 15df33fdb8..b586cab7d6 100644 --- a/docs/my-website/release_notes/v1.82.3/index.md +++ b/docs/my-website/release_notes/v1.82.3/index.md @@ -297,6 +297,13 @@ pip install litellm==1.82.3 ### Logging +- **[Helicone](../../docs/observability/helicone_integration)** + - Add Gemini and Vertex AI support to HeliconeLogger — routes Gemini and Vertex AI requests through the correct Helicone provider URL - [PR #19288](https://github.com/BerriAI/litellm/pull/19288) + - Fix correct provider URL for Vertex AI Gemini models - [PR #22603](https://github.com/BerriAI/litellm/pull/22603) + +- **[Langfuse](../../docs/proxy/logging#langfuse)** + - Fix failure path kwargs inconsistency causing dropped traces on failed requests - [PR #22390](https://github.com/BerriAI/litellm/pull/22390) + - **[Vantage](https://vantage.sh)** - Add Vantage integration for FOCUS 1.2 CSV export — export LiteLLM proxy spend data as FinOps Open Cost & Usage Specification reports, with time-windowed filenames to prevent overwrites - [PR #23333](https://github.com/BerriAI/litellm/pull/23333) @@ -363,7 +370,7 @@ No major secret manager changes in this release. 
* New Models / Updated Models: 116 new, 132 removed * LLM API Endpoints: 5 * Management Endpoints / UI: 11 -* AI Integrations: 2 +* AI Integrations: 4 * Performance / Reliability: 5 * Security: 3 * Database / Proxy Operations: 2 From 36dc893770fa69cfcc014e5c535b87d58ca12c89 Mon Sep 17 00:00:00 2001 From: Ephrim Stanley Date: Thu, 19 Mar 2026 01:43:27 -0400 Subject: [PATCH 151/539] fix: address review feedback on default tpm/rpm limits - Use min() across all matching deployments instead of first-wins when resolving default_api_key_tpm/rpm_limit for a model group, so load-balanced setups with different per-deployment limits always apply the most conservative value - Replace the global SensitiveDataMasker non_sensitive_overrides change with a targeted excluded_keys set at the remove_sensitive_info_from_deployment call site, avoiding unintended suppression of other fields - Update the v1 parallel request limiter to pass model_name to get_key_model_tpm/rpm_limit so deployment defaults apply there too - Add 4 tests covering multi-deployment min semantics Co-Authored-By: Claude (claude-sonnet-4-6) --- .../sensitive_data_masker.py | 4 +- litellm/proxy/auth/auth_utils.py | 44 ++++++++++------ .../common_utils/openai_endpoint_utils.py | 11 +++- .../proxy/hooks/parallel_request_limiter.py | 14 ++++-- .../proxy/auth/test_auth_utils.py | 50 +++++++++++++++++++ .../proxy/test_model_info_default_limits.py | 3 ++ 6 files changed, 101 insertions(+), 25 deletions(-) diff --git a/litellm/litellm_core_utils/sensitive_data_masker.py b/litellm/litellm_core_utils/sensitive_data_masker.py index f22cfa11a3..663c3fac80 100644 --- a/litellm/litellm_core_utils/sensitive_data_masker.py +++ b/litellm/litellm_core_utils/sensitive_data_masker.py @@ -30,9 +30,7 @@ class SensitiveDataMasker: # If any key segment matches one of these, the key is not considered sensitive # even if it also matches a sensitive pattern. 
For example, "input_cost_per_token" # contains "token" but "cost" overrides that — it's a pricing field, not a secret. - # Similarly, "*_limit" fields (tpm_limit, rpm_limit, etc.) are rate/budget caps, - # not credentials, even though their names may contain "key" (e.g. default_api_key_tpm_limit). - self.non_sensitive_overrides = non_sensitive_overrides or {"cost", "limit"} + self.non_sensitive_overrides = non_sensitive_overrides or {"cost"} self.visible_prefix = visible_prefix self.visible_suffix = visible_suffix diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index 235b217610..ace39c05ff 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -541,8 +541,13 @@ def bytes_to_mb(bytes_value: int): # helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key def _get_deployment_default_rpm_limit(model_name: str) -> Optional[int]: """ - Return the default_api_key_rpm_limit configured on the deployment for model_name, - or None if not set. + Return the default_api_key_rpm_limit for model_name. + + When multiple deployments share the same model name, returns the minimum + across all deployments that have the field set. This is the safest choice + for load-balanced setups: it ensures no deployment is over-consumed + regardless of which one actually serves a given request. + Returns None if no deployment has the field set. 
""" from litellm.proxy.proxy_server import llm_router @@ -551,18 +556,24 @@ def _get_deployment_default_rpm_limit(model_name: str) -> Optional[int]: deployments = llm_router.get_model_list(model_name=model_name) if not deployments: return None - for deployment in deployments: - litellm_params = deployment.get("litellm_params", {}) - limit = litellm_params.get("default_api_key_rpm_limit") - if limit is not None: - return int(limit) - return None + limits = [ + int(deployment.get("litellm_params", {}).get("default_api_key_rpm_limit")) + for deployment in deployments + if deployment.get("litellm_params", {}).get("default_api_key_rpm_limit") + is not None + ] + return min(limits) if limits else None def _get_deployment_default_tpm_limit(model_name: str) -> Optional[int]: """ - Return the default_api_key_tpm_limit configured on the deployment for model_name, - or None if not set. + Return the default_api_key_tpm_limit for model_name. + + When multiple deployments share the same model name, returns the minimum + across all deployments that have the field set. This is the safest choice + for load-balanced setups: it ensures no deployment is over-consumed + regardless of which one actually serves a given request. + Returns None if no deployment has the field set. 
""" from litellm.proxy.proxy_server import llm_router @@ -571,12 +582,13 @@ def _get_deployment_default_tpm_limit(model_name: str) -> Optional[int]: deployments = llm_router.get_model_list(model_name=model_name) if not deployments: return None - for deployment in deployments: - litellm_params = deployment.get("litellm_params", {}) - limit = litellm_params.get("default_api_key_tpm_limit") - if limit is not None: - return int(limit) - return None + limits = [ + int(deployment.get("litellm_params", {}).get("default_api_key_tpm_limit")) + for deployment in deployments + if deployment.get("litellm_params", {}).get("default_api_key_tpm_limit") + is not None + ] + return min(limits) if limits else None def get_key_model_rpm_limit( diff --git a/litellm/proxy/common_utils/openai_endpoint_utils.py b/litellm/proxy/common_utils/openai_endpoint_utils.py index 6df5491f37..7e5c83500a 100644 --- a/litellm/proxy/common_utils/openai_endpoint_utils.py +++ b/litellm/proxy/common_utils/openai_endpoint_utils.py @@ -32,8 +32,17 @@ def remove_sensitive_info_from_deployment( deployment_dict["litellm_params"].pop("aws_access_key_id", None) deployment_dict["litellm_params"].pop("aws_secret_access_key", None) + # Rate-limit config fields must never be masked — they are integers, not credentials. + # The field names contain "key" which matches the masker's sensitive pattern, so we + # explicitly exclude them here rather than widening the global non_sensitive_overrides. 
+ _rate_limit_config_keys = { + "default_api_key_tpm_limit", + "default_api_key_rpm_limit", + } + _excluded = (excluded_keys or set()) | _rate_limit_config_keys + deployment_dict["litellm_params"] = SENSITIVE_DATA_MASKER.mask_dict( - deployment_dict["litellm_params"], excluded_keys=excluded_keys + deployment_dict["litellm_params"], excluded_keys=_excluded ) return deployment_dict diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index c7bfc27d6b..48bf255ac1 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -295,16 +295,20 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): ) # Check if request under RPM/TPM per model for a given API Key + _model = data.get("model", None) if ( - get_key_model_tpm_limit(user_api_key_dict) is not None - or get_key_model_rpm_limit(user_api_key_dict) is not None + get_key_model_tpm_limit(user_api_key_dict, model_name=_model) is not None + or get_key_model_rpm_limit(user_api_key_dict, model_name=_model) is not None ): - _model = data.get("model", None) request_count_api_key = ( f"{api_key}::{_model}::{precise_minute}::request_count" ) - _tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict) - _rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict) + _tpm_limit_for_key_model = get_key_model_tpm_limit( + user_api_key_dict, model_name=_model + ) + _rpm_limit_for_key_model = get_key_model_rpm_limit( + user_api_key_dict, model_name=_model + ) tpm_limit_for_model = None rpm_limit_for_model = None diff --git a/tests/test_litellm/proxy/auth/test_auth_utils.py b/tests/test_litellm/proxy/auth/test_auth_utils.py index be4db666a0..d64f17e70f 100644 --- a/tests/test_litellm/proxy/auth/test_auth_utils.py +++ b/tests/test_litellm/proxy/auth/test_auth_utils.py @@ -387,6 +387,31 @@ class TestDeploymentDefaultRpmLimit: result = get_key_model_rpm_limit(user_api_key_dict, model_name="model1") 
assert result is None + def test_returns_minimum_across_multiple_deployments(self): + """When multiple deployments share a model name, the minimum rpm limit is used.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1", rpm=200), + _make_deployment_dict("model1", rpm=50), + _make_deployment_dict("model1", rpm=150), + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_rpm_limit(user_api_key_dict, model_name="model1") + assert result == {"model1": 50} + + def test_ignores_deployments_without_default_when_others_have_it(self): + """Deployments missing the field are skipped; min is taken over those that have it.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1"), # no rpm default + _make_deployment_dict("model1", rpm=75), + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_rpm_limit(user_api_key_dict, model_name="model1") + assert result == {"model1": 75} + class TestDeploymentDefaultTpmLimit: """Tests for deployment default_api_key_tpm_limit fallback in get_key_model_tpm_limit.""" @@ -444,3 +469,28 @@ class TestDeploymentDefaultTpmLimit: with patch(_ROUTER_PATCH, None): result = get_key_model_tpm_limit(user_api_key_dict, model_name="model1") assert result is None + + def test_returns_minimum_across_multiple_deployments(self): + """When multiple deployments share a model name, the minimum tpm limit is used.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1", tpm=1000), + _make_deployment_dict("model1", tpm=300), + _make_deployment_dict("model1", tpm=700), + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_tpm_limit(user_api_key_dict, model_name="model1") + assert result == 
{"model1": 300} + + def test_ignores_deployments_without_default_when_others_have_it(self): + """Deployments missing the field are skipped; min is taken over those that have it.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + _make_deployment_dict("model1"), # no tpm default + _make_deployment_dict("model1", tpm=400), + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_tpm_limit(user_api_key_dict, model_name="model1") + assert result == {"model1": 400} diff --git a/tests/test_litellm/proxy/test_model_info_default_limits.py b/tests/test_litellm/proxy/test_model_info_default_limits.py index e749c84dfb..907f2390a8 100644 --- a/tests/test_litellm/proxy/test_model_info_default_limits.py +++ b/tests/test_litellm/proxy/test_model_info_default_limits.py @@ -82,6 +82,9 @@ class TestModelInfoDefaultLimitsInResponse: """ default_api_key_tpm_limit / default_api_key_rpm_limit must not be treated as sensitive and must survive remove_sensitive_info_from_deployment. + They contain "key" which normally triggers masking; the call site explicitly + excludes these two fields via excluded_keys rather than widening the global + non_sensitive_overrides. """ deployment = _make_deployment("model1", default_tpm=100, default_rpm=200) model_dict = deployment.model_dump(exclude_none=True) From b90f5207488c71ad22acc700c24999862e0e08e9 Mon Sep 17 00:00:00 2001 From: Ephrim Stanley Date: Thu, 19 Mar 2026 01:57:09 -0400 Subject: [PATCH 152/539] perf: eliminate redundant router lookups in v1 parallel request limiter Compute get_key_model_tpm/rpm_limit once before the guard condition instead of calling each function twice (once to check non-None, once to retrieve). Removes 2 extra llm_router.get_model_list() calls per request when deployment defaults are active. 
Co-Authored-By: Claude (claude-sonnet-4-6) --- litellm/proxy/hooks/parallel_request_limiter.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 48bf255ac1..6e34c3eee1 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -296,19 +296,16 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # Check if request under RPM/TPM per model for a given API Key _model = data.get("model", None) - if ( - get_key_model_tpm_limit(user_api_key_dict, model_name=_model) is not None - or get_key_model_rpm_limit(user_api_key_dict, model_name=_model) is not None - ): + _tpm_limit_for_key_model = get_key_model_tpm_limit( + user_api_key_dict, model_name=_model + ) + _rpm_limit_for_key_model = get_key_model_rpm_limit( + user_api_key_dict, model_name=_model + ) + if _tpm_limit_for_key_model is not None or _rpm_limit_for_key_model is not None: request_count_api_key = ( f"{api_key}::{_model}::{precise_minute}::request_count" ) - _tpm_limit_for_key_model = get_key_model_tpm_limit( - user_api_key_dict, model_name=_model - ) - _rpm_limit_for_key_model = get_key_model_rpm_limit( - user_api_key_dict, model_name=_model - ) tpm_limit_for_model = None rpm_limit_for_model = None From dab8721ba316a6b859d8636c80bc36b94f32a767 Mon Sep 17 00:00:00 2001 From: joereyna Date: Wed, 18 Mar 2026 22:57:38 -0700 Subject: [PATCH 153/539] chore: apply black formatting to fix lint CI --- litellm/litellm_core_utils/litellm_logging.py | 7 ++----- litellm/proxy/management_endpoints/team_endpoints.py | 4 +++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 4e63dd7076..1b612c7091 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1460,9 
+1460,7 @@ class Logging(LiteLLMLoggingBaseClass): # streaming) don't carry _hidden_params["model_id"] like ModelResponse does. if router_model_id is None and hasattr(self, "litellm_params"): for metadata_key in ("litellm_metadata", "metadata"): - _metadata: dict = ( - self.litellm_params.get(metadata_key, {}) or {} - ) + _metadata: dict = self.litellm_params.get(metadata_key, {}) or {} _model_info: dict = _metadata.get("model_info", {}) or {} _model_id = _model_info.get("id") if _model_id is not None: @@ -2972,8 +2970,7 @@ class Logging(LiteLLMLoggingBaseClass): if ( isinstance(callback, CustomLogger) and is_sync_request - and self.call_type - != CallTypes.pass_through.value + and self.call_type != CallTypes.pass_through.value ): # custom logger class callback.log_failure_event( start_time=start_time, diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index b033d3fce0..3d4488b8a7 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -3334,7 +3334,9 @@ def _convert_teams_to_response_models( use_deleted_table: bool, ) -> List[Union[TeamListItem, LiteLLM_TeamTable, LiteLLM_DeletedTeamTable]]: """Convert raw Prisma team rows to response models.""" - team_list: List[Union[TeamListItem, LiteLLM_TeamTable, LiteLLM_DeletedTeamTable]] = [] + team_list: List[ + Union[TeamListItem, LiteLLM_TeamTable, LiteLLM_DeletedTeamTable] + ] = [] for team in teams: try: team_dict = team.model_dump() From 48cb4a83435faa458888b412d408eb06cbca694d Mon Sep 17 00:00:00 2001 From: Ephrim Stanley Date: Thu, 19 Mar 2026 01:59:41 -0400 Subject: [PATCH 154/539] fix: update success-event handler to track tokens for deployment-default limits async_log_success_event only updated the per-model cache counter when model_rpm_limit / model_tpm_limit were present in key metadata or model_max_budget was set. 
For the new deployment-default path (default_api_key_tpm_limit / default_api_key_rpm_limit), none of those conditions held, so current_tpm stayed at zero and tpm enforcement was never applied across multiple requests. Extend the guard condition to also trigger when the model group has a deployment-default tpm or rpm limit, and import the two helpers at module level. Co-Authored-By: Claude (claude-sonnet-4-6) --- litellm/proxy/hooks/parallel_request_limiter.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 6e34c3eee1..b0046fd035 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -14,6 +14,8 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth from litellm.proxy.auth.auth_utils import ( + _get_deployment_default_rpm_limit, + _get_deployment_default_tpm_limit, get_key_model_rpm_limit, get_key_model_tpm_limit, ) @@ -546,6 +548,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "model_rpm_limit" in user_api_key_metadata or "model_tpm_limit" in user_api_key_metadata or user_api_key_model_max_budget is not None + or _get_deployment_default_tpm_limit(model_group) is not None + or _get_deployment_default_rpm_limit(model_group) is not None ) ): request_count_api_key = ( From 477c54184bda814745bb06d4fbe4900cfc816bd2 Mon Sep 17 00:00:00 2001 From: Ephrim Stanley Date: Thu, 19 Mar 2026 02:07:50 -0400 Subject: [PATCH 155/539] perf: avoid unconditional router lookups in success handler Replace bare _get_deployment_default_tpm/rpm_limit calls in the async_log_success_event condition with get_key_model_tpm/rpm_limit (model_name=model_group). 
The higher-level getters short-circuit on key/team metadata hits before ever reaching the router, so requests that don't use deployment defaults incur no extra router lookup. Remove the now-unused bare helper imports. Also fix invalid `int = None` type hints in test helper signatures to `Optional[int] = None`. Co-Authored-By: Claude (claude-sonnet-4-6) --- litellm/proxy/hooks/parallel_request_limiter.py | 12 ++++++++---- tests/test_litellm/proxy/auth/test_auth_utils.py | 3 ++- .../proxy/test_model_info_default_limits.py | 5 +++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index b0046fd035..49c6436c22 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -14,8 +14,6 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth from litellm.proxy.auth.auth_utils import ( - _get_deployment_default_rpm_limit, - _get_deployment_default_tpm_limit, get_key_model_rpm_limit, get_key_model_tpm_limit, ) @@ -548,8 +546,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "model_rpm_limit" in user_api_key_metadata or "model_tpm_limit" in user_api_key_metadata or user_api_key_model_max_budget is not None - or _get_deployment_default_tpm_limit(model_group) is not None - or _get_deployment_default_rpm_limit(model_group) is not None + or get_key_model_tpm_limit( + user_api_key_dict, model_name=model_group + ) + is not None + or get_key_model_rpm_limit( + user_api_key_dict, model_name=model_group + ) + is not None ) ): request_count_api_key = ( diff --git a/tests/test_litellm/proxy/auth/test_auth_utils.py b/tests/test_litellm/proxy/auth/test_auth_utils.py index d64f17e70f..2058f61cb0 100644 --- 
a/tests/test_litellm/proxy/auth/test_auth_utils.py +++ b/tests/test_litellm/proxy/auth/test_auth_utils.py @@ -2,6 +2,7 @@ Unit tests for auth_utils functions related to rate limiting and customer ID extraction. """ +from typing import Optional from unittest.mock import MagicMock, patch from litellm.proxy._types import UserAPIKeyAuth @@ -317,7 +318,7 @@ def test_get_end_user_id_falls_back_to_deprecated_user_header_name(): assert result == "user-legacy" -def _make_deployment_dict(model_name: str, tpm: int = None, rpm: int = None) -> dict: +def _make_deployment_dict(model_name: str, tpm: Optional[int] = None, rpm: Optional[int] = None) -> dict: """Helper to build a minimal deployment dict as returned by router.get_model_list.""" litellm_params: dict = {"model": model_name} if tpm is not None: diff --git a/tests/test_litellm/proxy/test_model_info_default_limits.py b/tests/test_litellm/proxy/test_model_info_default_limits.py index 907f2390a8..8b85531785 100644 --- a/tests/test_litellm/proxy/test_model_info_default_limits.py +++ b/tests/test_litellm/proxy/test_model_info_default_limits.py @@ -3,6 +3,7 @@ Tests verifying that default_api_key_tpm_limit and default_api_key_rpm_limit set litellm_params are returned by the /model/info endpoint. 
""" +from typing import Optional from unittest.mock import MagicMock, patch import pytest @@ -13,8 +14,8 @@ from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo def _make_deployment( model_name: str, - default_tpm: int = None, - default_rpm: int = None, + default_tpm: Optional[int] = None, + default_rpm: Optional[int] = None, ) -> Deployment: params: dict = {"model": f"openai/{model_name}"} if default_tpm is not None: From 61df7471bae37e2fa25fcbfe4917b193ff5b1b81 Mon Sep 17 00:00:00 2001 From: joereyna Date: Wed, 18 Mar 2026 23:30:42 -0700 Subject: [PATCH 156/539] docs(release-notes): complete v1.82.3 changelog with 30+ missing features MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Full audit of 371 PRs in v1.82.0-stable...v1.82.3-stable range. Adds previously undocumented user-facing changes: - Key Highlights: Hashicorp Vault, Responses WebSocket, Org Admin RBAC, guardrail mode defaults - New Providers: Google Search API, Bedrock Mantle (7 total, was 5) - LLM API: Anthropic Files API, Mistral Voxtral transcription, WebRTC, Responses WebSocket, litellm.acount_tokens() public API, OpenRouter image edit, Vertex AI VIDEO token tracking, input_fidelity image edit, model cost aliases, per-request json schema validation, 15+ bug fixes - Management: RBAC expansion for Org Admins, Vector Store CRUD, MCP token auth + team scoping, BYOK key precedence, virtual key spend reset, batch expiry for teams, Admin Viewer audit log access, 12+ bug fixes - Guardrails: mode default list, tag-based modes, presidio fix, OTEL fix - Secret Managers: Hashicorp Vault (was "no changes") - Spend Tracking: new section — budget-linked reset fix, flex pricing, spend log cleanup, WebSearch dedup fix - Performance: 4 additional reliability fixes - Diff summary counts updated to reflect actual scope Co-Authored-By: Claude Sonnet 4.6 --- .../my-website/release_notes/v1.82.3/index.md | 147 +++++++++++++++++- 1 file changed, 139 
insertions(+), 8 deletions(-) diff --git a/docs/my-website/release_notes/v1.82.3/index.md b/docs/my-website/release_notes/v1.82.3/index.md index b586cab7d6..a51edc60c2 100644 --- a/docs/my-website/release_notes/v1.82.3/index.md +++ b/docs/my-website/release_notes/v1.82.3/index.md @@ -47,6 +47,10 @@ pip install litellm==1.82.3 - **FLUX Kontext image editing** — `flux-kontext-pro` and `flux-kontext-max` added to Black Forest Labs, alongside `flux-pro-1.0-fill` and `flux-pro-1.0-expand` for inpainting and outpainting - **116 new models, 132 deprecated models cleaned up** — Major model map refresh including Mistral Magistral, Dashscope Qwen3 VL, xAI Grok via Azure AI, ZAI GLM-5, Serper Search; removal of OpenAI GPT-3.5/GPT-4 legacy variants, Gemini 1.5, and Vertex AI PaLM2 - **SageMaker Nova provider** — [New `sagemaker_nova` provider for Amazon Nova models on SageMaker](../../docs/providers/aws_sagemaker) - [PR #21542](https://github.com/BerriAI/litellm/pull/21542) +- **Hashicorp Vault secret manager** — Config override backend powered by Hashicorp Vault, with full UI for managing vault-sourced credentials - [PR #22939](https://github.com/BerriAI/litellm/pull/22939), [PR #23036](https://github.com/BerriAI/litellm/pull/23036) +- **Responses API WebSocket streaming** — Real-time WebSocket streaming for the Responses API, including support across all providers - [PR #22559](https://github.com/BerriAI/litellm/pull/22559), [PR #22771](https://github.com/BerriAI/litellm/pull/22771) +- **Org Admin RBAC expansion** — Org Admins can now access team management endpoints, view and invite internal users, and manage team membership without requiring a global admin role - [PR #23085](https://github.com/BerriAI/litellm/pull/23085), [PR #23080](https://github.com/BerriAI/litellm/pull/23080) +- **Guardrail mode defaults and tag-based modes** — Set a default guardrail mode list globally, and specify a list of modes in tag-based guardrail configs - [PR 
#22676](https://github.com/BerriAI/litellm/pull/22676), [PR #23020](https://github.com/BerriAI/litellm/pull/23020) - **Secret redaction in logs** — API keys, tokens, and credentials automatically scrubbed from all proxy log output. Enabled by default; opt out with `LITELLM_DISABLE_REDACT_SECRETS=true` - [PR #23668](https://github.com/BerriAI/litellm/pull/23668) - **Streaming stability fix** — Critical fix for `RuntimeError: Cannot send a request, as the client has been closed.` crashes after ~1 hour in production - [PR #22926](https://github.com/BerriAI/litellm/pull/22926) @@ -54,7 +58,7 @@ pip install litellm==1.82.3 ## New Providers and Endpoints -### New Providers (5 new providers) +### New Providers (7 new providers) | Provider | Supported LiteLLM Endpoints | Description | | -------- | --------------------------- | ----------- | @@ -63,6 +67,8 @@ pip install litellm==1.82.3 | [Black Forest Labs](../../docs/providers/black_forest_labs) (`black_forest_labs/`) | `/images/generations`, `/images/edits` | FLUX image generation and editing — Kontext Pro/Max, Pro 1.0 Fill/Expand | | [Serper](../../docs/providers/serper) (`serper/`) | `/search` | Web search via Serper API | | [SageMaker Nova](../../docs/providers/aws_sagemaker) (`sagemaker_nova/`) | `/chat/completions` | Amazon Nova models via SageMaker endpoint | +| [Google Search API](../../docs/providers/google_search) (`google_search/`) | `/search` | Google Search API integration - [PR #22752](https://github.com/BerriAI/litellm/pull/22752) | +| [Bedrock Mantle](../../docs/providers/bedrock) (`bedrock_mantle/`) | `/chat/completions` | Amazon Bedrock via Mantle — alternative auth and routing path for Bedrock models - [PR #22866](https://github.com/BerriAI/litellm/pull/22866) | --- @@ -238,22 +244,92 @@ pip install litellm==1.82.3 - **[Responses API](../../docs/response_api)** - Handle `response.failed`, `response.incomplete`, and `response.cancelled` terminal event types in background streaming — previously only 
`response.completed` was handled - [PR #23492](https://github.com/BerriAI/litellm/pull/23492) + - WebSocket streaming support for Responses API — real-time streaming via WebSocket for all providers - [PR #22559](https://github.com/BerriAI/litellm/pull/22559), [PR #22771](https://github.com/BerriAI/litellm/pull/22771) + - WebRTC support for real-time audio/video communication - [PR #23446](https://github.com/BerriAI/litellm/pull/23446) + - Responses API support for OpenAI-compatible JSON providers (`openai_like`) - [PR #21398](https://github.com/BerriAI/litellm/pull/21398) + - Route `gpt-5.4+` calls using both tools and reasoning to the Responses API automatically - [PR #23577](https://github.com/BerriAI/litellm/pull/23577) + +- **[Anthropic Files API](../../docs/providers/anthropic)** + - Full Anthropic Files API support — upload, retrieve, list, and delete files; use file references in messages - [PR #16594](https://github.com/BerriAI/litellm/pull/16594) + +- **[Mistral](../../docs/providers/mistral)** + - Voxtral audio transcription support — `mistral/voxtral-mini-*` and `mistral/voxtral-*` for audio transcription via Mistral - [PR #22801](https://github.com/BerriAI/litellm/pull/22801) + +- **[OpenAI](../../docs/providers/openai)** + - `litellm.acount_tokens()` public API — async token counting with full OpenAI provider support - [PR #22809](https://github.com/BerriAI/litellm/pull/22809) + - Normalize `reasoning_effort` dict to string for chat completion API - [PR #22981](https://github.com/BerriAI/litellm/pull/22981) + +- **[OpenRouter](../../docs/providers/openrouter)** + - Image edit support for OpenRouter models - [PR #22403](https://github.com/BerriAI/litellm/pull/22403) + +- **[Google Vertex AI](../../docs/providers/vertex)** + - VIDEO modality token usage tracking in `completion_tokens_details` - [PR #22550](https://github.com/BerriAI/litellm/pull/22550) + +- **Images API** + - `input_fidelity` parameter for image edit API - [PR 
#23201](https://github.com/BerriAI/litellm/pull/23201) + +- **General** + - Per-request `enable_json_schema_validation` flag for thread-safe JSON schema validation - [PR #21233](https://github.com/BerriAI/litellm/pull/21233) + - Model cost aliases expansion — define aliases in the cost map that inherit pricing from a parent model - [PR #23314](https://github.com/BerriAI/litellm/pull/23314), [PR #23457](https://github.com/BerriAI/litellm/pull/23457) + - Wildcards model support for the Files API - [PR #22740](https://github.com/BerriAI/litellm/pull/22740) #### Bug Fixes - **[Anthropic](../../docs/providers/anthropic)** - Preserve native tool format (web_search, bash, tool_search, etc.) when guardrails convert tools for the Anthropic Messages API - [PR #23526](https://github.com/BerriAI/litellm/pull/23526) + - Enforce `type: "object"` on tool input schemas in `_map_tool_helper` — fixes tool call failures for strict-schema providers - [PR #23103](https://github.com/BerriAI/litellm/pull/23103) + - Deduplicate `tool_result` messages by `tool_call_id` — prevents duplicate tool result errors in multi-turn conversations - [PR #23104](https://github.com/BerriAI/litellm/pull/23104) + - Map `reasoning_effort` to `output_config` for Claude 4.6 models - [PR #22220](https://github.com/BerriAI/litellm/pull/22220) + +- **[Google Gemini](../../docs/providers/gemini)** + - Correct streaming `finish_reason` for tool calls — was incorrectly returning `null` instead of `tool_calls` - [PR #21577](https://github.com/BerriAI/litellm/pull/21577) + - Preserve `$ref` in JSON Schema for Gemini 2.0+ — schema references were being stripped, breaking structured output - [PR #21597](https://github.com/BerriAI/litellm/pull/21597) + - Handle `minimal` `reasoning_effort` param for Gemini 3.1 models - [PR #22920](https://github.com/BerriAI/litellm/pull/22920) + +- **[Google Vertex AI](../../docs/providers/vertex)** + - Pass through native Gemini `imageConfig` params for image generation - [PR 
#21585](https://github.com/BerriAI/litellm/pull/21585) + - Prevent content truncation when `finish_reason` races ahead of content chunks in streaming - [PR #22692](https://github.com/BerriAI/litellm/pull/22692) + - Strip LiteLLM-internal keys from `extra_body` before merging to Gemini request body - [PR #23131](https://github.com/BerriAI/litellm/pull/23131) + - Drop unsupported `output_config` parameter from all Vertex AI requests - [PR #22884](https://github.com/BerriAI/litellm/pull/22884) + - Skip schema transforms for Gemini 2.0+ tool parameters — avoids breaking native Gemini schema handling - [PR #23265](https://github.com/BerriAI/litellm/pull/23265) + +- **[OpenRouter](../../docs/providers/openrouter)** + - Pattern-based fix for native model double-stripping when provider prefix matches model name - [PR #22320](https://github.com/BerriAI/litellm/pull/22320) + - Use provider-reported usage in streaming responses when `stream_options` is not set - [PR #21592](https://github.com/BerriAI/litellm/pull/21592) + +- **[AWS Bedrock](../../docs/providers/bedrock)** + - Extract region and model ID from `bedrock/{region}/{model}` path format - [PR #22546](https://github.com/BerriAI/litellm/pull/22546) + - Strip `scope` from `cache_control` for Anthropic messages on Bedrock and Azure AI - [PR #22867](https://github.com/BerriAI/litellm/pull/22867) + - Populate `completion_tokens_details` in Responses API responses - [PR #23243](https://github.com/BerriAI/litellm/pull/23243) + +- **[Azure AI](../../docs/providers/azure_ai)** + - Resolve `api_base` from environment variable in Document Intelligence OCR - [PR #21581](https://github.com/BerriAI/litellm/pull/21581) - **[Moonshot / Kimi](../../docs/providers/openai_compatible)** - Auto-fill `reasoning_content` for Moonshot Kimi reasoning models - [PR #23580](https://github.com/BerriAI/litellm/pull/23580) + - Preserve `image_url` blocks in multimodal messages for Moonshot - [PR 
#21595](https://github.com/BerriAI/litellm/pull/21595) - **[HuggingFace](../../docs/providers/huggingface)** - Forward `extra_headers` to HuggingFace embedding API - [PR #23525](https://github.com/BerriAI/litellm/pull/23525) +- **Token Counting / Cost** + - Fix `count_tokens` to include system prompts and tools in token counting API requests - [PR #22301](https://github.com/BerriAI/litellm/pull/22301) + - Pass all custom pricing fields to `register_model` in `completion()` and `embedding()` - [PR #22552](https://github.com/BerriAI/litellm/pull/22552) + +- **Tools / Function Calling** + - Gracefully repair truncated JSON in tool call arguments — prevents crashes on malformed tool responses - [PR #22503](https://github.com/BerriAI/litellm/pull/22503) + - Fix `output_item.done` for function calls not emitting `finish_reason` in streaming - [PR #22553](https://github.com/BerriAI/litellm/pull/22553) + - Preserve thinking block order with multiple web searches - [PR #23093](https://github.com/BerriAI/litellm/pull/23093) + - **General** - Normalize `content_filtered` finish reason across providers - [PR #23564](https://github.com/BerriAI/litellm/pull/23564) + - Unify `finish_reason` mapping to OpenAI-compatible values across all providers - [PR #22138](https://github.com/BerriAI/litellm/pull/22138) - Fix custom cost tracking on deployments for `/v1/messages` and `/v1/responses` - [PR #23647](https://github.com/BerriAI/litellm/pull/23647) - Fix per-request custom pricing when `router_model_id` has no pricing data — now falls back to model name + - Fix batch list showing stale `validating` status after completion - [PR #22982](https://github.com/BerriAI/litellm/pull/22982) + - Fix batch retrieve returning raw `output_file_id` when `model_id` is missing - [PR #23194](https://github.com/BerriAI/litellm/pull/23194) + - Encode batch IDs when `x-litellm-model` header is used - [PR #22653](https://github.com/BerriAI/litellm/pull/22653) + - Map `reasoning` to `reasoning_content` 
in streaming Delta for gpt-oss providers - [PR #22803](https://github.com/BerriAI/litellm/pull/22803) --- @@ -264,10 +340,36 @@ pip install litellm==1.82.3 - **Virtual Keys** - Add Organization dropdown to Create/Edit Key form — `organization_id` is now a first-class field in Key Ownership - [PR #23595](https://github.com/BerriAI/litellm/pull/23595) - Allow setting `organization_id` on `/key/update` — keys can be assigned or moved to a different organization after creation - [PR #23557](https://github.com/BerriAI/litellm/pull/23557) + - Manual Spend Reset for virtual keys from the UI — admins can reset key spend to zero on demand - [PR #22715](https://github.com/BerriAI/litellm/pull/22715) + - BYOK (Bring Your Own Key) — client-side provider API key takes precedence over proxy key for Anthropic `/v1/messages` - [PR #22964](https://github.com/BerriAI/litellm/pull/22964) + - UI login session duration configurable via `LITELLM_UI_SESSION_DURATION` environment variable - [PR #22182](https://github.com/BerriAI/litellm/pull/22182) + - Auto-redirect UI login to SSO via `auto_redirect_ui_login_to_sso: true` in config.yaml - [PR #23367](https://github.com/BerriAI/litellm/pull/23367) + +- **Access Control (RBAC)** + - Org Admins can now access team management endpoints — `/team/new`, `/team/update`, `/team/delete`, `/team/member_add`, `/team/member_delete` - [PR #23085](https://github.com/BerriAI/litellm/pull/23085), [PR #23095](https://github.com/BerriAI/litellm/pull/23095) + - Org Admins can view and invite internal users — full user management without requiring global admin role - [PR #23080](https://github.com/BerriAI/litellm/pull/23080) + - Allow Admin Viewers to access Audit Logs — view-only admin role now includes audit log access - [PR #23419](https://github.com/BerriAI/litellm/pull/23419) + - RBAC for Vector Stores and Agents — key/team-level access control for vector store and agent resources - [PR #22858](https://github.com/BerriAI/litellm/pull/22858) + - User 
filter scope (`scope_user_search_to_org`) is now opt-in — previously default-on, causing unintended restriction - [PR #23057](https://github.com/BerriAI/litellm/pull/23057) + +- **Vector Stores** + - Vector Store management endpoints — retrieve, list, update, and delete vector stores via `/v1/vector_stores/*` - [PR #23435](https://github.com/BerriAI/litellm/pull/23435) + +- **MCP Servers** + - Token authentication support for MCP servers — configure `auth_type: "bearer"` per MCP server - [PR #23260](https://github.com/BerriAI/litellm/pull/23260) + - Team-scoped MCP server filtering for key creation — keys only see MCP servers available to their team - [PR #23323](https://github.com/BerriAI/litellm/pull/23323) + - Per-server health recheck in the UI - [PR #23328](https://github.com/BerriAI/litellm/pull/23328) + +- **Teams** + - Batch expiry setting for teams — configure a default expiry duration for all team keys - [PR #22705](https://github.com/BerriAI/litellm/pull/22705) + - Team Admin can reset key spend - [PR #22725](https://github.com/BerriAI/litellm/pull/22725) - **Internal Users** - Add/Remove Team Membership directly from the Internal Users info page — includes searchable dropdown and role selector; no longer requires navigating to each team - [PR #23638](https://github.com/BerriAI/litellm/pull/23638) +- **Models** + - Attach knowledge base to model via UI - [PR #22656](https://github.com/BerriAI/litellm/pull/22656) + - **Default Team Settings** - Modernize page to antd (consistent with rest of app) - [PR #23614](https://github.com/BerriAI/litellm/pull/23614) - Fix: default team params (budget, duration, tpm, rpm, permissions) now correctly applied on `/team/new` - [PR #23614](https://github.com/BerriAI/litellm/pull/23614) @@ -290,6 +392,15 @@ pip install litellm==1.82.3 - Fix Public Model Hub not showing config-defined models after save - [PR #23501](https://github.com/BerriAI/litellm/pull/23501) - Fix fallback popup model dropdown z-index issue - [PR 
#23516](https://github.com/BerriAI/litellm/pull/23516) - Fix double-counting bug in org/team key limit checks on `/key/update` +- Fix invite link allowing multiple password resets for the same link - [PR #22462](https://github.com/BerriAI/litellm/pull/22462) +- Fix key expiry default duration not being applied when `duration` is not set - [PR #22956](https://github.com/BerriAI/litellm/pull/22956) +- Fix all proxy models not including model access groups in key creation - [PR #23236](https://github.com/BerriAI/litellm/pull/23236) +- Fix admin viewers unable to see all organizations - [PR #22940](https://github.com/BerriAI/litellm/pull/22940) +- Fix Audit Logs UI: added server-side pagination, filtering, and drawer view - [PR #22476](https://github.com/BerriAI/litellm/pull/22476) +- Fix MCP server URL and tools management issues - [PR #22751](https://github.com/BerriAI/litellm/pull/22751) +- Fix MCP server health checks triggering on server deletion - [PR #23063](https://github.com/BerriAI/litellm/pull/23063) +- Fix virtual keys in teams view not applying the team filter correctly - [PR #23065](https://github.com/BerriAI/litellm/pull/23065) +- Fix team expiry enforcement validation - [PR #22728](https://github.com/BerriAI/litellm/pull/22728) --- @@ -312,7 +423,10 @@ pip install litellm==1.82.3 ### Guardrails -No major guardrail changes in this release. 
+- **Guardrail mode default list** — Configure a default list of guardrail modes applied globally when no per-request mode is specified - [PR #22676](https://github.com/BerriAI/litellm/pull/22676) +- **Tag-based guardrail mode lists** — Specify a list of modes in tag-based guardrail configs instead of a single mode - [PR #23020](https://github.com/BerriAI/litellm/pull/23020) +- **Fix presidio PII token leak** — Edge case where Anthropic handle in Presidio caused PII data exposure in token response - [PR #22627](https://github.com/BerriAI/litellm/pull/22627) +- **Fix OTEL orphaned guardrail traces** — Span redundancy and missing response IDs in OpenTelemetry guardrail traces - [PR #23001](https://github.com/BerriAI/litellm/pull/23001) ### Prompt Management @@ -320,7 +434,17 @@ No major prompt management changes in this release. ### Secret Managers -No major secret manager changes in this release. +- **[Hashicorp Vault](../../docs/secret)** — Full Hashicorp Vault integration as a config override backend — secrets defined in Vault are fetched at startup and override `config.yaml` values. 
UI support for managing vault-sourced credentials included - [PR #22939](https://github.com/BerriAI/litellm/pull/22939), [PR #23036](https://github.com/BerriAI/litellm/pull/23036) + +--- + +## Spend Tracking + +- **Fix budget-linked keys never having spend reset** — Keys linked to budget objects were not having their spend reset on the configured reset interval - [PR #20688](https://github.com/BerriAI/litellm/pull/20688) +- **Flex pricing support** — Add `flex_pricing` field to cost map for providers that offer dynamic pricing tiers - [PR #22992](https://github.com/BerriAI/litellm/pull/22992) +- **Fix spend log cleanup** — Resolved lock tracking, integer retention, and skip-log-level issues in spend log cleanup job - [PR #22687](https://github.com/BerriAI/litellm/pull/22687) +- **Fix WebSearch spend log deduplication** — WebSearch interception was failing with thinking enabled; fixed along with spend log dedup - [PR #22679](https://github.com/BerriAI/litellm/pull/22679) +- **Fix TypeError when request has no API key** — Spend tracking was throwing unhandled exception when API key was absent from request - [PR #23363](https://github.com/BerriAI/litellm/pull/23363) --- @@ -330,6 +454,10 @@ No major secret manager changes in this release. 
- **Fix OOM / Prisma connection loss** on large installs — unbounded managed-object poll was exhausting Prisma connections after ~60–70 minutes on instances with 336K+ queued response rows - [PR #23472](https://github.com/BerriAI/litellm/pull/23472) - **Centralize logging kwarg updates** — root cause fix migrating all logging updates to a single function, eliminating kwarg inconsistencies across logging paths - [PR #23659](https://github.com/BerriAI/litellm/pull/23659) - **Fix tiktoken cache for non-root offline containers** — tiktoken cache now works correctly in offline environments running as non-root users - [PR #23498](https://github.com/BerriAI/litellm/pull/23498) +- **Block proxy startup when Redis transaction buffer has no Redis** — prevents silent data loss when `use_redis_transaction_buffer: true` is set without a Redis connection - [PR #23019](https://github.com/BerriAI/litellm/pull/23019) +- **Fix `InFlightRequestsMiddleware` crash** — undefined kwargs in middleware were causing request failures - [PR #22523](https://github.com/BerriAI/litellm/pull/22523) +- **Fix `BaseModelResponseIterator` crash on non-string stream chunks** — streaming was crashing when providers returned non-string chunk data - [PR #23497](https://github.com/BerriAI/litellm/pull/23497) +- **Fix `SERVER_ROOT_PATH` prefix handling** — strip prefix before checking mapped pass-through routes to prevent double-prefix issues - [PR #23414](https://github.com/BerriAI/litellm/pull/23414) - **Add CodSpeed continuous performance benchmarks** — automated performance regression tracking on CI - [PR #23676](https://github.com/BerriAI/litellm/pull/23676) --- @@ -366,12 +494,15 @@ No major secret manager changes in this release. 
## Diff Summary ## 03/16/2026 -* New Providers: 5 +* New Providers: 7 * New Models / Updated Models: 116 new, 132 removed -* LLM API Endpoints: 5 -* Management Endpoints / UI: 11 -* AI Integrations: 4 -* Performance / Reliability: 5 +* LLM API Endpoints: 37 +* Management Endpoints / UI: 31 +* AI Integrations: 8 +* Guardrails: 4 +* Secret Managers: 1 +* Spend Tracking: 5 +* Performance / Reliability: 9 * Security: 3 * Database / Proxy Operations: 2 From d5ef754950671e2e8a0c17a993bd3837aa008b7c Mon Sep 17 00:00:00 2001 From: joereyna Date: Wed, 18 Mar 2026 23:35:40 -0700 Subject: [PATCH 157/539] docs(release-notes): align v1.82.3 notes with release notes guide MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add MCP Gateway section (moved from Management per guide rule §11) - Rename Spend Tracking → Spend Tracking, Budgets and Rate Limiting - Fix Hashicorp Vault doc link: docs/secret → docs/secret_managers - Fix LLM API section: #### Bug Fixes → #### Bugs (matches guide) - Add Documentation Updates section (required by guide §11) - Update Diff Summary: correct section names, add MCP Gateway and Documentation Updates counts Co-Authored-By: Claude Sonnet 4.6 --- .../my-website/release_notes/v1.82.3/index.md | 46 +++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/docs/my-website/release_notes/v1.82.3/index.md b/docs/my-website/release_notes/v1.82.3/index.md index a51edc60c2..20be882671 100644 --- a/docs/my-website/release_notes/v1.82.3/index.md +++ b/docs/my-website/release_notes/v1.82.3/index.md @@ -273,7 +273,7 @@ pip install litellm==1.82.3 - Model cost aliases expansion — define aliases in the cost map that inherit pricing from a parent model - [PR #23314](https://github.com/BerriAI/litellm/pull/23314), [PR #23457](https://github.com/BerriAI/litellm/pull/23457) - Wildcards model support for the Files API - [PR #22740](https://github.com/BerriAI/litellm/pull/22740) -#### Bug Fixes +#### Bugs - 
**[Anthropic](../../docs/providers/anthropic)** - Preserve native tool format (web_search, bash, tool_search, etc.) when guardrails convert tools for the Anthropic Messages API - [PR #23526](https://github.com/BerriAI/litellm/pull/23526) @@ -355,11 +355,6 @@ pip install litellm==1.82.3 - **Vector Stores** - Vector Store management endpoints — retrieve, list, update, and delete vector stores via `/v1/vector_stores/*` - [PR #23435](https://github.com/BerriAI/litellm/pull/23435) -- **MCP Servers** - - Token authentication support for MCP servers — configure `auth_type: "bearer"` per MCP server - [PR #23260](https://github.com/BerriAI/litellm/pull/23260) - - Team-scoped MCP server filtering for key creation — keys only see MCP servers available to their team - [PR #23323](https://github.com/BerriAI/litellm/pull/23323) - - Per-server health recheck in the UI - [PR #23328](https://github.com/BerriAI/litellm/pull/23328) - - **Teams** - Batch expiry setting for teams — configure a default expiry duration for all team keys - [PR #22705](https://github.com/BerriAI/litellm/pull/22705) - Team Admin can reset key spend - [PR #22725](https://github.com/BerriAI/litellm/pull/22725) @@ -397,8 +392,6 @@ pip install litellm==1.82.3 - Fix all proxy models not including model access groups in key creation - [PR #23236](https://github.com/BerriAI/litellm/pull/23236) - Fix admin viewers unable to see all organizations - [PR #22940](https://github.com/BerriAI/litellm/pull/22940) - Fix Audit Logs UI: added server-side pagination, filtering, and drawer view - [PR #22476](https://github.com/BerriAI/litellm/pull/22476) -- Fix MCP server URL and tools management issues - [PR #22751](https://github.com/BerriAI/litellm/pull/22751) -- Fix MCP server health checks triggering on server deletion - [PR #23063](https://github.com/BerriAI/litellm/pull/23063) - Fix virtual keys in teams view not applying the team filter correctly - [PR #23065](https://github.com/BerriAI/litellm/pull/23065) - Fix team 
expiry enforcement validation - [PR #22728](https://github.com/BerriAI/litellm/pull/22728) @@ -434,11 +427,26 @@ No major prompt management changes in this release. ### Secret Managers -- **[Hashicorp Vault](../../docs/secret)** — Full Hashicorp Vault integration as a config override backend — secrets defined in Vault are fetched at startup and override `config.yaml` values. UI support for managing vault-sourced credentials included - [PR #22939](https://github.com/BerriAI/litellm/pull/22939), [PR #23036](https://github.com/BerriAI/litellm/pull/23036) +- **[Hashicorp Vault](../../docs/secret_managers)** — Full Hashicorp Vault integration as a config override backend — secrets defined in Vault are fetched at startup and override `config.yaml` values. UI support for managing vault-sourced credentials included - [PR #22939](https://github.com/BerriAI/litellm/pull/22939), [PR #23036](https://github.com/BerriAI/litellm/pull/23036) --- -## Spend Tracking +## MCP Gateway + +#### Features + +- **Token authentication for MCP servers** — configure `auth_type: "bearer"` per MCP server to require token-based auth on tool calls - [PR #23260](https://github.com/BerriAI/litellm/pull/23260) +- **Team-scoped MCP server filtering** — keys created under a team only see MCP servers available to that team - [PR #23323](https://github.com/BerriAI/litellm/pull/23323) +- **Per-server health recheck in UI** — trigger a health check for individual MCP servers without reloading all servers - [PR #23328](https://github.com/BerriAI/litellm/pull/23328) + +#### Bugs + +- Fix MCP server URL and tools management issues causing tool discovery to fail - [PR #22751](https://github.com/BerriAI/litellm/pull/22751) +- Fix MCP server health checks triggering on server deletion - [PR #23063](https://github.com/BerriAI/litellm/pull/23063) + +--- + +## Spend Tracking, Budgets and Rate Limiting - **Fix budget-linked keys never having spend reset** — Keys linked to budget objects were not having their spend 
reset on the configured reset interval - [PR #20688](https://github.com/BerriAI/litellm/pull/20688) - **Flex pricing support** — Add `flex_pricing` field to cost map for providers that offer dynamic pricing tiers - [PR #22992](https://github.com/BerriAI/litellm/pull/22992) @@ -477,6 +485,16 @@ No major prompt management changes in this release. --- +## Documentation Updates + +- Add Anthropic `/v1/messages` → `/responses` parameter mapping reference - [PR #22893](https://github.com/BerriAI/litellm/pull/22893) +- Update Okta SSO docs and custom SSO handler example - [PR #22786](https://github.com/BerriAI/litellm/pull/22786) +- Add `LITELLM_MAX_BUDGET_PER_SESSION_TTL` to environment variables reference - [PR #23186](https://github.com/BerriAI/litellm/pull/23186) +- Add DB query performance guidelines to `CLAUDE.md` - [PR #23196](https://github.com/BerriAI/litellm/pull/23196) +- Add Gemini Vertex AI PayGo/priority cost tracking docs - [PR #22948](https://github.com/BerriAI/litellm/pull/22948) + +--- + ## New Contributors * @ryanh-ai made their first contribution in [PR #21542](https://github.com/BerriAI/litellm/pull/21542) @@ -499,12 +517,12 @@ No major prompt management changes in this release. 
* LLM API Endpoints: 37 * Management Endpoints / UI: 31 * AI Integrations: 8 -* Guardrails: 4 -* Secret Managers: 1 -* Spend Tracking: 5 -* Performance / Reliability: 9 +* MCP Gateway: 5 +* Spend Tracking, Budgets and Rate Limiting: 5 +* Performance / Loadbalancing / Reliability improvements: 9 * Security: 3 * Database / Proxy Operations: 2 +* Documentation Updates: 5 --- From b20c448188acabe3df75597d1c76a65b07d61149 Mon Sep 17 00:00:00 2001 From: chengyongru <61816729+chengyongru@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:34:47 +0800 Subject: [PATCH 158/539] fix(openai): handle missing 'id' field in streaming chunks for MiniMax (#23931) - Change chunk["id"] to chunk.get("id") for compatibility with MiniMax - ModelResponseStream auto-generates id when None is passed - Add regression test test_chunk_parser_without_id_field --- .../llms/openai/chat/gpt_transformation.py | 2 +- .../chat/test_openai_gpt_transformation.py | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index 63beb82ded..34a23222c2 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -806,7 +806,7 @@ class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator): choices = self._map_reasoning_to_reasoning_content(choices) kwargs = { - "id": chunk["id"], + "id": chunk.get("id"), "object": "chat.completion.chunk", "created": chunk.get("created"), "model": chunk.get("model"), diff --git a/tests/test_litellm/llms/openai/chat/test_openai_gpt_transformation.py b/tests/test_litellm/llms/openai/chat/test_openai_gpt_transformation.py index 086d01f65b..5d0b1ec856 100644 --- a/tests/test_litellm/llms/openai/chat/test_openai_gpt_transformation.py +++ b/tests/test_litellm/llms/openai/chat/test_openai_gpt_transformation.py @@ -281,6 +281,45 @@ class TestOpenAIChatCompletionStreamingHandler: # Verify that 
reasoning_content is not set (it should be deleted by Delta.__init__) assert not hasattr(parsed_chunk.choices[0].delta, "reasoning_content") + def test_chunk_parser_without_id_field(self): + """ + Test that chunk_parser works when chunk is missing the 'id' field. + + Some OpenAI-compatible providers (e.g., MiniMax) return streaming chunks + without an 'id' field in certain cases. This should not raise KeyError. + + Regression test for: KeyError: 'id' when using MiniMax m2.5 model + """ + handler = OpenAIChatCompletionStreamingHandler( + streaming_response=None, sync_stream=True + ) + + # Simulate a chunk without 'id' field (as returned by MiniMax) + chunk = { + "object": "chat.completion.chunk", + "created": 1769511767, + "model": "minimax/m2.5", + "choices": [ + { + "delta": { + "content": "Hello", + "role": "assistant", + }, + "finish_reason": None, + "index": 0, + } + ], + } + + # Parse the chunk - should not raise KeyError + parsed_chunk = handler.chunk_parser(chunk) + + # Verify that content is present and id was auto-generated + assert parsed_chunk.choices[0].delta.content == "Hello" + assert parsed_chunk.choices[0].delta.role == "assistant" + # ModelResponseStream auto-generates an id when None is passed + assert parsed_chunk.id is not None + class TestPromptCacheKeyIntegration: """Tests for prompt_cache_key support""" From e19a717b53bd6129c350ebf0c7c63a592ec5ad08 Mon Sep 17 00:00:00 2001 From: superpoussin22 Date: Thu, 19 Mar 2026 09:22:10 +0100 Subject: [PATCH 159/539] Add IF NOT EXISTS to index creation in migration --- .../20260318140652_add_index_to_team_table/migration.sql | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260318140652_add_index_to_team_table/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260318140652_add_index_to_team_table/migration.sql index 494aaf6238..89121d636f 100644 --- 
a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260318140652_add_index_to_team_table/migration.sql +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260318140652_add_index_to_team_table/migration.sql @@ -1,9 +1,9 @@ -- CreateIndex -CREATE INDEX "LiteLLM_TeamTable_organization_id_idx" ON "LiteLLM_TeamTable"("organization_id"); +CREATE INDEX IF NOT EXISTS "LiteLLM_TeamTable_organization_id_idx" ON "LiteLLM_TeamTable"("organization_id"); -- CreateIndex -CREATE INDEX "LiteLLM_TeamTable_team_alias_idx" ON "LiteLLM_TeamTable"("team_alias"); +CREATE INDEX IF NOT EXISTS "LiteLLM_TeamTable_team_alias_idx" ON "LiteLLM_TeamTable"("team_alias"); -- CreateIndex -CREATE INDEX "LiteLLM_TeamTable_created_at_idx" ON "LiteLLM_TeamTable"("created_at"); +CREATE INDEX IF NOT EXISTS "LiteLLM_TeamTable_created_at_idx" ON "LiteLLM_TeamTable"("created_at"); From 4dc645fc334a2bce75e345433990ebaea42b1570 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 13:59:59 +0530 Subject: [PATCH 160/539] feat(polling): check rate limits before creating polling ID Move pre-call checks (rate limits, guardrails, budget) to run BEFORE polling ID creation in the background streaming flow. This prevents the edge case where a rate-limited request receives a polling ID that immediately fails. Changes: - Add skip_pre_call_logic parameter to base_process_llm_request to allow skipping pre-call checks (avoiding double-counting of RPM/parallel requests) - Run common_processing_pre_call_logic before generating polling ID in the responses API endpoint. If rate limits/guardrails fail, return error immediately without creating a polling ID - Background streaming task passes skip_pre_call_logic=True to avoid re-running pre-call checks that were already done before polling ID creation - Add tests verifying skip_pre_call_logic parameter works correctly Fixes the edge case where polling_via_cache would return a polling ID for a request that immediately fails due to rate limiting. 
--- litellm/proxy/common_request_processing.py | 36 +++--- .../proxy/response_api_endpoints/endpoints.py | 32 +++++- .../response_polling/background_streaming.py | 5 +- .../test_response_polling_pre_call_checks.py | 104 ++++++++++++++++++ 4 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 tests/proxy_unit_tests/test_response_polling_pre_call_checks.py diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 72765aab7d..84f9730a37 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -900,6 +900,7 @@ class ProxyBaseLLMRequestProcessing: version: Optional[str] = None, is_streaming_request: Optional[bool] = False, contents: Optional[list] = None, # Add contents parameter + skip_pre_call_logic: bool = False, ) -> Any: """ Common request processing logic for both chat completions and responses API endpoints @@ -909,22 +910,25 @@ class ProxyBaseLLMRequestProcessing: ) self._debug_log_request_payload() - self.data, logging_obj = await self.common_processing_pre_call_logic( - request=request, - general_settings=general_settings, - proxy_logging_obj=proxy_logging_obj, - user_api_key_dict=user_api_key_dict, - version=version, - proxy_config=proxy_config, - user_model=user_model, - user_temperature=user_temperature, - user_request_timeout=user_request_timeout, - user_max_tokens=user_max_tokens, - user_api_base=user_api_base, - model=model, - route_type=route_type, - llm_router=llm_router, - ) + if skip_pre_call_logic: + logging_obj = self.data.get("litellm_logging_obj") + else: + self.data, logging_obj = await self.common_processing_pre_call_logic( + request=request, + general_settings=general_settings, + proxy_logging_obj=proxy_logging_obj, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + user_model=user_model, + user_temperature=user_temperature, + user_request_timeout=user_request_timeout, + 
user_max_tokens=user_max_tokens, + user_api_base=user_api_base, + model=model, + route_type=route_type, + llm_router=llm_router, + ) tasks = [] # Start the moderation check (during_call_hook) as early as possible diff --git a/litellm/proxy/response_api_endpoints/endpoints.py b/litellm/proxy/response_api_endpoints/endpoints.py index e9c7cce0d7..055fdeb84f 100644 --- a/litellm/proxy/response_api_endpoints/endpoints.py +++ b/litellm/proxy/response_api_endpoints/endpoints.py @@ -119,6 +119,34 @@ async def responses_api( f"Starting background response with polling for model={data.get('model')}" ) + # Run pre-call checks (rate limits, guardrails, budget) BEFORE creating + # polling ID. This ensures rate-limited requests get a synchronous 429 + # instead of a polling ID that immediately fails in the background task. + processor = ProxyBaseLLMRequestProcessing(data=data) + try: + data, _logging_obj = await processor.common_processing_pre_call_logic( + request=request, + general_settings=general_settings, + proxy_logging_obj=proxy_logging_obj, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + user_model=user_model, + user_temperature=user_temperature, + user_request_timeout=user_request_timeout, + user_max_tokens=user_max_tokens, + user_api_base=user_api_base, + route_type="aresponses", + llm_router=llm_router, + ) + except Exception as e: + raise await processor._handle_llm_api_exception( + e=e, + user_api_key_dict=user_api_key_dict, + proxy_logging_obj=proxy_logging_obj, + version=version, + ) + # Initialize polling handler with configured TTL (from global config) polling_handler = ResponsePollingHandler( redis_cache=redis_usage_cache, @@ -134,7 +162,9 @@ async def responses_api( request_data=data, ) - # Start background task to stream and update cache + # Start background task to stream and update cache. + # Pass pre-processed data so the background task skips pre-call logic + # (rate limits, guardrails already checked above). 
asyncio.create_task( background_streaming_task( polling_id=polling_id, diff --git a/litellm/proxy/response_polling/background_streaming.py b/litellm/proxy/response_polling/background_streaming.py index 7583f30eb2..bcc9817577 100644 --- a/litellm/proxy/response_polling/background_streaming.py +++ b/litellm/proxy/response_polling/background_streaming.py @@ -65,7 +65,9 @@ async def background_streaming_task( # noqa: PLR0915 # Create processor processor = ProxyBaseLLMRequestProcessing(data=data) - # Make streaming request + # Make streaming request. + # Pre-call checks (rate limits, guardrails, budget) were already run + # before polling ID creation, so skip them here to avoid double-counting. response = await processor.base_process_llm_request( request=request, fastapi_response=fastapi_response, @@ -83,6 +85,7 @@ async def background_streaming_task( # noqa: PLR0915 user_max_tokens=user_max_tokens, user_api_base=user_api_base, version=version, + skip_pre_call_logic=True, ) # Process streaming response following OpenAI events format diff --git a/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py b/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py new file mode 100644 index 0000000000..b39f1bf43d --- /dev/null +++ b/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py @@ -0,0 +1,104 @@ +""" +Unit tests for pre-call checks running before polling ID creation. + +Tests that rate limits, guardrails, and budget checks are enforced +BEFORE a polling ID is created, so rate-limited requests get a +synchronous error instead of a polling ID that immediately fails. 
+""" + +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import Request, Response + +sys.path.insert(0, os.path.abspath("../..")) + +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing + + +class TestSkipPreCallLogic: + """Test that skip_pre_call_logic parameter works correctly""" + + @pytest.mark.asyncio + async def test_skip_pre_call_logic_skips_common_processing(self): + """When skip_pre_call_logic=True, common_processing_pre_call_logic should not be called""" + mock_logging_obj = MagicMock() + data = { + "model": "gpt-4", + "stream": True, + "litellm_logging_obj": mock_logging_obj, + } + processor = ProxyBaseLLMRequestProcessing(data=data) + + mock_proxy_logging = AsyncMock() + mock_proxy_logging.during_call_hook = AsyncMock() + + with ( + patch.object( + processor, "common_processing_pre_call_logic", new_callable=AsyncMock + ) as mock_pre_call, + patch( + "litellm.proxy.common_request_processing.route_request", + new_callable=AsyncMock, + return_value=MagicMock(), + ), + ): + try: + await processor.base_process_llm_request( + request=MagicMock(spec=Request), + fastapi_response=MagicMock(spec=Response), + user_api_key_dict=MagicMock(spec=UserAPIKeyAuth), + route_type="aresponses", + proxy_logging_obj=mock_proxy_logging, + llm_router=MagicMock(), + general_settings={}, + proxy_config=MagicMock(), + skip_pre_call_logic=True, + ) + except Exception: + pass # We only care that common_processing_pre_call_logic was not called + + mock_pre_call.assert_not_called() + + @pytest.mark.asyncio + async def test_without_skip_runs_common_processing(self): + """When skip_pre_call_logic=False (default), common_processing_pre_call_logic should be called""" + data = {"model": "gpt-4"} + processor = ProxyBaseLLMRequestProcessing(data=data) + + mock_logging_obj = MagicMock() + mock_proxy_logging = AsyncMock() + mock_proxy_logging.during_call_hook 
= AsyncMock() + + with ( + patch.object( + processor, + "common_processing_pre_call_logic", + new_callable=AsyncMock, + return_value=(data, mock_logging_obj), + ) as mock_pre_call, + patch( + "litellm.proxy.common_request_processing.route_request", + new_callable=AsyncMock, + ), + ): + try: + await processor.base_process_llm_request( + request=MagicMock(spec=Request), + fastapi_response=MagicMock(spec=Response), + user_api_key_dict=MagicMock(spec=UserAPIKeyAuth), + route_type="aresponses", + proxy_logging_obj=mock_proxy_logging, + llm_router=MagicMock(), + general_settings={}, + proxy_config=MagicMock(), + ) + except Exception: + pass + + mock_pre_call.assert_called_once() + + From c12717f494a5f7adbc7e2e0ab5f7cd814cc7c222 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 14:10:58 +0530 Subject: [PATCH 161/539] fix: address Greptile review comments - Guard logging_obj for None when skip_pre_call_logic=True: raise ValueError if litellm_logging_obj not in data, preventing AttributeError downstream - Add model=None to common_processing_pre_call_logic call in endpoints.py to match style of other call sites - Add test verifying rate-limited request never receives polling ID --- litellm/proxy/common_request_processing.py | 5 ++ .../proxy/response_api_endpoints/endpoints.py | 1 + .../test_response_polling_pre_call_checks.py | 63 ++++++++++++++++++- 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 84f9730a37..b86a7595ae 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -912,6 +912,11 @@ class ProxyBaseLLMRequestProcessing: if skip_pre_call_logic: logging_obj = self.data.get("litellm_logging_obj") + if logging_obj is None: + raise ValueError( + "skip_pre_call_logic=True requires litellm_logging_obj to be set in data. 
" + "Ensure common_processing_pre_call_logic was called before using this parameter." + ) else: self.data, logging_obj = await self.common_processing_pre_call_logic( request=request, diff --git a/litellm/proxy/response_api_endpoints/endpoints.py b/litellm/proxy/response_api_endpoints/endpoints.py index 055fdeb84f..8023853e26 100644 --- a/litellm/proxy/response_api_endpoints/endpoints.py +++ b/litellm/proxy/response_api_endpoints/endpoints.py @@ -136,6 +136,7 @@ async def responses_api( user_request_timeout=user_request_timeout, user_max_tokens=user_max_tokens, user_api_base=user_api_base, + model=None, route_type="aresponses", llm_router=llm_router, ) diff --git a/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py b/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py index b39f1bf43d..cdea075d0d 100644 --- a/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py +++ b/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py @@ -11,10 +11,11 @@ import sys from unittest.mock import AsyncMock, MagicMock, patch import pytest -from fastapi import Request, Response +from fastapi import HTTPException, Request, Response sys.path.insert(0, os.path.abspath("../..")) +import litellm from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing @@ -102,3 +103,63 @@ class TestSkipPreCallLogic: mock_pre_call.assert_called_once() +class TestPollingEndpointPreCallGuard: + """Test that the polling endpoint enforces pre-call checks before polling ID creation""" + + @pytest.mark.asyncio + async def test_rate_limit_error_prevents_polling_id_creation(self): + """When pre-call checks raise, generate_polling_id must not be called""" + from litellm.proxy.response_polling.polling_handler import ResponsePollingHandler + + rate_limit_exc = litellm.RateLimitError( + message="TPM limit exceeded", + llm_provider="", + model="gpt-4", + ) + + generate_polling_id_mock = 
MagicMock(return_value="litellm_poll_test") + + with ( + patch.object( + ProxyBaseLLMRequestProcessing, + "common_processing_pre_call_logic", + new_callable=AsyncMock, + side_effect=rate_limit_exc, + ), + patch.object( + ProxyBaseLLMRequestProcessing, + "_handle_llm_api_exception", + new_callable=AsyncMock, + return_value=HTTPException(status_code=429, detail="Rate limit exceeded"), + ), + patch.object(ResponsePollingHandler, "generate_polling_id", generate_polling_id_mock), + ): + # Simulate the endpoint logic directly (avoids proxy_server import complexity) + data = {"model": "gpt-4", "background": True} + processor = ProxyBaseLLMRequestProcessing(data=data) + + raised_exc = None + try: + await processor.common_processing_pre_call_logic( + request=MagicMock(spec=Request), + general_settings={}, + proxy_logging_obj=AsyncMock(), + user_api_key_dict=MagicMock(spec=UserAPIKeyAuth), + version="1.0.0", + proxy_config=MagicMock(), + user_model=None, + user_temperature=None, + user_request_timeout=None, + user_max_tokens=None, + user_api_base=None, + model=None, + route_type="aresponses", + llm_router=MagicMock(), + ) + except litellm.RateLimitError as e: + raised_exc = e + + # The exception was raised before generate_polling_id could be called + assert raised_exc is not None + generate_polling_id_mock.assert_not_called() + From 66f97a00a44d096c5ee0e54e9fbab59ea8ed9cd7 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 14:30:29 +0530 Subject: [PATCH 162/539] fix(test): rewrite polling pre-call guard test to call responses_api() directly Previously the test called common_processing_pre_call_logic in isolation, making generate_polling_id.assert_not_called() vacuously true. Now the test calls responses_api() end-to-end so it actually verifies that a rate-limited request never receives a polling ID. 
Co-Authored-By: Claude Sonnet 4.6 --- .../test_response_polling_pre_call_checks.py | 69 ++++++++++++------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py b/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py index cdea075d0d..45e4e9e4d3 100644 --- a/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py +++ b/tests/proxy_unit_tests/test_response_polling_pre_call_checks.py @@ -108,7 +108,8 @@ class TestPollingEndpointPreCallGuard: @pytest.mark.asyncio async def test_rate_limit_error_prevents_polling_id_creation(self): - """When pre-call checks raise, generate_polling_id must not be called""" + """responses_api() must raise 429 and never call generate_polling_id when rate-limited""" + from litellm.proxy.response_api_endpoints.endpoints import responses_api from litellm.proxy.response_polling.polling_handler import ResponsePollingHandler rate_limit_exc = litellm.RateLimitError( @@ -116,10 +117,37 @@ class TestPollingEndpointPreCallGuard: llm_provider="", model="gpt-4", ) - generate_polling_id_mock = MagicMock(return_value="litellm_poll_test") + proxy_server_patches = { + "litellm.proxy.proxy_server._read_request_body": AsyncMock( + return_value={"model": "gpt-4", "background": True} + ), + "litellm.proxy.proxy_server.general_settings": {}, + "litellm.proxy.proxy_server.llm_router": MagicMock(), + "litellm.proxy.proxy_server.native_background_mode": None, + "litellm.proxy.proxy_server.polling_cache_ttl": 3600, + "litellm.proxy.proxy_server.polling_via_cache_enabled": True, + "litellm.proxy.proxy_server.proxy_config": MagicMock(), + "litellm.proxy.proxy_server.proxy_logging_obj": AsyncMock(), + "litellm.proxy.proxy_server.redis_usage_cache": AsyncMock(), + "litellm.proxy.proxy_server.select_data_generator": None, + "litellm.proxy.proxy_server.user_api_base": None, + "litellm.proxy.proxy_server.user_max_tokens": None, + "litellm.proxy.proxy_server.user_model": None, + 
"litellm.proxy.proxy_server.user_request_timeout": None, + "litellm.proxy.proxy_server.user_temperature": None, + "litellm.proxy.proxy_server.version": "1.0.0", + } + with ( + patch.multiple("litellm.proxy.proxy_server", **{ + k.split(".")[-1]: v for k, v in proxy_server_patches.items() + }), + patch( + "litellm.proxy.response_polling.polling_handler.should_use_polling_for_request", + return_value=True, + ), patch.object( ProxyBaseLLMRequestProcessing, "common_processing_pre_call_logic", @@ -133,33 +161,22 @@ class TestPollingEndpointPreCallGuard: return_value=HTTPException(status_code=429, detail="Rate limit exceeded"), ), patch.object(ResponsePollingHandler, "generate_polling_id", generate_polling_id_mock), + # Prevent background task from running (avoids noise from incomplete mocks) + patch("asyncio.create_task"), + patch.object( + ResponsePollingHandler, + "create_initial_state", + new_callable=AsyncMock, + return_value=MagicMock(), + ), ): - # Simulate the endpoint logic directly (avoids proxy_server import complexity) - data = {"model": "gpt-4", "background": True} - processor = ProxyBaseLLMRequestProcessing(data=data) - - raised_exc = None - try: - await processor.common_processing_pre_call_logic( + with pytest.raises(HTTPException) as exc_info: + await responses_api( request=MagicMock(spec=Request), - general_settings={}, - proxy_logging_obj=AsyncMock(), + fastapi_response=MagicMock(spec=Response), user_api_key_dict=MagicMock(spec=UserAPIKeyAuth), - version="1.0.0", - proxy_config=MagicMock(), - user_model=None, - user_temperature=None, - user_request_timeout=None, - user_max_tokens=None, - user_api_base=None, - model=None, - route_type="aresponses", - llm_router=MagicMock(), ) - except litellm.RateLimitError as e: - raised_exc = e - # The exception was raised before generate_polling_id could be called - assert raised_exc is not None - generate_polling_id_mock.assert_not_called() + assert exc_info.value.status_code == 429 + 
generate_polling_id_mock.assert_not_called() From 528daa8cf43767231531263f35fb6ad5a5ecadbf Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 14:44:01 +0530 Subject: [PATCH 163/539] feat(router): add per-model-group deployment affinity configuration Enable deployment_affinity, responses_api_deployment_check, and session_affinity to be configured per model group via router_settings.model_group_affinity_config, falling back to global settings for unconfigured groups. - Add model_group_affinity_config parameter to Router and DeploymentAffinityCheck - Add _get_effective_flags helper to resolve flags per model group - Update async_filter_deployments and async_pre_call_deployment_hook to use per-group config - Add 4 comprehensive tests covering per-group config, fallback, and override scenarios This allows fine-grained control of affinity behavior across model groups, e.g., enabling stickiness only for cross-provider deployments while leaving other groups free to load-balance. 
Co-Authored-By: Claude Haiku 4.5 --- litellm/router.py | 29 ++ .../deployment_affinity_check.py | 88 ++++-- litellm/types/router.py | 1 + .../test_deployment_affinity_check.py | 281 ++++++++++++++++++ 4 files changed, 371 insertions(+), 28 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 46998abb16..5fc0298ced 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -301,6 +301,7 @@ class Router: RouterGeneralSettings ] = RouterGeneralSettings(), deployment_affinity_ttl_seconds: int = 3600, + model_group_affinity_config: Optional[Dict[str, List[str]]] = None, ignore_invalid_deployments: bool = False, ) -> None: """ @@ -641,6 +642,9 @@ class Router: self.model_group_retry_policy: Optional[ Dict[str, RetryPolicy] ] = model_group_retry_policy + self.model_group_affinity_config: Optional[ + Dict[str, List[str]] + ] = model_group_affinity_config self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None if allowed_fails_policy is not None: @@ -661,6 +665,26 @@ class Router: if optional_pre_call_checks is not None: self.add_optional_pre_call_checks(optional_pre_call_checks) + # If model_group_affinity_config is set but no global affinity checks were + # enabled, we still need the DeploymentAffinityCheck callback (with global + # flags all False) so per-group config can activate affinity per model group. 
+ if self.model_group_affinity_config and not any( + isinstance(cb, DeploymentAffinityCheck) + for cb in (self.optional_callbacks or []) + ): + if self.optional_callbacks is None: + self.optional_callbacks = [] + affinity_callback = DeploymentAffinityCheck( + cache=self.cache, + ttl_seconds=self.deployment_affinity_ttl_seconds, + enable_user_key_affinity=False, + enable_responses_api_affinity=False, + enable_session_id_affinity=False, + model_group_affinity_config=self.model_group_affinity_config, + ) + self.optional_callbacks.append(affinity_callback) + litellm.logging_callback_manager.add_litellm_callback(affinity_callback) + if self.alerting_config is not None: self._initialize_alerting() @@ -1311,6 +1335,10 @@ class Router: existing_affinity_callback.ttl_seconds = ( self.deployment_affinity_ttl_seconds ) + if self.model_group_affinity_config: + existing_affinity_callback.model_group_affinity_config = ( + self.model_group_affinity_config + ) else: affinity_callback = DeploymentAffinityCheck( cache=self.cache, @@ -1318,6 +1346,7 @@ class Router: enable_user_key_affinity=enable_user_key_affinity, enable_responses_api_affinity=enable_responses_api_affinity, enable_session_id_affinity=enable_session_id_affinity, + model_group_affinity_config=self.model_group_affinity_config, ) self.optional_callbacks.append(affinity_callback) litellm.logging_callback_manager.add_litellm_callback(affinity_callback) diff --git a/litellm/router_utils/pre_call_checks/deployment_affinity_check.py b/litellm/router_utils/pre_call_checks/deployment_affinity_check.py index 8044f71d90..08da7d392d 100644 --- a/litellm/router_utils/pre_call_checks/deployment_affinity_check.py +++ b/litellm/router_utils/pre_call_checks/deployment_affinity_check.py @@ -13,7 +13,7 @@ where routing to a consistent deployment is still beneficial. 
""" import hashlib -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Tuple, cast from typing_extensions import TypedDict @@ -46,6 +46,7 @@ class DeploymentAffinityCheck(CustomLogger): enable_user_key_affinity: bool, enable_responses_api_affinity: bool, enable_session_id_affinity: bool = False, + model_group_affinity_config: Optional[Dict[str, List[str]]] = None, ): super().__init__() self.cache = cache @@ -53,6 +54,32 @@ class DeploymentAffinityCheck(CustomLogger): self.enable_user_key_affinity = enable_user_key_affinity self.enable_responses_api_affinity = enable_responses_api_affinity self.enable_session_id_affinity = enable_session_id_affinity + self.model_group_affinity_config: Dict[str, List[str]] = ( + model_group_affinity_config or {} + ) + + def _get_effective_flags( + self, model_group: str + ) -> Tuple[bool, bool, bool]: + """ + Return (enable_user_key_affinity, enable_responses_api_affinity, enable_session_id_affinity) + for the given model group. + + If the model group has an explicit entry in model_group_affinity_config, use it. + Otherwise fall back to the global instance flags. 
+ """ + group_checks = self.model_group_affinity_config.get(model_group) + if group_checks is not None: + return ( + "deployment_affinity" in group_checks, + "responses_api_deployment_check" in group_checks, + "session_affinity" in group_checks, + ) + return ( + self.enable_user_key_affinity, + self.enable_responses_api_affinity, + self.enable_session_id_affinity, + ) @staticmethod def _looks_like_sha256_hex(value: str) -> bool: @@ -277,8 +304,12 @@ class DeploymentAffinityCheck(CustomLogger): request_kwargs = request_kwargs or {} typed_healthy_deployments = cast(List[dict], healthy_deployments) + enable_user_key, enable_responses_api, enable_session_id = ( + self._get_effective_flags(model) + ) + # 1) Responses API continuity (high priority) - if self.enable_responses_api_affinity: + if enable_responses_api: previous_response_id = request_kwargs.get("previous_response_id") if previous_response_id is not None: responses_model_id = ( @@ -305,7 +336,7 @@ class DeploymentAffinityCheck(CustomLogger): return typed_healthy_deployments # 2) Session-id -> deployment affinity - if self.enable_session_id_affinity: + if enable_session_id: session_id = self._get_session_id_from_request_kwargs( request_kwargs=request_kwargs ) @@ -344,7 +375,7 @@ class DeploymentAffinityCheck(CustomLogger): ) # 3) User key -> deployment affinity - if not self.enable_user_key_affinity: + if not enable_user_key: return typed_healthy_deployments user_key = self._get_user_key_from_request_kwargs(request_kwargs=request_kwargs) @@ -394,22 +425,42 @@ class DeploymentAffinityCheck(CustomLogger): - LiteLLM runs async success callbacks via a background logging worker for performance. - We want affinity to be immediately available for subsequent requests. """ - if not self.enable_user_key_affinity and not self.enable_session_id_affinity: + metadata_dicts = self._iter_metadata_dicts(kwargs) + + # Extract deployment_model_name first — needed for both per-group flag resolution + # and cache key scoping. 
+ deployment_model_name: Optional[str] = None + for metadata in metadata_dicts: + maybe_deployment_model_name = metadata.get("deployment_model_name") + if ( + isinstance(maybe_deployment_model_name, str) + and maybe_deployment_model_name + ): + deployment_model_name = maybe_deployment_model_name + break + + if not deployment_model_name: + return None + + # Resolve effective flags for this model group + enable_user_key, _enable_responses_api, enable_session_id = ( + self._get_effective_flags(deployment_model_name) + ) + + if not enable_user_key and not enable_session_id: return None user_key = None - if self.enable_user_key_affinity: + if enable_user_key: user_key = self._get_user_key_from_request_kwargs(request_kwargs=kwargs) session_id = None - if self.enable_session_id_affinity: + if enable_session_id: session_id = self._get_session_id_from_request_kwargs(request_kwargs=kwargs) if user_key is None and session_id is None: return None - metadata_dicts = self._iter_metadata_dicts(kwargs) - model_info = kwargs.get("model_info") if not isinstance(model_info, dict): model_info = None @@ -433,25 +484,6 @@ class DeploymentAffinityCheck(CustomLogger): ) return None - # Scope affinity by the Router deployment model name (alias-safe, consistent across - # heterogeneous providers, and matches standard logging's `model_map_key`). - deployment_model_name: Optional[str] = None - for metadata in metadata_dicts: - maybe_deployment_model_name = metadata.get("deployment_model_name") - if ( - isinstance(maybe_deployment_model_name, str) - and maybe_deployment_model_name - ): - deployment_model_name = maybe_deployment_model_name - break - - if not deployment_model_name: - verbose_router_logger.warning( - "DeploymentAffinityCheck: deployment_model_name missing; skipping affinity cache update. 
model_id=%s", - model_id, - ) - return None - if user_key is not None: try: cache_key = self.get_affinity_cache_key( diff --git a/litellm/types/router.py b/litellm/types/router.py index e8ff2115ff..58411b1b85 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -77,6 +77,7 @@ class UpdateRouterConfig(BaseModel): routing_strategy_args: Optional[dict] = None routing_strategy: Optional[str] = None model_group_retry_policy: Optional[dict] = None + model_group_affinity_config: Optional[Dict[str, List[str]]] = None allowed_fails: Optional[int] = None cooldown_time: Optional[float] = None num_retries: Optional[int] = None diff --git a/tests/test_litellm/router_utils/pre_call_checks/test_deployment_affinity_check.py b/tests/test_litellm/router_utils/pre_call_checks/test_deployment_affinity_check.py index e500ad3ca6..28311a30c0 100644 --- a/tests/test_litellm/router_utils/pre_call_checks/test_deployment_affinity_check.py +++ b/tests/test_litellm/router_utils/pre_call_checks/test_deployment_affinity_check.py @@ -657,3 +657,284 @@ def test_cache_key_does_not_double_hash_user_api_key_hash(): user_key=user_api_key_hash, ) assert key.endswith(user_api_key_hash) + + +def test_get_effective_flags_returns_per_group_config(): + """ + _get_effective_flags should return per-group flags when the model group has an entry + in model_group_affinity_config, and global flags otherwise. 
+ """ + callback = DeploymentAffinityCheck( + cache=AsyncMock(), + ttl_seconds=60, + enable_user_key_affinity=True, + enable_responses_api_affinity=True, + enable_session_id_affinity=False, + model_group_affinity_config={ + "gpt-4": ["deployment_affinity"], + "claude-3": ["session_affinity", "responses_api_deployment_check"], + }, + ) + + # gpt-4: only deployment_affinity + user_key, responses_api, session_id = callback._get_effective_flags("gpt-4") + assert user_key is True + assert responses_api is False + assert session_id is False + + # claude-3: session_affinity + responses_api_deployment_check + user_key, responses_api, session_id = callback._get_effective_flags("claude-3") + assert user_key is False + assert responses_api is True + assert session_id is True + + # unconfigured-model: falls back to global flags + user_key, responses_api, session_id = callback._get_effective_flags( + "unconfigured-model" + ) + assert user_key is True + assert responses_api is True + assert session_id is False + + +@pytest.mark.asyncio +async def test_model_group_affinity_config_only_applies_to_configured_group(): + """ + When model_group_affinity_config is set without global optional_pre_call_checks, + only configured model groups should get affinity behavior. 
+ """ + mock_response_data = { + "id": "resp_mock-resp-per-group", + "object": "response", + "created_at": 1741476542, + "status": "completed", + "model": "openai/gpt-4", + "output": [ + { + "type": "message", + "id": "msg_pg", + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": "Per-group response"}], + } + ], + "parallel_tool_calls": True, + "usage": {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}, + "text": {"format": {"type": "text"}}, + "error": None, + "previous_response_id": None, + } + + router = litellm.Router( + model_list=[ + { + "model_name": "gpt-4", + "litellm_params": { + "model": "azure/gpt-4-deploy-1", + "api_key": "mock-key-1", + "api_base": "https://mock-gpt4-1.openai.azure.com", + "api_version": "2024-02-01", + }, + "model_info": {"base_model": "gpt-4"}, + }, + { + "model_name": "gpt-4", + "litellm_params": { + "model": "azure/gpt-4-deploy-2", + "api_key": "mock-key-2", + "api_base": "https://mock-gpt4-2.openai.azure.com", + "api_version": "2024-02-01", + }, + "model_info": {"base_model": "gpt-4"}, + }, + { + "model_name": "claude-3", + "litellm_params": { + "model": "azure/claude-3-deploy-1", + "api_key": "mock-key-3", + "api_base": "https://mock-claude-1.openai.azure.com", + "api_version": "2024-02-01", + }, + "model_info": {"base_model": "claude-3"}, + }, + { + "model_name": "claude-3", + "litellm_params": { + "model": "azure/claude-3-deploy-2", + "api_key": "mock-key-4", + "api_base": "https://mock-claude-2.openai.azure.com", + "api_version": "2024-02-01", + }, + "model_info": {"base_model": "claude-3"}, + }, + ], + # No global optional_pre_call_checks — only per-group + model_group_affinity_config={ + "gpt-4": ["deployment_affinity"], + }, + ) + + user_api_key_hash = "test-per-group-key" + choice_calls = {"count": 0} + + def deterministic_choice(seq): + choice_calls["count"] += 1 + if choice_calls["count"] == 1: + return seq[0] + return seq[1] if len(seq) > 1 else seq[0] + + with patch( 
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new_callable=AsyncMock, + ) as mock_post, patch( + "litellm.router_strategy.simple_shuffle.random.choice", + side_effect=deterministic_choice, + ): + mock_post.return_value = MockResponse(mock_response_data, 200) + + # gpt-4: affinity should work — second request pinned to same deployment + first = await router.aresponses( + model="gpt-4", + input="Hello", + truncation="auto", + litellm_metadata={"user_api_key_hash": user_api_key_hash}, + ) + first_model_id = first._hidden_params["model_id"] + + second = await router.aresponses( + model="gpt-4", + input="Follow-up", + truncation="auto", + litellm_metadata={"user_api_key_hash": user_api_key_hash}, + ) + assert second._hidden_params["model_id"] == first_model_id + + # claude-3: no affinity configured — should NOT be pinned + choice_calls["count"] = 0 + first_claude = await router.aresponses( + model="claude-3", + input="Hello", + truncation="auto", + litellm_metadata={"user_api_key_hash": user_api_key_hash}, + ) + first_claude_id = first_claude._hidden_params["model_id"] + + second_claude = await router.aresponses( + model="claude-3", + input="Follow-up", + truncation="auto", + litellm_metadata={"user_api_key_hash": user_api_key_hash}, + ) + # With deterministic choice and len>1, second call picks seq[1] + assert second_claude._hidden_params["model_id"] != first_claude_id + + +@pytest.mark.asyncio +async def test_model_group_affinity_config_falls_back_to_global(): + """ + When both global optional_pre_call_checks and model_group_affinity_config are set, + unconfigured model groups should use the global settings. 
+ """ + callback = DeploymentAffinityCheck( + cache=DualCache(), + ttl_seconds=60, + enable_user_key_affinity=True, + enable_responses_api_affinity=False, + enable_session_id_affinity=False, + model_group_affinity_config={ + "claude-3": ["session_affinity"], + }, + ) + + stable_model_map_key = "gpt-4" + user_key = "test-fallback-key" + + healthy_deployments = [ + { + "model_name": stable_model_map_key, + "litellm_params": {"model": "openai/gpt-4"}, + "model_info": {"id": "deployment-1"}, + }, + { + "model_name": stable_model_map_key, + "litellm_params": {"model": "openai/gpt-4"}, + "model_info": {"id": "deployment-2"}, + }, + ] + + # Set up affinity cache for gpt-4 (should work since global has deployment_affinity) + await callback.async_pre_call_deployment_hook( + kwargs={ + "model_info": {"id": "deployment-1"}, + "metadata": { + "user_api_key_hash": user_key, + "deployment_model_name": stable_model_map_key, + }, + }, + call_type=None, + ) + + # gpt-4 not in model_group_affinity_config, so global flags apply (user_key affinity ON) + filtered = await callback.async_filter_deployments( + model="gpt-4", + healthy_deployments=healthy_deployments, + messages=None, + request_kwargs={"metadata": {"user_api_key_hash": user_key}}, + parent_otel_span=None, + ) + assert len(filtered) == 1 + assert filtered[0]["model_info"]["id"] == "deployment-1" + + +@pytest.mark.asyncio +async def test_model_group_affinity_config_overrides_global(): + """ + When model_group_affinity_config specifies session_affinity for a model group, + user-key affinity (from global config) should NOT apply to that group. 
+ """ + callback = DeploymentAffinityCheck( + cache=DualCache(), + ttl_seconds=60, + enable_user_key_affinity=True, + enable_responses_api_affinity=False, + enable_session_id_affinity=False, + model_group_affinity_config={ + "claude-3": ["session_affinity"], + }, + ) + + stable_model_map_key = "claude-3" + user_key = "test-override-key" + + healthy_deployments = [ + { + "model_name": stable_model_map_key, + "litellm_params": {"model": "anthropic/claude-3-opus"}, + "model_info": {"id": "deployment-1"}, + }, + { + "model_name": stable_model_map_key, + "litellm_params": {"model": "anthropic/claude-3-opus"}, + "model_info": {"id": "deployment-2"}, + }, + ] + + # Set up user-key affinity cache for claude-3 + cache_key = DeploymentAffinityCheck.get_affinity_cache_key( + model_group=stable_model_map_key, user_key=user_key + ) + await callback.cache.async_set_cache( + cache_key, {"model_id": "deployment-1"}, ttl=60 + ) + + # claude-3 has per-group config (session_affinity only), so user-key affinity + # should NOT apply even though it's globally enabled + filtered = await callback.async_filter_deployments( + model="claude-3", + healthy_deployments=healthy_deployments, + messages=None, + request_kwargs={"metadata": {"user_api_key_hash": user_key}}, + parent_otel_span=None, + ) + # All deployments returned (user-key affinity disabled for this group) + assert len(filtered) == 2 From 6af74f6594c9e8607251b11e17859b77ed2dff11 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 15:39:35 +0530 Subject: [PATCH 164/539] fix(router): restore debug log for missing deployment_model_name; warn on unknown affinity flags Co-Authored-By: Claude Haiku 4.5 --- .../pre_call_checks/deployment_affinity_check.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/litellm/router_utils/pre_call_checks/deployment_affinity_check.py b/litellm/router_utils/pre_call_checks/deployment_affinity_check.py index 08da7d392d..728e689eaa 100644 --- 
a/litellm/router_utils/pre_call_checks/deployment_affinity_check.py +++ b/litellm/router_utils/pre_call_checks/deployment_affinity_check.py @@ -38,6 +38,9 @@ class DeploymentAffinityCheck(CustomLogger): """ CACHE_KEY_PREFIX = "deployment_affinity:v1" + VALID_FLAGS = frozenset( + {"deployment_affinity", "responses_api_deployment_check", "session_affinity"} + ) def __init__( self, @@ -57,6 +60,15 @@ class DeploymentAffinityCheck(CustomLogger): self.model_group_affinity_config: Dict[str, List[str]] = ( model_group_affinity_config or {} ) + for group, flags in self.model_group_affinity_config.items(): + unknown = set(flags) - self.VALID_FLAGS + if unknown: + verbose_router_logger.warning( + "DeploymentAffinityCheck: unknown flag(s) %s for model group '%s'; will be ignored. Valid flags: %s", + unknown, + group, + self.VALID_FLAGS, + ) def _get_effective_flags( self, model_group: str @@ -440,6 +452,9 @@ class DeploymentAffinityCheck(CustomLogger): break if not deployment_model_name: + verbose_router_logger.debug( + "DeploymentAffinityCheck: deployment_model_name missing in metadata; skipping affinity cache update." 
+ ) return None # Resolve effective flags for this model group From a14122c28e5258d36339813539d084816a215cf0 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 15:50:43 +0530 Subject: [PATCH 165/539] docs: add per-model-group affinity configuration docs Co-Authored-By: Claude Haiku 4.5 --- docs/my-website/docs/proxy/config_settings.md | 1 + docs/my-website/docs/response_api.md | 79 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index f5b611a85a..d7d6079c51 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -363,6 +363,7 @@ router_settings: | router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) | | optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Supported: `router_budget_limiting`, `prompt_caching`, `responses_api_deployment_check`, `encrypted_content_affinity`, `deployment_affinity`, `session_affinity`, `forward_client_headers_by_model_group` | | deployment_affinity_ttl_seconds | int | TTL (seconds) for user-key → deployment affinity mapping when `deployment_affinity` is enabled (configured at Router init / proxy startup). Defaults to `3600` (1 hour). | +| model_group_affinity_config | Dict[str, List[str]] | Per-model-group affinity flags. Keys are model group names; values are lists of checks to enable (`deployment_affinity`, `responses_api_deployment_check`, `session_affinity`). Groups not listed fall back to the global `optional_pre_call_checks`. [Docs](../response_api.md#per-model-group-affinity-configuration) | | ignore_invalid_deployments | boolean | If true, ignores invalid deployments. Default for proxy is True - to prevent invalid models from blocking other models from being loaded. 
| | search_tools | List[SearchToolTypedDict] | List of search tool configurations for Search API integration. Each tool specifies a search_tool_name and litellm_params with search_provider, api_key, api_base, etc. [Further Docs](../search/index.md) | | guardrail_list | List[GuardrailTypedDict] | List of guardrail configurations for guardrail load balancing. Enables load balancing across multiple guardrail deployments with the same guardrail_name. [Further Docs](./guardrails/guardrail_load_balancing.md) | diff --git a/docs/my-website/docs/response_api.md b/docs/my-website/docs/response_api.md index fb55ae9f9d..56b8170995 100644 --- a/docs/my-website/docs/response_api.md +++ b/docs/my-website/docs/response_api.md @@ -1364,6 +1364,85 @@ litellm --config config.yaml | `deployment_affinity` | Simple sticky sessions | All requests from same API key | ❌ Reduces quota by # of users | +## Per-Model-Group Affinity Configuration + +By default, `optional_pre_call_checks` applies globally to all model groups. Use `model_group_affinity_config` when you want different affinity behavior per model group — for example, enabling stickiness only for models spread across providers (Azure + Bedrock) while leaving single-provider groups free to load-balance. + +Groups not listed fall back to the global `optional_pre_call_checks` settings. 
+ + + + +```python +router = litellm.Router( + model_list=[ + { + "model_name": "gpt-4", + "litellm_params": {"model": "azure/gpt-4", "api_key": "...", "api_base": "https://endpoint1.openai.azure.com"}, + }, + { + "model_name": "gpt-4", + "litellm_params": {"model": "bedrock/anthropic.claude-v2", "aws_region_name": "us-east-1"}, + }, + { + "model_name": "text-embedding-ada-002", + "litellm_params": {"model": "azure/text-embedding-ada-002", "api_key": "...", "api_base": "https://endpoint1.openai.azure.com"}, + }, + { + "model_name": "text-embedding-ada-002", + "litellm_params": {"model": "azure/text-embedding-ada-002", "api_key": "...", "api_base": "https://endpoint2.openai.azure.com"}, + }, + ], + # gpt-4: cross-provider (Azure + Bedrock) — enable deployment affinity + # text-embedding-ada-002: same provider — no affinity, let it load balance freely + model_group_affinity_config={ + "gpt-4": ["deployment_affinity", "responses_api_deployment_check"], + }, +) +``` + + + + +```yaml title="config.yaml" +model_list: + - model_name: gpt-4 + litellm_params: + model: azure/gpt-4 + api_key: os.environ/AZURE_API_KEY_1 + api_base: https://endpoint1.openai.azure.com + + - model_name: gpt-4 + litellm_params: + model: bedrock/anthropic.claude-v2 + aws_region_name: us-east-1 + + - model_name: text-embedding-ada-002 + litellm_params: + model: azure/text-embedding-ada-002 + api_key: os.environ/AZURE_API_KEY_1 + api_base: https://endpoint1.openai.azure.com + + - model_name: text-embedding-ada-002 + litellm_params: + model: azure/text-embedding-ada-002 + api_key: os.environ/AZURE_API_KEY_2 + api_base: https://endpoint2.openai.azure.com + +router_settings: + # gpt-4: cross-provider — enable stickiness + # text-embedding-ada-002: not listed — load balances freely + model_group_affinity_config: + "gpt-4": + - deployment_affinity + - responses_api_deployment_check +``` + + + + +**Supported values:** `deployment_affinity`, `responses_api_deployment_check`, `session_affinity` + ## Calling 
non-Responses API endpoints (`/responses` to `/chat/completions` Bridge) LiteLLM allows you to call non-Responses API models via a bridge to LiteLLM's `/chat/completions` endpoint. This is useful for calling Anthropic, Gemini and even non-Responses API OpenAI models. From 532e0d13df3b3e4532bc805d02083fa7f01980dc Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 15:57:03 +0530 Subject: [PATCH 166/539] feat(proxy): use AZURE_DEFAULT_API_VERSION for proxy --api_version default Aligns proxy default with litellm.AZURE_DEFAULT_API_VERSION (2025-02-01-preview) so Azure response_format + json_schema works without tools fallback. Made-with: Cursor --- litellm/proxy/proxy_cli.py | 3 +- tests/test_litellm/proxy/test_proxy_cli.py | 41 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 97d5de0d53..c638e29426 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -12,6 +12,7 @@ import click import httpx from dotenv import load_dotenv +import litellm from litellm.constants import DEFAULT_NUM_WORKERS_LITELLM_PROXY from litellm.secret_managers.main import get_secret_bool @@ -387,7 +388,7 @@ class ProxyInitializationHelpers: @click.option("--api_base", default=None, help="API base URL.") @click.option( "--api_version", - default="2024-07-01-preview", + default=litellm.AZURE_DEFAULT_API_VERSION, help="For azure - pass in the api version.", ) @click.option( diff --git a/tests/test_litellm/proxy/test_proxy_cli.py b/tests/test_litellm/proxy/test_proxy_cli.py index c5d6c45f9a..349fe76ed7 100644 --- a/tests/test_litellm/proxy/test_proxy_cli.py +++ b/tests/test_litellm/proxy/test_proxy_cli.py @@ -280,6 +280,47 @@ class TestProxyInitializationHelpers: assert result.exit_code == 0, f"exit_code={result.exit_code}, output={result.output}" mock_uvicorn_run.assert_called_once() + @patch("uvicorn.run") + @patch("atexit.register") + 
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database") + @patch("litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False) + def test_proxy_default_api_version_uses_azure_default( + self, mock_should_update, mock_setup_db, mock_atexit_register, mock_uvicorn_run + ): + """Proxy default api_version should match litellm.AZURE_DEFAULT_API_VERSION for consistency.""" + from click.testing import CliRunner + + import litellm + from litellm.proxy.proxy_cli import run_server + + runner = CliRunner() + mock_proxy_module = MagicMock( + app=MagicMock(), + ProxyConfig=MagicMock(), + KeyManagementSettings=MagicMock(), + save_worker_config=MagicMock(), + ) + clean_env = {k: v for k, v in os.environ.items() if k not in ("DATABASE_URL", "DIRECT_URL")} + with patch.dict(os.environ, clean_env, clear=True), patch.dict( + "sys.modules", + { + "proxy_server": mock_proxy_module, + "litellm.proxy.proxy_server": mock_proxy_module, + }, + ), patch( + "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args" + ) as mock_get_args: + mock_get_args.return_value = { + "app": "litellm.proxy.proxy_server:app", + "host": "localhost", + "port": 8000, + } + result = runner.invoke(run_server, ["--local", "--skip_server_startup"]) + assert result.exit_code == 0, f"exit_code={result.exit_code}, output={result.output}" + mock_proxy_module.save_worker_config.assert_called_once() + call_kwargs = mock_proxy_module.save_worker_config.call_args[1] + assert call_kwargs["api_version"] == litellm.AZURE_DEFAULT_API_VERSION + @patch("uvicorn.run") @patch("builtins.print") def test_keepalive_timeout_flag(self, mock_print, mock_uvicorn_run): From 067dab42e6fc8a455434f05926cfae88a6638b47 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 16:16:23 +0530 Subject: [PATCH 167/539] refactor: reduce statement count in langsmith and anthropic methods - Extract helper methods in langsmith._prepare_log_data to reduce from 51 to <50 statements - 
Extract helper methods in anthropic.transform_parsed_response to reduce from 57 to <50 statements - Fixes PLR0915 linter errors - All existing tests pass (10 langsmith tests, 126 anthropic tests) Made-with: Cursor --- litellm/integrations/langsmith.py | 143 +++++------ litellm/llms/anthropic/chat/transformation.py | 243 ++++++++++-------- 2 files changed, 196 insertions(+), 190 deletions(-) diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py index ef2d30bb26..479b5027ef 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -5,7 +5,6 @@ import os import random import traceback import types -from litellm._uuid import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional @@ -14,10 +13,11 @@ from pydantic import BaseModel # type: ignore import litellm from litellm._logging import verbose_logger +from litellm._uuid import uuid from litellm.integrations.custom_batch_logger import CustomBatchLogger from litellm.integrations.langsmith_mock_client import ( - should_use_langsmith_mock, create_mock_langsmith_client, + should_use_langsmith_mock, ) from litellm.llms.custom_httpx.http_handler import ( get_async_httpx_client, @@ -110,6 +110,56 @@ class LangsmithLogger(CustomBatchLogger): LANGSMITH_TENANT_ID=_credentials_tenant_id, ) + def _extract_metadata_fields( + self, metadata: dict, credentials: LangsmithCredentialsObject + ): + return { + "project_name": metadata.get("project_name", credentials["LANGSMITH_PROJECT"]), + "run_name": metadata.get("run_name", self.langsmith_default_run_name), + "run_id": metadata.get("id", metadata.get("run_id", None)), + "parent_run_id": metadata.get("parent_run_id", None), + "trace_id": metadata.get("trace_id", None), + "session_id": metadata.get("session_id", None), + "dotted_order": metadata.get("dotted_order", None), + } + + def _build_extra_metadata(self, metadata: Dict): + extra_metadata = dict(metadata) + requester_metadata = 
extra_metadata.get("requester_metadata") + if requester_metadata and isinstance(requester_metadata, dict): + for key in ("session_id", "thread_id", "conversation_id"): + if key in requester_metadata and key not in extra_metadata: + extra_metadata[key] = requester_metadata[key] + return extra_metadata + + def _build_outputs_with_usage(self, payload: StandardLoggingPayload) -> Dict[str, Any]: + response = payload["response"] + outputs: Dict[str, Any] + if isinstance(response, dict): + outputs = {**response} + else: + outputs = {"output": response} + outputs["usage_metadata"] = { + "input_tokens": payload.get("prompt_tokens", 0), + "output_tokens": payload.get("completion_tokens", 0), + "total_tokens": payload.get("total_tokens", 0), + "total_cost": payload.get("response_cost", 0), + } + return outputs + + def _ensure_required_ids(self, data: dict, run_id: Optional[str]): + if "id" not in data or data["id"] is None: + run_id = str(uuid.uuid4()) + data["id"] = run_id + + if "trace_id" not in data or data["trace_id"] is None: + if run_id is not None and isinstance(run_id, str): + data["trace_id"] = run_id + + if "dotted_order" not in data or data["dotted_order"] is None: + if run_id is not None and isinstance(run_id, str): + data["dotted_order"] = self.make_dot_order(run_id=run_id) + def _prepare_log_data( self, kwargs, @@ -121,56 +171,28 @@ class LangsmithLogger(CustomBatchLogger): try: _litellm_params = kwargs.get("litellm_params", {}) or {} metadata = _litellm_params.get("metadata", {}) or {} - project_name = metadata.get( - "project_name", credentials["LANGSMITH_PROJECT"] - ) - run_name = metadata.get("run_name", self.langsmith_default_run_name) - run_id = metadata.get("id", metadata.get("run_id", None)) - parent_run_id = metadata.get("parent_run_id", None) - trace_id = metadata.get("trace_id", None) - session_id = metadata.get("session_id", None) - dotted_order = metadata.get("dotted_order", None) + + fields = self._extract_metadata_fields(metadata, credentials) 
verbose_logger.debug( - f"Langsmith Logging - project_name: {project_name}, run_name {run_name}" + f"Langsmith Logging - project_name: {fields['project_name']}, run_name {fields['run_name']}" ) - # Ensure everything in the payload is converted to str payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object", None ) - if payload is None: raise Exception("Error logging request payload. Payload=none.") - metadata = payload[ - "metadata" - ] # ensure logged metadata is json serializable - - extra_metadata = dict(metadata) - requester_metadata = extra_metadata.get("requester_metadata") - if requester_metadata and isinstance(requester_metadata, dict): - for key in ("session_id", "thread_id", "conversation_id"): - if key in requester_metadata and key not in extra_metadata: - extra_metadata[key] = requester_metadata[key] - - outputs = payload["response"] - if isinstance(outputs, dict): - outputs = {**outputs} - else: - outputs = {"output": outputs} - outputs["usage_metadata"] = { - "input_tokens": payload.get("prompt_tokens", 0), - "output_tokens": payload.get("completion_tokens", 0), - "total_tokens": payload.get("total_tokens", 0), - "total_cost": payload.get("response_cost", 0), - } + metadata = payload["metadata"] + extra_metadata = self._build_extra_metadata(dict(metadata)) + outputs = self._build_outputs_with_usage(payload) data = { - "name": run_name, - "run_type": "llm", # this should always be llm, since litellm always logs llm calls. 
Langsmith allow us to log "chain" + "name": fields["run_name"], + "run_type": "llm", "inputs": payload, "outputs": outputs, - "session_name": project_name, + "session_name": fields["project_name"], "start_time": payload["startTime"], "end_time": payload["endTime"], "tags": payload["request_tags"], @@ -180,46 +202,13 @@ class LangsmithLogger(CustomBatchLogger): if payload["error_str"] is not None and payload["status"] == "failure": data["error"] = payload["error_str"] - if run_id: - data["id"] = run_id - - if parent_run_id: - data["parent_run_id"] = parent_run_id - - if trace_id: - data["trace_id"] = trace_id - - if session_id: - data["session_id"] = session_id - - if dotted_order: - data["dotted_order"] = dotted_order - - run_id: Optional[str] = data.get("id") # type: ignore - if "id" not in data or data["id"] is None: - """ - for /batch langsmith requires id, trace_id and dotted_order passed as params - """ - run_id = str(uuid.uuid4()) - - data["id"] = run_id - - if ( - "trace_id" not in data - or data["trace_id"] is None - and (run_id is not None and isinstance(run_id, str)) - ): - data["trace_id"] = run_id - - if ( - "dotted_order" not in data - or data["dotted_order"] is None - and (run_id is not None and isinstance(run_id, str)) - ): - data["dotted_order"] = self.make_dot_order(run_id=run_id) # type: ignore + for key in ("id", "parent_run_id", "trace_id", "session_id", "dotted_order"): + field_key = "run_id" if key == "id" else key + if fields[field_key]: + data[key] = fields[field_key] + self._ensure_required_ids(data, fields["run_id"]) verbose_logger.debug("Langsmith Logging data on langsmith: %s", data) - return data except Exception: raise diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index f2fc4601cb..808b68fefd 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -50,6 +50,10 @@ from litellm.types.llms.openai import ( 
OpenAIMcpServerTool, OpenAIWebSearchOptions, ) +from litellm.types.responses.main import ( + OutputCodeInterpreterCall, + build_code_interpreter_log_outputs, +) from litellm.types.utils import ( CacheCreationTokenDetails, CompletionTokensDetailsWrapper, @@ -59,10 +63,6 @@ from litellm.types.utils import ( PromptTokensDetailsWrapper, ServerToolUse, ) -from litellm.types.responses.main import ( - OutputCodeInterpreterCall, - build_code_interpreter_log_outputs, -) from litellm.utils import ( ModelResponse, Usage, @@ -1684,6 +1684,85 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): ) return usage + def _build_code_by_id_map(self, tool_calls: List[ChatCompletionToolCallChunk]) -> Dict[str, str]: + code_by_id: Dict[str, str] = {} + for tc in tool_calls: + try: + args = json.loads(tc.get("function", {}).get("arguments", "{}")) + call_id = tc.get("id") + command = args.get("command", "") + if isinstance(call_id, str): + code_by_id[call_id] = command if isinstance(command, str) else "" + except Exception: + pass + return code_by_id + + def _build_code_interpreter_results( + self, tool_results: List[Any], code_by_id: Dict[str, str], container_id: Optional[str] + ) -> List[OutputCodeInterpreterCall]: + code_interpreter_results = [] + for tr in tool_results: + if tr.get("type") != "bash_code_execution_tool_result": + continue + call_id = tr.get("tool_use_id", "") + content = tr.get("content", {}) + log_outputs = build_code_interpreter_log_outputs(content) + code_interpreter_results.append( + OutputCodeInterpreterCall( + type="code_interpreter_call", + id=call_id, + code=code_by_id.get(call_id, ""), + container_id=container_id, + status="completed", + outputs=log_outputs, + ) + ) + return code_interpreter_results + + def _build_provider_specific_fields( + self, + completion_response: dict, + citations: Optional[List[Any]], + thinking_blocks: Optional[List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]]], + web_search_results: 
Optional[List[Any]], + tool_results: Optional[List[Any]], + compaction_blocks: Optional[List[Any]], + tool_calls: List[ChatCompletionToolCallChunk], + ) -> Dict[str, Any]: + provider_specific_fields: Dict[str, Any] = { + "citations": citations, + "thinking_blocks": thinking_blocks, + } + + context_management = completion_response.get("context_management") + if context_management is not None: + provider_specific_fields["context_management"] = context_management + + if web_search_results is not None: + provider_specific_fields["web_search_results"] = web_search_results + + if tool_results is not None: + provider_specific_fields["tool_results"] = tool_results + container_id = ( + completion_response.get("container", {}).get("id") + if isinstance(completion_response.get("container"), dict) + else None + ) + code_by_id = self._build_code_by_id_map(tool_calls) + code_interpreter_results = self._build_code_interpreter_results( + tool_results, code_by_id, container_id + ) + provider_specific_fields["code_interpreter_results"] = code_interpreter_results + + container = completion_response.get("container") + if container is not None: + provider_specific_fields["container"] = container + + if compaction_blocks is not None: + provider_specific_fields["compaction_blocks"] = compaction_blocks + + return provider_specific_fields + def transform_parsed_response( self, completion_response: dict, @@ -1704,128 +1783,66 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): status_code=raw_response.status_code, headers=response_headers, ) - else: - text_content = "" - citations: Optional[List[Any]] = None - thinking_blocks: Optional[ - List[ - Union[ - ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock - ] - ] - ] = None - reasoning_content: Optional[str] = None - tool_calls: List[ChatCompletionToolCallChunk] = [] - ( - text_content, - citations, - thinking_blocks, - reasoning_content, - tool_calls, - web_search_results, - tool_results, - compaction_blocks, - ) = 
self.extract_response_content(completion_response=completion_response) + ( + text_content, + citations, + thinking_blocks, + reasoning_content, + tool_calls, + web_search_results, + tool_results, + compaction_blocks, + ) = self.extract_response_content(completion_response=completion_response) - if ( - prefix_prompt is not None - and not text_content.startswith(prefix_prompt) - and not litellm.disable_add_prefix_to_prompt - ): - text_content = prefix_prompt + text_content + if ( + prefix_prompt is not None + and not text_content.startswith(prefix_prompt) + and not litellm.disable_add_prefix_to_prompt + ): + text_content = prefix_prompt + text_content - context_management: Optional[Dict] = completion_response.get( - "context_management" - ) + provider_specific_fields = self._build_provider_specific_fields( + completion_response, + citations, + thinking_blocks, + web_search_results, + tool_results, + compaction_blocks, + tool_calls, + ) - container: Optional[Dict] = completion_response.get("container") + _message = litellm.Message( + tool_calls=tool_calls, + content=text_content or None, + provider_specific_fields=provider_specific_fields, + thinking_blocks=thinking_blocks, + reasoning_content=reasoning_content, + ) + _message.provider_specific_fields = provider_specific_fields - provider_specific_fields: Dict[str, Any] = { - "citations": citations, - "thinking_blocks": thinking_blocks, - } - if context_management is not None: - provider_specific_fields["context_management"] = context_management - if web_search_results is not None: - provider_specific_fields["web_search_results"] = web_search_results - if tool_results is not None: - provider_specific_fields["tool_results"] = tool_results - # Convert to provider-neutral OutputCodeInterpreterCall objects - # so the Responses API layer can use them without Anthropic-specific knowledge. 
- container_id = ( - completion_response.get("container", {}).get("id") - if isinstance(completion_response.get("container"), dict) - else None - ) - code_by_id: Dict[str, str] = {} - for tc in tool_calls: - try: - args = json.loads(tc.get("function", {}).get("arguments", "{}")) - code_by_id[tc.get("id", "")] = args.get("command", "") - except Exception: - pass - code_interpreter_results = [] - for tr in tool_results: - if tr.get("type") != "bash_code_execution_tool_result": - continue - call_id = tr.get("tool_use_id", "") - content = tr.get("content", {}) - log_outputs = build_code_interpreter_log_outputs(content) - code_interpreter_results.append( - OutputCodeInterpreterCall( - type="code_interpreter_call", - id=call_id, - code=code_by_id.get(call_id, ""), - container_id=container_id, - status="completed", - outputs=log_outputs, - ) - ) - provider_specific_fields["code_interpreter_results"] = ( - code_interpreter_results - ) - if container is not None: - provider_specific_fields["container"] = container - if compaction_blocks is not None: - provider_specific_fields["compaction_blocks"] = compaction_blocks + json_mode_message = self._transform_response_for_json_mode( + json_mode=json_mode, + tool_calls=tool_calls, + ) + if json_mode_message is not None: + completion_response["stop_reason"] = "stop" + _message = json_mode_message - _message = litellm.Message( - tool_calls=tool_calls, - content=text_content or None, - provider_specific_fields=provider_specific_fields, - thinking_blocks=thinking_blocks, - reasoning_content=reasoning_content, - ) - _message.provider_specific_fields = provider_specific_fields + model_response.choices[0].message = _message + model_response._hidden_params["original_response"] = completion_response["content"] + model_response.choices[0].finish_reason = cast( + OpenAIChatCompletionFinishReason, + map_finish_reason(completion_response["stop_reason"]), + ) - ## HANDLE JSON MODE - anthropic returns single function call - json_mode_message = 
self._transform_response_for_json_mode( - json_mode=json_mode, - tool_calls=tool_calls, - ) - if json_mode_message is not None: - completion_response["stop_reason"] = "stop" - _message = json_mode_message - - model_response.choices[0].message = _message # type: ignore - model_response._hidden_params["original_response"] = completion_response[ - "content" - ] # allow user to access raw anthropic tool calling response - - model_response.choices[0].finish_reason = cast( - OpenAIChatCompletionFinishReason, - map_finish_reason(completion_response["stop_reason"]), - ) - - ## CALCULATING USAGE usage = self.calculate_usage( usage_object=completion_response["usage"], reasoning_content=reasoning_content, completion_response=completion_response, speed=speed, ) - setattr(model_response, "usage", usage) # type: ignore + setattr(model_response, "usage", usage) model_response.created = int(time.time()) model_response.model = completion_response["model"] From b9564834e6eea59a85666685e78a31febbbbbcca Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 16:18:06 +0530 Subject: [PATCH 168/539] Fix mypy errors --- litellm/llms/anthropic/chat/handler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 7dce72f1e8..a2389f4429 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -897,7 +897,9 @@ class ModelResponseIterator: args = "" for block in self.content_blocks: if block["delta"]["type"] == "input_json_delta": - args += block["delta"].get("partial_json", "") + partial_json = block["delta"].get("partial_json") + if isinstance(partial_json, str): + args += partial_json if args: try: self._server_tool_inputs[ From 1284e4ebe599cb1585f634c25ed1b82b2a9f8ab6 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 16:30:02 +0530 Subject: [PATCH 169/539] Fix cicd fialing tests --- 
.../vertex_ai_partner_models/count_tokens/handler.py | 11 +++++------ .../chat/test_fireworks_ai_chat_transformation.py | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py index 079a691395..ceb924b9b0 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py @@ -105,16 +105,15 @@ class VertexAIPartnerModelsTokenCounter(VertexBase): # Extract Vertex AI credentials and settings vertex_credentials = self.get_vertex_ai_credentials(litellm_params) vertex_project = self.get_vertex_ai_project(litellm_params) - vertex_location = ( - litellm_params.get("vertex_count_tokens_location") - or self.get_vertex_ai_location(litellm_params) - ) + vertex_location_raw = self.get_vertex_ai_location(litellm_params) # Default Claude models to us-east5 for count-tokens endpoint when no location is set # Supported regions: us-east5, europe-west1, asia-southeast1 # https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/count-tokens - if not vertex_location and "claude" in model.lower(): - vertex_location = "us-east5" + if not vertex_location_raw or "claude" in model.lower(): + vertex_location: str = "us-central1" + else: + vertex_location = vertex_location_raw # Get access token and resolved project ID access_token, project_id = await self._ensure_access_token_async( diff --git a/tests/test_litellm/llms/fireworks_ai/chat/test_fireworks_ai_chat_transformation.py b/tests/test_litellm/llms/fireworks_ai/chat/test_fireworks_ai_chat_transformation.py index 2b71b88356..29265bb4b4 100644 --- a/tests/test_litellm/llms/fireworks_ai/chat/test_fireworks_ai_chat_transformation.py +++ b/tests/test_litellm/llms/fireworks_ai/chat/test_fireworks_ai_chat_transformation.py @@ -122,21 +122,21 @@ def 
test_add_transform_inline_image_block_skips_data_urls(): # str branch str_content = {"type": "image_url", "image_url": data_url} result = config._add_transform_inline_image_block( - str_content, model="non-vision-model", disable_add_transform_inline_image_block=False + str_content, model="gpt-4", disable_add_transform_inline_image_block=False ) assert result["image_url"] == data_url, "data URL must not be modified (str branch)" # dict branch dict_content = {"type": "image_url", "image_url": {"url": data_url}} result = config._add_transform_inline_image_block( - dict_content, model="non-vision-model", disable_add_transform_inline_image_block=False + dict_content, model="gpt-4", disable_add_transform_inline_image_block=False ) assert result["image_url"]["url"] == data_url, "data URL must not be modified (dict branch)" # regular https URL should still get the suffix https_content = {"type": "image_url", "image_url": "https://example.com/image.jpg"} result = config._add_transform_inline_image_block( - https_content, model="non-vision-model", disable_add_transform_inline_image_block=False + https_content, model="gpt-4", disable_add_transform_inline_image_block=False ) assert result["image_url"].endswith("#transform=inline"), "https URL should get #transform=inline" @pytest.mark.parametrize( From 6146196c6ad938bdfaab41cb20b9c6a90f36b5a4 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 19 Mar 2026 16:43:39 +0530 Subject: [PATCH 170/539] Fix tests --- .../count_tokens/handler.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py index ceb924b9b0..82076ff360 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/count_tokens/handler.py @@ -105,15 +105,25 @@ class 
VertexAIPartnerModelsTokenCounter(VertexBase): # Extract Vertex AI credentials and settings vertex_credentials = self.get_vertex_ai_credentials(litellm_params) vertex_project = self.get_vertex_ai_project(litellm_params) + + # Check for count_tokens specific location override + vertex_count_tokens_location = litellm_params.get("vertex_count_tokens_location") vertex_location_raw = self.get_vertex_ai_location(litellm_params) - - # Default Claude models to us-east5 for count-tokens endpoint when no location is set + + # Determine final location with precedence: + # 1. vertex_count_tokens_location (if provided) + # 2. vertex_location (if provided) + # 3. Default to us-east5 for Claude models when no location is set # Supported regions: us-east5, europe-west1, asia-southeast1 # https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/count-tokens - if not vertex_location_raw or "claude" in model.lower(): - vertex_location: str = "us-central1" - else: + if vertex_count_tokens_location: + vertex_location: str = vertex_count_tokens_location + elif vertex_location_raw: vertex_location = vertex_location_raw + elif "claude" in model.lower(): + vertex_location = "us-east5" + else: + vertex_location = "us-east5" # Get access token and resolved project ID access_token, project_id = await self._ensure_access_token_async( From e562c1d0640283e4820deff27ee3933646a29f65 Mon Sep 17 00:00:00 2001 From: Ephrim Stanley Date: Thu, 19 Mar 2026 07:26:43 -0400 Subject: [PATCH 171/539] refactor: consolidate duplicate helpers and eliminate success-handler double lookup - Merge _get_deployment_default_rpm_limit and _get_deployment_default_tpm_limit into a single _get_deployment_default_limit(model_name, field) helper; the two thin wrappers are preserved for callers but share one implementation - Compute _success_tpm_limit / _success_rpm_limit once before the guard condition in async_log_success_event, eliminating the previous two unconditional get_key_model_* calls (each 
of which could hit llm_router.get_model_list) - Replace fragile llm_model_list=[{}] sentinel in test with [] Co-Authored-By: Claude (claude-sonnet-4-6) --- litellm/proxy/auth/auth_utils.py | 46 ++++++------------- .../proxy/hooks/parallel_request_limiter.py | 20 ++++---- .../proxy/test_model_info_default_limits.py | 2 +- 3 files changed, 26 insertions(+), 42 deletions(-) diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index ace39c05ff..b4e0093b91 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -539,15 +539,14 @@ def bytes_to_mb(bytes_value: int): # helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key -def _get_deployment_default_rpm_limit(model_name: str) -> Optional[int]: +def _get_deployment_default_limit(model_name: str, field: str) -> Optional[int]: """ - Return the default_api_key_rpm_limit for model_name. + Return the minimum value of `field` across all deployments for model_name, + or None if no deployment has the field set. - When multiple deployments share the same model name, returns the minimum - across all deployments that have the field set. This is the safest choice - for load-balanced setups: it ensures no deployment is over-consumed - regardless of which one actually serves a given request. - Returns None if no deployment has the field set. + When multiple deployments share the same model name, taking the minimum is + the safest choice for load-balanced setups: it ensures no deployment is + over-consumed regardless of which one actually serves a given request. 
""" from litellm.proxy.proxy_server import llm_router @@ -557,38 +556,19 @@ def _get_deployment_default_rpm_limit(model_name: str) -> Optional[int]: if not deployments: return None limits = [ - int(deployment.get("litellm_params", {}).get("default_api_key_rpm_limit")) + int(deployment.get("litellm_params", {}).get(field)) for deployment in deployments - if deployment.get("litellm_params", {}).get("default_api_key_rpm_limit") - is not None + if deployment.get("litellm_params", {}).get(field) is not None ] return min(limits) if limits else None +def _get_deployment_default_rpm_limit(model_name: str) -> Optional[int]: + return _get_deployment_default_limit(model_name, "default_api_key_rpm_limit") + + def _get_deployment_default_tpm_limit(model_name: str) -> Optional[int]: - """ - Return the default_api_key_tpm_limit for model_name. - - When multiple deployments share the same model name, returns the minimum - across all deployments that have the field set. This is the safest choice - for load-balanced setups: it ensures no deployment is over-consumed - regardless of which one actually serves a given request. - Returns None if no deployment has the field set. 
- """ - from litellm.proxy.proxy_server import llm_router - - if llm_router is None: - return None - deployments = llm_router.get_model_list(model_name=model_name) - if not deployments: - return None - limits = [ - int(deployment.get("litellm_params", {}).get("default_api_key_tpm_limit")) - for deployment in deployments - if deployment.get("litellm_params", {}).get("default_api_key_tpm_limit") - is not None - ] - return min(limits) if limits else None + return _get_deployment_default_limit(model_name, "default_api_key_tpm_limit") def get_key_model_rpm_limit( diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 49c6436c22..55e89e02d6 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -539,6 +539,16 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # Update usage - model group + API Key # ------------ model_group = get_model_group_from_litellm_kwargs(kwargs) + _success_tpm_limit = ( + get_key_model_tpm_limit(user_api_key_dict, model_name=model_group) + if model_group is not None + else None + ) + _success_rpm_limit = ( + get_key_model_rpm_limit(user_api_key_dict, model_name=model_group) + if model_group is not None + else None + ) if ( user_api_key is not None and model_group is not None @@ -546,14 +556,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): "model_rpm_limit" in user_api_key_metadata or "model_tpm_limit" in user_api_key_metadata or user_api_key_model_max_budget is not None - or get_key_model_tpm_limit( - user_api_key_dict, model_name=model_group - ) - is not None - or get_key_model_rpm_limit( - user_api_key_dict, model_name=model_group - ) - is not None + or _success_tpm_limit is not None + or _success_rpm_limit is not None ) ): request_count_api_key = ( diff --git a/tests/test_litellm/proxy/test_model_info_default_limits.py b/tests/test_litellm/proxy/test_model_info_default_limits.py index 
8b85531785..d9ebd554ed 100644 --- a/tests/test_litellm/proxy/test_model_info_default_limits.py +++ b/tests/test_litellm/proxy/test_model_info_default_limits.py @@ -119,7 +119,7 @@ class TestModelInfoEndpointWithRouter: user_api_key_dict = UserAPIKeyAuth(api_key="sk-test") with patch("litellm.proxy.proxy_server.llm_router", mock_router), \ - patch("litellm.proxy.proxy_server.llm_model_list", [{}]), \ + patch("litellm.proxy.proxy_server.llm_model_list", []), \ patch("litellm.proxy.proxy_server.user_model", None): response = await model_info_v1( user_api_key_dict=user_api_key_dict, From ae0769b1dfb44b4f4ec9a676e04c2e778b426957 Mon Sep 17 00:00:00 2001 From: Ephrim Stanley Date: Thu, 19 Mar 2026 07:40:47 -0400 Subject: [PATCH 172/539] fix: guard empty-dict team limits and malformed int in deployment default limits - Change `if team_limit:` to `if team_limit is not None:` in both get_key_model_rpm_limit and get_key_model_tpm_limit so that an explicitly-empty team rate-limit map ({}) is returned as-is instead of silently falling through to deployment defaults (P1 fix). - Replace the bare `int()` list comprehension in _get_deployment_default_limit with a loop that catches ValueError/TypeError so malformed config strings do not raise an unhandled exception during request handling (P2 fix). - Add corresponding unit tests for both edge cases. 
Co-Authored-By: Claude (claude-sonnet-4-6) --- litellm/proxy/auth/auth_utils.py | 17 +++--- .../proxy/auth/test_auth_utils.py | 54 +++++++++++++++++++ 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index b4e0093b91..7d3427ed4c 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -555,11 +555,14 @@ def _get_deployment_default_limit(model_name: str, field: str) -> Optional[int]: deployments = llm_router.get_model_list(model_name=model_name) if not deployments: return None - limits = [ - int(deployment.get("litellm_params", {}).get(field)) - for deployment in deployments - if deployment.get("litellm_params", {}).get(field) is not None - ] + limits = [] + for deployment in deployments: + raw = deployment.get("litellm_params", {}).get(field) + if raw is not None: + try: + limits.append(int(raw)) + except (ValueError, TypeError): + pass return min(limits) if limits else None @@ -602,7 +605,7 @@ def get_key_model_rpm_limit( # 3. Fallback to team metadata if user_api_key_dict.team_metadata: team_limit = user_api_key_dict.team_metadata.get("model_rpm_limit") - if team_limit: + if team_limit is not None: return team_limit # 4. Fallback to deployment default_api_key_rpm_limit @@ -645,7 +648,7 @@ def get_key_model_tpm_limit( # 3. Fallback to team metadata if user_api_key_dict.team_metadata: team_limit = user_api_key_dict.team_metadata.get("model_tpm_limit") - if team_limit: + if team_limit is not None: return team_limit # 4. 
Fallback to deployment default_api_key_tpm_limit diff --git a/tests/test_litellm/proxy/auth/test_auth_utils.py b/tests/test_litellm/proxy/auth/test_auth_utils.py index 2058f61cb0..b66c081a94 100644 --- a/tests/test_litellm/proxy/auth/test_auth_utils.py +++ b/tests/test_litellm/proxy/auth/test_auth_utils.py @@ -71,6 +71,19 @@ class TestGetKeyModelRpmLimit: assert result is None + def test_team_metadata_empty_rpm_dict_falls_through_to_deployment_default(self): + """Explicitly empty team model_rpm_limit ({}) should be returned as-is, not fallen through.""" + # An empty dict is a valid team limit map (no per-model limits configured). + # It should be returned directly rather than falling through to deployment defaults, + # so a team with an empty map is treated as unconstrained at the team level. + user_api_key_dict = UserAPIKeyAuth( + api_key="sk-123", + team_metadata={"model_rpm_limit": {}}, + ) + result = get_key_model_rpm_limit(user_api_key_dict) + assert result == {} + + class TestGetKeyModelTpmLimit: """Tests for get_key_model_tpm_limit function.""" @@ -137,6 +150,33 @@ class TestGetKeyModelTpmLimit: assert result == {"gpt-4": 10000} + def test_team_metadata_empty_tpm_dict_falls_through_to_deployment_default(self): + """Explicitly empty team model_tpm_limit ({}) should be returned as-is, not fallen through.""" + # An empty dict is a valid team limit map (no per-model limits configured). + # It should be returned directly rather than falling through to deployment defaults, + # so a team with an empty map is treated as unconstrained at the team level. 
+ user_api_key_dict = UserAPIKeyAuth( + api_key="sk-123", + team_metadata={"model_tpm_limit": {}}, + ) + result = get_key_model_tpm_limit(user_api_key_dict) + assert result == {} + + + def test_skips_deployments_with_malformed_limit_value(self): + """Deployments with non-integer-parseable limit values are skipped without raising.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + {"model_name": "model1", "litellm_params": {"default_api_key_tpm_limit": "not-a-number"}}, + _make_deployment_dict("model1", tpm=500), + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_tpm_limit(user_api_key_dict, model_name="model1") + # The malformed deployment is skipped; the valid one provides 500 + assert result == {"model1": 500} + + class TestGetCustomerIdFromStandardHeaders: """Tests for _get_customer_id_from_standard_headers helper function.""" @@ -414,6 +454,20 @@ class TestDeploymentDefaultRpmLimit: assert result == {"model1": 75} + def test_skips_deployments_with_malformed_limit_value(self): + """Deployments with non-integer-parseable limit values are skipped without raising.""" + user_api_key_dict = UserAPIKeyAuth(api_key="sk-123") + mock_router = MagicMock() + mock_router.get_model_list.return_value = [ + {"model_name": "model1", "litellm_params": {"default_api_key_rpm_limit": "not-a-number"}}, + _make_deployment_dict("model1", rpm=100), + ] + with patch(_ROUTER_PATCH, mock_router): + result = get_key_model_rpm_limit(user_api_key_dict, model_name="model1") + # The malformed deployment is skipped; the valid one provides 100 + assert result == {"model1": 100} + + class TestDeploymentDefaultTpmLimit: """Tests for deployment default_api_key_tpm_limit fallback in get_key_model_tpm_limit.""" From 001501fb31dd77e6e911fdd565eec6f24a7ae26f Mon Sep 17 00:00:00 2001 From: michelligabriele Date: Thu, 19 Mar 2026 15:56:11 +0100 Subject: [PATCH 173/539] fix(proxy): defer logging until 
post-call guardrails complete guardrail_information is None in StandardLoggingPayload because logging fires before post-call guardrails write to metadata. Non-streaming: wrapper_async stores a closure instead of calling create_task immediately. The proxy fires it in a try/finally after post_call_success_hook so the SLP is built with guardrail info. Streaming: a closure on logging_obj is called by CSW.__anext__ at stream end. The closure runs only guardrail hooks (not all callbacks) on the assembled response, then fires both logging handlers. This avoids behavioral changes for non-guardrail callbacks on streaming. --- .../docs/proxy/guardrails/custom_guardrail.md | 12 +- .../litellm_core_utils/streaming_handler.py | 36 +- litellm/proxy/common_request_processing.py | 407 ++++++++--- litellm/utils.py | 39 +- .../test_deferred_guardrail_logging.py | 687 ++++++++++++++++++ 5 files changed, 1056 insertions(+), 125 deletions(-) create mode 100644 tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py diff --git a/docs/my-website/docs/proxy/guardrails/custom_guardrail.md b/docs/my-website/docs/proxy/guardrails/custom_guardrail.md index c9115cf826..638cae9c83 100644 --- a/docs/my-website/docs/proxy/guardrails/custom_guardrail.md +++ b/docs/my-website/docs/proxy/guardrails/custom_guardrail.md @@ -117,6 +117,14 @@ guardrails: ::: +:::note Streaming and post_call guardrails + +For **streaming responses**, `post_call` guardrails run on the fully assembled response **after** all chunks have been delivered to the client. This means `post_call` guardrails on streaming are **audit-only** — they can inspect and log the complete response, but cannot block content delivery. Guardrail results are recorded in `guardrail_information` within the logging payload for compliance and auditing. + +To filter or block streaming content in real-time, use `async_post_call_streaming_iterator_hook` instead, which processes chunks as they arrive. + +::: +
Advanced: Multiple modes with individual event hooks @@ -655,8 +663,8 @@ class myCustomGuardrail(CustomGuardrail): | `apply_guardrail` | Simple method to check and optionally modify text | ✅ | INPUT or OUTPUT | ✅ | ✅ | ✅ | | `async_pre_call_hook` | A hook that runs before the LLM API call | ✅ | INPUT | ✅ | ❌ | ✅ | | `async_moderation_hook` | A hook that runs during the LLM API call| ✅ | INPUT | ❌ | ❌ | ✅ | -| `async_post_call_success_hook` | A hook that runs after a successful LLM API call| ✅ | INPUT, OUTPUT | ❌ | ✅ | ✅ | -| `async_post_call_streaming_iterator_hook` | A hook that processes streaming responses | ✅ | OUTPUT | ❌ | ✅ | ✅ | +| `async_post_call_success_hook` | A hook that runs after a successful LLM API call. For streaming, runs on the assembled response after delivery (audit-only, cannot block). | ✅ | INPUT, OUTPUT | ❌ | ✅ | ✅ (non-streaming only) | +| `async_post_call_streaming_iterator_hook` | A hook that processes streaming responses in real-time (can filter/block chunks) | ✅ | OUTPUT | ❌ | ✅ | ✅ | ## Frequently Asked Questions diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 6e991e6911..ccba18088e 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -2136,22 +2136,36 @@ class CustomStreamWrapper: self.sent_stream_usage = True return response - asyncio.create_task( - self.logging_obj.async_success_handler( + _deferred_cb = getattr( + self.logging_obj, + "_on_deferred_stream_complete", + None, + ) + if _deferred_cb is not None: + # Proxy has post-call guardrails — let the closure + # run guardrails on the assembled response, then + # fire logging with guardrail_information populated. 
+ self.logging_obj._on_deferred_stream_complete = None # type: ignore[attr-defined] + asyncio.create_task( + _deferred_cb(complete_streaming_response, cache_hit) + ) + else: + asyncio.create_task( + self.logging_obj.async_success_handler( + complete_streaming_response, + cache_hit=cache_hit, + start_time=None, + end_time=None, + ) + ) + + executor.submit( + self.logging_obj.success_handler, complete_streaming_response, cache_hit=cache_hit, start_time=None, end_time=None, ) - ) - - executor.submit( - self.logging_obj.success_handler, - complete_streaming_response, - cache_hit=cache_hit, - start_time=None, - end_time=None, - ) raise StopAsyncIteration # Re-raise StopIteration else: diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 72765aab7d..eb2a376ef1 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -45,7 +45,9 @@ from litellm.proxy.common_utils.callback_utils import ( from litellm.proxy.dd_span_tagger import DDSpanTagger from litellm.proxy.route_llm_request import route_request from litellm.proxy.utils import ProxyLogging +from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.router import Router +from litellm.types.guardrails import GuardrailEventHooks from litellm.types.utils import ServerToolUse # Type alias for streaming chunk serializer (chunk after hooks + cost injection -> wire format) @@ -801,7 +803,7 @@ class ProxyBaseLLMRequestProcessing: json.dumps(self.data, indent=4, default=str), ) - async def base_process_llm_request( + async def base_process_llm_request( # noqa: PLR0915 self, request: Request, fastapi_response: Response, @@ -926,6 +928,26 @@ class ProxyBaseLLMRequestProcessing: llm_router=llm_router, ) + # Defer async logging when post-call guardrails are configured so the + # StandardLoggingPayload is built after guardrails write to metadata. + # Cache the result to avoid scanning litellm.callbacks twice. 
+ _has_post_call_guardrails = self._has_post_call_guardrails() + + # Non-streaming: defer the create_task in wrapper_async so the + # SLP is built after guardrails write to metadata. Streaming + # uses a separate closure mechanism (see below). + # + # Edge case: if _is_streaming_request is False but the response + # turns out to be a CustomStreamWrapper (rare provider behavior), + # wrapper_async exits early before the _defer_async_logging block + # so _enqueue_deferred_logging is never stored — the finally + # block is a no-op. The CSW path handles this correctly via + # _on_deferred_stream_complete, which fires its own logging. + if _has_post_call_guardrails and not self._is_streaming_request( + data=self.data, is_streaming_request=is_streaming_request + ): + logging_obj._defer_async_logging = True # type: ignore + tasks = [] # Start the moderation check (during_call_hook) as early as possible # This gives it a head start to mask/validate input while the proxy handles routing @@ -962,124 +984,181 @@ class ProxyBaseLLMRequestProcessing: response = responses[1] - hidden_params = getattr(response, "_hidden_params", {}) or {} - model_id = self._get_model_id_from_response(hidden_params, self.data) + try: + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = self._get_model_id_from_response(hidden_params, self.data) - cache_key, api_base, response_cost = ( - hidden_params.get("cache_key", None) or "", - hidden_params.get("api_base", None) or "", - hidden_params.get("response_cost", None) or "", - ) - fastest_response_batch_completion, additional_headers = ( - hidden_params.get("fastest_response_batch_completion", None), - hidden_params.get("additional_headers", {}) or {}, - ) - - # Post Call Processing - if llm_router is not None: - self.data["deployment"] = llm_router.get_deployment(model_id=model_id) - asyncio.create_task( - proxy_logging_obj.update_request_status( - litellm_call_id=self.data.get("litellm_call_id", ""), status="success" + 
cache_key, api_base, response_cost = ( + hidden_params.get("cache_key", None) or "", + hidden_params.get("api_base", None) or "", + hidden_params.get("response_cost", None) or "", ) - ) - if self._is_streaming_request( - data=self.data, is_streaming_request=is_streaming_request - ) or self._is_streaming_response( - response - ): # use generate_responses to stream responses - custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( - user_api_key_dict=user_api_key_dict, - call_id=logging_obj.litellm_call_id, - model_id=model_id, - cache_key=cache_key, - api_base=api_base, - version=version, - response_cost=response_cost, - model_region=getattr(user_api_key_dict, "allowed_model_region", ""), - fastest_response_batch_completion=fastest_response_batch_completion, - request_data=self.data, - hidden_params=hidden_params, - litellm_logging_obj=logging_obj, - **additional_headers, + fastest_response_batch_completion, additional_headers = ( + hidden_params.get("fastest_response_batch_completion", None), + hidden_params.get("additional_headers", {}) or {}, ) - # Call response headers hook for streaming success - callback_headers = await proxy_logging_obj.post_call_response_headers_hook( - data=self.data, - user_api_key_dict=user_api_key_dict, - response=response, - request_headers=dict(request.headers), + # Post Call Processing + if llm_router is not None: + self.data["deployment"] = llm_router.get_deployment(model_id=model_id) + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=self.data.get("litellm_call_id", ""), status="success" + ) ) - if callback_headers: - custom_headers.update(callback_headers) + if self._is_streaming_request( + data=self.data, is_streaming_request=is_streaming_request + ) or self._is_streaming_response( + response + ): # use generate_responses to stream responses + custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( + user_api_key_dict=user_api_key_dict, + call_id=logging_obj.litellm_call_id, 
+ model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + response_cost=response_cost, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + fastest_response_batch_completion=fastest_response_batch_completion, + request_data=self.data, + hidden_params=hidden_params, + litellm_logging_obj=logging_obj, + **additional_headers, + ) - # Preserve the original client-requested model (pre-alias mapping) for downstream - # streaming generators. Pre-call processing can rewrite `self.data["model"]` for - # aliasing/routing, but the OpenAI-compatible response `model` field should reflect - # what the client sent. - if requested_model_from_client: - self.data[ - "_litellm_client_requested_model" - ] = requested_model_from_client - if route_type == "allm_passthrough_route": - # Check if response is an async generator - if self._is_streaming_response(response): - if asyncio.iscoroutine(response): - generator = await response - else: - generator = response + # Call response headers hook for streaming success + callback_headers = await proxy_logging_obj.post_call_response_headers_hook( + data=self.data, + user_api_key_dict=user_api_key_dict, + response=response, + request_headers=dict(request.headers), + ) + if callback_headers: + custom_headers.update(callback_headers) - # For passthrough routes, stream directly without error parsing - # since we're dealing with raw binary data (e.g., AWS event streams) - return StreamingResponse( - content=generator, - status_code=status.HTTP_200_OK, - headers=custom_headers, - ) - else: - # Traditional HTTP response with aiter_bytes - return StreamingResponse( - content=response.aiter_bytes(), - status_code=response.status_code, - headers=custom_headers, - ) - elif route_type == "anthropic_messages": - # Check if response is actually a streaming response (async generator) - # Non-streaming responses (dict) should be returned directly - # This handles cases like websearch_interception agentic loop - 
# which returns a non-streaming dict even for streaming requests - if self._is_streaming_response(response): - selected_data_generator = ( - ProxyBaseLLMRequestProcessing.async_sse_data_generator( - response=response, - user_api_key_dict=user_api_key_dict, - request_data=self.data, - proxy_logging_obj=proxy_logging_obj, + # Preserve the original client-requested model (pre-alias mapping) for downstream + # streaming generators. Pre-call processing can rewrite `self.data["model"]` for + # aliasing/routing, but the OpenAI-compatible response `model` field should reflect + # what the client sent. + if requested_model_from_client: + self.data[ + "_litellm_client_requested_model" + ] = requested_model_from_client + + # Streaming: attach a closure that CSW.__anext__ will call + # at stream end instead of firing logging directly. The + # closure runs ONLY guardrail hooks (not all callbacks) on + # the assembled response so guardrail_information is + # populated, then fires both logging handlers. + # Only for CustomStreamWrapper — raw async generators from + # passthrough routes bypass CSW and would orphan the closure. + from litellm.litellm_core_utils.streaming_handler import ( + CustomStreamWrapper, + ) + + if _has_post_call_guardrails and isinstance( + response, CustomStreamWrapper + ): + # Intentionally a live reference (not a copy) — mirrors + # ProxyLogging.post_call_success_hook which also mutates + # data["guardrail_to_apply"] during iteration. 
+ _captured_data = self.data + _captured_user_api_key_dict = user_api_key_dict + _captured_logging_obj = logging_obj + + async def _on_deferred_stream_complete( + assembled_response, cache_hit + ): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data=_captured_data, + captured_user_api_key_dict=_captured_user_api_key_dict, + captured_logging_obj=_captured_logging_obj, + assembled_response=assembled_response, + cache_hit=cache_hit, ) + + logging_obj._on_deferred_stream_complete = _on_deferred_stream_complete # type: ignore[attr-defined] + + if route_type == "allm_passthrough_route": + # Check if response is an async generator + if self._is_streaming_response(response): + if asyncio.iscoroutine(response): + generator = await response + else: + generator = response + + # For passthrough routes, stream directly without error parsing + # since we're dealing with raw binary data (e.g., AWS event streams) + return StreamingResponse( + content=generator, + status_code=status.HTTP_200_OK, + headers=custom_headers, + ) + else: + # Traditional HTTP response with aiter_bytes + return StreamingResponse( + content=response.aiter_bytes(), + status_code=response.status_code, + headers=custom_headers, + ) + elif route_type == "anthropic_messages": + # Check if response is actually a streaming response (async generator) + # Non-streaming responses (dict) should be returned directly + # This handles cases like websearch_interception agentic loop + # which returns a non-streaming dict even for streaming requests + if self._is_streaming_response(response): + selected_data_generator = ( + ProxyBaseLLMRequestProcessing.async_sse_data_generator( + response=response, + user_api_key_dict=user_api_key_dict, + request_data=self.data, + proxy_logging_obj=proxy_logging_obj, + ) + ) + return await create_response( + generator=selected_data_generator, + media_type="text/event-stream", + headers=custom_headers, + ) + # Non-streaming response - fall through to 
normal response handling + elif select_data_generator: + selected_data_generator = select_data_generator( + response=response, + user_api_key_dict=user_api_key_dict, + request_data=self.data, ) return await create_response( generator=selected_data_generator, media_type="text/event-stream", headers=custom_headers, ) - # Non-streaming response - fall through to normal response handling - elif select_data_generator: - selected_data_generator = select_data_generator( - response=response, - user_api_key_dict=user_api_key_dict, - request_data=self.data, - ) - return await create_response( - generator=selected_data_generator, - media_type="text/event-stream", - headers=custom_headers, - ) - ### CALL HOOKS ### - modify outgoing data - response = await proxy_logging_obj.post_call_success_hook( - data=self.data, user_api_key_dict=user_api_key_dict, response=response - ) + ### CALL HOOKS ### - modify outgoing data + # If we reach here with a streaming closure still set, it means + # no early-return route consumed the CSW (hypothetical fallthrough). + # Clear the closure so guardrails run inline as before — this + # preserves blocking behavior and avoids double invocation. + if getattr(logging_obj, "_on_deferred_stream_complete", None): + logging_obj._on_deferred_stream_complete = None # type: ignore[attr-defined] + response = await proxy_logging_obj.post_call_success_hook( + data=self.data, user_api_key_dict=user_api_key_dict, response=response + ) + finally: + # Enqueue deferred logging after post-call guardrails have written + # guardrail_information to metadata. The finally block ensures + # logging fires even if a guardrail raises. + # For streaming early-returns: no closure is stored (wrapper_async + # returns before the deferred block), so _enqueue_fn is None — no-op. 
+ _enqueue_fn = getattr(logging_obj, "_enqueue_deferred_logging", None) + if _enqueue_fn is not None: + logging_obj._enqueue_deferred_logging = None # type: ignore[attr-defined] + try: + _enqueue_fn() + except Exception as e: + verbose_proxy_logger.exception( + "Error firing deferred logging: %s", e + ) # Always return the client-requested model name (not provider-prefixed internal identifiers) # for OpenAI-compatible responses. @@ -1217,6 +1296,126 @@ class ProxyBaseLLMRequestProcessing: return True return False + @staticmethod + def _has_post_call_guardrails() -> bool: + """ + Check if any registered callback is a post-call guardrail. + + Uses the global litellm.callbacks list rather than per-request + should_run_guardrail() — intentionally conservative so that the + check is simple and stateless. The deferral path produces + identical logging output, just fires it slightly later, so + false-positives are harmless. + """ + for cb in litellm.callbacks: + if isinstance(cb, CustomGuardrail) and cb._event_hook_is_event_type( + GuardrailEventHooks.post_call + ): + return True + return False + + @staticmethod + async def _run_deferred_stream_guardrails( + captured_data: dict, + captured_user_api_key_dict: "UserAPIKeyAuth", + captured_logging_obj: Any, + assembled_response: Any, + cache_hit: Any, + ) -> None: + """ + Run only post-call guardrail hooks on an assembled streaming response, + then fire both async and sync logging handlers. + + Called by CSW.__anext__ at stream end via a closure stored on + logging_obj._on_deferred_stream_complete. + + This is audit-only — content has already been delivered to the client. + Blocking guardrails that raise HTTPException cannot prevent content + delivery for streaming. Per-chunk filtering should use + async_post_call_streaming_hook instead. + + Extracted as a static method so tests can call the production + implementation directly rather than reimplementing the closure. 
+ """ + from litellm.litellm_core_utils.thread_pool_executor import executor + from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( + UnifiedLLMGuardrails, + ) + from litellm.proxy.proxy_server import llm_router as _global_llm_router + from litellm.proxy.utils import _check_and_merge_model_level_guardrails + + _response = assembled_response + _unified_guardrail = UnifiedLLMGuardrails() + guardrail_data = _check_and_merge_model_level_guardrails( + data=captured_data, llm_router=_global_llm_router + ) + for cb in litellm.callbacks: + if not isinstance(cb, CustomGuardrail): + continue + if not cb.should_run_guardrail( + data=guardrail_data, + event_type=GuardrailEventHooks.post_call, + ): + continue + try: + guardrail_result = None + if "apply_guardrail" in type(cb).__dict__: + captured_data["guardrail_to_apply"] = cb + guardrail_result = ( + await _unified_guardrail.async_post_call_success_hook( + user_api_key_dict=captured_user_api_key_dict, + data=captured_data, + response=_response, + ) + ) + else: + guardrail_result = await cb.async_post_call_success_hook( + user_api_key_dict=captured_user_api_key_dict, + data=captured_data, + response=_response, + ) + if guardrail_result is not None: + _response = guardrail_result + except Exception as e: + verbose_proxy_logger.exception( + "Error running post-call guardrail %s on streaming response: %s", + getattr(cb, "guardrail_name", type(cb).__name__), + e, + ) + if isinstance(e, HTTPException) and hasattr( + captured_logging_obj, "model_call_details" + ): + captured_logging_obj.model_call_details.setdefault( + "metadata", {} + )["guardrail_blocked"] = True + + try: + asyncio.create_task( + captured_logging_obj.async_success_handler( + _response, + cache_hit=cache_hit, + start_time=None, + end_time=None, + ) + ) + except Exception as e: + verbose_proxy_logger.exception( + "Error in deferred streaming async logging: %s", e, + ) + + try: + executor.submit( + 
captured_logging_obj.success_handler, + _response, + cache_hit=cache_hit, + start_time=None, + end_time=None, + ) + except Exception as e: + verbose_proxy_logger.exception( + "Error in deferred streaming sync logging: %s", e, + ) + async def _handle_llm_api_exception( self, e: Exception, diff --git a/litellm/utils.py b/litellm/utils.py index 81d749ab82..0fda994aea 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1944,15 +1944,38 @@ def client(original_function): # noqa: PLR0915 ) # LOG SUCCESS - handle streaming success logging in the _next_ object - asyncio.create_task( - _client_async_logging_helper( - logging_obj=logging_obj, - result=result, - start_time=start_time, - end_time=end_time, - is_completion_with_fallbacks=is_completion_with_fallbacks, + # NOTE: streaming requests return early (before this point) via + # CustomStreamWrapper, so this block is non-streaming only. + if getattr(logging_obj, "_defer_async_logging", False): + # Proxy has post-call guardrails that must complete before the + # SLP is built. Store a closure the proxy will call after + # post_call_success_hook so guardrail_information is in metadata. + # Only create_task is deferred; sync callbacks fire immediately + # (below, outside the if/else) for billing/rate-limiting. 
+ def _enqueue_deferred_logging() -> None: + asyncio.create_task( + _client_async_logging_helper( + logging_obj=logging_obj, + result=result, + start_time=start_time, + end_time=end_time, + is_completion_with_fallbacks=is_completion_with_fallbacks, + ) + ) + + logging_obj._enqueue_deferred_logging = _enqueue_deferred_logging # type: ignore + else: + asyncio.create_task( + _client_async_logging_helper( + logging_obj=logging_obj, + result=result, + start_time=start_time, + end_time=end_time, + is_completion_with_fallbacks=is_completion_with_fallbacks, + ) ) - ) + + # Sync callbacks always fire immediately regardless of deferral logging_obj.handle_sync_success_callbacks_for_async_calls( result=result, start_time=start_time, diff --git a/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py b/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py new file mode 100644 index 0000000000..82e389da65 --- /dev/null +++ b/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py @@ -0,0 +1,687 @@ +""" +Tests for deferred logging with post-call guardrails. + +When post-call guardrails are configured, the async logging task is deferred +until after guardrails complete. This ensures the StandardLoggingPayload +is built with guardrail_information populated. + +Non-streaming: create_task in wrapper_async is replaced by a closure that + the proxy fires in a try/finally after post_call_success_hook. + +Streaming: a closure on logging_obj is called by CSW.__anext__ at stream end. + The closure runs ONLY guardrail hooks (not all callbacks), then fires + both logging handlers. 
+""" + +import asyncio +import os +import sys +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from starlette.exceptions import HTTPException + +sys.path.insert(0, os.path.abspath("../../../..")) + +import litellm +from litellm.caching.caching import DualCache +from litellm.integrations.custom_guardrail import CustomGuardrail +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing +from litellm.proxy.utils import ProxyLogging +from litellm.types.guardrails import GuardrailEventHooks + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class PostCallGuardrail(CustomGuardrail): + """A post-call guardrail.""" + + def __init__(self): + super().__init__( + guardrail_name="post-call", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + return response + + +class PreCallGuardrail(CustomGuardrail): + """A pre-call-only guardrail — should NOT trigger deferral.""" + + def __init__(self): + super().__init__( + guardrail_name="pre-call", + default_on=True, + event_hook=GuardrailEventHooks.pre_call, + ) + + +class AllEventsGuardrail(CustomGuardrail): + """A guardrail with event_hook=None (runs on all events).""" + + def __init__(self): + super().__init__( + guardrail_name="all-events", + default_on=True, + event_hook=None, + ) + + +# --------------------------------------------------------------------------- +# 1. 
_has_post_call_guardrails detection +# --------------------------------------------------------------------------- + + +class TestHasPostCallGuardrails: + def test_returns_true_for_post_call_guardrail(self): + with patch("litellm.callbacks", [PostCallGuardrail()]): + assert ProxyBaseLLMRequestProcessing._has_post_call_guardrails() is True + + def test_returns_true_for_event_hook_none(self): + """event_hook=None means 'all events', including post_call.""" + with patch("litellm.callbacks", [AllEventsGuardrail()]): + assert ProxyBaseLLMRequestProcessing._has_post_call_guardrails() is True + + def test_returns_false_for_pre_call_only(self): + with patch("litellm.callbacks", [PreCallGuardrail()]): + assert ProxyBaseLLMRequestProcessing._has_post_call_guardrails() is False + + def test_returns_false_for_no_callbacks(self): + with patch("litellm.callbacks", []): + assert ProxyBaseLLMRequestProcessing._has_post_call_guardrails() is False + + def test_ignores_non_guardrail_callbacks(self): + """String callbacks and CustomLogger instances are not guardrails.""" + with patch("litellm.callbacks", ["langfuse", CustomLogger()]): + assert ProxyBaseLLMRequestProcessing._has_post_call_guardrails() is False + + def test_returns_true_for_list_with_post_call(self): + """event_hook as a list containing post_call should trigger deferral.""" + + class ListGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="list-post", + default_on=True, + event_hook=[GuardrailEventHooks.pre_call, GuardrailEventHooks.post_call], + ) + + with patch("litellm.callbacks", [ListGuardrail()]): + assert ProxyBaseLLMRequestProcessing._has_post_call_guardrails() is True + + def test_returns_false_for_list_without_post_call(self): + """event_hook as a list without post_call should not trigger deferral.""" + + class ListGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="list-pre", + default_on=True, + event_hook=[GuardrailEventHooks.pre_call], 
+ ) + + with patch("litellm.callbacks", [ListGuardrail()]): + assert ProxyBaseLLMRequestProcessing._has_post_call_guardrails() is False + + +# --------------------------------------------------------------------------- +# 2. Non-streaming: deferral flag → closure stored, create_task skipped +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_deferred_flag_stores_and_executes_closure(): + """ + When _defer_async_logging is True on logging_obj: + 1. wrapper_async stores a callable closure instead of calling create_task + 2. Calling the closure fires create_task + 3. Sync callbacks fire immediately (not deferred) + """ + mock_logging_obj = MagicMock() + mock_logging_obj._defer_async_logging = True + mock_logging_obj._enqueue_deferred_logging = None + + await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + litellm_logging_obj=mock_logging_obj, + ) + + # Closure was stored + enqueue_fn = mock_logging_obj._enqueue_deferred_logging + assert callable(enqueue_fn), "Closure should be stored on logging_obj" + + # Sync callbacks fired immediately + mock_logging_obj.handle_sync_success_callbacks_for_async_calls.assert_called_once() + + # Calling the closure fires create_task + created_tasks = [] + real_create_task = asyncio.create_task + + def tracking_create_task(coro): + task = real_create_task(coro) + created_tasks.append(task) + return task + + with patch("asyncio.create_task", side_effect=tracking_create_task): + enqueue_fn() + + assert len(created_tasks) >= 1, "Closure should fire asyncio.create_task" + + for task in created_tasks: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + +# --------------------------------------------------------------------------- +# 3. 
Non-streaming regression: without flag, create_task fires normally +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_no_flag_fires_create_task_normally(): + """Without _defer_async_logging, wrapper_async calls create_task as before.""" + created_tasks = [] + real_create_task = asyncio.create_task + + def tracking_create_task(coro): + task = real_create_task(coro) + created_tasks.append(task) + return task + + with patch("asyncio.create_task", side_effect=tracking_create_task): + await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + ) + + assert len(created_tasks) >= 1 + + for task in created_tasks: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + +# --------------------------------------------------------------------------- +# 4. Non-streaming: deferred logging fires even if guardrail raises +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_deferred_logging_fires_on_guardrail_exception(): + """ + If post_call_success_hook raises (e.g., guardrail blocks content), + the deferred logging closure must still fire (via try/finally). 
+ """ + enqueue_called = False + + def mock_enqueue(): + nonlocal enqueue_called + enqueue_called = True + + class BlockingGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="blocker", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + raise HTTPException(status_code=400, detail="Content blocked") + + guardrail = BlockingGuardrail() + + logging_obj = MagicMock() + logging_obj._enqueue_deferred_logging = mock_enqueue + + with patch("litellm.callbacks", [guardrail]): + proxy_logging = ProxyLogging(user_api_key_cache=DualCache()) + + with pytest.raises(HTTPException): + try: + await proxy_logging.post_call_success_hook( + data={"model": "gpt-4", "metadata": {}}, + response=MagicMock(), + user_api_key_dict=UserAPIKeyAuth(api_key="test"), + ) + finally: + # Mirrors the proxy's finally block + _enqueue_fn = getattr(logging_obj, "_enqueue_deferred_logging", None) + if _enqueue_fn is not None: + logging_obj._enqueue_deferred_logging = None + _enqueue_fn() + + assert enqueue_called is True + assert logging_obj._enqueue_deferred_logging is None + + +# --------------------------------------------------------------------------- +# 5. 
Streaming: closure defers logging at stream end +# --------------------------------------------------------------------------- + + +class TestDeferredStreamingClosure: + @pytest.mark.asyncio + async def test_streaming_closure_defers_logging(self): + """When _on_deferred_stream_complete is set, CSW calls the closure + instead of firing async_success_handler directly.""" + mock_logging_obj = MagicMock() + callback_called = False + callback_args = {} + + async def mock_callback(assembled_response, cache_hit): + nonlocal callback_called, callback_args + callback_called = True + callback_args = {"response": assembled_response, "cache_hit": cache_hit} + + mock_logging_obj._on_deferred_stream_complete = mock_callback + + resp = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + stream=True, + litellm_logging_obj=mock_logging_obj, + ) + async for _ in resp: + pass + + await asyncio.sleep(0) + + assert callback_called is True, "Closure should be called at stream end" + assert callback_args["response"] is not None + assert mock_logging_obj._on_deferred_stream_complete is None + + @pytest.mark.asyncio + async def test_streaming_no_closure_fires_normally(self): + """Regression: without closure, CSW fires logging immediately.""" + created_tasks = [] + real_create_task = asyncio.create_task + + def tracking_create_task(coro): + task = real_create_task(coro) + created_tasks.append(task) + return task + + resp = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + stream=True, + ) + with patch("asyncio.create_task", side_effect=tracking_create_task): + async for _ in resp: + pass + + assert len(created_tasks) >= 1 + for task in created_tasks: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + @pytest.mark.asyncio + async def 
test_closure_runs_only_guardrail_hooks(self): + """The closure must call only CustomGuardrail hooks, not all callbacks. + This is the key v2 change — PR #23929 called post_call_success_hook + which ran ALL callbacks, causing behavioral changes for streaming.""" + guardrail_called = False + logger_called = False + + class TrackingGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="tracker", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + nonlocal guardrail_called + guardrail_called = True + return response + + class TrackingLogger(CustomLogger): + async def async_post_call_success_hook( + self, user_api_key_dict, data, response + ): + nonlocal logger_called + logger_called = True + return response + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def track_async_success(*args, **kwargs): + pass + + mock_logging_obj.async_success_handler = track_async_success + + tracking_guardrail = TrackingGuardrail() + tracking_logger = TrackingLogger() + + # Use the real production static method via a thin closure + _captured_data = {"model": "gpt-4", "metadata": {}} + _captured_user_api_key_dict = UserAPIKeyAuth(api_key="test") + + async def _on_deferred_stream_complete(assembled_response, cache_hit): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data=_captured_data, + captured_user_api_key_dict=_captured_user_api_key_dict, + captured_logging_obj=mock_logging_obj, + assembled_response=assembled_response, + cache_hit=cache_hit, + ) + + mock_logging_obj._on_deferred_stream_complete = _on_deferred_stream_complete + + with patch("litellm.callbacks", [tracking_guardrail, tracking_logger]): + resp = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + 
mock_response="Hello!", + stream=True, + litellm_logging_obj=mock_logging_obj, + ) + async for _ in resp: + pass + + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert guardrail_called is True, "Guardrail hook should be called" + assert logger_called is False, "Non-guardrail logger should NOT be called by closure" + + @pytest.mark.asyncio + async def test_closure_passes_guardrail_modified_response_to_logging(self): + """The closure passes the guardrail-modified response to logging handlers.""" + mock_logging_obj = MagicMock() + modified_response = MagicMock() + logged_response = None + + async def mock_async_success(*args, **kwargs): + nonlocal logged_response + logged_response = args[0] if args else None + + mock_logging_obj.async_success_handler = mock_async_success + + async def closure(assembled_response, cache_hit): + # Simulate guardrail modifying the response + asyncio.create_task( + mock_logging_obj.async_success_handler( + modified_response, cache_hit=cache_hit, start_time=None, end_time=None + ) + ) + + mock_logging_obj._on_deferred_stream_complete = closure + + resp = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + stream=True, + litellm_logging_obj=mock_logging_obj, + ) + async for _ in resp: + pass + + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert logged_response is modified_response + + @pytest.mark.asyncio + async def test_closure_logs_even_on_guardrail_exception(self): + """If the guardrail raises HTTPException, logging still fires + and guardrail_blocked is set in metadata.""" + logging_called = False + + async def mock_async_success(*args, **kwargs): + nonlocal logging_called + logging_called = True + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + mock_logging_obj.async_success_handler = mock_async_success + + async def closure(assembled_response, cache_hit): + _response = assembled_response + try: + 
raise HTTPException(status_code=400, detail="Blocked") + except Exception as e: + if isinstance(e, HTTPException) and hasattr( + mock_logging_obj, "model_call_details" + ): + mock_logging_obj.model_call_details.setdefault( + "metadata", {} + )["guardrail_blocked"] = True + + asyncio.create_task( + mock_logging_obj.async_success_handler( + _response, cache_hit=cache_hit, start_time=None, end_time=None + ) + ) + + mock_logging_obj._on_deferred_stream_complete = closure + + resp = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + stream=True, + litellm_logging_obj=mock_logging_obj, + ) + async for _ in resp: + pass + + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert logging_called is True + assert mock_logging_obj.model_call_details["metadata"].get( + "guardrail_blocked" + ) is True + + @pytest.mark.asyncio + async def test_transient_error_does_not_set_guardrail_blocked(self): + """Transient errors (not HTTPException) should NOT set guardrail_blocked.""" + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def mock_async_success(*args, **kwargs): + pass + + mock_logging_obj.async_success_handler = mock_async_success + + async def closure(assembled_response, cache_hit): + try: + raise ConnectionError("Network timeout") + except Exception as e: + if isinstance(e, HTTPException) and hasattr( + mock_logging_obj, "model_call_details" + ): + mock_logging_obj.model_call_details.setdefault( + "metadata", {} + )["guardrail_blocked"] = True + + asyncio.create_task( + mock_logging_obj.async_success_handler( + assembled_response, cache_hit=cache_hit, start_time=None, end_time=None + ) + ) + + mock_logging_obj._on_deferred_stream_complete = closure + + resp = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + stream=True, + litellm_logging_obj=mock_logging_obj, + ) + 
async for _ in resp: + pass + + await asyncio.sleep(0) + + assert mock_logging_obj.model_call_details["metadata"].get( + "guardrail_blocked" + ) is not True + + @pytest.mark.asyncio + async def test_production_closure_integration(self): + """Integration test: calls the real _run_deferred_stream_guardrails + static method and verifies it calls guardrail hooks and passes + the modified response to logging.""" + hook_called = False + logged_response = None + modified_response = MagicMock() + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def track_async_success(*args, **kwargs): + nonlocal logged_response + logged_response = args[0] if args else None + + mock_logging_obj.async_success_handler = track_async_success + + class TestGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="test", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + nonlocal hook_called + hook_called = True + return modified_response + + guardrail = TestGuardrail() + + async def _on_deferred_stream_complete(assembled_response, cache_hit): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data={"model": "gpt-4", "metadata": {}}, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=assembled_response, + cache_hit=cache_hit, + ) + + mock_logging_obj._on_deferred_stream_complete = _on_deferred_stream_complete + + with patch("litellm.callbacks", [guardrail]): + resp = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + stream=True, + litellm_logging_obj=mock_logging_obj, + ) + async for _ in resp: + pass + + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert hook_called is True, \ + 
"Production closure must call guardrail hook" + assert logged_response is modified_response, \ + "Production closure must pass guardrail-modified response to logging" + + @pytest.mark.asyncio + async def test_apply_guardrail_path_uses_unified_guardrail(self): + """Guardrails that define apply_guardrail should be dispatched through + UnifiedLLMGuardrails.async_post_call_success_hook via the real + _run_deferred_stream_guardrails static method.""" + from litellm.types.utils import GenericGuardrailAPIInputs + + unified_hook_called = False + + class ApplyGuardrailType(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="apply-type", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def apply_guardrail( + self, inputs, request_data, input_type, logging_obj=None + ) -> GenericGuardrailAPIInputs: + nonlocal unified_hook_called + unified_hook_called = True + return inputs + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + logged_response = None + + async def track_async_success(*args, **kwargs): + nonlocal logged_response + logged_response = args[0] if args else None + + mock_logging_obj.async_success_handler = track_async_success + + guardrail = ApplyGuardrailType() + + async def _on_deferred_stream_complete(assembled_response, cache_hit): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data={"model": "gpt-4", "metadata": {}}, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=assembled_response, + cache_hit=cache_hit, + ) + + mock_logging_obj._on_deferred_stream_complete = _on_deferred_stream_complete + + with patch("litellm.callbacks", [guardrail]): + resp = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "hi"}], + mock_response="Hello!", + stream=True, + litellm_logging_obj=mock_logging_obj, + ) + async for _ in resp: + pass + + 
await asyncio.sleep(0) + await asyncio.sleep(0) + + assert unified_hook_called is True, \ + "apply_guardrail guardrails must be dispatched through UnifiedLLMGuardrails" + assert logged_response is not None, \ + "Logging must fire after unified guardrail path" From 1d7cff22cb80429cd545fa131c970dc7ef6d2862 Mon Sep 17 00:00:00 2001 From: Ryan Crabbe Date: Thu, 19 Mar 2026 08:57:10 -0700 Subject: [PATCH 174/539] feat(ui): add click-to-copy icon on User ID in internal users table Add a CopyOutlined icon next to the truncated User ID that copies the full UUID to clipboard on click. Follows the existing pattern used in model_hub_table_columns.tsx. --- .../src/components/view_users/columns.tsx | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/ui/litellm-dashboard/src/components/view_users/columns.tsx b/ui/litellm-dashboard/src/components/view_users/columns.tsx index 73750d7a48..4b9fad8f23 100644 --- a/ui/litellm-dashboard/src/components/view_users/columns.tsx +++ b/ui/litellm-dashboard/src/components/view_users/columns.tsx @@ -3,7 +3,8 @@ import { Badge, Grid, Icon } from "@tremor/react"; import { Tooltip, Checkbox } from "antd"; import { UserInfo } from "./types"; import { PencilAltIcon, TrashIcon, InformationCircleIcon, RefreshIcon } from "@heroicons/react/outline"; -import { formatNumberWithCommas } from "@/utils/dataUtils"; +import { CopyOutlined } from "@ant-design/icons"; +import { formatNumberWithCommas, copyToClipboard } from "@/utils/dataUtils"; interface SelectionOptions { selectedUsers: UserInfo[]; @@ -29,9 +30,22 @@ export const columns = ( accessorKey: "user_id", enableSorting: true, cell: ({ row }) => ( - - {row.original.user_id ? `${row.original.user_id.slice(0, 7)}...` : "-"} - +
+ + {row.original.user_id ? `${row.original.user_id.slice(0, 7)}...` : "-"} + + {row.original.user_id && ( + + { + e.stopPropagation(); + copyToClipboard(row.original.user_id, "User ID copied to clipboard"); + }} + className="cursor-pointer text-gray-500 hover:text-blue-500 text-xs" + /> + + )} +
), }, { From 4b8c532ba8cbdb9b55fc006c9e218e3dd669d5e3 Mon Sep 17 00:00:00 2001 From: michelligabriele Date: Thu, 19 Mar 2026 17:39:06 +0100 Subject: [PATCH 175/539] fix(proxy): pass guardrail_data to hooks in streaming deferred path Use the merged guardrail_data dict (from _check_and_merge_model_level_guardrails) for hook invocations in _run_deferred_stream_guardrails, instead of the original captured_data. This ensures model-level non-default guardrails are visible to inner should_run_guardrail re-checks inside UnifiedLLMGuardrails. Rewrite three hand-crafted closure tests to exercise the production _run_deferred_stream_guardrails exception-handling path. Add three new tests that use deep-copy mocks to prove hooks receive the merged dict. --- litellm/proxy/common_request_processing.py | 6 +- .../test_deferred_guardrail_logging.py | 395 ++++++++++++++---- 2 files changed, 309 insertions(+), 92 deletions(-) diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index eb2a376ef1..1db8327482 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -1360,18 +1360,18 @@ class ProxyBaseLLMRequestProcessing: try: guardrail_result = None if "apply_guardrail" in type(cb).__dict__: - captured_data["guardrail_to_apply"] = cb + guardrail_data["guardrail_to_apply"] = cb guardrail_result = ( await _unified_guardrail.async_post_call_success_hook( user_api_key_dict=captured_user_api_key_dict, - data=captured_data, + data=guardrail_data, response=_response, ) ) else: guardrail_result = await cb.async_post_call_success_hook( user_api_key_dict=captured_user_api_key_dict, - data=captured_data, + data=guardrail_data, response=_response, ) if guardrail_result is not None: diff --git a/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py b/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py index 82e389da65..f5c9eeba1c 100644 --- 
a/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py +++ b/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py @@ -20,7 +20,7 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest -from starlette.exceptions import HTTPException +from fastapi import HTTPException sys.path.insert(0, os.path.abspath("../../../..")) @@ -421,139 +421,141 @@ class TestDeferredStreamingClosure: @pytest.mark.asyncio async def test_closure_passes_guardrail_modified_response_to_logging(self): - """The closure passes the guardrail-modified response to logging handlers.""" - mock_logging_obj = MagicMock() - modified_response = MagicMock() + """The production _run_deferred_stream_guardrails must pass the + guardrail-modified response to async_success_handler.""" logged_response = None + modified_response = MagicMock() - async def mock_async_success(*args, **kwargs): + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def track_async_success(*args, **kwargs): nonlocal logged_response logged_response = args[0] if args else None - mock_logging_obj.async_success_handler = mock_async_success + mock_logging_obj.async_success_handler = track_async_success - async def closure(assembled_response, cache_hit): - # Simulate guardrail modifying the response - asyncio.create_task( - mock_logging_obj.async_success_handler( - modified_response, cache_hit=cache_hit, start_time=None, end_time=None + class ModifyingGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="modifier", + default_on=True, + event_hook=GuardrailEventHooks.post_call, ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + return modified_response + + guardrail = ModifyingGuardrail() + + with patch("litellm.callbacks", [guardrail]): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + 
captured_data={"model": "gpt-4", "metadata": {}}, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=MagicMock(), + cache_hit=False, ) - mock_logging_obj._on_deferred_stream_complete = closure - - resp = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "hi"}], - mock_response="Hello!", - stream=True, - litellm_logging_obj=mock_logging_obj, - ) - async for _ in resp: - pass - await asyncio.sleep(0) await asyncio.sleep(0) - assert logged_response is modified_response + assert logged_response is modified_response, \ + "Logging must receive the guardrail-modified response" @pytest.mark.asyncio async def test_closure_logs_even_on_guardrail_exception(self): - """If the guardrail raises HTTPException, logging still fires - and guardrail_blocked is set in metadata.""" + """If a guardrail raises HTTPException, the production + _run_deferred_stream_guardrails must still fire logging + and set guardrail_blocked in metadata.""" logging_called = False - async def mock_async_success(*args, **kwargs): + class BlockingGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="blocker", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + raise HTTPException(status_code=400, detail="Blocked") + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def track_async_success(*args, **kwargs): nonlocal logging_called logging_called = True - mock_logging_obj = MagicMock() - mock_logging_obj.model_call_details = {"metadata": {}} - mock_logging_obj.async_success_handler = mock_async_success + mock_logging_obj.async_success_handler = track_async_success - async def closure(assembled_response, cache_hit): - _response = assembled_response - try: - raise 
HTTPException(status_code=400, detail="Blocked") - except Exception as e: - if isinstance(e, HTTPException) and hasattr( - mock_logging_obj, "model_call_details" - ): - mock_logging_obj.model_call_details.setdefault( - "metadata", {} - )["guardrail_blocked"] = True + guardrail = BlockingGuardrail() - asyncio.create_task( - mock_logging_obj.async_success_handler( - _response, cache_hit=cache_hit, start_time=None, end_time=None - ) + with patch("litellm.callbacks", [guardrail]): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data={"model": "gpt-4", "metadata": {}}, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=MagicMock(), + cache_hit=False, ) - mock_logging_obj._on_deferred_stream_complete = closure - - resp = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "hi"}], - mock_response="Hello!", - stream=True, - litellm_logging_obj=mock_logging_obj, - ) - async for _ in resp: - pass - await asyncio.sleep(0) await asyncio.sleep(0) - assert logging_called is True + assert logging_called is True, \ + "Logging must fire even when guardrail raises HTTPException" assert mock_logging_obj.model_call_details["metadata"].get( "guardrail_blocked" - ) is True + ) is True, "guardrail_blocked must be set for HTTPException" @pytest.mark.asyncio async def test_transient_error_does_not_set_guardrail_blocked(self): - """Transient errors (not HTTPException) should NOT set guardrail_blocked.""" + """Transient errors (not HTTPException) should NOT set + guardrail_blocked. 
Uses the production _run_deferred_stream_guardrails.""" + + class TransientErrorGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="transient", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + raise ConnectionError("Network timeout") + mock_logging_obj = MagicMock() mock_logging_obj.model_call_details = {"metadata": {}} - async def mock_async_success(*args, **kwargs): + async def track_async_success(*args, **kwargs): pass - mock_logging_obj.async_success_handler = mock_async_success + mock_logging_obj.async_success_handler = track_async_success - async def closure(assembled_response, cache_hit): - try: - raise ConnectionError("Network timeout") - except Exception as e: - if isinstance(e, HTTPException) and hasattr( - mock_logging_obj, "model_call_details" - ): - mock_logging_obj.model_call_details.setdefault( - "metadata", {} - )["guardrail_blocked"] = True + guardrail = TransientErrorGuardrail() - asyncio.create_task( - mock_logging_obj.async_success_handler( - assembled_response, cache_hit=cache_hit, start_time=None, end_time=None - ) + with patch("litellm.callbacks", [guardrail]): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data={"model": "gpt-4", "metadata": {}}, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=MagicMock(), + cache_hit=False, ) - mock_logging_obj._on_deferred_stream_complete = closure - - resp = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "hi"}], - mock_response="Hello!", - stream=True, - litellm_logging_obj=mock_logging_obj, - ) - async for _ in resp: - pass - await asyncio.sleep(0) assert mock_logging_obj.model_call_details["metadata"].get( "guardrail_blocked" - ) is not True + ) is not True, 
"guardrail_blocked must NOT be set for transient errors" @pytest.mark.asyncio async def test_production_closure_integration(self): @@ -685,3 +687,218 @@ class TestDeferredStreamingClosure: "apply_guardrail guardrails must be dispatched through UnifiedLLMGuardrails" assert logged_response is not None, \ "Logging must fire after unified guardrail path" + + @pytest.mark.asyncio + async def test_hooks_receive_merged_guardrail_data(self): + """Hooks must receive guardrail_data (the merged dict from + _check_and_merge_model_level_guardrails), not the original + captured_data. This ensures model-level non-default guardrails + are visible to any inner should_run_guardrail re-checks. + + Uses a deep-copy mock to break the shallow-copy side-effect that + would otherwise mask the bug — verifying the code is explicitly + correct, not correct-by-accident.""" + import copy + + hook_received_data = None + + class InspectingGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="inspector", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + nonlocal hook_received_data + hook_received_data = data + return response + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def track_async_success(*args, **kwargs): + pass + + mock_logging_obj.async_success_handler = track_async_success + + guardrail = InspectingGuardrail() + + captured_data = {"model": "gpt-4", "metadata": {"existing_key": "value"}} + + def mock_merge(data, llm_router): + """Return a fully independent dict (deep copy) so the original + captured_data is NOT mutated. 
This simulates a correct merge + implementation and proves _run_deferred_stream_guardrails uses + the return value, not the original data.""" + merged = copy.deepcopy(data) + merged["metadata"]["guardrails"] = ["model-guardrail"] + merged["_merged_marker"] = True + return merged + + with patch("litellm.callbacks", [guardrail]), \ + patch( + "litellm.proxy.utils._check_and_merge_model_level_guardrails", + side_effect=mock_merge, + ): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data=captured_data, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=MagicMock(), + cache_hit=False, + ) + + assert hook_received_data is not None, "Guardrail hook must be called" + assert hook_received_data.get("_merged_marker") is True, \ + "Hook must receive guardrail_data (merged), not original captured_data" + assert "model-guardrail" in hook_received_data.get("metadata", {}).get( + "guardrails", [] + ), "Hook data must contain model-level guardrails" + + @pytest.mark.asyncio + async def test_apply_guardrail_path_receives_merged_guardrail_data(self): + """The apply_guardrail path (through UnifiedLLMGuardrails) must also + receive guardrail_data so that the inner should_run_guardrail re-check + inside UnifiedLLMGuardrails sees model-level guardrails. 
+ + This is the specific scenario Greptile flagged: a default_on=False + guardrail configured at the model level would pass the outer gate but + be silently skipped at execution time if captured_data (unmerged) were + passed instead of guardrail_data (merged).""" + import copy + from litellm.types.utils import GenericGuardrailAPIInputs + + unified_received_data = None + + class ModelLevelApplyGuardrail(CustomGuardrail): + def __init__(self): + super().__init__( + guardrail_name="model-apply-guardrail", + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def apply_guardrail( + self, inputs, request_data, input_type, logging_obj=None + ) -> GenericGuardrailAPIInputs: + return inputs + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def track_async_success(*args, **kwargs): + pass + + mock_logging_obj.async_success_handler = track_async_success + + guardrail = ModelLevelApplyGuardrail() + captured_data = {"model": "gpt-4", "metadata": {}} + + def mock_merge(data, llm_router): + merged = copy.deepcopy(data) + merged["metadata"]["guardrails"] = ["model-apply-guardrail"] + merged["_merged_marker"] = True + return merged + + # Capture what UnifiedLLMGuardrails.async_post_call_success_hook receives + original_unified_hook = None + from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( + UnifiedLLMGuardrails, + ) + original_unified_hook = UnifiedLLMGuardrails.async_post_call_success_hook + + async def tracking_unified_hook(self, user_api_key_dict, data, response): + nonlocal unified_received_data + unified_received_data = data + return response + + with patch("litellm.callbacks", [guardrail]), \ + patch( + "litellm.proxy.utils._check_and_merge_model_level_guardrails", + side_effect=mock_merge, + ), \ + patch.object( + UnifiedLLMGuardrails, + "async_post_call_success_hook", + tracking_unified_hook, + ): + await 
ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data=captured_data, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=MagicMock(), + cache_hit=False, + ) + + assert unified_received_data is not None, \ + "UnifiedLLMGuardrails must be called for apply_guardrail guardrails" + assert unified_received_data.get("_merged_marker") is True, \ + "UnifiedLLMGuardrails must receive guardrail_data (merged), not captured_data" + assert "model-apply-guardrail" in unified_received_data.get( + "metadata", {} + ).get("guardrails", []), \ + "UnifiedLLMGuardrails data must contain model-level guardrails" + + @pytest.mark.asyncio + async def test_multiple_guardrails_all_receive_merged_data(self): + """When multiple guardrails are configured, ALL of them must receive + guardrail_data (merged), not just the first one.""" + import copy + + received_data_per_guardrail = {} + + class TaggedGuardrail(CustomGuardrail): + def __init__(self, name): + super().__init__( + guardrail_name=name, + default_on=True, + event_hook=GuardrailEventHooks.post_call, + ) + + async def async_post_call_success_hook( + self, data: dict, user_api_key_dict: UserAPIKeyAuth, response: Any + ) -> Any: + received_data_per_guardrail[self.guardrail_name] = data + return response + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def track_async_success(*args, **kwargs): + pass + + mock_logging_obj.async_success_handler = track_async_success + + guardrail_a = TaggedGuardrail("guardrail-a") + guardrail_b = TaggedGuardrail("guardrail-b") + + captured_data = {"model": "gpt-4", "metadata": {}} + + def mock_merge(data, llm_router): + merged = copy.deepcopy(data) + merged["metadata"]["guardrails"] = ["guardrail-a", "guardrail-b"] + merged["_merged_marker"] = True + return merged + + with patch("litellm.callbacks", [guardrail_a, guardrail_b]), \ + patch( + 
"litellm.proxy.utils._check_and_merge_model_level_guardrails", + side_effect=mock_merge, + ): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data=captured_data, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=MagicMock(), + cache_hit=False, + ) + + for name in ("guardrail-a", "guardrail-b"): + assert name in received_data_per_guardrail, \ + f"{name} must be called" + assert received_data_per_guardrail[name].get("_merged_marker") is True, \ + f"{name} must receive guardrail_data (merged), not captured_data" From 0057452485d2b12a719e5e66262e22ef676a20e3 Mon Sep 17 00:00:00 2001 From: michelligabriele Date: Thu, 19 Mar 2026 18:00:49 +0100 Subject: [PATCH 176/539] fix(proxy): guard streaming deferred init with try/finally, fix test imports Wrap _run_deferred_stream_guardrails initialization (UnifiedLLMGuardrails constructor and _check_and_merge_model_level_guardrails) in try/finally so logging always fires even if init throws. Prevents silent logging loss on transient errors. Move fastapi.HTTPException import from module-level to local test-function scope. Add test_logging_fires_even_if_guardrail_init_raises to verify the try/finally guard. 
--- litellm/proxy/common_request_processing.py | 115 +++++++++--------- .../test_deferred_guardrail_logging.py | 42 ++++++- 2 files changed, 101 insertions(+), 56 deletions(-) diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 1db8327482..0e3e89c531 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -1345,76 +1345,81 @@ class ProxyBaseLLMRequestProcessing: from litellm.proxy.utils import _check_and_merge_model_level_guardrails _response = assembled_response - _unified_guardrail = UnifiedLLMGuardrails() - guardrail_data = _check_and_merge_model_level_guardrails( - data=captured_data, llm_router=_global_llm_router - ) - for cb in litellm.callbacks: - if not isinstance(cb, CustomGuardrail): - continue - if not cb.should_run_guardrail( - data=guardrail_data, - event_type=GuardrailEventHooks.post_call, - ): - continue - try: - guardrail_result = None - if "apply_guardrail" in type(cb).__dict__: - guardrail_data["guardrail_to_apply"] = cb - guardrail_result = ( - await _unified_guardrail.async_post_call_success_hook( + try: + _unified_guardrail = UnifiedLLMGuardrails() + guardrail_data = _check_and_merge_model_level_guardrails( + data=captured_data, llm_router=_global_llm_router + ) + for cb in litellm.callbacks: + if not isinstance(cb, CustomGuardrail): + continue + if not cb.should_run_guardrail( + data=guardrail_data, + event_type=GuardrailEventHooks.post_call, + ): + continue + try: + guardrail_result = None + if "apply_guardrail" in type(cb).__dict__: + guardrail_data["guardrail_to_apply"] = cb + guardrail_result = ( + await _unified_guardrail.async_post_call_success_hook( + user_api_key_dict=captured_user_api_key_dict, + data=guardrail_data, + response=_response, + ) + ) + else: + guardrail_result = await cb.async_post_call_success_hook( user_api_key_dict=captured_user_api_key_dict, data=guardrail_data, response=_response, ) + if guardrail_result is not 
None: + _response = guardrail_result + except Exception as e: + verbose_proxy_logger.exception( + "Error running post-call guardrail %s on streaming response: %s", + getattr(cb, "guardrail_name", type(cb).__name__), + e, ) - else: - guardrail_result = await cb.async_post_call_success_hook( - user_api_key_dict=captured_user_api_key_dict, - data=guardrail_data, - response=_response, + if isinstance(e, HTTPException) and hasattr( + captured_logging_obj, "model_call_details" + ): + captured_logging_obj.model_call_details.setdefault( + "metadata", {} + )["guardrail_blocked"] = True + except Exception as e: + verbose_proxy_logger.exception( + "Error in deferred streaming guardrail initialization: %s", e, + ) + finally: + try: + asyncio.create_task( + captured_logging_obj.async_success_handler( + _response, + cache_hit=cache_hit, + start_time=None, + end_time=None, ) - if guardrail_result is not None: - _response = guardrail_result + ) except Exception as e: verbose_proxy_logger.exception( - "Error running post-call guardrail %s on streaming response: %s", - getattr(cb, "guardrail_name", type(cb).__name__), - e, + "Error in deferred streaming async logging: %s", e, ) - if isinstance(e, HTTPException) and hasattr( - captured_logging_obj, "model_call_details" - ): - captured_logging_obj.model_call_details.setdefault( - "metadata", {} - )["guardrail_blocked"] = True - try: - asyncio.create_task( - captured_logging_obj.async_success_handler( + try: + executor.submit( + captured_logging_obj.success_handler, _response, cache_hit=cache_hit, start_time=None, end_time=None, ) - ) - except Exception as e: - verbose_proxy_logger.exception( - "Error in deferred streaming async logging: %s", e, - ) - - try: - executor.submit( - captured_logging_obj.success_handler, - _response, - cache_hit=cache_hit, - start_time=None, - end_time=None, - ) - except Exception as e: - verbose_proxy_logger.exception( - "Error in deferred streaming sync logging: %s", e, - ) + except Exception as e: + 
verbose_proxy_logger.exception( + "Error in deferred streaming sync logging: %s", e, + ) async def _handle_llm_api_exception( self, diff --git a/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py b/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py index f5c9eeba1c..c4d2dce587 100644 --- a/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py +++ b/tests/test_litellm/proxy/guardrails/test_deferred_guardrail_logging.py @@ -20,7 +20,6 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest -from fastapi import HTTPException sys.path.insert(0, os.path.abspath("../../../..")) @@ -233,6 +232,8 @@ async def test_deferred_logging_fires_on_guardrail_exception(): If post_call_success_hook raises (e.g., guardrail blocks content), the deferred logging closure must still fire (via try/finally). """ + from fastapi import HTTPException # noqa: local import for test isolation + enqueue_called = False def mock_enqueue(): @@ -470,6 +471,8 @@ class TestDeferredStreamingClosure: """If a guardrail raises HTTPException, the production _run_deferred_stream_guardrails must still fire logging and set guardrail_blocked in metadata.""" + from fastapi import HTTPException # noqa: local import for test isolation + logging_called = False class BlockingGuardrail(CustomGuardrail): @@ -902,3 +905,40 @@ class TestDeferredStreamingClosure: f"{name} must be called" assert received_data_per_guardrail[name].get("_merged_marker") is True, \ f"{name} must receive guardrail_data (merged), not captured_data" + + @pytest.mark.asyncio + async def test_logging_fires_even_if_guardrail_init_raises(self): + """If _check_and_merge_model_level_guardrails raises during + initialization, logging must still fire via the try/finally guard. 
+ This prevents silent logging loss on transient init errors.""" + logging_called = False + + mock_logging_obj = MagicMock() + mock_logging_obj.model_call_details = {"metadata": {}} + + async def track_async_success(*args, **kwargs): + nonlocal logging_called + logging_called = True + + mock_logging_obj.async_success_handler = track_async_success + + def exploding_merge(data, llm_router): + raise RuntimeError("Simulated init failure") + + with patch( + "litellm.proxy.utils._check_and_merge_model_level_guardrails", + side_effect=exploding_merge, + ): + await ProxyBaseLLMRequestProcessing._run_deferred_stream_guardrails( + captured_data={"model": "gpt-4", "metadata": {}}, + captured_user_api_key_dict=UserAPIKeyAuth(api_key="test"), + captured_logging_obj=mock_logging_obj, + assembled_response=MagicMock(), + cache_hit=False, + ) + + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert logging_called is True, \ + "Logging must fire even when guardrail initialization raises" From f415b72bcfa795c3673de5d13b68658fc9a3482e Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 18 Mar 2026 15:20:45 -0700 Subject: [PATCH 177/539] feat(anthropic): support ANTHROPIC_AUTH_TOKEN and ANTHROPIC_BASE_URL env vars Co-Authored-By: Claude Signed-off-by: Devin Petersohn --- litellm/batches/main.py | 1 + litellm/constants.py | 1 + .../llms/anthropic/batches/transformation.py | 7 +- litellm/llms/anthropic/common_utils.py | 54 ++- .../messages/transformation.py | 10 +- litellm/llms/anthropic/files/handler.py | 10 +- .../llms/anthropic/files/transformation.py | 8 +- .../llms/anthropic/skills/transformation.py | 13 +- .../llm_passthrough_endpoints.py | 4 +- litellm/utils.py | 8 +- .../anthropic/test_anthropic_common_utils.py | 372 +++++++++++++++++- 11 files changed, 441 insertions(+), 47 deletions(-) diff --git a/litellm/batches/main.py b/litellm/batches/main.py index e176dc4292..17d73aae6a 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -524,6 +524,7 @@ def 
_handle_retrieve_batch_providers_without_provider_config( optional_params.api_base or litellm.api_base or get_secret_str("ANTHROPIC_API_BASE") + or get_secret_str("ANTHROPIC_BASE_URL") ) api_key = ( optional_params.api_key diff --git a/litellm/constants.py b/litellm/constants.py index 89c59ee932..c0dd115210 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -1444,6 +1444,7 @@ SENTRY_DENYLIST = [ "credential", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", + "ANTHROPIC_AUTH_TOKEN", "AZURE_API_KEY", "COHERE_API_KEY", "REPLICATE_API_KEY", diff --git a/litellm/llms/anthropic/batches/transformation.py b/litellm/llms/anthropic/batches/transformation.py index 699f133f0f..98c0588a09 100644 --- a/litellm/llms/anthropic/batches/transformation.py +++ b/litellm/llms/anthropic/batches/transformation.py @@ -42,9 +42,8 @@ class AnthropicBatchesConfig(BaseBatchesConfig): api_base: Optional[str] = None, ) -> dict: """Validate and prepare environment-specific headers and parameters.""" - # Resolve api_key from environment if not provided - api_key = api_key or self.anthropic_model_info.get_api_key() - if api_key is None: + auth_header = self.anthropic_model_info.get_auth_header(api_key) + if auth_header is None: raise ValueError( "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params" ) @@ -52,8 +51,8 @@ class AnthropicBatchesConfig(BaseBatchesConfig): "accept": "application/json", "anthropic-version": "2023-06-01", "content-type": "application/json", - "x-api-key": api_key, } + _headers.update(auth_header) # Add beta header for message batches if "anthropic-beta" not in headers: headers["anthropic-beta"] = "message-batches-2024-09-24" diff --git a/litellm/llms/anthropic/common_utils.py b/litellm/llms/anthropic/common_utils.py index ac35246787..7199e21014 100644 --- a/litellm/llms/anthropic/common_utils.py +++ b/litellm/llms/anthropic/common_utils.py @@ -359,9 +359,7 @@ class 
AnthropicModelInfo(BaseLLMModelInfo): Returns: List of beta header strings """ - from litellm.types.llms.anthropic import ( - ANTHROPIC_EFFORT_BETA_HEADER, - ) + from litellm.types.llms.anthropic import ANTHROPIC_EFFORT_BETA_HEADER betas = [] @@ -390,7 +388,8 @@ class AnthropicModelInfo(BaseLLMModelInfo): def get_anthropic_headers( self, - api_key: str, + api_key: Optional[str] = None, + auth_token: Optional[str] = None, anthropic_version: Optional[str] = None, computer_tool_used: Optional[str] = None, prompt_caching_set: bool = False, @@ -451,6 +450,8 @@ class AnthropicModelInfo(BaseLLMModelInfo): headers["authorization"] = f"Bearer {api_key}" headers["anthropic-dangerous-direct-browser-access"] = "true" betas.add(ANTHROPIC_OAUTH_BETA_HEADER) + elif auth_token and not api_key: + headers["authorization"] = f"Bearer {auth_token}" else: headers["x-api-key"] = api_key @@ -485,9 +486,13 @@ class AnthropicModelInfo(BaseLLMModelInfo): headers, api_key = optionally_handle_anthropic_oauth( headers=headers, api_key=api_key ) + # Resolve auth_token from ANTHROPIC_AUTH_TOKEN if api_key is not set + auth_token: Optional[str] = None if api_key is None: + auth_token = AnthropicModelInfo.get_auth_token() + if api_key is None and auth_token is None: raise litellm.AuthenticationError( - message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars", + message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. 
Please set `ANTHROPIC_API_KEY` or `ANTHROPIC_AUTH_TOKEN` in your environment vars", llm_provider="anthropic", model=model, ) @@ -519,6 +524,7 @@ class AnthropicModelInfo(BaseLLMModelInfo): prompt_caching_set=prompt_caching_set, pdf_used=pdf_used, api_key=api_key, + auth_token=auth_token, file_id_used=file_id_used, web_search_tool_used=web_search_tool_used, is_vertex_request=optional_params.get("is_vertex_request", False), @@ -543,6 +549,7 @@ class AnthropicModelInfo(BaseLLMModelInfo): return ( api_base or get_secret_str("ANTHROPIC_API_BASE") + or get_secret_str("ANTHROPIC_BASE_URL") or "https://api.anthropic.com" ) @@ -552,6 +559,33 @@ class AnthropicModelInfo(BaseLLMModelInfo): return api_key or get_secret_str("ANTHROPIC_API_KEY") + @staticmethod + def get_auth_token(auth_token: Optional[str] = None) -> Optional[str]: + """Get auth token from ANTHROPIC_AUTH_TOKEN env var. + + Unlike api_key (which uses X-Api-Key header), auth_token uses + Authorization: Bearer header, matching the official Anthropic SDK behavior. + """ + from litellm.secret_managers.main import get_secret_str + + return auth_token or get_secret_str("ANTHROPIC_AUTH_TOKEN") + + @staticmethod + def get_auth_header(api_key: Optional[str] = None) -> Optional[dict]: + """Resolve Anthropic credentials and return the appropriate auth header dict. + + Checks ANTHROPIC_API_KEY first (-> x-api-key), then + ANTHROPIC_AUTH_TOKEN (-> Authorization: Bearer). + Returns None if neither is available. 
+ """ + resolved_key = AnthropicModelInfo.get_api_key(api_key) + if resolved_key is not None: + return {"x-api-key": resolved_key} + auth_token = AnthropicModelInfo.get_auth_token() + if auth_token is not None: + return {"authorization": f"Bearer {auth_token}"} + return None + @staticmethod def get_base_model(model: Optional[str] = None) -> Optional[str]: return model.replace("anthropic/", "") if model else None @@ -560,14 +594,16 @@ class AnthropicModelInfo(BaseLLMModelInfo): self, api_key: Optional[str] = None, api_base: Optional[str] = None ) -> List[str]: api_base = AnthropicModelInfo.get_api_base(api_base) - api_key = AnthropicModelInfo.get_api_key(api_key) - if api_base is None or api_key is None: + auth_header = AnthropicModelInfo.get_auth_header(api_key) + if api_base is None or auth_header is None: raise ValueError( - "ANTHROPIC_API_BASE or ANTHROPIC_API_KEY is not set. Please set the environment variable, to query Anthropic's `/models` endpoint." + "ANTHROPIC_API_BASE/ANTHROPIC_BASE_URL or ANTHROPIC_API_KEY/ANTHROPIC_AUTH_TOKEN is not set. Please set the environment variable, to query Anthropic's `/models` endpoint." 
) + headers = {"anthropic-version": "2023-06-01"} + headers.update(auth_header) response = litellm.module_level_client.get( url=f"{api_base}/v1/models", - headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"}, + headers=headers, ) try: diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py index e9ceea4822..8e7c850627 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py @@ -142,17 +142,15 @@ class AnthropicMessagesConfig(BaseAnthropicMessagesConfig): api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> Tuple[dict, Optional[str]]: - import os - # Check for Anthropic OAuth token in Authorization header headers, api_key = optionally_handle_anthropic_oauth( headers=headers, api_key=api_key ) - if api_key is None: - api_key = os.getenv("ANTHROPIC_API_KEY") - if "x-api-key" not in headers and "authorization" not in headers and api_key: - headers["x-api-key"] = api_key + if "x-api-key" not in headers and "authorization" not in headers: + auth_header = AnthropicModelInfo.get_auth_header(api_key) + if auth_header is not None: + headers.update(auth_header) if "anthropic-version" not in headers: headers["anthropic-version"] = DEFAULT_ANTHROPIC_API_VERSION if "content-type" not in headers: diff --git a/litellm/llms/anthropic/files/handler.py b/litellm/llms/anthropic/files/handler.py index 77cc8c2731..c56799f30c 100644 --- a/litellm/llms/anthropic/files/handler.py +++ b/litellm/llms/anthropic/files/handler.py @@ -8,10 +8,8 @@ import httpx import litellm from litellm._logging import verbose_logger from litellm._uuid import uuid -from litellm.llms.custom_httpx.http_handler import ( - get_async_httpx_client, -) from litellm.litellm_core_utils.litellm_logging import Logging +from litellm.llms.custom_httpx.http_handler import 
get_async_httpx_client from litellm.types.llms.openai import ( FileContentRequest, HttpxBinaryResponseContent, @@ -85,9 +83,9 @@ class AnthropicFilesHandler: # Get Anthropic API credentials api_base = self.anthropic_model_info.get_api_base(api_base) - api_key = api_key or self.anthropic_model_info.get_api_key() + auth_header = self.anthropic_model_info.get_auth_header(api_key) - if not api_key: + if auth_header is None: raise ValueError("Missing Anthropic API Key") # Construct the Anthropic batch results URL @@ -97,8 +95,8 @@ class AnthropicFilesHandler: headers = { "accept": "application/json", "anthropic-version": "2023-06-01", - "x-api-key": api_key, } + headers.update(auth_header) # Make the request to Anthropic async_client = get_async_httpx_client(llm_provider=LlmProviders.ANTHROPIC) diff --git a/litellm/llms/anthropic/files/transformation.py b/litellm/llms/anthropic/files/transformation.py index 98a548a136..0545cefb07 100644 --- a/litellm/llms/anthropic/files/transformation.py +++ b/litellm/llms/anthropic/files/transformation.py @@ -94,14 +94,14 @@ class AnthropicFilesConfig(BaseFilesConfig): api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: - api_key = AnthropicModelInfo.get_api_key(api_key) - if not api_key: + auth_header = AnthropicModelInfo.get_auth_header(api_key) + if auth_header is None: raise ValueError( - "Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter." + "Anthropic API key is required. Set ANTHROPIC_API_KEY or ANTHROPIC_AUTH_TOKEN environment variable or pass api_key parameter." 
) headers.update( { - "x-api-key": api_key, + **auth_header, "anthropic-version": "2023-06-01", "anthropic-beta": ANTHROPIC_FILES_BETA_HEADER, } diff --git a/litellm/llms/anthropic/skills/transformation.py b/litellm/llms/anthropic/skills/transformation.py index af9863534e..f582aefd81 100644 --- a/litellm/llms/anthropic/skills/transformation.py +++ b/litellm/llms/anthropic/skills/transformation.py @@ -35,17 +35,16 @@ class AnthropicSkillsConfig(BaseSkillsAPIConfig): """Add Anthropic-specific headers""" from litellm.llms.anthropic.common_utils import AnthropicModelInfo - # Get API key + # Get API key from litellm_params if available api_key = None - if litellm_params: + if litellm_params is not None: api_key = litellm_params.api_key - api_key = AnthropicModelInfo.get_api_key(api_key) - if not api_key: - raise ValueError("ANTHROPIC_API_KEY is required for Skills API") + auth_header = AnthropicModelInfo.get_auth_header(api_key) + if auth_header is None: + raise ValueError("ANTHROPIC_API_KEY or ANTHROPIC_AUTH_TOKEN is required for Skills API") - # Add required headers - headers["x-api-key"] = api_key + headers.update(auth_header) headers["anthropic-version"] = "2023-06-01" # Add beta header for skills API diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 4e3e04a847..f6b83a68ad 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -585,7 +585,7 @@ async def anthropic_proxy_route( """ [Docs](https://docs.litellm.ai/docs/pass_through/anthropic_completion) """ - base_target_url = os.getenv("ANTHROPIC_API_BASE") or "https://api.anthropic.com" + base_target_url = os.getenv("ANTHROPIC_API_BASE") or os.getenv("ANTHROPIC_BASE_URL") or "https://api.anthropic.com" encoded_endpoint = httpx.URL(endpoint).path # Ensure endpoint starts with '/' for proper URL construction @@ -609,7 
+609,7 @@ async def anthropic_proxy_route( endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(updated_url), - custom_headers={"x-api-key": "{}".format(anthropic_api_key)}, + custom_headers={"x-api-key": "{}".format(anthropic_api_key)} if anthropic_api_key else {}, _forward_headers=True, is_streaming_request=is_streaming_request, ) # dynamically construct pass-through endpoint based on incoming path diff --git a/litellm/utils.py b/litellm/utils.py index 81d749ab82..77625809bf 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6160,7 +6160,7 @@ def validate_environment( # noqa: PLR0915 ["AZURE_API_BASE", "AZURE_API_VERSION", "AZURE_API_KEY"] ) elif custom_llm_provider == "anthropic": - if "ANTHROPIC_API_KEY" in os.environ: + if "ANTHROPIC_API_KEY" in os.environ or "ANTHROPIC_AUTH_TOKEN" in os.environ: keys_in_environment = True else: missing_keys.append("ANTHROPIC_API_KEY") @@ -6399,7 +6399,7 @@ def validate_environment( # noqa: PLR0915 missing_keys.append("OPENAI_API_KEY") ## anthropic elif model in litellm.anthropic_models: - if "ANTHROPIC_API_KEY" in os.environ: + if "ANTHROPIC_API_KEY" in os.environ or "ANTHROPIC_AUTH_TOKEN" in os.environ: keys_in_environment = True else: missing_keys.append("ANTHROPIC_API_KEY") @@ -8593,9 +8593,7 @@ class ProviderConfigManager: return ManusFilesConfig() elif LlmProviders.ANTHROPIC == provider: - from litellm.llms.anthropic.files.transformation import ( - AnthropicFilesConfig, - ) + from litellm.llms.anthropic.files.transformation import AnthropicFilesConfig return AnthropicFilesConfig() return None diff --git a/tests/test_litellm/llms/anthropic/test_anthropic_common_utils.py b/tests/test_litellm/llms/anthropic/test_anthropic_common_utils.py index b4f2629f8f..fcec155bae 100644 --- a/tests/test_litellm/llms/anthropic/test_anthropic_common_utils.py +++ b/tests/test_litellm/llms/anthropic/test_anthropic_common_utils.py @@ -1,20 +1,27 @@ """ -Tests for Anthropic OAuth token handling in common_utils. 
+Tests for Anthropic authentication and environment variable handling in common_utils. -Verifies that OAuth tokens (sk-ant-oat*) are sent via Authorization: Bearer -instead of x-api-key, per Anthropic's OAuth specification. +Verifies that: +- OAuth tokens (sk-ant-oat*) produce Authorization: Bearer headers with OAuth beta flags. +- Regular API keys produce x-api-key headers. +- ANTHROPIC_AUTH_TOKEN produces Authorization: Bearer headers, + matching the official Anthropic SDK behavior. +- ANTHROPIC_BASE_URL is used as a fallback for base URL resolution. +- ANTHROPIC_API_KEY / ANTHROPIC_API_BASE take precedence over their aliases. """ import os import sys +from unittest.mock import patch sys.path.insert( 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) ) -# Fake OAuth token for testing (not a real secret) +# Fake tokens for testing (not real secrets) FAKE_OAUTH_TOKEN = "sk-ant-oat01-fake-token-for-testing-123456789abcdef" FAKE_REGULAR_KEY = "sk-ant-api03-regular-key-for-testing-123456789" +FAKE_AUTH_TOKEN = "sk-ant-aut01-fake-auth-token-for-testing-123456789" class TestOptionallyHandleAnthropicOAuth: @@ -697,3 +704,360 @@ class TestProxyOAuthHeaderForwarding: assert cleaned["authorization"] == oauth_token # Proxy key must be stripped assert "x-litellm-api-key" not in cleaned + + +class TestGetAnthropicHeadersWithAuthToken: + """Tests for get_anthropic_headers with auth_token parameter.""" + + def test_auth_token_uses_bearer_header(self): + """auth_token should produce Authorization: Bearer header.""" + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + config = AnthropicModelInfo() + headers = config.get_anthropic_headers( + api_key=None, + auth_token=FAKE_AUTH_TOKEN, + computer_tool_used=False, + prompt_caching_set=False, + pdf_used=False, + is_vertex_request=False, + ) + + assert headers["authorization"] == f"Bearer {FAKE_AUTH_TOKEN}" + assert "x-api-key" not in headers + # auth_token should NOT set OAuth-specific 
flags + assert "anthropic-dangerous-direct-browser-access" not in headers + + def test_auth_token_includes_standard_headers(self): + """auth_token path should include standard Anthropic headers.""" + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + config = AnthropicModelInfo() + headers = config.get_anthropic_headers( + api_key=None, + auth_token=FAKE_AUTH_TOKEN, + computer_tool_used=False, + prompt_caching_set=False, + pdf_used=False, + is_vertex_request=False, + ) + + assert headers["anthropic-version"] == "2023-06-01" + assert headers["accept"] == "application/json" + assert headers["content-type"] == "application/json" + + def test_api_key_takes_precedence_over_auth_token(self): + """When both api_key and auth_token are provided, api_key wins.""" + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + config = AnthropicModelInfo() + headers = config.get_anthropic_headers( + api_key=FAKE_REGULAR_KEY, + auth_token=FAKE_AUTH_TOKEN, + computer_tool_used=False, + prompt_caching_set=False, + pdf_used=False, + is_vertex_request=False, + ) + + assert headers["x-api-key"] == FAKE_REGULAR_KEY + assert "authorization" not in headers + + +class TestValidateEnvironmentAuthToken: + """Tests for validate_environment with auth_token resolution.""" + + def test_auth_token_env_var_produces_bearer_header(self): + """validate_environment should use Bearer auth when only ANTHROPIC_AUTH_TOKEN is set.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + config = AnthropicModelInfo() + with mock_patch.dict( + "os.environ", + {"ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN}, + clear=True, + ): + headers = config.validate_environment( + headers={}, + model="claude-sonnet-4-5-20250929", + messages=[{"role": "user", "content": "Hello"}], + optional_params={}, + litellm_params={}, + api_key=None, + api_base=None, + ) + + assert headers["authorization"] == f"Bearer {FAKE_AUTH_TOKEN}" + 
assert "x-api-key" not in headers + assert "anthropic-dangerous-direct-browser-access" not in headers + + def test_api_key_param_takes_precedence_over_auth_token_env_var(self): + """validate_environment should prefer explicit api_key over ANTHROPIC_AUTH_TOKEN.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + config = AnthropicModelInfo() + with mock_patch.dict( + "os.environ", + {"ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN}, + clear=True, + ): + headers = config.validate_environment( + headers={}, + model="claude-sonnet-4-5-20250929", + messages=[{"role": "user", "content": "Hello"}], + optional_params={}, + litellm_params={}, + api_key=FAKE_REGULAR_KEY, + api_base=None, + ) + + assert headers["x-api-key"] == FAKE_REGULAR_KEY + assert "authorization" not in headers + + def test_raises_when_no_credentials(self): + """validate_environment should raise when neither API key nor auth token is available.""" + from unittest.mock import patch as mock_patch + + import pytest + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + config = AnthropicModelInfo() + with mock_patch.dict("os.environ", {}, clear=True): + with pytest.raises( + Exception, match="ANTHROPIC_API_KEY.*ANTHROPIC_AUTH_TOKEN" + ): + config.validate_environment( + headers={}, + model="claude-sonnet-4-5-20250929", + messages=[{"role": "user", "content": "Hello"}], + optional_params={}, + litellm_params={}, + api_key=None, + api_base=None, + ) + + + + +class TestGetAuthToken: + """Tests for AnthropicModelInfo.get_auth_token() static method.""" + + def test_returns_env_var_value(self): + """get_auth_token returns the ANTHROPIC_AUTH_TOKEN env var value.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict( + "os.environ", {"ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN}, clear=True + ): + assert AnthropicModelInfo.get_auth_token() == 
FAKE_AUTH_TOKEN + + def test_returns_none_when_not_set(self): + """get_auth_token returns None when ANTHROPIC_AUTH_TOKEN is not set.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict("os.environ", {}, clear=True): + assert AnthropicModelInfo.get_auth_token() is None + + def test_explicit_param_takes_precedence(self): + """Explicit auth_token param takes precedence over env var.""" + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + explicit_token = "sk-ant-aut01-explicit-token-override-123456789" + assert AnthropicModelInfo.get_auth_token(explicit_token) == explicit_token + + +class TestGetAuthHeader: + """Tests for AnthropicModelInfo.get_auth_header() centralized helper.""" + + def test_returns_x_api_key_when_api_key_provided(self): + """Explicit api_key param should return x-api-key header.""" + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + result = AnthropicModelInfo.get_auth_header(api_key=FAKE_REGULAR_KEY) + assert result == {"x-api-key": FAKE_REGULAR_KEY} + + def test_returns_x_api_key_from_env(self): + """ANTHROPIC_API_KEY env var should return x-api-key header.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict( + "os.environ", + {"ANTHROPIC_API_KEY": FAKE_REGULAR_KEY}, + clear=True, + ): + result = AnthropicModelInfo.get_auth_header() + assert result == {"x-api-key": FAKE_REGULAR_KEY} + + def test_returns_bearer_from_auth_token_env(self): + """ANTHROPIC_AUTH_TOKEN env var should return Authorization: Bearer header.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict( + "os.environ", + {"ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN}, + clear=True, + ): + result = AnthropicModelInfo.get_auth_header() + assert result == {"authorization": 
f"Bearer {FAKE_AUTH_TOKEN}"} + + def test_api_key_takes_precedence_over_auth_token(self): + """ANTHROPIC_API_KEY should take precedence over ANTHROPIC_AUTH_TOKEN.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict( + "os.environ", + { + "ANTHROPIC_API_KEY": FAKE_REGULAR_KEY, + "ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN, + }, + clear=True, + ): + result = AnthropicModelInfo.get_auth_header() + assert result == {"x-api-key": FAKE_REGULAR_KEY} + + def test_explicit_api_key_overrides_env_auth_token(self): + """Explicit api_key param should override ANTHROPIC_AUTH_TOKEN env var.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict( + "os.environ", + {"ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN}, + clear=True, + ): + result = AnthropicModelInfo.get_auth_header(api_key=FAKE_REGULAR_KEY) + assert result == {"x-api-key": FAKE_REGULAR_KEY} + + def test_returns_none_when_no_credentials(self): + """Should return None when neither api_key nor auth_token is available.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict("os.environ", {}, clear=True): + result = AnthropicModelInfo.get_auth_header() + assert result is None + + +class TestGetApiBaseFallbackChain: + """Tests for AnthropicModelInfo.get_api_base() fallback to ANTHROPIC_BASE_URL.""" + + def test_explicit_param_takes_precedence(self): + """Explicit api_base param takes precedence over all env vars.""" + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + assert ( + AnthropicModelInfo.get_api_base("https://explicit.example.com") + == "https://explicit.example.com" + ) + + def test_defaults_to_anthropic_api(self): + """get_api_base returns the default Anthropic API base when no env vars are set.""" + from unittest.mock import patch 
as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict("os.environ", {}, clear=True): + assert AnthropicModelInfo.get_api_base() == "https://api.anthropic.com" + + def test_api_base_env_preferred_over_base_url_env(self): + """ANTHROPIC_API_BASE takes precedence over ANTHROPIC_BASE_URL.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict( + "os.environ", + { + "ANTHROPIC_API_BASE": "https://api-base.example.com", + "ANTHROPIC_BASE_URL": "https://base-url.example.com", + }, + clear=True, + ): + assert AnthropicModelInfo.get_api_base() == "https://api-base.example.com" + + def test_falls_back_to_base_url_env(self): + """get_api_base falls back to ANTHROPIC_BASE_URL when ANTHROPIC_API_BASE is not set.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.common_utils import AnthropicModelInfo + + with mock_patch.dict( + "os.environ", + {"ANTHROPIC_BASE_URL": "https://base-url.example.com"}, + clear=True, + ): + assert AnthropicModelInfo.get_api_base() == "https://base-url.example.com" + + +class TestPassthroughAuthToken: + """Tests for passthrough messages endpoint with ANTHROPIC_AUTH_TOKEN.""" + + def test_passthrough_auth_token_uses_bearer_header(self): + """Passthrough endpoint should use Bearer auth when only ANTHROPIC_AUTH_TOKEN is set.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.experimental_pass_through.messages.transformation import ( + AnthropicMessagesConfig, + ) + + config = AnthropicMessagesConfig() + with mock_patch.dict( + "os.environ", {"ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN}, clear=True + ): + updated_headers, _ = config.validate_anthropic_messages_environment( + headers={}, + model="claude-sonnet-4-5-20250929", + messages=[{"role": "user", "content": "Hello"}], + optional_params={}, + litellm_params={}, + api_key=None, + api_base=None, + 
) + + assert updated_headers["authorization"] == f"Bearer {FAKE_AUTH_TOKEN}" + assert "x-api-key" not in updated_headers + assert "anthropic-dangerous-direct-browser-access" not in updated_headers + + def test_passthrough_api_key_takes_precedence(self): + """Passthrough endpoint should prefer ANTHROPIC_API_KEY over ANTHROPIC_AUTH_TOKEN.""" + from unittest.mock import patch as mock_patch + + from litellm.llms.anthropic.experimental_pass_through.messages.transformation import ( + AnthropicMessagesConfig, + ) + + config = AnthropicMessagesConfig() + with mock_patch.dict( + "os.environ", + {"ANTHROPIC_API_KEY": FAKE_REGULAR_KEY, "ANTHROPIC_AUTH_TOKEN": FAKE_AUTH_TOKEN}, + clear=True, + ): + updated_headers, _ = config.validate_anthropic_messages_environment( + headers={}, + model="claude-sonnet-4-5-20250929", + messages=[{"role": "user", "content": "Hello"}], + optional_params={}, + litellm_params={}, + api_key=None, + api_base=None, + ) + + assert updated_headers["x-api-key"] == FAKE_REGULAR_KEY + assert "authorization" not in updated_headers From 81dadb698a5984a4bf825903b3384927d12d54bc Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 19 Mar 2026 10:20:35 -0700 Subject: [PATCH 178/539] Ishaan - March 18th changes (#24056) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add DD Tracing (#24033) * feat(models): add Azure GPT-5.4 mini and nano variants (#24045) Add `azure/gpt-5.4-mini` and `azure/gpt-5.4-nano` to the model database with official pricing from Azure OpenAI: - GPT-5.4 mini: $0.75/M input, $0.075/M cached, $4.5/M output - GPT-5.4 nano: $0.20/M input, $0.02/M cached, $1.25/M output Both models support: - 1.05M input / 128K output context window - Chat, batch, and responses endpoints - Function calling, tools, vision, reasoning - Prompt caching with automatic tiered pricing Co-authored-by: Claude Opus 4.6 * Add new model pricing details for volcengine Doubao-Seed-2.0 series (#23871) Add entries for volcengine 
Doubao-Seed-2.0 series * fix(mcp): support refresh_token grant type in OAuth token endpoint (#23701) * fix(mcp): support refresh_token grant type in OAuth token endpoint (#23700) The .well-known/oauth-authorization-server metadata advertises refresh_token as a supported grant type, but the token endpoint rejected it with HTTP 400. This adds refresh_token grant support so MCP clients can refresh expired tokens without re-authenticating. * test(mcp): add tests for refresh_token grant type in OAuth token endpoint * fix(mcp): move code_verifier guard into authorization_code branch code_verifier is only relevant for authorization_code grants (PKCE). Move it inside the else branch so it doesn't apply to refresh_token. * fix(mcp): guard None client_secret and forward scope in token exchange - Conditionally include client_secret in form data to prevent httpx from sending the literal string "None" (applies to both authorization_code and refresh_token branches) - Forward optional scope parameter per RFC 6749 §6, allowing clients to request a subset of originally-granted scopes on refresh * fix(mcp): validate code param in authorization_code grant Guard against None code being form-encoded as literal string "None" by httpx, symmetric with the existing refresh_token guard. * docs: add incident report for guardrail logging secret exposure (#24059) Add blog post documenting the guardrail logging path exposing internal request data (e.g. Authorization headers) in spend logs and OTEL traces. Fix available in LiteLLM 1.82.3+. Made-with: Cursor * [Fix] Datadog LLM Observability tags format (env, service, version missing) (#23673) * tag fix * greptile comment * fix(ci): stabilize 6 failing CI jobs 1. mypy: remove duplicate type annotation for token_data in discoverable_endpoints.py 2. integrations tests: add parameterized to CI test deps 3. doc quality: document OTEL_IGNORE_CONTEXT_PROPAGATION env key 4. 
security: allowlist CVE-2026-2673, CVE-2026-3644, CVE-2026-4224 (no fix available) 5. proxy_store_model_in_db: fix missing x-litellm-call-id header on error responses 6. google tests: add --retries 3 for transient Vertex AI rate limits Co-authored-by: Ishaan Jaff * fix(streaming): handle RuntimeError during model_copy in streaming handler The race condition occurs when model_copy(deep=True) tries to deepcopy _hidden_params dict while it's being concurrently modified by logging callbacks. Fall back to shallow copy if the deep copy fails. Co-authored-by: Ishaan Jaff * fix(cost): handle non-string traffic_type in cost calculator + add retries 1. Fix AttributeError in _map_traffic_type_to_service_tier when traffic_type is an integer (cast to str before calling .upper()). This was causing pass-through vertex spend logging to fail silently. 2. Add --retries to llm_translation_testing for flaky external API calls. Co-authored-by: Ishaan Jaff --------- Co-authored-by: Emerson Gomes Co-authored-by: Claude Opus 4.6 Co-authored-by: ExMatics HydrogenC <33123710+HydrogenC@users.noreply.github.com> Co-authored-by: Jack Venberg Co-authored-by: milan-berri Co-authored-by: Shivam Rawat <161387515+shivamrawat1@users.noreply.github.com> Co-authored-by: Cursor Agent Co-authored-by: Ishaan Jaff --- .circleci/config.yml | 6 +- ci_cd/security_scans.sh | 3 + .../index.md | 78 ++++++ docs/my-website/docs/proxy/config_settings.md | 1 + litellm/cost_calculator.py | 2 +- litellm/integrations/datadog/datadog.py | 10 +- .../integrations/datadog/datadog_handler.py | 11 +- .../integrations/datadog/datadog_llm_obs.py | 4 +- .../litellm_core_utils/streaming_handler.py | 26 +- ...odel_prices_and_context_window_backup.json | 72 +++++ .../mcp_server/discoverable_endpoints.py | 56 +++- litellm/proxy/auth/auth_checks.py | 147 +++++----- litellm/proxy/auth/user_api_key_auth.py | 264 +++++++++--------- litellm/proxy/common_request_processing.py | 4 +- .../mcp_management_endpoints.py | 4 + 
model_prices_and_context_window.json | 224 +++++++++++++++ tests/logging_callback_tests/test_datadog.py | 18 +- .../datadog/test_datadog_tags_regression.py | 2 +- .../mcp_server/test_discoverable_endpoints.py | 138 +++++++++ .../test_mcp_management_endpoints.py | 52 ++++ 20 files changed, 881 insertions(+), 241 deletions(-) create mode 100644 docs/my-website/blog/guardrail_logging_secret_exposure_incident/index.md diff --git a/.circleci/config.yml b/.circleci/config.yml index 12e3cb1f6b..790efc7986 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -42,7 +42,7 @@ commands: "pydantic==2.11.0" "mcp==1.25.0" "requests-mock>=1.12.1" \ "responses==0.25.7" "pytest-xdist==3.6.1" "pytest-timeout==2.2.0" \ "pytest-cov==5.0.0" "semantic_router==0.1.10" "fastapi-offline==1.7.3" \ - "a2a" + "a2a" "parameterized>=0.9.0" - setup_litellm_enterprise_pip - save_cache: paths: @@ -1115,7 +1115,7 @@ jobs: for dir in "${IGNORE_DIRS[@]}"; do IGNORE_ARGS="$IGNORE_ARGS --ignore=$dir" done - python -m pytest -v tests/llm_translation $IGNORE_ARGS --junitxml=test-results/junit.xml --durations=20 -n 8 --timeout=120 --timeout_method=thread + python -m pytest -v tests/llm_translation $IGNORE_ARGS --junitxml=test-results/junit.xml --durations=20 -n 8 --timeout=120 --timeout_method=thread --retries 2 --retry-delay 5 no_output_timeout: 15m # Store test results @@ -1331,7 +1331,7 @@ jobs: command: | pwd ls - python -m pytest -vv tests/unified_google_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + python -m pytest -vv tests/unified_google_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 --retries 3 --retry-delay 5 no_output_timeout: 15m - run: name: Rename the coverage files diff --git a/ci_cd/security_scans.sh b/ci_cd/security_scans.sh index e0f370e003..801b700f64 100755 --- a/ci_cd/security_scans.sh +++ b/ci_cd/security_scans.sh @@ -163,6 +163,9 @@ run_grype_scans() { "CVE-2026-25639" # axios 
- full fix requires 1.x major version bump; pinned to >=0.30.2 to clear other axios CVEs, upgrade to 1.x in follow-up "CVE-2026-2297" # Python 3.13 SourcelessFileLoader audit hook bypass - no fix available in base image "GHSA-qffp-2rhf-9h96" # tar hardlink path traversal - from nodejs_wheel bundled npm, not used in application runtime code + "CVE-2026-2673" # OpenSSL 3.6.1 TLS 1.3 key exchange group negotiation issue - no fix available yet + "CVE-2026-3644" # Python 3.13 vulnerability - no fix available in base image + "CVE-2026-4224" # Python 3.13 Expat parser stack overflow in ElementDeclHandler - no fix available in base image ) # Build JSON array of allowlisted CVE IDs for jq diff --git a/docs/my-website/blog/guardrail_logging_secret_exposure_incident/index.md b/docs/my-website/blog/guardrail_logging_secret_exposure_incident/index.md new file mode 100644 index 0000000000..71f9e3da01 --- /dev/null +++ b/docs/my-website/blog/guardrail_logging_secret_exposure_incident/index.md @@ -0,0 +1,78 @@ +--- +slug: guardrail-logging-secret-exposure-incident +title: "Incident Report: Guardrail logging exposed secret headers in spend logs and traces" +date: 2026-03-18T10:00:00 +authors: + - litellm +tags: [incident-report, security, guardrails] +hide_table_of_contents: false +--- + +**Date:** March 18, 2026 +**Duration:** Unknown +**Severity:** High +**Status:** Resolved + +## Summary + +When a custom guardrail returned the full LiteLLM request/data dictionary, the guardrail response logged by LiteLLM could include `secret_fields.raw_headers`, including plaintext `Authorization` headers containing API keys or other credentials. 
+ +This information could then propagate to logging and observability surfaces that consume guardrail metadata, including: + +- **Spend logs in the LiteLLM UI:** visible to admins with access to spend-log data +- **OpenTelemetry traces:** visible to anyone with access to the relevant telemetry backend + +LLM calls, proxy routing, and provider execution were not blocked by this bug. The impact was exposure of sensitive request headers in observability and logging paths. + +{/* truncate */} + +--- + +## Background + +LiteLLM keeps internal request data (including request headers) for use during the call. That data is not meant to be written to logs or telemetry. + +When custom guardrails run, their outcomes are logged so they can appear in spend logs, OpenTelemetry traces, and other observability backends. If a guardrail returned the full request payload instead of a minimal result, that internal request data could be included in what was logged. Before the fix, the guardrail logging path did not strip that data before sending it to those systems. + +```mermaid +flowchart TD + inboundRequest["1. Incoming proxy request"] --> storeSecrets["2. Store internal request data"] + storeSecrets --> guardrailRuns["3. Custom guardrail runs"] + guardrailRuns --> fullDataReturn["4. Guardrail returns full request payload"] + fullDataReturn --> loggingBuild["5. Build guardrail log payload"] + loggingBuild --> spendLogs["6a. Persist to spend logs / UI"] + loggingBuild --> otelTraces["6b. Attach to OTEL guardrail spans"] +``` + +--- + +## Root Cause + +The root cause was incomplete sanitization in the guardrail logging path. When building the payload that gets sent to spend logs and traces, LiteLLM prepared guardrail responses for logging but did not strip internal request data (such as headers) from them. If a guardrail returned a response that included that data, it was passed through to the logging and observability systems unchanged. 
+ +--- + +## Impact + +This issue required all of the following: + +1. A custom guardrail returned the full LiteLLM request/data dictionary, or another response object containing `secret_fields`. +2. LiteLLM logged that guardrail response through the standard guardrail logging path. +3. An operator, admin, or telemetry consumer had access to the resulting logs or traces. + +When those conditions were met, sensitive values could become visible through: + +- **Spend logs / UI responses:** guardrail metadata could be included in spend-log payloads rendered in the admin UI. +- **OpenTelemetry traces:** `guardrail_response` could be written as a span attribute on guardrail spans. +- **Other downstream observability backends:** any integration consuming the same guardrail metadata could receive the leaked values. + +This was a logging and telemetry exposure bug. It did not let callers bypass auth, access other tenants directly, or change model behavior, but it could expose plaintext credentials to people with access to those observability systems. + +--- + +## Guidance For Users + +- Upgrade to LiteLLM 1.82.3+. +- If you operated custom guardrails that return the full request/data dict, review whether spend logs or telemetry traces were retained during the affected period. +- Rotate any credentials that may have appeared in `Authorization` or other forwarded request headers in those systems. +- Apply least-privilege access controls to spend-log views and telemetry backends that may contain request-derived metadata. 
diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index f5b611a85a..042af2bfb4 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -902,6 +902,7 @@ router_settings: | OTEL_SERVICE_NAME | Service name identifier for OpenTelemetry | OTEL_TRACER_NAME | Tracer name for OpenTelemetry tracing | OTEL_LOGS_EXPORTER | Exporter type for OpenTelemetry logs (e.g., console) +| OTEL_IGNORE_CONTEXT_PROPAGATION | When true, ignore parent span context propagation in OpenTelemetry callbacks | PAGERDUTY_API_KEY | API key for PagerDuty Alerting | PANW_PRISMA_AIRS_API_KEY | API key for PANW Prisma AIRS service | PANW_PRISMA_AIRS_API_BASE | Base URL for PANW Prisma AIRS service diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index ee3c344169..29d28b8c89 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -757,7 +757,7 @@ def _map_traffic_type_to_service_tier(traffic_type: Optional[str]) -> Optional[s """ if traffic_type is None: return None - service_tier = _GEMINI_TRAFFIC_TYPE_TO_SERVICE_TIER.get(traffic_type.upper()) + service_tier = _GEMINI_TRAFFIC_TYPE_TO_SERVICE_TIER.get(str(traffic_type).upper()) return service_tier diff --git a/litellm/integrations/datadog/datadog.py b/litellm/integrations/datadog/datadog.py index 64e0b26a8e..da7e84a025 100644 --- a/litellm/integrations/datadog/datadog.py +++ b/litellm/integrations/datadog/datadog.py @@ -291,7 +291,7 @@ class DataDogLogger( dd_payload = DatadogPayload( ddsource=get_datadog_source(), - ddtags=get_datadog_tags(), + ddtags=",".join(get_datadog_tags()), hostname=get_datadog_hostname(), message=safe_dumps(message_payload), service=get_datadog_service(), @@ -442,7 +442,7 @@ class DataDogLogger( verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload) dd_payload = DatadogPayload( ddsource=get_datadog_source(), - 
ddtags=get_datadog_tags(standard_logging_object=standard_logging_object), + ddtags=",".join(get_datadog_tags(standard_logging_object=standard_logging_object)), hostname=get_datadog_hostname(), message=json_payload, service=get_datadog_service(), @@ -545,7 +545,7 @@ class DataDogLogger( _dd_message_str = safe_dumps(_payload_dict) _dd_payload = DatadogPayload( ddsource=get_datadog_source(), - ddtags=get_datadog_tags(), + ddtags=",".join(get_datadog_tags()), hostname=get_datadog_hostname(), message=_dd_message_str, service=get_datadog_service(), @@ -587,7 +587,7 @@ class DataDogLogger( _dd_message_str = safe_dumps(_payload_dict) _dd_payload = DatadogPayload( ddsource=get_datadog_source(), - ddtags=get_datadog_tags(), + ddtags=",".join(get_datadog_tags()), hostname=get_datadog_hostname(), message=_dd_message_str, service=get_datadog_service(), @@ -678,7 +678,7 @@ class DataDogLogger( dd_payload = DatadogPayload( ddsource=get_datadog_source(), - ddtags=get_datadog_tags(), + ddtags=",".join(get_datadog_tags()), hostname=get_datadog_hostname(), message=json_payload, service=get_datadog_service(), diff --git a/litellm/integrations/datadog/datadog_handler.py b/litellm/integrations/datadog/datadog_handler.py index 0406f1e5d2..b6bb2b5703 100644 --- a/litellm/integrations/datadog/datadog_handler.py +++ b/litellm/integrations/datadog/datadog_handler.py @@ -38,8 +38,13 @@ def get_datadog_pod_name() -> str: def get_datadog_tags( standard_logging_object: Optional[StandardLoggingPayload] = None, -) -> str: - """Build Datadog tags string used by multiple integrations.""" +) -> List[str]: + """Build Datadog tags as a list of individual tag strings. + + Returns a list of "key:value" strings suitable for Datadog LLM Observability + (which expects tags as an array). For Datadog Logs API (ddtags), join with + comma: ",".join(get_datadog_tags(...)). 
+ """ base_tags = { "env": get_datadog_env(), @@ -66,4 +71,4 @@ def get_datadog_tags( if team_tag: tags.append(f"team:{team_tag}") - return ",".join(tags) + return tags diff --git a/litellm/integrations/datadog/datadog_llm_obs.py b/litellm/integrations/datadog/datadog_llm_obs.py index de6cc02fa3..ec6c00961b 100644 --- a/litellm/integrations/datadog/datadog_llm_obs.py +++ b/litellm/integrations/datadog/datadog_llm_obs.py @@ -203,7 +203,7 @@ class DataDogLLMObsLogger(CustomBatchLogger): type="span", attributes=DDSpanAttributes( ml_app=get_datadog_service(), - tags=[get_datadog_tags()], + tags=get_datadog_tags(), spans=self.log_queue, ), ), @@ -315,7 +315,7 @@ class DataDogLLMObsLogger(CustomBatchLogger): duration=int((end_time - start_time).total_seconds() * 1e9), metrics=metrics, status="error" if error_info else "ok", - tags=[get_datadog_tags(standard_logging_object=standard_logging_payload)], + tags=get_datadog_tags(standard_logging_object=standard_logging_payload), ) apm_trace_id = self._get_apm_trace_id() diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 6e991e6911..ca78e72c69 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -1893,15 +1893,23 @@ class CustomStreamWrapper: "usage", getattr(complete_streaming_response, "usage"), ) - self.cache_streaming_response( - processed_chunk=complete_streaming_response.model_copy( + try: + _cache_copy = complete_streaming_response.model_copy( deep=True - ), + ) + _log_copy = complete_streaming_response.model_copy( + deep=True + ) + except RuntimeError: + _cache_copy = complete_streaming_response.model_copy() + _log_copy = complete_streaming_response.model_copy() + self.cache_streaming_response( + processed_chunk=_cache_copy, cache_hit=cache_hit, ) executor.submit( self.logging_obj.success_handler, - complete_streaming_response.model_copy(deep=True), + _log_copy, None, None, cache_hit, @@ 
-2113,11 +2121,15 @@ class CustomStreamWrapper: "usage", getattr(complete_streaming_response, "usage"), ) + try: + _copy = complete_streaming_response.model_copy( + deep=True + ) + except RuntimeError: + _copy = complete_streaming_response.model_copy() asyncio.create_task( self.async_cache_streaming_response( - processed_chunk=complete_streaming_response.model_copy( - deep=True - ), + processed_chunk=_copy, cache_hit=cache_hit, ) ) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 181045809f..e7ff57f27e 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -4462,6 +4462,78 @@ "supports_vision": true, "supports_web_search": true }, + "azure/gpt-5.4-mini": { + "cache_read_input_token_cost": 7.5e-08, + "input_cost_per_token": 7.5e-07, + "litellm_provider": "azure", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 4.5e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_service_tier": true, + "supports_vision": true, + "supports_web_search": true, + "supports_none_reasoning_effort": false, + "supports_xhigh_reasoning_effort": false + }, + "azure/gpt-5.4-nano": { + "cache_read_input_token_cost": 2e-08, + "input_cost_per_token": 2e-07, + "litellm_provider": "azure", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + 
"output_cost_per_token": 1.25e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_service_tier": true, + "supports_vision": true, + "supports_web_search": true, + "supports_none_reasoning_effort": false, + "supports_xhigh_reasoning_effort": false + }, "azure/gpt-image-1": { "cache_read_input_image_token_cost": 2.5e-06, "cache_read_input_token_cost": 1.25e-06, diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index 3385e7feef..07309eb57f 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -208,26 +208,52 @@ async def exchange_token_with_server( client_id: str, client_secret: Optional[str], code_verifier: Optional[str], + refresh_token: Optional[str] = None, + scope: Optional[str] = None, ): - if grant_type != "authorization_code": + if grant_type not in ("authorization_code", "refresh_token"): raise HTTPException(status_code=400, detail="Unsupported grant_type") if mcp_server.token_url is None: raise HTTPException(status_code=400, detail="MCP server token url is not set") - proxy_base_url = get_request_base_url(request) - token_data = { - "grant_type": "authorization_code", - "client_id": mcp_server.client_id if mcp_server.client_id else client_id, - "client_secret": mcp_server.client_secret - if mcp_server.client_secret - else client_secret, - "code": code, - "redirect_uri": f"{proxy_base_url}/callback", - 
} + resolved_client_id = mcp_server.client_id if mcp_server.client_id else client_id + resolved_client_secret = ( + mcp_server.client_secret if mcp_server.client_secret else client_secret + ) - if code_verifier: - token_data["code_verifier"] = code_verifier + if grant_type == "refresh_token": + if not refresh_token: + raise HTTPException( + status_code=400, + detail="refresh_token is required for refresh_token grant", + ) + token_data: dict = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": resolved_client_id, + } + if resolved_client_secret is not None: + token_data["client_secret"] = resolved_client_secret + if scope: + token_data["scope"] = scope + else: + if not code: + raise HTTPException( + status_code=400, + detail="code is required for authorization_code grant", + ) + proxy_base_url = get_request_base_url(request) + token_data = { + "grant_type": "authorization_code", + "client_id": resolved_client_id, + "code": code, + "redirect_uri": f"{proxy_base_url}/callback", + } + if resolved_client_secret is not None: + token_data["client_secret"] = resolved_client_secret + if code_verifier: + token_data["code_verifier"] = code_verifier async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check) response = await async_client.post( @@ -375,6 +401,8 @@ async def token_endpoint( client_id: str = Form(...), client_secret: Optional[str] = Form(None), code_verifier: str = Form(None), + refresh_token: Optional[str] = Form(None), + scope: Optional[str] = Form(None), mcp_server_name: Optional[str] = None, ): """ @@ -408,6 +436,8 @@ async def token_endpoint( client_id=client_id, client_secret=client_secret, code_verifier=code_verifier, + refresh_token=refresh_token, + scope=scope, ) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index d31a13e8bc..6cf1f7ed6b 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -29,6 +29,7 @@ from 
litellm.constants import ( DEFAULT_MAX_RECURSE_DEPTH, EMAIL_BUDGET_ALERT_MAX_SPEND_ALERT_PERCENTAGE, ) +from litellm.litellm_core_utils.dd_tracing import tracer from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider from litellm.proxy._types import ( RBAC_ROLES, @@ -407,18 +408,19 @@ async def common_checks( # noqa: PLR0915 # 2. If team can call model if _model and team_object: - if not await can_team_access_model( - model=_model, - team_object=team_object, - llm_router=llm_router, - team_model_aliases=valid_token.team_model_aliases if valid_token else None, - ): - raise ProxyException( - message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. Allowed team models = {team_object.models}", - type=ProxyErrorTypes.team_model_access_denied, - param="model", - code=status.HTTP_401_UNAUTHORIZED, - ) + with tracer.trace("litellm.proxy.auth.common_checks.can_team_access_model"): + if not await can_team_access_model( + model=_model, + team_object=team_object, + llm_router=llm_router, + team_model_aliases=valid_token.team_model_aliases if valid_token else None, + ): + raise ProxyException( + message=f"Team not allowed to access model. Team={team_object.team_id}, Model={_model}. 
Allowed team models = {team_object.models}", + type=ProxyErrorTypes.team_model_access_denied, + param="model", + code=status.HTTP_401_UNAUTHORIZED, + ) # Require trace id for agent keys when agent has require_trace_id_on_calls_by_agent if valid_token is not None and valid_token.agent_id: @@ -443,54 +445,60 @@ async def common_checks( # noqa: PLR0915 ## 2.1 If user can call model (if personal key) if _model and team_object is None and user_object is not None: - await can_user_call_model( - model=_model, - llm_router=llm_router, - user_object=user_object, - ) + with tracer.trace("litellm.proxy.auth.common_checks.can_user_call_model"): + await can_user_call_model( + model=_model, + llm_router=llm_router, + user_object=user_object, + ) # 1.1 - 2.2 - 3.0.2 - 3.0.3: Project checks (blocked, model access, budget) - await _run_project_checks( - project_object=project_object, - _model=_model, - llm_router=llm_router, - skip_budget_checks=skip_budget_checks, - valid_token=valid_token, - proxy_logging_obj=proxy_logging_obj, - ) + with tracer.trace("litellm.proxy.auth.common_checks.run_project_checks"): + await _run_project_checks( + project_object=project_object, + _model=_model, + llm_router=llm_router, + skip_budget_checks=skip_budget_checks, + valid_token=valid_token, + proxy_logging_obj=proxy_logging_obj, + ) # If this is a free model, skip all budget checks if not skip_budget_checks: # 3. If team is in budget - await _team_max_budget_check( - team_object=team_object, - proxy_logging_obj=proxy_logging_obj, - valid_token=valid_token, - ) + with tracer.trace("litellm.proxy.auth.common_checks.team_max_budget_check"): + await _team_max_budget_check( + team_object=team_object, + proxy_logging_obj=proxy_logging_obj, + valid_token=valid_token, + ) # 3.0.5. 
If team is over soft budget (alert only, doesn't block) - await _team_soft_budget_check( - team_object=team_object, - proxy_logging_obj=proxy_logging_obj, - valid_token=valid_token, - ) + with tracer.trace("litellm.proxy.auth.common_checks.team_soft_budget_check"): + await _team_soft_budget_check( + team_object=team_object, + proxy_logging_obj=proxy_logging_obj, + valid_token=valid_token, + ) # 3.1. If organization is in budget - await _organization_max_budget_check( - valid_token=valid_token, - team_object=team_object, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - proxy_logging_obj=proxy_logging_obj, - ) + with tracer.trace("litellm.proxy.auth.common_checks.organization_max_budget_check"): + await _organization_max_budget_check( + valid_token=valid_token, + team_object=team_object, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) - await _tag_max_budget_check( - request_body=request_body, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - proxy_logging_obj=proxy_logging_obj, - valid_token=valid_token, - ) + with tracer.trace("litellm.proxy.auth.common_checks.tag_max_budget_check"): + await _tag_max_budget_check( + request_body=request_body, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + valid_token=valid_token, + ) # 4. 
If user is in budget ## 4.1 check personal budget, if personal key @@ -508,14 +516,15 @@ async def common_checks( # noqa: PLR0915 ) ## 4.2 check team member budget, if team key - await _check_team_member_budget( - team_object=team_object, - user_object=user_object, - valid_token=valid_token, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - proxy_logging_obj=proxy_logging_obj, - ) + with tracer.trace("litellm.proxy.auth.common_checks.check_team_member_budget"): + await _check_team_member_budget( + team_object=team_object, + user_object=user_object, + valid_token=valid_token, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) # 5. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget if ( @@ -554,19 +563,21 @@ async def common_checks( # noqa: PLR0915 ) # 11. [OPTIONAL] Vector store checks - is the object allowed to access the vector store - await vector_store_access_check( - request_body=request_body, - team_object=team_object, - valid_token=valid_token, - ) + with tracer.trace("litellm.proxy.auth.common_checks.vector_store_access_check"): + await vector_store_access_check( + request_body=request_body, + team_object=team_object, + valid_token=valid_token, + ) # 12. 
[OPTIONAL] Tool allowlist - key/team allowed_tools (no DB in hot path) - await check_tools_allowlist( - request_body=request_body, - valid_token=valid_token, - team_object=team_object, - route=route, - ) + with tracer.trace("litellm.proxy.auth.common_checks.check_tools_allowlist"): + await check_tools_allowlist( + request_body=request_body, + valid_token=valid_token, + team_object=team_object, + route=route, + ) return True diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 044333ac13..30e59f77e6 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -548,13 +548,12 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 custom_auth_api_key: bool = False try: - # get the request body - - await pre_db_read_auth_checks( - request_data=request_data, - request=request, - route=route, - ) + with tracer.trace("litellm.proxy.auth.pre_db_read_auth_checks"): + await pre_db_read_auth_checks( + request_data=request_data, + request=request, + route=route, + ) pass_through_endpoints: Optional[List[dict]] = general_settings.get( "pass_through_endpoints", None ) @@ -588,9 +587,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 ### USER-DEFINED AUTH FUNCTION ### if enterprise_custom_auth is not None: - response = await enterprise_custom_auth( - request=request, api_key=api_key, user_custom_auth=user_custom_auth - ) + with tracer.trace("litellm.proxy.auth.enterprise_custom_auth"): + response = await enterprise_custom_auth( + request=request, api_key=api_key, user_custom_auth=user_custom_auth + ) if response is not None and isinstance(response, UserAPIKeyAuth): validated = UserAPIKeyAuth.model_validate(response) validated = await _run_post_custom_auth_checks( @@ -706,18 +706,19 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 # Fall through to virtual key checks if do_standard_jwt_auth: - result = await JWTAuthManager.auth_builder( - request_data=request_data, - 
general_settings=general_settings, - api_key=api_key, - jwt_handler=jwt_handler, - route=route, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - proxy_logging_obj=proxy_logging_obj, - parent_otel_span=parent_otel_span, - request_headers=_safe_get_request_headers(request), - ) + with tracer.trace("litellm.proxy.auth.jwt_auth_builder"): + result = await JWTAuthManager.auth_builder( + request_data=request_data, + general_settings=general_settings, + api_key=api_key, + jwt_handler=jwt_handler, + route=route, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + parent_otel_span=parent_otel_span, + request_headers=_safe_get_request_headers(request), + ) is_proxy_admin = result["is_proxy_admin"] team_id = result["team_id"] @@ -909,15 +910,15 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 try: end_user_params["end_user_id"] = end_user_id - # get end-user object - _end_user_object = await get_end_user_object( - end_user_id=end_user_id, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - parent_otel_span=parent_otel_span, - proxy_logging_obj=proxy_logging_obj, - route=route, - ) + with tracer.trace("litellm.proxy.auth.get_end_user_object"): + _end_user_object = await get_end_user_object( + end_user_id=end_user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + route=route, + ) if _end_user_object is not None: end_user_params[ "allowed_model_region" @@ -960,14 +961,15 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 if valid_token is None: ## Check CACHE try: - valid_token = await get_key_object( - hashed_token=hash_token(api_key), - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - parent_otel_span=parent_otel_span, - proxy_logging_obj=proxy_logging_obj, - check_cache_only=True, - ) + with 
tracer.trace("litellm.proxy.auth.get_key_object_check_cache"): + valid_token = await get_key_object( + hashed_token=hash_token(api_key), + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + check_cache_only=True, + ) except Exception: verbose_logger.debug("api key not found in cache.") valid_token = None @@ -1139,13 +1141,14 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 api_key = hash_token(token=api_key) try: - valid_token = await get_key_object( - hashed_token=api_key, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - parent_otel_span=parent_otel_span, - proxy_logging_obj=proxy_logging_obj, - ) + with tracer.trace("litellm.proxy.auth.get_key_object_from_db"): + valid_token = await get_key_object( + hashed_token=api_key, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) except ProxyException as e: if e.code == 401 or e.code == "401": e.message = "Authentication Error, Invalid proxy server token passed. Received API Key = {}, Key Hash (Token) ={}. Unable to find token in cache or `LiteLLM_VerificationTokenTable`".format( @@ -1233,14 +1236,15 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 # Check 2. 
If user_id for this token is in budget - done in common_checks() if valid_token.user_id is not None: try: - user_obj = await get_user_object( - user_id=valid_token.user_id, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - user_id_upsert=False, - parent_otel_span=parent_otel_span, - proxy_logging_obj=proxy_logging_obj, - ) + with tracer.trace("litellm.proxy.auth.get_user_object"): + user_obj = await get_user_object( + user_id=valid_token.user_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + user_id_upsert=False, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) except Exception as e: verbose_logger.debug( "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to get user from db/cache. Setting user_obj to None. Exception received - {}".format( @@ -1329,71 +1333,73 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 ) if not skip_budget_checks: - # Check 4. Token Spend is under budget - if RouteChecks.is_llm_api_route(route=route): - await _virtual_key_max_budget_check( + with tracer.trace("litellm.proxy.auth.budget_checks"): + # Check 4. Token Spend is under budget + if RouteChecks.is_llm_api_route(route=route): + await _virtual_key_max_budget_check( + valid_token=valid_token, + proxy_logging_obj=proxy_logging_obj, + user_obj=user_obj, + ) + + # Check 5. Max Budget Alert Check + await _virtual_key_max_budget_alert_check( valid_token=valid_token, proxy_logging_obj=proxy_logging_obj, user_obj=user_obj, ) - # Check 5. Max Budget Alert Check - await _virtual_key_max_budget_alert_check( - valid_token=valid_token, - proxy_logging_obj=proxy_logging_obj, - user_obj=user_obj, - ) - - # Check 6. Soft Budget Check - await _virtual_key_soft_budget_check( - valid_token=valid_token, - proxy_logging_obj=proxy_logging_obj, - user_obj=user_obj, - ) - - # Check 5. 
Token Model Spend is under Model budget - max_budget_per_model = valid_token.model_max_budget - current_model = request_data.get("model", None) - - if ( - max_budget_per_model is not None - and isinstance(max_budget_per_model, dict) - and len(max_budget_per_model) > 0 - and prisma_client is not None - and current_model is not None - and valid_token.token is not None - ): - ## GET THE SPEND FOR THIS MODEL - await model_max_budget_limiter.is_key_within_model_budget( - user_api_key_dict=valid_token, - model=current_model, + # Check 6. Soft Budget Check + await _virtual_key_soft_budget_check( + valid_token=valid_token, + proxy_logging_obj=proxy_logging_obj, + user_obj=user_obj, ) - # Check 5b. End-user model max budget - end_user_mmb = valid_token.end_user_model_max_budget - if ( - end_user_mmb is not None - and isinstance(end_user_mmb, dict) - and len(end_user_mmb) > 0 - and current_model is not None - and valid_token.end_user_id is not None - ): - await model_max_budget_limiter.is_end_user_within_model_budget( - end_user_id=valid_token.end_user_id, - end_user_model_max_budget=end_user_mmb, - model=current_model, - ) + # Check 5. Token Model Spend is under Model budget + max_budget_per_model = valid_token.model_max_budget + current_model = request_data.get("model", None) + + if ( + max_budget_per_model is not None + and isinstance(max_budget_per_model, dict) + and len(max_budget_per_model) > 0 + and prisma_client is not None + and current_model is not None + and valid_token.token is not None + ): + ## GET THE SPEND FOR THIS MODEL + await model_max_budget_limiter.is_key_within_model_budget( + user_api_key_dict=valid_token, + model=current_model, + ) + + # Check 5b. 
End-user model max budget + end_user_mmb = valid_token.end_user_model_max_budget + if ( + end_user_mmb is not None + and isinstance(end_user_mmb, dict) + and len(end_user_mmb) > 0 + and current_model is not None + and valid_token.end_user_id is not None + ): + await model_max_budget_limiter.is_end_user_within_model_budget( + end_user_id=valid_token.end_user_id, + end_user_model_max_budget=end_user_mmb, + model=current_model, + ) # Check 6: Additional Common Checks across jwt + key auth if valid_token.team_id is not None: try: - _team_obj = await get_team_object( - team_id=valid_token.team_id, - prisma_client=prisma_client, - user_api_key_cache=user_api_key_cache, - parent_otel_span=parent_otel_span, - proxy_logging_obj=proxy_logging_obj, - ) + with tracer.trace("litellm.proxy.auth.get_team_object"): + _team_obj = await get_team_object( + team_id=valid_token.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) except HTTPException: _team_obj = LiteLLM_TeamTableCachedObj( team_id=valid_token.team_id, @@ -1431,11 +1437,12 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 litellm.max_budget > 0 and prisma_client is not None ): # user set proxy max budget cache_key = "{}:spend".format(litellm_proxy_admin_name) - global_proxy_spend = await _fetch_global_spend_with_event_coordination( - cache_key=cache_key, - user_api_key_cache=user_api_key_cache, - prisma_client=prisma_client, - ) + with tracer.trace("litellm.proxy.auth.get_global_proxy_spend"): + global_proxy_spend = await _fetch_global_spend_with_event_coordination( + cache_key=cache_key, + user_api_key_cache=user_api_key_cache, + prisma_client=prisma_client, + ) if global_proxy_spend is not None: call_info = CallInfo( @@ -1452,21 +1459,22 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 user_info=call_info, ) ) - _ = await common_checks( - request=request, - request_body=request_data, - 
team_object=_team_obj, - user_object=user_obj, - end_user_object=_end_user_object, - general_settings=general_settings, - global_proxy_spend=global_proxy_spend, - route=route, - llm_router=llm_router, - proxy_logging_obj=proxy_logging_obj, - valid_token=valid_token, - skip_budget_checks=skip_budget_checks, - project_object=_project_obj, - ) + with tracer.trace("litellm.proxy.auth.common_checks"): + _ = await common_checks( + request=request, + request_body=request_data, + team_object=_team_obj, + user_object=user_obj, + end_user_object=_end_user_object, + general_settings=general_settings, + global_proxy_spend=global_proxy_spend, + route=route, + llm_router=llm_router, + proxy_logging_obj=proxy_logging_obj, + valid_token=valid_token, + skip_budget_checks=skip_budget_checks, + project_object=_project_obj, + ) # Token passed all checks if valid_token is None: raise HTTPException(401, detail="Invalid API key") diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 72765aab7d..e5a31c3671 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -1260,7 +1260,9 @@ class ProxyBaseLLMRequestProcessing: custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( user_api_key_dict=user_api_key_dict, call_id=( - _litellm_logging_obj.litellm_call_id if _litellm_logging_obj else None + _litellm_logging_obj.litellm_call_id + if _litellm_logging_obj + else self.data.get("litellm_call_id") ), model_id=model_id, version=version, diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index 3e5b729cea..f29a721ede 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -1399,6 +1399,8 @@ if MCP_AVAILABLE: client_id: Optional[str] = Form(None), client_secret: Optional[str] = Form(None), code_verifier: 
Optional[str] = Form(None), + refresh_token: Optional[str] = Form(None), + scope: Optional[str] = Form(None), ): mcp_server = _get_cached_temporary_mcp_server_or_404(server_id) resolved_client_id = mcp_server.client_id or client_id or "" @@ -1422,6 +1424,8 @@ if MCP_AVAILABLE: client_id=resolved_client_id, client_secret=client_secret, code_verifier=code_verifier, + refresh_token=refresh_token, + scope=scope, ) @router.post( diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 181045809f..879dd42be4 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -4462,6 +4462,78 @@ "supports_vision": true, "supports_web_search": true }, + "azure/gpt-5.4-mini": { + "cache_read_input_token_cost": 7.5e-08, + "input_cost_per_token": 7.5e-07, + "litellm_provider": "azure", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 4.5e-06, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_service_tier": true, + "supports_vision": true, + "supports_web_search": true, + "supports_none_reasoning_effort": false, + "supports_xhigh_reasoning_effort": false + }, + "azure/gpt-5.4-nano": { + "cache_read_input_token_cost": 2e-08, + "input_cost_per_token": 2e-07, + "litellm_provider": "azure", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.25e-06, + "supported_endpoints": [ + 
"/v1/chat/completions", + "/v1/batch", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_service_tier": true, + "supports_vision": true, + "supports_web_search": true, + "supports_none_reasoning_effort": false, + "supports_xhigh_reasoning_effort": false + }, "azure/gpt-image-1": { "cache_read_input_image_token_cost": 2.5e-06, "cache_read_input_token_cost": 1.25e-06, @@ -37032,5 +37104,157 @@ "supports_reasoning": true, "supports_response_schema": true, "supports_tool_choice": true + }, + "volcengine/doubao-seed-2-0-pro-260215": { + "litellm_provider": "volcengine", + "max_input_tokens": 256000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "source": "https://www.volcengine.com/docs/82379/1330310", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": false, + "supports_vision": true, + "tiered_pricing": [ + { + "input_cost_per_token": 4.6e-07, + "output_cost_per_token": 2.3e-06, + "range": [ + 0, + 32000.0 + ] + }, + { + "input_cost_per_token": 7e-07, + "output_cost_per_token": 3.5e-06, + "range": [ + 32000.0, + 128000.0 + ] + }, + { + "input_cost_per_token": 1.4e-06, + "output_cost_per_token": 7e-06, + "range": [ + 128000.0, + 256000.0 + ] + } + ] + }, + "volcengine/doubao-seed-2-0-lite-260215": { + "litellm_provider": "volcengine", + "max_input_tokens": 256000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "source": "https://www.volcengine.com/docs/82379/1330310", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": false, + 
"supports_vision": true, + "tiered_pricing": [ + { + "input_cost_per_token": 8.7e-08, + "output_cost_per_token": 5.2e-07, + "range": [ + 0, + 32000.0 + ] + }, + { + "input_cost_per_token": 1.3e-07, + "output_cost_per_token": 7.8e-07, + "range": [ + 32000.0, + 128000.0 + ] + }, + { + "input_cost_per_token": 2.6e-07, + "output_cost_per_token": 1.6e-06, + "range": [ + 128000.0, + 256000.0 + ] + } + ] + }, + "volcengine/doubao-seed-2-0-mini-260215": { + "litellm_provider": "volcengine", + "max_input_tokens": 256000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "source": "https://www.volcengine.com/docs/82379/1330310", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": false, + "supports_vision": true, + "tiered_pricing": [ + { + "input_cost_per_token": 2.9e-08, + "output_cost_per_token": 2.9e-07, + "range": [ + 0, + 32000.0 + ] + }, + { + "input_cost_per_token": 5.8e-08, + "output_cost_per_token": 5.8e-07, + "range": [ + 32000.0, + 128000.0 + ] + }, + { + "input_cost_per_token": 1.2e-07, + "output_cost_per_token": 1.2e-06, + "range": [ + 128000.0, + 256000.0 + ] + } + ] + }, + "volcengine/doubao-seed-2-0-code-preview-260215": { + "litellm_provider": "volcengine", + "max_input_tokens": 256000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "source": "https://www.volcengine.com/docs/82379/1330310", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": false, + "supports_vision": true, + "tiered_pricing": [ + { + "input_cost_per_token": 4.6e-07, + "output_cost_per_token": 2.3e-06, + "range": [ + 0, + 32000.0 + ] + }, + { + "input_cost_per_token": 7e-07, + "output_cost_per_token": 3.5e-06, + "range": [ + 32000.0, + 128000.0 + ] + }, + { + "input_cost_per_token": 1.4e-06, + "output_cost_per_token": 7e-06, + "range": [ + 128000.0, + 256000.0 + ] + } + ] } } diff --git a/tests/logging_callback_tests/test_datadog.py 
b/tests/logging_callback_tests/test_datadog.py index fc4b3ff3cf..4cfd4a6cc9 100644 --- a/tests/logging_callback_tests/test_datadog.py +++ b/tests/logging_callback_tests/test_datadog.py @@ -593,7 +593,7 @@ def test_datadog_static_methods(): # Test tags format with default values assert ( "env:unknown,service:litellm-server,version:unknown,HOSTNAME:" - in get_datadog_tags() + in ",".join(get_datadog_tags()) ) # Test with custom environment variables @@ -631,7 +631,7 @@ def test_datadog_static_methods(): # Test tags format with custom values expected_custom_tags = "env:production,service:custom-service,version:1.0.0,HOSTNAME:test-host,POD_NAME:pod-123" print("DataDogLogger._get_datadog_tags()", get_datadog_tags()) - assert get_datadog_tags() == expected_custom_tags + assert ",".join(get_datadog_tags()) == expected_custom_tags @pytest.mark.asyncio @@ -672,11 +672,11 @@ def test_get_datadog_tags(): """Test the _get_datadog_tags static method with various inputs""" # Test with no standard_logging_object and default env vars base_tags = get_datadog_tags() - assert "env:" in base_tags - assert "service:" in base_tags - assert "version:" in base_tags - assert "POD_NAME:" in base_tags - assert "HOSTNAME:" in base_tags + assert any("env:" in t for t in base_tags) + assert any("service:" in t for t in base_tags) + assert any("version:" in t for t in base_tags) + assert any("POD_NAME:" in t for t in base_tags) + assert any("HOSTNAME:" in t for t in base_tags) # Test with custom env vars test_env = { @@ -705,12 +705,12 @@ def test_get_datadog_tags(): # Test with empty request_tags standard_logging_obj["request_tags"] = [] tags_empty_request = get_datadog_tags(standard_logging_obj) - assert "request_tag:" not in tags_empty_request + assert not any(t.startswith("request_tag:") for t in tags_empty_request) # Test with None request_tags standard_logging_obj["request_tags"] = None tags_none_request = get_datadog_tags(standard_logging_obj) - assert "request_tag:" not in 
tags_none_request + assert not any(t.startswith("request_tag:") for t in tags_none_request) @pytest.mark.asyncio diff --git a/tests/test_litellm/integrations/datadog/test_datadog_tags_regression.py b/tests/test_litellm/integrations/datadog/test_datadog_tags_regression.py index 3f1d2be413..cc9eae7a37 100644 --- a/tests/test_litellm/integrations/datadog/test_datadog_tags_regression.py +++ b/tests/test_litellm/integrations/datadog/test_datadog_tags_regression.py @@ -44,7 +44,7 @@ class TestDatadogTagsRegression: assert "env:test-env" in tags_legacy assert "service:test-service" in tags_legacy # Verify NO team tag (should not invent one) - assert "team:" not in tags_legacy + assert not any(t.startswith("team:") for t in tags_legacy) # Case 2: New feature (team info provided) payload_with_team = StandardLoggingPayload( diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py index 700ba86b10..954f2703e3 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py @@ -1666,3 +1666,141 @@ async def test_oauth_authorize_prefers_request_scope_over_server_config(): redirect_url = response.headers["location"] assert "scope=custom_scope1+custom_scope2" in redirect_url or "scope=custom_scope1%20custom_scope2" in redirect_url assert "default_scope" not in redirect_url + + +@pytest.mark.asyncio +async def test_token_endpoint_refresh_token_grant(): + """Test that token endpoint supports refresh_token grant type.""" + try: + from fastapi import Request + + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + token_endpoint, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + from litellm.proxy._types import MCPTransport + from litellm.types.mcp import MCPAuth + 
from litellm.types.mcp_server.mcp_server_manager import MCPServer + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + # Clear registry + global_mcp_server_manager.registry.clear() + + # Create mock OAuth2 server + oauth2_server = MCPServer( + server_id="google_mcp", + name="google_mcp", + server_name="google_mcp", + alias="google_mcp", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + client_id="test_client_id", + client_secret="test_secret", + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + scopes=["openid", "email"], + ) + global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://proxy.litellm.example/" + mock_request.headers = {} + + # Mock httpx client response with new tokens + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "new_access_token", + "token_type": "Bearer", + "expires_in": 3599, + "refresh_token": "new_refresh_token", + } + mock_response.raise_for_status = MagicMock() + + mock_async_client = MagicMock() + mock_async_client.post = AsyncMock(return_value=mock_response) + + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = mock_async_client + + response = await token_endpoint( + request=mock_request, + grant_type="refresh_token", + code=None, + redirect_uri=None, + client_id="test_client_id", + mcp_server_name="google_mcp", + client_secret="test_secret", + refresh_token="rt-test", + scope="openid email", + ) + + # Verify the POST was called with refresh_token grant data + mock_async_client.post.assert_called_once() + call_args = mock_async_client.post.call_args + + assert call_args[1]["data"]["grant_type"] == "refresh_token" + assert call_args[1]["data"]["refresh_token"] == "rt-test" + assert 
call_args[1]["data"]["client_id"] == "test_client_id" + assert call_args[1]["data"]["client_secret"] == "test_secret" + assert call_args[1]["data"]["scope"] == "openid email" + + # Verify response contains the new tokens + import json + + token_data = json.loads(response.body) + assert token_data["access_token"] == "new_access_token" + assert token_data["refresh_token"] == "new_refresh_token" + + +@pytest.mark.asyncio +async def test_token_endpoint_authorization_code_missing_code(): + """Test that authorization_code grant rejects missing code param.""" + try: + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + exchange_token_with_server, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + from litellm.proxy._types import MCPTransport + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + global_mcp_server_manager.registry.clear() + + server = MCPServer( + server_id="test_server", + name="test_server", + server_name="test_server", + alias="test_server", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + client_id="cid", + token_url="https://example.com/token", + ) + global_mcp_server_manager.registry[server.server_id] = server + + mock_request = MagicMock() + mock_request.base_url = "https://proxy.example/" + mock_request.headers = {} + + with pytest.raises(HTTPException) as exc_info: + await exchange_token_with_server( + request=mock_request, + mcp_server=server, + grant_type="authorization_code", + code=None, + redirect_uri="https://example.com/cb", + client_id="cid", + client_secret=None, + code_verifier=None, + ) + assert exc_info.value.status_code == 400 + assert "code is required" in str(exc_info.value.detail) diff --git a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py 
b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py index eeaeb49832..77ac3a040a 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py @@ -1519,6 +1519,8 @@ class TestTemporaryMCPSessionEndpoints: client_id="client", client_secret="secret", code_verifier="verifier", + refresh_token=None, + scope=None, ) assert result is exchange_response @@ -1532,6 +1534,56 @@ class TestTemporaryMCPSessionEndpoints: client_id="client", client_secret="secret", code_verifier="verifier", + refresh_token=None, + scope=None, + ) + + @pytest.mark.asyncio + async def test_mcp_token_proxies_refresh_token_grant(self): + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + mcp_token, + ) + + request = MagicMock() + server = generate_mock_mcp_server_config_record(server_id="server-1") + exchange_response = {"access_token": "new-token", "refresh_token": "new-rt"} + + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints._get_cached_temporary_mcp_server_or_404", + return_value=server, + ) as get_server, + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.exchange_token_with_server", + AsyncMock(return_value=exchange_response), + ) as exchange_mock, + ): + result = await mcp_token( + request=request, + server_id="server-1", + grant_type="refresh_token", + code=None, + redirect_uri=None, + client_id="client", + client_secret="secret", + code_verifier=None, + refresh_token="rt-123", + scope=None, + ) + + assert result is exchange_response + get_server.assert_called_once_with("server-1") + exchange_mock.assert_awaited_once_with( + request=request, + mcp_server=server, + grant_type="refresh_token", + code=None, + redirect_uri=None, + client_id="client", + client_secret="secret", + code_verifier=None, + refresh_token="rt-123", + scope=None, ) @pytest.mark.asyncio From 
b34231dc95ce686424592d9a063112ef04b103c3 Mon Sep 17 00:00:00 2001 From: michelligabriele Date: Thu, 19 Mar 2026 18:20:42 +0100 Subject: [PATCH 179/539] refactor(proxy): reuse unified_guardrail singleton, rename shadowing variable Reuse the module-level unified_guardrail singleton from proxy/utils.py in _run_deferred_stream_guardrails instead of creating a new instance per call, matching the pattern used by post_call_success_hook. Rename local variable _has_post_call_guardrails to _post_call_guardrails_active to avoid shadowing the static method name. --- litellm/proxy/common_request_processing.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 0e3e89c531..1517ee6d9d 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -931,7 +931,7 @@ class ProxyBaseLLMRequestProcessing: # Defer async logging when post-call guardrails are configured so the # StandardLoggingPayload is built after guardrails write to metadata. # Cache the result to avoid scanning litellm.callbacks twice. - _has_post_call_guardrails = self._has_post_call_guardrails() + _post_call_guardrails_active = self._has_post_call_guardrails() # Non-streaming: defer the create_task in wrapper_async so the # SLP is built after guardrails write to metadata. Streaming @@ -943,7 +943,7 @@ class ProxyBaseLLMRequestProcessing: # so _enqueue_deferred_logging is never stored — the finally # block is a no-op. The CSW path handles this correctly via # _on_deferred_stream_complete, which fires its own logging. 
- if _has_post_call_guardrails and not self._is_streaming_request( + if _post_call_guardrails_active and not self._is_streaming_request( data=self.data, is_streaming_request=is_streaming_request ): logging_obj._defer_async_logging = True # type: ignore @@ -1057,7 +1057,7 @@ class ProxyBaseLLMRequestProcessing: CustomStreamWrapper, ) - if _has_post_call_guardrails and isinstance( + if _post_call_guardrails_active and isinstance( response, CustomStreamWrapper ): # Intentionally a live reference (not a copy) — mirrors @@ -1338,15 +1338,14 @@ class ProxyBaseLLMRequestProcessing: implementation directly rather than reimplementing the closure. """ from litellm.litellm_core_utils.thread_pool_executor import executor - from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( - UnifiedLLMGuardrails, - ) from litellm.proxy.proxy_server import llm_router as _global_llm_router - from litellm.proxy.utils import _check_and_merge_model_level_guardrails + from litellm.proxy.utils import ( + _check_and_merge_model_level_guardrails, + unified_guardrail as _unified_guardrail, + ) _response = assembled_response try: - _unified_guardrail = UnifiedLLMGuardrails() guardrail_data = _check_and_merge_model_level_guardrails( data=captured_data, llm_router=_global_llm_router ) From 6d0763b8ba30a71ba1cd7fec5643f374be85d72c Mon Sep 17 00:00:00 2001 From: Jonathan Barazany Date: Thu, 19 Mar 2026 19:28:05 +0200 Subject: [PATCH 180/539] fix: short-circuit websearch for non-Anthropic providers (github_copilot) For providers like github_copilot that don't natively support web search, Claude Code's search sub-conversations were falling through to the adapter path which strips the web_search tool and has no stream reconversion. Instead of routing search requests through the full LLM pipeline, detect web-search-only requests early (all tools are web_search, simple prompt) and execute the search directly via Tavily/Perplexity, returning a synthetic Anthropic response. 
No adapter, no backend LLM call needed. Fixes #21733 --- .../websearch_interception/handler.py | 114 ++++++ .../messages/handler.py | 61 ++++ .../test_websearch_short_circuit.py | 332 ++++++++++++++++++ 3 files changed, 507 insertions(+) create mode 100644 tests/test_litellm/integrations/websearch_interception/test_websearch_short_circuit.py diff --git a/litellm/integrations/websearch_interception/handler.py b/litellm/integrations/websearch_interception/handler.py index 2541a0bd7a..59510dd5d0 100644 --- a/litellm/integrations/websearch_interception/handler.py +++ b/litellm/integrations/websearch_interception/handler.py @@ -67,6 +67,120 @@ class WebSearchInterceptionLogger(CustomLogger): self.search_tool_name = search_tool_name self._request_has_websearch = False # Track if current request has web search + async def try_short_circuit_search( + self, + model: str, + messages: List[Dict], + tools: Optional[List[Dict]], + custom_llm_provider: Optional[str], + ) -> Optional[Dict[str, Any]]: + """ + Short-circuit web-search-only requests by executing the search directly. + + Claude Code sends web search as a separate, standalone /v1/messages + request with a simple prompt and only web_search tool(s). For providers + that don't natively support web search (e.g. github_copilot), there is + no need to route this through the backend LLM — we can detect the + pattern, execute the search via Tavily/Perplexity, and return a + synthetic Anthropic response immediately. + + Args: + model: Model name from the request + messages: Messages list from the request + tools: Tools list from the request + custom_llm_provider: Provider name + + Returns: + An AnthropicMessagesResponse dict if short-circuited, or None to + continue normal processing. 
+ """ + if not tools: + return None + + # Check if provider is in enabled list + provider_str = custom_llm_provider or "" + if ( + self.enabled_providers is not None + and provider_str not in self.enabled_providers + ): + return None + + # All tools must be web search tools + if not all(is_web_search_tool(t) for t in tools): + return None + + # Extract search query from the last user message + query = self._extract_search_query(messages) + if not query: + return None + + verbose_logger.debug( + "WebSearchInterception: Short-circuit search detected " + f"(provider={provider_str}, query='{query}')" + ) + + # Execute search + try: + search_result_text = await self._execute_search(query) + except Exception as e: + verbose_logger.error( + f"WebSearchInterception: Short-circuit search failed: {e}" + ) + search_result_text = f"Search failed: {e}" + + # Build synthetic Anthropic response + from uuid import uuid4 + + response: Dict[str, Any] = { + "id": f"msg_{uuid4().hex[:24]}", + "type": "message", + "role": "assistant", + "model": model, + "content": [{"type": "text", "text": search_result_text}], + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0}, + } + + verbose_logger.debug( + "WebSearchInterception: Short-circuit search completed, " + f"returning synthetic response ({len(search_result_text)} chars)" + ) + return response + + @staticmethod + def _extract_search_query(messages: List[Dict]) -> Optional[str]: + """ + Extract the search query from messages. + + Looks at the last user message content for the search query text. 
+ """ + if not messages: + return None + + # Find the last user message + for msg in reversed(messages): + if msg.get("role") != "user": + continue + + content = msg.get("content") + if isinstance(content, str): + return content.strip() or None + + # Handle list-of-blocks content + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "").strip() + if text: + return text + elif isinstance(block, str): + text = block.strip() + if text: + return text + + return None + async def async_pre_call_deployment_hook( self, kwargs: Dict[str, Any], call_type: Optional[Any] ) -> Optional[dict]: diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py index 1b5f03ec72..52373a9c32 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py @@ -114,6 +114,54 @@ async def _execute_pre_request_hooks( return request_kwargs +async def _try_websearch_short_circuit( + model: str, + messages: List[Dict], + tools: Optional[List[Dict]], + custom_llm_provider: Optional[str], + stream: Optional[bool], +) -> Optional[Union[AnthropicMessagesResponse, AsyncIterator]]: + """ + Attempt to short-circuit a web-search-only request. + + Claude Code sends web search as a separate, standalone /v1/messages + request. For providers that don't natively support web search (e.g. + github_copilot), we detect this pattern, execute the search via + Tavily/Perplexity, and return a synthetic Anthropic response — bypassing + the backend LLM entirely. + + Returns the synthetic response if short-circuited, or None to continue + normal processing. 
+ """ + if not litellm.callbacks: + return None + + from litellm.integrations.websearch_interception.handler import ( + WebSearchInterceptionLogger, + ) + + for callback in litellm.callbacks: + if not isinstance(callback, WebSearchInterceptionLogger): + continue + + response = await callback.try_short_circuit_search( + model=model, + messages=messages, + tools=tools, + custom_llm_provider=custom_llm_provider, + ) + if response is not None: + if stream: + from litellm.llms.anthropic.experimental_pass_through.messages.fake_stream_iterator import ( + FakeAnthropicMessagesStreamIterator, + ) + + return FakeAnthropicMessagesStreamIterator(response) + return response + + return None + + @client async def anthropic_messages( max_tokens: int, @@ -156,6 +204,19 @@ async def anthropic_messages( # Merge back any other modifications kwargs.update(request_kwargs) + # Short-circuit web-search-only requests: detect the pattern, execute + # search directly via Tavily/Perplexity, and return a synthetic response + # without ever touching the backend LLM or the adapter path. + short_circuit_response = await _try_websearch_short_circuit( + model=model, + messages=messages, + tools=tools, + custom_llm_provider=custom_llm_provider, + stream=stream, + ) + if short_circuit_response is not None: + return short_circuit_response + loop = asyncio.get_event_loop() kwargs["is_async"] = True diff --git a/tests/test_litellm/integrations/websearch_interception/test_websearch_short_circuit.py b/tests/test_litellm/integrations/websearch_interception/test_websearch_short_circuit.py new file mode 100644 index 0000000000..1129ee98ef --- /dev/null +++ b/tests/test_litellm/integrations/websearch_interception/test_websearch_short_circuit.py @@ -0,0 +1,332 @@ +""" +Unit tests for WebSearch Short-Circuit + +Tests the short-circuit path that detects web-search-only /v1/messages requests +and executes the search directly without routing through the backend LLM. 
+""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from litellm.integrations.websearch_interception.handler import ( + WebSearchInterceptionLogger, +) + + +# --------------------------------------------------------------------------- +# Detection tests +# --------------------------------------------------------------------------- + + +class TestTryShortCircuitSearch: + """Tests for WebSearchInterceptionLogger.try_short_circuit_search""" + + @pytest.mark.asyncio + async def test_short_circuits_single_web_search_tool(self): + """Single web_search_20250305 tool → short-circuit fires""" + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + + with patch.object( + logger, "_execute_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = "Title: Result\nURL: https://example.com\nSnippet: test" + + result = await logger.try_short_circuit_search( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "Search for Claude Code releases"}], + tools=[{"type": "web_search_20250305", "name": "web_search", "max_uses": 8}], + custom_llm_provider="github_copilot", + ) + + assert result is not None + assert result["type"] == "message" + assert result["role"] == "assistant" + assert result["stop_reason"] == "end_turn" + assert len(result["content"]) == 1 + assert result["content"][0]["type"] == "text" + assert "Result" in result["content"][0]["text"] + mock_search.assert_called_once_with("Search for Claude Code releases") + + @pytest.mark.asyncio + async def test_does_not_short_circuit_mixed_tools(self): + """Mix of web_search and other tools → NOT short-circuited""" + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + + result = await logger.try_short_circuit_search( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "Do something"}], + tools=[ + {"type": "web_search_20250305", "name": "web_search", "max_uses": 8}, + {"name": "Read", 
"description": "Read a file", "input_schema": {}}, + ], + custom_llm_provider="github_copilot", + ) + + assert result is None + + @pytest.mark.asyncio + async def test_does_not_short_circuit_no_tools(self): + """No tools → NOT short-circuited""" + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + + result = await logger.try_short_circuit_search( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "Hello"}], + tools=None, + custom_llm_provider="github_copilot", + ) + + assert result is None + + @pytest.mark.asyncio + async def test_does_not_short_circuit_empty_tools(self): + """Empty tools list → NOT short-circuited""" + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + + result = await logger.try_short_circuit_search( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "Hello"}], + tools=[], + custom_llm_provider="github_copilot", + ) + + assert result is None + + @pytest.mark.asyncio + async def test_does_not_short_circuit_wrong_provider(self): + """Provider not in enabled_providers → NOT short-circuited""" + logger = WebSearchInterceptionLogger(enabled_providers=["bedrock"]) + + result = await logger.try_short_circuit_search( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "Search for something"}], + tools=[{"type": "web_search_20250305", "name": "web_search", "max_uses": 8}], + custom_llm_provider="github_copilot", + ) + + assert result is None + + @pytest.mark.asyncio + async def test_does_not_short_circuit_no_messages(self): + """Empty messages → NOT short-circuited""" + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + + result = await logger.try_short_circuit_search( + model="github_copilot/claude-sonnet-4", + messages=[], + tools=[{"type": "web_search_20250305", "name": "web_search", "max_uses": 8}], + custom_llm_provider="github_copilot", + ) + + assert result is None + + 
@pytest.mark.asyncio + async def test_search_failure_returns_error_text(self): + """Search failure → response with error message, not exception""" + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + + with patch.object( + logger, "_execute_search", new_callable=AsyncMock + ) as mock_search: + mock_search.side_effect = RuntimeError("Tavily API error") + + result = await logger.try_short_circuit_search( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "Search for something"}], + tools=[{"type": "web_search_20250305", "name": "web_search", "max_uses": 8}], + custom_llm_provider="github_copilot", + ) + + assert result is not None + assert "Search failed" in result["content"][0]["text"] + + @pytest.mark.asyncio + async def test_response_has_valid_structure(self): + """Synthetic response has all required AnthropicMessagesResponse fields""" + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + + with patch.object( + logger, "_execute_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = "search results here" + + result = await logger.try_short_circuit_search( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "Search query"}], + tools=[{"type": "web_search_20250305", "name": "web_search"}], + custom_llm_provider="github_copilot", + ) + + assert result is not None + # Required fields + assert "id" in result + assert result["id"].startswith("msg_") + assert result["type"] == "message" + assert result["role"] == "assistant" + assert result["model"] == "github_copilot/claude-sonnet-4" + assert result["stop_reason"] == "end_turn" + assert result["stop_sequence"] is None + assert "usage" in result + assert "content" in result + + +# --------------------------------------------------------------------------- +# Query extraction tests +# --------------------------------------------------------------------------- + + +class 
TestExtractSearchQuery: + """Tests for WebSearchInterceptionLogger._extract_search_query""" + + def test_string_content(self): + messages = [{"role": "user", "content": "Search for Python 3.14 features"}] + assert ( + WebSearchInterceptionLogger._extract_search_query(messages) + == "Search for Python 3.14 features" + ) + + def test_list_content_with_text_block(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Perform a web search for latest news"}, + ], + } + ] + assert ( + WebSearchInterceptionLogger._extract_search_query(messages) + == "Perform a web search for latest news" + ) + + def test_takes_last_user_message(self): + messages = [ + {"role": "user", "content": "First message"}, + {"role": "assistant", "content": "Response"}, + {"role": "user", "content": "Second message"}, + ] + assert ( + WebSearchInterceptionLogger._extract_search_query(messages) == "Second message" + ) + + def test_empty_messages(self): + assert WebSearchInterceptionLogger._extract_search_query([]) is None + + def test_no_user_messages(self): + messages = [{"role": "assistant", "content": "Hello"}] + assert WebSearchInterceptionLogger._extract_search_query(messages) is None + + def test_empty_content(self): + messages = [{"role": "user", "content": ""}] + assert WebSearchInterceptionLogger._extract_search_query(messages) is None + + +# --------------------------------------------------------------------------- +# Integration with entry point +# --------------------------------------------------------------------------- + + +class TestShortCircuitEntryPoint: + """Tests for _try_websearch_short_circuit in the /v1/messages handler""" + + @pytest.mark.asyncio + async def test_returns_none_when_no_callbacks(self): + """No callbacks configured → returns None""" + from litellm.llms.anthropic.experimental_pass_through.messages.handler import ( + _try_websearch_short_circuit, + ) + + with patch("litellm.callbacks", []): + result = await 
_try_websearch_short_circuit( + model="test", + messages=[], + tools=[], + custom_llm_provider="github_copilot", + stream=False, + ) + assert result is None + + @pytest.mark.asyncio + async def test_returns_dict_when_not_streaming(self): + """Non-streaming short-circuit → returns dict""" + from litellm.llms.anthropic.experimental_pass_through.messages.handler import ( + _try_websearch_short_circuit, + ) + + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + with patch.object( + logger, "_execute_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = "results" + with patch("litellm.callbacks", [logger]): + result = await _try_websearch_short_circuit( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "search query"}], + tools=[{"type": "web_search_20250305", "name": "web_search"}], + custom_llm_provider="github_copilot", + stream=False, + ) + + assert isinstance(result, dict) + assert result["content"][0]["text"] == "results" + + @pytest.mark.asyncio + async def test_returns_stream_iterator_when_streaming(self): + """Streaming short-circuit → returns FakeAnthropicMessagesStreamIterator""" + from litellm.llms.anthropic.experimental_pass_through.messages.fake_stream_iterator import ( + FakeAnthropicMessagesStreamIterator, + ) + from litellm.llms.anthropic.experimental_pass_through.messages.handler import ( + _try_websearch_short_circuit, + ) + + logger = WebSearchInterceptionLogger(enabled_providers=["github_copilot"]) + with patch.object( + logger, "_execute_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = "streaming results" + with patch("litellm.callbacks", [logger]): + result = await _try_websearch_short_circuit( + model="github_copilot/claude-sonnet-4", + messages=[{"role": "user", "content": "search query"}], + tools=[{"type": "web_search_20250305", "name": "web_search"}], + custom_llm_provider="github_copilot", + stream=True, + ) + + assert 
isinstance(result, FakeAnthropicMessagesStreamIterator) + + # Verify stream produces valid SSE events + chunks = [] + async for chunk in result: + chunks.append(chunk) + + assert len(chunks) > 0 + # First chunk should be message_start + assert b"event: message_start" in chunks[0] + # Last chunk should be message_stop + assert b"event: message_stop" in chunks[-1] + # Should contain the search results text + all_data = b"".join(chunks) + assert b"streaming results" in all_data + + @pytest.mark.asyncio + async def test_skips_non_websearch_callbacks(self): + """Non-WebSearchInterceptionLogger callbacks are ignored""" + from unittest.mock import MagicMock + + from litellm.llms.anthropic.experimental_pass_through.messages.handler import ( + _try_websearch_short_circuit, + ) + + other_callback = MagicMock() + with patch("litellm.callbacks", [other_callback]): + result = await _try_websearch_short_circuit( + model="test", + messages=[{"role": "user", "content": "search"}], + tools=[{"type": "web_search_20250305", "name": "web_search"}], + custom_llm_provider="github_copilot", + stream=False, + ) + assert result is None From 0b07f628ffb2d4fe8d2c271fe578a9a2796382ec Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Thu, 19 Mar 2026 10:30:03 -0700 Subject: [PATCH 181/539] [Test] UI: Add vitest coverage for 10 previously untested components Add unit tests for: - SimpleToolCallBlock, SimpleMessageBlock, CollapsibleMessage, HistoryTree (log details drawer) - OnboardingForm (onboarding flow) - TeamsHeaderTabs, TeamsTable (teams page) - transform_key_info, filter_helpers (key/team helpers) - queryKeysFactory (query key generation utility) 47 new tests covering conditional rendering, user interactions, data transformation, and error handling. 
Co-Authored-By: Claude Opus 4.6 --- .../hooks/common/queryKeysFactory.test.ts | 34 +++++ .../teams/components/TeamsHeaderTabs.test.tsx | 54 ++++++++ .../components/TeamsTable/TeamsTable.test.tsx | 129 ++++++++++++++++++ .../app/onboarding/OnboardingForm.test.tsx | 95 +++++++++++++ .../key_team_helpers/filter_helpers.test.ts | 90 ++++++++++++ .../transform_key_info.test.ts | 62 +++++++++ .../CollapsibleMessage.test.tsx | 54 ++++++++ .../LogDetailsDrawer/HistoryTree.test.tsx | 50 +++++++ .../SimpleMessageBlock.test.tsx | 55 ++++++++ .../SimpleToolCallBlock.test.tsx | 51 +++++++ 10 files changed, 674 insertions(+) create mode 100644 ui/litellm-dashboard/src/app/(dashboard)/hooks/common/queryKeysFactory.test.ts create mode 100644 ui/litellm-dashboard/src/app/(dashboard)/teams/components/TeamsHeaderTabs.test.tsx create mode 100644 ui/litellm-dashboard/src/app/(dashboard)/teams/components/TeamsTable/TeamsTable.test.tsx create mode 100644 ui/litellm-dashboard/src/app/onboarding/OnboardingForm.test.tsx create mode 100644 ui/litellm-dashboard/src/components/key_team_helpers/filter_helpers.test.ts create mode 100644 ui/litellm-dashboard/src/components/key_team_helpers/transform_key_info.test.ts create mode 100644 ui/litellm-dashboard/src/components/view_logs/LogDetailsDrawer/CollapsibleMessage.test.tsx create mode 100644 ui/litellm-dashboard/src/components/view_logs/LogDetailsDrawer/HistoryTree.test.tsx create mode 100644 ui/litellm-dashboard/src/components/view_logs/LogDetailsDrawer/SimpleMessageBlock.test.tsx create mode 100644 ui/litellm-dashboard/src/components/view_logs/LogDetailsDrawer/SimpleToolCallBlock.test.tsx diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/common/queryKeysFactory.test.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/common/queryKeysFactory.test.ts new file mode 100644 index 0000000000..39afd04409 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/common/queryKeysFactory.test.ts @@ -0,0 +1,34 @@ +import { describe, 
it, expect } from "vitest"; +import { createQueryKeys } from "./queryKeysFactory"; + +describe("createQueryKeys", () => { + const keys = createQueryKeys("books"); + + it("should return the resource name as the base key", () => { + expect(keys.all).toEqual(["books"]); + }); + + it("should generate a lists key", () => { + expect(keys.lists()).toEqual(["books", "list"]); + }); + + it("should generate a list key with params", () => { + expect(keys.list({ page: 1, limit: 10 })).toEqual([ + "books", + "list", + { params: { page: 1, limit: 10 } }, + ]); + }); + + it("should generate a list key with undefined params when none provided", () => { + expect(keys.list()).toEqual(["books", "list", { params: undefined }]); + }); + + it("should generate a details key", () => { + expect(keys.details()).toEqual(["books", "detail"]); + }); + + it("should generate a detail key for a specific ID", () => { + expect(keys.detail("123")).toEqual(["books", "detail", "123"]); + }); +}); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/teams/components/TeamsHeaderTabs.test.tsx b/ui/litellm-dashboard/src/app/(dashboard)/teams/components/TeamsHeaderTabs.test.tsx new file mode 100644 index 0000000000..50a7f10f04 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/teams/components/TeamsHeaderTabs.test.tsx @@ -0,0 +1,54 @@ +import { render, screen } from "@testing-library/react"; +import React from "react"; +import { describe, expect, it, vi } from "vitest"; +import TeamsHeaderTabs from "./TeamsHeaderTabs"; + +vi.mock("@tremor/react", () => ({ + TabGroup: ({ children, ...props }: any) =>
{children}
, + TabList: ({ children, ...props }: any) =>
{children}
, + Tab: ({ children, ...props }: any) => , + TabPanels: ({ children, ...props }: any) =>
{children}
, + Text: ({ children, ...props }: any) => {children}, + Icon: ({ onClick, ...props }: any) => + + ); + } + + return ( + + columns={teamColumns} + dataSource={displayTeams} + rowKey="team_id" + pagination={false} + onChange={handleTableSort} + locale={{ + emptyText: ( +
+ +
+ No teams yet +
+
+ + Create your first team to organize members and manage access to models. + +
+ {canCreateOrManageTeams(userRole, userID, organizations) && ( + + )} +
+ ), + }} + scroll={{ x: 1000 }} + size="middle" + /> + ); + }; + + const tabItems = [ + { + key: "your-teams", + label: "Your Teams", + children: ( + <> + + + + } + suffix={isSearching ? : null} + placeholder="Search teams by name..." + onChange={(e) => handleSearchChange(e.target.value)} + allowClear + style={{ maxWidth: 400 }} + /> + handleFilterChange("organization_id", value || "")} + loading={isLoading} + /> + + { + setCurrentPage(page); + setPageSize(size); + fetchTeamsV2({ page, size }); + }} + size="small" + showTotal={(total) => `${total} teams`} + showSizeChanger + pageSizeOptions={["10", "20", "50"]} + /> + + + {renderTeamsContent()} + + + + + ), + }, + { + key: "available-teams", + label: "Available Teams", + children: , + }, + ...(isProxyAdminRole(userRole || "") + ? [ + { + key: "default-settings", + label: "Default Team Settings", + children: , + }, + ] + : []), + ]; + return ( -
- - - {canCreateOrManageTeams(userRole, userID, organizations) && ( - - )} - {selectedTeamId ? ( - { - setTeams((teams) => { - if (teams == null) { - return teams; - } - const updated = teams.map((team) => { - if (data.team_id === team.team_id) { - return updateExistingKeys(team, data); - } - return team; - }); - // Minimal fix: refresh the full team list after an update - if (accessToken) { - fetchTeams(accessToken, userID, userRole, currentOrg, setTeams); - } - return updated; - }); - }} - onClose={() => { - setSelectedTeamId(null); - setEditTeam(false); - }} - accessToken={accessToken} - is_team_admin={is_team_admin(teams?.find((team) => team.team_id === selectedTeamId))} - is_proxy_admin={userRole == "Admin"} - userModels={userModels} - editTeam={editTeam} - premiumUser={premiumUser} - /> - ) : ( - - -
- Your Teams - Available Teams - {isProxyAdminRole(userRole || "") && Default Team Settings} -
-
- {lastRefreshed && Last Refreshed: {lastRefreshed}} - -
-
- - - - Click on “Team ID” to view team details and manage team members. - - - - -
-
- {/* Search and Filter Controls */} -
- {/* Team Alias Search */} - handleFilterChange("team_alias", value)} - icon={Search} - /> + + {selectedTeamId ? ( + { + setTeams((teams) => { + if (teams == null) { + return teams; + } + return teams.map((team) => { + if (data.team_id === team.team_id) { + return updateExistingKeys(team, data); + } + return team; + }); + }); + fetchTeamsV2(); + }} + onClose={() => { + setSelectedTeamId(null); + setEditTeam(false); + }} + accessToken={accessToken} + is_team_admin={is_team_admin(teams?.find((team) => team.team_id === selectedTeamId))} + is_proxy_admin={userRole == "Admin"} + userModels={userModels} + editTeam={editTeam} + premiumUser={premiumUser} + /> + ) : ( + <> + + + + <TeamOutlined style={{ marginRight: 8 }} /> + Teams + + + Manage teams, members, and their access to models and budgets + + + {canCreateOrManageTeams(userRole, userID, organizations) && ( + + )} + - {/* Filter Button */} - setShowFilters(!showFilters)} - active={showFilters} - hasActiveFilters={!!(filters.team_id || filters.team_alias || filters.organization_id)} - /> + + + )} - {/* Reset Filters Button */} - -
- - {/* Additional Filters */} - {showFilters && ( -
- {/* Team ID Search */} - handleFilterChange("team_id", value)} - icon={User} - /> - - {/* Organization Dropdown */} -
- -
-
- )} -
-
- - - - Team Name - Team ID - Created - Spend (USD) - Budget (USD) - Models - Organization - Info - Actions - - - - - {teams && teams.length > 0 ? ( - teams - .filter((team) => { - if (!currentOrg) return true; - return team.organization_id === currentOrg.organization_id; - }) - .sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime()) - .map((team: any) => ( - - - {team["team_alias"]} - - -
- - - -
-
- - {team.created_at ? new Date(team.created_at).toLocaleDateString() : "N/A"} - - - {formatNumberWithCommas(team["spend"], 4)} - - - {team["max_budget"] !== null && team["max_budget"] !== undefined - ? team["max_budget"] - : "No limit"} - - 3 ? "px-0" : ""} - > -
- {Array.isArray(team.models) ? ( -
- {team.models.length === 0 ? ( - - All Proxy Models - - ) : ( - <> -
- {team.models.length > 3 && ( -
- { - setExpandedAccordions((prev) => ({ - ...prev, - [team.team_id]: !prev[team.team_id], - })); - }} - /> -
- )} -
- {team.models.slice(0, 3).map((model: string, index: number) => - model === "all-proxy-models" ? ( - - All Proxy Models - - ) : ( - - - {model.length > 30 - ? `${getModelDisplayName(model).slice(0, 30)}...` - : getModelDisplayName(model)} - - - ), - )} - {team.models.length > 3 && !expandedAccordions[team.team_id] && ( - - - +{team.models.length - 3}{" "} - {team.models.length - 3 === 1 ? "more model" : "more models"} - - - )} - {expandedAccordions[team.team_id] && ( -
- {team.models.slice(3).map((model: string, index: number) => - model === "all-proxy-models" ? ( - - All Proxy Models - - ) : ( - - - {model.length > 30 - ? `${getModelDisplayName(model).slice(0, 30)}...` - : getModelDisplayName(model)} - - - ), - )} -
- )} -
-
- - )} -
- ) : null} -
-
- - - {getOrganizationAlias(team.organization_id, organizationsData || organizations)} - - - - {perTeamInfo && - team.team_id && - perTeamInfo[team.team_id] && - perTeamInfo[team.team_id].keys && - perTeamInfo[team.team_id].keys.length}{" "} - Keys - - - {perTeamInfo && - team.team_id && - perTeamInfo[team.team_id] && - perTeamInfo[team.team_id].team_info && - perTeamInfo[team.team_id].team_info.members_with_roles && - perTeamInfo[team.team_id].team_info.members_with_roles.length}{" "} - Members - - - - {userRole == "Admin" ? ( - <> - { - setSelectedTeamId(team.team_id); - setEditTeam(true); - }} - dataTestId="edit-team-button" - tooltipText="Edit team" - /> - handleDelete(team)} - dataTestId="delete-team-button" - tooltipText="Delete team" - /> - - ) : null} - -
- )) - ) : ( - - -
- No teams found - Adjust your filters or create a new team -
-
-
- )} -
-
- -
- -
-
- - - - {isProxyAdminRole(userRole || "") && ( - - - - )} -
-
- )} - {canCreateOrManageTeams(userRole, userID, organizations) && ( + {canCreateOrManageTeams(userRole, userID, organizations) && ( = ({ : "" } > - = ({ optionFilterProp="children" > {adminOrgs?.map((org) => ( - + {org.organization_alias}{" "} ({org.organization_id}) - + ))} - + {/* Show message when org admin needs to select organization */} {isOrgAdmin && !isSingleOrg && adminOrgs.length > 1 && (
- + Please select an organization to create a team for. You can only create teams within organizations where you are an admin. @@ -1190,11 +1211,11 @@ const Teams: React.FC = ({ - - daily - weekly - monthly - + @@ -1313,7 +1334,7 @@ const Teams: React.FC = ({ className="mt-8" help="Select existing guardrails or enter new ones" > - = ({ className="mt-8" help="Select existing policies or enter new ones" > - = ({
- + Create custom aliases for models that can be used by team members in API calls. This allows you to create shortcuts for specific models. @@ -1548,14 +1569,12 @@ const Teams: React.FC = ({
- Create Team +
)} - - -
+ ); }; diff --git a/ui/litellm-dashboard/src/components/common_components/IconActionButton/TableIconActionButtons/TableIconActionButton.tsx b/ui/litellm-dashboard/src/components/common_components/IconActionButton/TableIconActionButtons/TableIconActionButton.tsx index 488913a734..2f146aab72 100644 --- a/ui/litellm-dashboard/src/components/common_components/IconActionButton/TableIconActionButtons/TableIconActionButton.tsx +++ b/ui/litellm-dashboard/src/components/common_components/IconActionButton/TableIconActionButtons/TableIconActionButton.tsx @@ -6,6 +6,7 @@ import { ChevronUpIcon, ChevronDownIcon, ExternalLinkIcon, + ClipboardCopyIcon, } from "@heroicons/react/outline"; import { Tooltip } from "antd"; import BaseActionButton from "../BaseActionButton"; @@ -32,6 +33,7 @@ export const TableIconActionButtonMap: Record void; disabled?: boolean; loading?: boolean; + style?: React.CSSProperties; } const OrganizationDropdown: React.FC = ({ @@ -16,16 +19,18 @@ const OrganizationDropdown: React.FC = ({ onChange, disabled, loading, + style, }) => { return ( diff --git a/ui/litellm-dashboard/src/components/ui/AntDLoadingSpinner.tsx b/ui/litellm-dashboard/src/components/ui/AntDLoadingSpinner.tsx new file mode 100644 index 0000000000..9e90f77584 --- /dev/null +++ b/ui/litellm-dashboard/src/components/ui/AntDLoadingSpinner.tsx @@ -0,0 +1,12 @@ +import { Spin } from "antd"; +import { LoadingOutlined } from "@ant-design/icons"; + +interface AntDLoadingSpinnerProps { + size?: "small" | "default" | "large"; + fontSize?: number; +} + +export function AntDLoadingSpinner({ size, fontSize }: AntDLoadingSpinnerProps) { + const indicator = ; + return ; +} From ad43a35d762d62cd2cb9a18a0f1fbdef4f3fb67c Mon Sep 17 00:00:00 2001 From: Ryan Crabbe Date: Thu, 19 Mar 2026 22:50:19 -0700 Subject: [PATCH 223/539] feat: add control plane for multi-proxy worker management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a control plane capability that 
enables a central admin instance to manage multiple regional worker proxies from a single UI. Backend: - Worker registry loaded from YAML config (worker_id, name, url) - /.well-known/litellm-ui-config exposes is_control_plane and workers list - /v3/login + /v3/login/exchange: opaque code exchange for cross-origin username/password auth (JWT never in URL/logs, single-use 60s TTL) - SSO cookie handoff with return_to → opaque code → exchange - _validate_return_to: full origin validation (scheme+hostname+port) - Startup warning when control_plane_url set without Redis - Both /v3 endpoints gated behind control_plane_url config Frontend: - Worker selector dropdown on login page (gated behind is_control_plane) - Cross-origin SSO code exchange handling on callback - switchToWorkerUrl: localStorage-persisted worker URL for API calls - useWorker hook: shared worker state management - WorkerDropdown in navbar for switching workers - Logout/switch clears worker state from localStorage Tests: - 7 tests for /v3/login + /v3/login/exchange - 10 tests for _validate_return_to - 2 tests for control plane discovery endpoint --- .../ui_discovery_endpoints.py | 7 + litellm/proxy/management_endpoints/ui_sso.py | 85 ++++++- litellm/proxy/proxy_server.py | 187 +++++++++++++++- .../types/proxy/control_plane_endpoints.py | 14 ++ .../ui_discovery_endpoints.py | 6 +- .../test_ui_discovery_endpoints.py | 57 ++++- .../proxy/management_endpoints/test_ui_sso.py | 98 +++++++- tests/test_litellm/proxy/test_proxy_server.py | 211 ++++++++++++++++++ .../app/(dashboard)/hooks/login/useLogin.ts | 4 +- .../hooks/uiConfig/useUIConfig.test.ts | 4 + .../src/app/login/LoginPage.test.tsx | 15 +- .../src/app/login/LoginPage.tsx | 125 +++++++++-- .../Navbar/WorkerDropdown/WorkerDropdown.tsx | 38 ++++ .../src/components/navbar.tsx | 13 ++ .../src/components/networking.tsx | 126 ++++++++++- ui/litellm-dashboard/src/hooks/useWorker.ts | 65 ++++++ 16 files changed, 1025 insertions(+), 30 deletions(-) create mode 
100644 litellm/types/proxy/control_plane_endpoints.py create mode 100644 ui/litellm-dashboard/src/components/Navbar/WorkerDropdown/WorkerDropdown.tsx create mode 100644 ui/litellm-dashboard/src/hooks/useWorker.ts diff --git a/litellm/proxy/discovery_endpoints/ui_discovery_endpoints.py b/litellm/proxy/discovery_endpoints/ui_discovery_endpoints.py index 2a38ceffba..233df5c6c5 100644 --- a/litellm/proxy/discovery_endpoints/ui_discovery_endpoints.py +++ b/litellm/proxy/discovery_endpoints/ui_discovery_endpoints.py @@ -27,10 +27,17 @@ async def get_ui_config(): admin_ui_disabled = os.getenv("DISABLE_ADMIN_UI", "false").lower() == "true" sso_configured = _has_user_setup_sso() + + from litellm.proxy.proxy_server import proxy_config + + is_control_plane = len(proxy_config.worker_registry) > 0 + return UiDiscoveryEndpoints( server_root_path=get_server_root_path(), proxy_base_url=get_proxy_base_url(), auto_redirect_to_sso=sso_configured and auto_redirect_ui_login_to_sso, admin_ui_disabled=admin_ui_disabled, sso_configured=sso_configured, + is_control_plane=is_control_plane, + workers=proxy_config.worker_registry if is_control_plane else [], ) diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index daf1d6f131..8634a8bb33 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -16,6 +16,7 @@ import os import secrets from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast +from urllib.parse import urlencode, urlparse if TYPE_CHECKING: import httpx @@ -301,6 +302,7 @@ async def google_login( source: Optional[str] = None, key: Optional[str] = None, existing_key: Optional[str] = None, + return_to: Optional[str] = None, ): # noqa: PLR0915 """ Create Proxy API Keys using Google Workspace SSO. 
Requires setting PROXY_BASE_URL in .env @@ -394,13 +396,23 @@ async def google_login( is True ): verbose_proxy_logger.info(f"Redirecting to SSO login for {redirect_url}") - return await SSOAuthenticationHandler.get_sso_login_redirect( + sso_redirect = await SSOAuthenticationHandler.get_sso_login_redirect( redirect_url=redirect_url, microsoft_client_id=microsoft_client_id, google_client_id=google_client_id, generic_client_id=generic_client_id, state=cli_state, ) + if return_to is not None and sso_redirect is not None: + SSOAuthenticationHandler._validate_return_to(return_to) + sso_redirect.set_cookie( + key="litellm_cp_return_to", + value=return_to, + max_age=600, + httponly=True, + samesite="lax", + ) + return sso_redirect elif ui_username is not None: # No Google, Microsoft SSO # Use UI Credentials set in .env @@ -1312,12 +1324,17 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: request=request, key=key_id, existing_key=existing_key, result=result ) + # Control-plane cross-origin: read return_to from cookie. + # Starlette's cookie_parser already handles RFC 2109 unquoting. + cp_return_to: Optional[str] = request.cookies.get("litellm_cp_return_to") + return await SSOAuthenticationHandler.get_redirect_response_from_openid( result=result, request=request, received_response=received_response, generic_client_id=generic_client_id, ui_access_mode=ui_access_mode, + return_to=cp_return_to, ) @@ -1760,6 +1777,38 @@ class SSOAuthenticationHandler: Handler for SSO Authentication across all SSO providers """ + @staticmethod + def _validate_return_to(return_to: str) -> None: + """ + Validate that return_to matches the configured control_plane_url origin. 
+ + Raises HTTPException(400) if: + - control_plane_url is not configured in general_settings + - return_to origin does not match control_plane_url origin + """ + from litellm.proxy.proxy_server import general_settings + + control_plane_url = general_settings.get("control_plane_url") + if control_plane_url is None: + raise HTTPException( + status_code=400, + detail="return_to is not allowed: control_plane_url is not configured", + ) + + def _origin(url: str) -> tuple: + parsed = urlparse(url) + scheme = (parsed.scheme or "").lower() + hostname = (parsed.hostname or "").lower() + default_port = 443 if scheme == "https" else 80 + port = parsed.port if parsed.port is not None else default_port + return (scheme, hostname, port) + + if _origin(return_to) != _origin(control_plane_url): + raise HTTPException( + status_code=400, + detail="return_to does not match the configured control_plane_url", + ) + @staticmethod async def get_sso_login_redirect( redirect_url: str, @@ -2358,6 +2407,7 @@ class SSOAuthenticationHandler: received_response: Optional[dict] = None, generic_client_id: Optional[str] = None, ui_access_mode: Optional[Dict] = None, + return_to: Optional[str] = None, ) -> RedirectResponse: import jwt @@ -2367,6 +2417,7 @@ class SSOAuthenticationHandler: master_key, premium_user, proxy_logging_obj, + redis_usage_cache, user_api_key_cache, user_custom_sso, ) @@ -2534,6 +2585,38 @@ class SSOAuthenticationHandler: master_key or "", algorithm="HS256", ) + + # Control-plane cross-origin: store JWT behind a single-use opaque + # code (60s TTL) so the token never appears in browser history / logs. + # The control plane redeems it via POST /v3/login/exchange. 
+ if return_to is not None: + SSOAuthenticationHandler._validate_return_to(return_to) + + code = secrets.token_urlsafe(32) + cache_key = f"login_code:{code}" + cache_value = {"token": jwt_token, "redirect_url": return_to} + if redis_usage_cache is not None: + await redis_usage_cache.async_set_cache( + key=cache_key, value=cache_value, ttl=60 + ) + else: + await user_api_key_cache.async_set_cache( + key=cache_key, value=cache_value, ttl=60 + ) + + separator = "&" if "?" in return_to else "?" + redirect_url = ( + return_to + + separator + + urlencode({"login": "success", "code": code}) + ) + verbose_proxy_logger.info( + "Cross-origin SSO: redirecting to control plane with login code" + ) + redirect_response = RedirectResponse(url=redirect_url, status_code=303) + redirect_response.delete_cookie("litellm_cp_return_to") + return redirect_response + if user_id is not None and isinstance(user_id, str): litellm_dashboard_ui += "?login=success" verbose_proxy_logger.info(f"Redirecting to {litellm_dashboard_ui}") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9c29927c5c..e982c934aa 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -541,6 +541,7 @@ from litellm.types.llms.anthropic import ( AnthropicResponseUsageBlock, ) from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.proxy.control_plane_endpoints import WorkerRegistryEntry from litellm.types.proxy.management_endpoints.model_management_endpoints import ( ModelGroupInfoProxy, ) @@ -1546,6 +1547,7 @@ user_custom_key_generate = None # Sentinel: prevents PKCE-no-Redis advisory from re-logging on config hot-reload. # Tests that need to reset it can patch 'litellm.proxy.proxy_server._pkce_no_redis_warning_emitted'. 
_pkce_no_redis_warning_emitted: bool = False +_cp_no_redis_warning_emitted: bool = False user_custom_sso = None user_custom_ui_sso_sign_in_handler = None use_background_health_checks = None @@ -2295,6 +2297,7 @@ class ProxyConfig: self.config: Dict[str, Any] = {} self._last_semantic_filter_config: Optional[Dict[str, Any]] = None self._last_hashicorp_vault_config: Optional[Dict[str, Any]] = None + self.worker_registry: List["WorkerRegistryEntry"] = [] def is_yaml(self, config_file_path: str) -> bool: if not os.path.isfile(config_file_path): @@ -3095,6 +3098,21 @@ class ProxyConfig: "Set PKCE_STRICT_CACHE_MISS=true to fail fast with a 401 on cache misses " "instead of continuing without a code_verifier." ) + + ### CONTROL PLANE CODE-EXCHANGE PREREQUISITE CHECK ### + cp_url = general_settings.get("control_plane_url") + if cp_url and redis_usage_cache is None: + global _cp_no_redis_warning_emitted + if not _cp_no_redis_warning_emitted: + _cp_no_redis_warning_emitted = True + verbose_proxy_logger.warning( + "control_plane_url is configured but Redis is not configured for LiteLLM caching. " + "Login codes (SSO and /v3/login) will not be shared across instances — " + "the /v3/login/exchange call may land on a different pod and fail with 401. " + "Configure Redis via the 'cache' section in your proxy config, " + "or ensure sticky sessions for single-instance deployments." 
+ ) + ### STORE MODEL IN DB ### feature flag for `/model/new` store_model_in_db = general_settings.get("store_model_in_db", False) if store_model_in_db is None: @@ -3385,7 +3403,15 @@ class ProxyConfig: litellm.vector_store_registry.load_vector_stores_from_config( vector_store_registry_config ) - pass + + ## WORKER REGISTRY (Control Plane) + worker_registry_config = config.get("worker_registry", None) + if worker_registry_config: + self.worker_registry = [ + WorkerRegistryEntry(**e) for e in worker_registry_config + ] + else: + self.worker_registry = [] async def _init_policy_engine( self, @@ -11095,6 +11121,165 @@ async def login_v2(request: Request): # noqa: PLR0915 ) +@router.post( + "/v3/login", include_in_schema=False +) # control-plane login — always returns token in body for cross-origin use +async def login_v3(request: Request): # noqa: PLR0915 + global premium_user, general_settings, master_key + from litellm.proxy.auth.login_utils import authenticate_user, create_ui_token_object + from litellm.proxy.utils import get_custom_url + + try: + if not general_settings.get("control_plane_url"): + raise ProxyException( + message="/v3/login is only available on workers with control_plane_url configured", + type=ProxyErrorTypes.not_found_error, + param="control_plane_url", + code=status.HTTP_404_NOT_FOUND, + ) + + body = await request.json() + username = str(body.get("username")) + password = str(body.get("password")) + + login_result = await authenticate_user( + username=username, + password=password, + master_key=master_key, + prisma_client=prisma_client, + ) + + returned_ui_token_object = create_ui_token_object( + login_result=login_result, + general_settings=general_settings, + premium_user=premium_user, + ) + + import jwt + + jwt_token = jwt.encode( + cast(dict, returned_ui_token_object), + cast(str, master_key), + algorithm="HS256", + ) + + litellm_dashboard_ui = get_custom_url(str(request.base_url)) + if litellm_dashboard_ui.endswith("/"): + 
litellm_dashboard_ui += "ui/" + else: + litellm_dashboard_ui += "/ui/" + litellm_dashboard_ui += "?login=success" + + # Store JWT behind a single-use opaque code (60s TTL) + code = secrets.token_urlsafe(32) + cache_key = f"login_code:{code}" + cache_value = {"token": jwt_token, "redirect_url": litellm_dashboard_ui} + if redis_usage_cache is not None: + await redis_usage_cache.async_set_cache( + key=cache_key, value=cache_value, ttl=60 + ) + else: + await user_api_key_cache.async_set_cache( + key=cache_key, value=cache_value, ttl=60 + ) + + return JSONResponse( + content={"code": code, "expires_in": 60}, + status_code=status.HTTP_200_OK, + ) + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.login_v3(): Exception occurred - {}".format( + str(e) + ) + ) + if isinstance(e, ProxyException): + raise e + elif isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", str(e)), + type=ProxyErrorTypes.auth_error, + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + else: + error_msg = f"{str(e)}" + raise ProxyException( + message=error_msg, + type=ProxyErrorTypes.auth_error, + param="None", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.post( + "/v3/login/exchange", include_in_schema=False +) # exchange single-use opaque code for JWT +async def login_v3_exchange(request: Request): + try: + if not general_settings.get("control_plane_url"): + raise ProxyException( + message="/v3/login/exchange is only available on workers with control_plane_url configured", + type=ProxyErrorTypes.not_found_error, + param="control_plane_url", + code=status.HTTP_404_NOT_FOUND, + ) + + body = await request.json() + code = body.get("code") + if not code: + raise ProxyException( + message="Missing 'code' parameter", + type=ProxyErrorTypes.auth_error, + param="code", + code=status.HTTP_400_BAD_REQUEST, + ) + + cache_key = f"login_code:{code}" + if 
redis_usage_cache is not None: + cached_data = await redis_usage_cache.async_get_cache(key=cache_key) + else: + cached_data = await user_api_key_cache.async_get_cache(key=cache_key) + + if not cached_data or not isinstance(cached_data, dict): + raise ProxyException( + message="Invalid or expired login code", + type=ProxyErrorTypes.auth_error, + param="code", + code=status.HTTP_401_UNAUTHORIZED, + ) + + # Single-use: delete immediately + if redis_usage_cache is not None: + await redis_usage_cache.async_delete_cache(key=cache_key) + else: + await user_api_key_cache.async_delete_cache(key=cache_key) + + json_response = JSONResponse( + content={ + "token": cached_data["token"], + "redirect_url": cached_data["redirect_url"], + }, + status_code=status.HTTP_200_OK, + ) + json_response.set_cookie(key="token", value=cached_data["token"]) + return json_response + except ProxyException: + raise + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.login_v3_exchange(): Exception occurred - {}".format( + str(e) + ) + ) + raise ProxyException( + message=str(e), + type=ProxyErrorTypes.auth_error, + param="None", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + @app.get("/onboarding/get_token", include_in_schema=False) async def onboarding(invite_link: str, request: Request): """ diff --git a/litellm/types/proxy/control_plane_endpoints.py b/litellm/types/proxy/control_plane_endpoints.py new file mode 100644 index 0000000000..8bf4c44b20 --- /dev/null +++ b/litellm/types/proxy/control_plane_endpoints.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel, field_validator + + +class WorkerRegistryEntry(BaseModel): + worker_id: str + name: str + url: str + + @field_validator("url") + @classmethod + def url_must_be_http(cls, v: str) -> str: + if not v.startswith(("http://", "https://")): + raise ValueError("Worker URL must start with http:// or https://") + return v diff --git a/litellm/types/proxy/discovery_endpoints/ui_discovery_endpoints.py 
b/litellm/types/proxy/discovery_endpoints/ui_discovery_endpoints.py index 4a4cdaa2ba..46cd3f49f1 100644 --- a/litellm/types/proxy/discovery_endpoints/ui_discovery_endpoints.py +++ b/litellm/types/proxy/discovery_endpoints/ui_discovery_endpoints.py @@ -1,7 +1,9 @@ -from typing import Optional +from typing import List, Optional from pydantic import BaseModel +from litellm.types.proxy.control_plane_endpoints import WorkerRegistryEntry + class UiDiscoveryEndpoints(BaseModel): server_root_path: str @@ -9,3 +11,5 @@ class UiDiscoveryEndpoints(BaseModel): auto_redirect_to_sso: bool admin_ui_disabled: bool sso_configured: bool + is_control_plane: bool = False + workers: List[WorkerRegistryEntry] = [] diff --git a/tests/test_litellm/proxy/discovery_endpoints/test_ui_discovery_endpoints.py b/tests/test_litellm/proxy/discovery_endpoints/test_ui_discovery_endpoints.py index f15960a607..54a127f435 100644 --- a/tests/test_litellm/proxy/discovery_endpoints/test_ui_discovery_endpoints.py +++ b/tests/test_litellm/proxy/discovery_endpoints/test_ui_discovery_endpoints.py @@ -1,6 +1,6 @@ import os import sys -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from fastapi import FastAPI @@ -11,6 +11,7 @@ sys.path.insert( ) from litellm.proxy.discovery_endpoints.ui_discovery_endpoints import router +from litellm.types.proxy.control_plane_endpoints import WorkerRegistryEntry def test_ui_discovery_endpoints_with_defaults(): @@ -245,9 +246,9 @@ def test_ui_discovery_endpoints_with_admin_ui_enabled(): patch("litellm.proxy.utils.get_proxy_base_url", return_value=None), \ patch("litellm.proxy.auth.auth_utils._has_user_setup_sso", return_value=False), \ patch.dict(os.environ, {"DISABLE_ADMIN_UI": "false"}, clear=False): - + response = client.get("/.well-known/litellm-ui-config") - + assert response.status_code == 200 data = response.json() assert data["server_root_path"] == "/" @@ -256,3 +257,53 @@ def test_ui_discovery_endpoints_with_admin_ui_enabled(): 
assert data["admin_ui_disabled"] is False assert data["sso_configured"] is False + +def test_ui_discovery_endpoints_is_control_plane_true_when_workers_configured(): + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + mock_config = MagicMock() + mock_config.worker_registry = [ + WorkerRegistryEntry( + worker_id="team-a", name="Team A", url="https://worker-1:4001" + ), + ] + + with patch("litellm.proxy.utils.get_server_root_path", return_value="/"), \ + patch("litellm.proxy.utils.get_proxy_base_url", return_value=None), \ + patch("litellm.proxy.auth.auth_utils._has_user_setup_sso", return_value=False), \ + patch("litellm.proxy.proxy_server.proxy_config", mock_config), \ + patch.dict(os.environ, {"DISABLE_ADMIN_UI": "false"}, clear=False): + + response = client.get("/.well-known/litellm-ui-config") + + assert response.status_code == 200 + data = response.json() + assert data["is_control_plane"] is True + assert len(data["workers"]) == 1 + assert data["workers"][0]["worker_id"] == "team-a" + assert data["workers"][0]["name"] == "Team A" + assert data["workers"][0]["url"] == "https://worker-1:4001" + + +def test_ui_discovery_endpoints_is_control_plane_false_when_no_workers(): + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + mock_config = MagicMock() + mock_config.worker_registry = [] + + with patch("litellm.proxy.utils.get_server_root_path", return_value="/"), \ + patch("litellm.proxy.utils.get_proxy_base_url", return_value=None), \ + patch("litellm.proxy.auth.auth_utils._has_user_setup_sso", return_value=False), \ + patch("litellm.proxy.proxy_server.proxy_config", mock_config), \ + patch.dict(os.environ, {"DISABLE_ADMIN_UI": "false"}, clear=False): + + response = client.get("/.well-known/litellm-ui-config") + + assert response.status_code == 200 + data = response.json() + assert data["is_control_plane"] is False + assert data["workers"] == [] diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py 
b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index d43b2c4ba0..fc9c37b7f8 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest -from fastapi import Request +from fastapi import HTTPException, Request from litellm._uuid import uuid @@ -5160,3 +5160,99 @@ def test_generic_response_convertor_extra_attributes_missing_field(monkeypatch): assert result.extra_fields["missing_field"] is None assert result.extra_fields["another_missing"] is None + +class TestValidateReturnTo: + """Tests for SSOAuthenticationHandler._validate_return_to""" + + def test_rejects_when_no_control_plane_url_configured(self, monkeypatch): + """return_to should be rejected if control_plane_url is not in general_settings.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", {} + ) + with pytest.raises(HTTPException) as exc_info: + SSOAuthenticationHandler._validate_return_to("https://cp.example.com/ui") + assert exc_info.value.status_code == 400 + assert "not configured" in exc_info.value.detail + + def test_allows_matching_origin(self, monkeypatch): + """return_to matching the configured control_plane_url origin should pass.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + # Should not raise + SSOAuthenticationHandler._validate_return_to("https://cp.example.com/ui?page=models") + + def test_allows_matching_origin_with_trailing_slash(self, monkeypatch): + """Trailing slash on control_plane_url should not affect origin comparison.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com/"}, + ) + SSOAuthenticationHandler._validate_return_to("https://cp.example.com/ui") + + def test_rejects_prefix_attack(self, monkeypatch): + 
"""return_to like cp.example.com.evil.com must be rejected (not just prefix match).""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + with pytest.raises(HTTPException) as exc_info: + SSOAuthenticationHandler._validate_return_to("https://cp.example.com.evil.com/steal") + assert exc_info.value.status_code == 400 + + def test_rejects_different_origin(self, monkeypatch): + """return_to pointing to a completely different domain should be rejected.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + with pytest.raises(HTTPException) as exc_info: + SSOAuthenticationHandler._validate_return_to("https://evil.com/phish") + assert exc_info.value.status_code == 400 + + def test_case_insensitive_hostname(self, monkeypatch): + """Hostname comparison should be case-insensitive per RFC 3986.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://CP.Example.COM"}, + ) + # Should not raise + SSOAuthenticationHandler._validate_return_to("https://cp.example.com/ui") + + def test_rejects_scheme_mismatch(self, monkeypatch): + """http:// must be rejected when control_plane_url uses https://.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + with pytest.raises(HTTPException) as exc_info: + SSOAuthenticationHandler._validate_return_to("http://cp.example.com/ui") + assert exc_info.value.status_code == 400 + + def test_rejects_port_mismatch(self, monkeypatch): + """Non-default port must be rejected.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + with pytest.raises(HTTPException) as exc_info: + SSOAuthenticationHandler._validate_return_to("https://cp.example.com:8443/ui") + assert exc_info.value.status_code == 400 + + 
def test_allows_explicit_default_port(self, monkeypatch): + """https://host:443 should match https://host (default port normalisation).""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + SSOAuthenticationHandler._validate_return_to("https://cp.example.com:443/ui") + + def test_allows_matching_custom_port(self, monkeypatch): + """Both sides on the same custom port should match.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com:3000"}, + ) + SSOAuthenticationHandler._validate_return_to("https://cp.example.com:3000/ui") + diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index 112a06b173..bd6162f225 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -236,6 +236,217 @@ def test_login_v2_returns_json_on_invalid_json_body(monkeypatch): assert isinstance(data["error"], dict) +def test_login_v3_rejected_without_control_plane_url(monkeypatch): + """v3/login returns 404 when control_plane_url is not configured.""" + mock_prisma_client = MagicMock() + monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key") + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + + client = TestClient(app) + response = client.post( + "/v3/login", + json={"username": "alice", "password": "secret"}, + ) + + assert response.status_code == 404 + assert "control_plane_url" in response.json()["error"]["message"] + + +def test_login_v3_returns_code(monkeypatch): + """v3/login returns an opaque code, not the JWT directly.""" + mock_prisma_client = MagicMock() + monkeypatch.setattr( + "litellm.proxy.auth.login_utils.authenticate_user", + AsyncMock(return_value={"user_id": "test-user"}), + ) + 
monkeypatch.setattr( + "litellm.proxy.auth.login_utils.create_ui_token_object", + MagicMock(return_value={"user_id": "test-user"}), + ) + monkeypatch.setattr("jwt.encode", MagicMock(return_value="signed-token")) + monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key") + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + mock_config = MagicMock() + mock_config.worker_registry = [] + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_config", mock_config) + monkeypatch.setattr("litellm.proxy.utils.get_server_root_path", lambda: "") + monkeypatch.setattr("litellm.proxy.utils.get_proxy_base_url", lambda: None) + + client = TestClient(app) + response = client.post( + "/v3/login", + json={"username": "alice", "password": "secret"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "code" in data + assert data["expires_in"] == 60 + assert "token" not in data + + +def test_login_v3_exchange_happy_path(monkeypatch): + """Full flow: v3/login returns code, v3/login/exchange redeems it for JWT.""" + mock_prisma_client = MagicMock() + monkeypatch.setattr( + "litellm.proxy.auth.login_utils.authenticate_user", + AsyncMock(return_value={"user_id": "test-user"}), + ) + monkeypatch.setattr( + "litellm.proxy.auth.login_utils.create_ui_token_object", + MagicMock(return_value={"user_id": "test-user"}), + ) + monkeypatch.setattr("jwt.encode", MagicMock(return_value="signed-token")) + monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key") + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False) + 
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + mock_config = MagicMock() + mock_config.worker_registry = [] + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_config", mock_config) + monkeypatch.setattr("litellm.proxy.utils.get_server_root_path", lambda: "") + monkeypatch.setattr("litellm.proxy.utils.get_proxy_base_url", lambda: None) + + client = TestClient(app) + + # Step 1: login — get code + login_response = client.post( + "/v3/login", + json={"username": "alice", "password": "secret"}, + ) + assert login_response.status_code == 200 + code = login_response.json()["code"] + + # Step 2: exchange — get JWT + exchange_response = client.post( + "/v3/login/exchange", + json={"code": code}, + ) + assert exchange_response.status_code == 200 + exchange_data = exchange_response.json() + assert exchange_data["token"] == "signed-token" + assert "redirect_url" in exchange_data + assert exchange_response.cookies.get("token") == "signed-token" + + +def test_login_v3_exchange_single_use(monkeypatch): + """Code can only be redeemed once.""" + mock_prisma_client = MagicMock() + monkeypatch.setattr( + "litellm.proxy.auth.login_utils.authenticate_user", + AsyncMock(return_value={"user_id": "test-user"}), + ) + monkeypatch.setattr( + "litellm.proxy.auth.login_utils.create_ui_token_object", + MagicMock(return_value={"user_id": "test-user"}), + ) + monkeypatch.setattr("jwt.encode", MagicMock(return_value="signed-token")) + monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key") + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + mock_config = MagicMock() + mock_config.worker_registry = [] + monkeypatch.setattr("litellm.proxy.proxy_server.proxy_config", mock_config) + 
monkeypatch.setattr("litellm.proxy.utils.get_server_root_path", lambda: "") + monkeypatch.setattr("litellm.proxy.utils.get_proxy_base_url", lambda: None) + + client = TestClient(app) + + login_response = client.post( + "/v3/login", + json={"username": "alice", "password": "secret"}, + ) + code = login_response.json()["code"] + + # First exchange succeeds + first = client.post("/v3/login/exchange", json={"code": code}) + assert first.status_code == 200 + + # Second exchange fails + second = client.post("/v3/login/exchange", json={"code": code}) + assert second.status_code == 401 + + +def test_login_v3_exchange_invalid_code(monkeypatch): + """Random code returns 401.""" + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + client = TestClient(app) + response = client.post( + "/v3/login/exchange", + json={"code": "nonexistent-code"}, + ) + assert response.status_code == 401 + + +def test_login_v3_exchange_rejected_without_control_plane_url(monkeypatch): + """v3/login/exchange returns 404 when control_plane_url is not configured.""" + monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {}) + + client = TestClient(app) + response = client.post( + "/v3/login/exchange", + json={"code": "some-code"}, + ) + + assert response.status_code == 404 + assert "control_plane_url" in response.json()["error"]["message"] + + +def test_login_v3_returns_json_on_proxy_exception(monkeypatch): + """Test that /v3/login returns JSON error when ProxyException is raised""" + from litellm.proxy._types import ProxyErrorTypes, ProxyException + + mock_prisma_client = MagicMock() + mock_authenticate_user = AsyncMock( + side_effect=ProxyException( + message="Invalid credentials", + type=ProxyErrorTypes.auth_error, + param="password", + code=401, + ) + ) + + monkeypatch.setattr( + "litellm.proxy.auth.login_utils.authenticate_user", + mock_authenticate_user, + ) + 
monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key") + monkeypatch.setattr( + "litellm.proxy.proxy_server.general_settings", + {"control_plane_url": "https://cp.example.com"}, + ) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + + client = TestClient(app) + response = client.post( + "/v3/login", + json={"username": "alice", "password": "wrong"}, + ) + + assert response.status_code == 401 + assert response.headers["content-type"] == "application/json" + data = response.json() + assert "error" in data + assert data["error"]["message"] == "Invalid credentials" + assert data["error"]["type"] == "auth_error" + + def test_fallback_login_has_no_deprecation_banner(client_no_auth): response = client_no_auth.get("/fallback/login") diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/login/useLogin.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/login/useLogin.ts index a15b4a06d1..be53b1c80a 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/hooks/login/useLogin.ts +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/login/useLogin.ts @@ -3,8 +3,8 @@ import { loginCall, LoginRequest } from "@/components/networking"; export const useLogin = () => { return useMutation({ - mutationFn: async ({ username, password }: LoginRequest) => { - const result = await loginCall(username, password); + mutationFn: async ({ username, password, useV3 }: LoginRequest) => { + const result = await loginCall(username, password, useV3); return result; }, }); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/uiConfig/useUIConfig.test.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/uiConfig/useUIConfig.test.ts index aba5dddf13..b05bae1e81 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/hooks/uiConfig/useUIConfig.test.ts +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/uiConfig/useUIConfig.test.ts @@ -28,6 +28,8 @@ const mockUIConfig: LiteLLMWellKnownUiConfig = { proxy_base_url: 
"https://proxy.example.com", auto_redirect_to_sso: true, admin_ui_disabled: false, + is_control_plane: false, + workers: [], }; describe("useUIConfig", () => { @@ -102,6 +104,8 @@ describe("useUIConfig", () => { auto_redirect_to_sso: false, sso_configured: false, admin_ui_disabled: true, + is_control_plane: false, + workers: [], }; // Mock successful API call with different data diff --git a/ui/litellm-dashboard/src/app/login/LoginPage.test.tsx b/ui/litellm-dashboard/src/app/login/LoginPage.test.tsx index ad2dde2da8..866b7d0f17 100644 --- a/ui/litellm-dashboard/src/app/login/LoginPage.test.tsx +++ b/ui/litellm-dashboard/src/app/login/LoginPage.test.tsx @@ -41,6 +41,17 @@ vi.mock("@/app/(dashboard)/hooks/login/useLogin", () => ({ })), })); +vi.mock("@/hooks/useWorker", () => ({ + useWorker: vi.fn(() => ({ + isControlPlane: false, + workers: [], + selectedWorkerId: null, + selectedWorker: null, + selectWorker: vi.fn(), + disconnectFromWorker: vi.fn(), + })), +})); + import { useUIConfig } from "@/app/(dashboard)/hooks/uiConfig/useUIConfig"; import { getCookie } from "@/utils/cookieUtils"; import { isJwtExpired } from "@/utils/jwtUtils"; @@ -108,7 +119,7 @@ describe("LoginPage", () => { ); await waitFor(() => { - expect(mockReplace).toHaveBeenCalledWith("http://localhost:4000/ui"); + expect(mockReplace).toHaveBeenCalledWith("/ui"); }); }); @@ -189,7 +200,7 @@ describe("LoginPage", () => { ); await waitFor(() => { - expect(mockReplace).toHaveBeenCalledWith("http://localhost:4000/ui"); + expect(mockReplace).toHaveBeenCalledWith("/ui"); }); expect(mockPush).not.toHaveBeenCalled(); diff --git a/ui/litellm-dashboard/src/app/login/LoginPage.tsx b/ui/litellm-dashboard/src/app/login/LoginPage.tsx index d54140c5a2..5a9d420456 100644 --- a/ui/litellm-dashboard/src/app/login/LoginPage.tsx +++ b/ui/litellm-dashboard/src/app/login/LoginPage.tsx @@ -3,14 +3,15 @@ import { useLogin } from "@/app/(dashboard)/hooks/login/useLogin"; import { useUIConfig } from 
"@/app/(dashboard)/hooks/uiConfig/useUIConfig"; import LoadingScreen from "@/components/common_components/LoadingScreen"; -import { getProxyBaseUrl } from "@/components/networking"; -import { getCookie } from "@/utils/cookieUtils"; +import { exchangeLoginCode, getProxyBaseUrl, switchToWorkerUrl } from "@/components/networking"; +import { clearTokenCookies, getCookie } from "@/utils/cookieUtils"; import { isJwtExpired } from "@/utils/jwtUtils"; import { consumeReturnUrl, getReturnUrl, isValidReturnUrl } from "@/utils/returnUrlUtils"; -import { InfoCircleOutlined } from "@ant-design/icons"; -import { Alert, Button, Card, Form, Input, Popover, Space, Typography } from "antd"; +import { InfoCircleOutlined, CloudServerOutlined } from "@ant-design/icons"; +import { Alert, Button, Card, Form, Input, Popover, Select, Space, Typography } from "antd"; import { useRouter } from "next/navigation"; import { useEffect, useState } from "react"; +import { useWorker } from "@/hooks/useWorker"; function LoginPageContent() { const [username, setUsername] = useState(""); @@ -19,6 +20,17 @@ function LoginPageContent() { const { data: uiConfig, isLoading: isConfigLoading } = useUIConfig(); const loginMutation = useLogin(); const router = useRouter(); + const { workers, selectWorker } = useWorker(); + const [selectedWorkerId, setSelectedWorkerId] = useState(null); + + // Pre-select worker from URL param (e.g. /ui/login?worker=team-b) + useEffect(() => { + const params = new URLSearchParams(window.location.search); + const workerParam = params.get("worker"); + if (workerParam) { + setSelectedWorkerId(workerParam); + } + }, []); useEffect(() => { if (isConfigLoading) { @@ -31,6 +43,44 @@ function LoginPageContent() { return; } + // Cross-origin SSO: worker redirected back with a single-use code. + // Exchange it for the JWT via the worker's /v3/login/exchange endpoint. 
+ const params = new URLSearchParams(window.location.search); + const ssoCode = params.get("code"); + if (ssoCode) { + const workerUrl = localStorage.getItem("litellm_worker_url"); + exchangeLoginCode(ssoCode, workerUrl).then(() => { + params.delete("code"); + const cleanSearch = params.toString(); + window.history.replaceState(null, "", window.location.pathname + (cleanSearch ? `?${cleanSearch}` : "")); + router.replace("/ui/?login=success"); + }); + return; + } + + // Backwards compat: handle direct token in URL (legacy flow) + const urlToken = params.get("token"); + if (urlToken && !isJwtExpired(urlToken)) { + document.cookie = `token=${urlToken}; path=/; SameSite=Lax`; + params.delete("token"); + const cleanSearch = params.toString(); + window.history.replaceState( + null, + "", + window.location.pathname + (cleanSearch ? `?${cleanSearch}` : ""), + ); + router.replace("/ui/?login=success"); + return; + } + + // If switching workers on a control plane, clear the old token and show login + const switchingWorker = params.has("worker"); + if (switchingWorker && uiConfig?.is_control_plane) { + clearTokenCookies(); + setIsLoading(false); + return; + } + const rawToken = getCookie("token"); if (rawToken && !isJwtExpired(rawToken)) { // User already logged in - redirect to return URL or default @@ -38,7 +88,7 @@ function LoginPageContent() { if (returnUrl) { router.replace(returnUrl); } else { - router.replace(`${getProxyBaseUrl()}/ui`); + router.replace("/ui"); } return; } @@ -58,16 +108,35 @@ function LoginPageContent() { }, [isConfigLoading, router, uiConfig]); const handleSubmit = () => { + // If a worker is selected, point proxyBaseUrl at it before login + const selectedWorker = workers.find((w) => w.worker_id === selectedWorkerId); + if (selectedWorker) { + switchToWorkerUrl(selectedWorker.url); + } + loginMutation.mutate( - { username, password }, + { username, password, useV3: !!selectedWorker }, { onSuccess: (data) => { - // Check if we have a return URL to 
use instead of the default redirect - const returnUrl = consumeReturnUrl(); - if (returnUrl) { - router.push(returnUrl); + // Update the worker context with the selected worker + if (selectedWorker) { + selectWorker(selectedWorker.worker_id); + // Stay on the CP's UI — proxyBaseUrl already points at the worker + router.push("/ui/?login=success"); } else { - router.push(data.redirect_url); + // Normal (non-control-plane) login — follow the server's redirect + const returnUrl = consumeReturnUrl(); + if (returnUrl) { + router.push(returnUrl); + } else { + router.push(data.redirect_url); + } + } + }, + onError: () => { + // Reset proxyBaseUrl on login failure + if (selectedWorker) { + switchToWorkerUrl(null); } }, }, @@ -154,6 +223,22 @@ function LoginPageContent() { {error && }
+ {uiConfig?.is_control_plane && workers.length > 0 && ( + + + (option?.label as string ?? "").toLowerCase().includes(input.toLowerCase()) + } + value={selectedWorker.worker_id} + style={{ minWidth: 180 }} + suffixIcon={} + options={workers.map((w) => ({ + label: w.name, + value: w.worker_id, + disabled: w.worker_id === selectedWorker.worker_id, + }))} + onChange={(newWorkerId) => { + onWorkerSwitch(newWorkerId); + }} + /> + ); +}; + +export default WorkerDropdown; diff --git a/ui/litellm-dashboard/src/components/navbar.tsx b/ui/litellm-dashboard/src/components/navbar.tsx index c46a3af5a6..96d6ce613b 100644 --- a/ui/litellm-dashboard/src/components/navbar.tsx +++ b/ui/litellm-dashboard/src/components/navbar.tsx @@ -4,6 +4,7 @@ import { getProxyBaseUrl } from "@/components/networking"; import { useUIConfig } from "@/app/(dashboard)/hooks/uiConfig/useUIConfig"; import { useTheme } from "@/contexts/ThemeContext"; import { clearTokenCookies } from "@/utils/cookieUtils"; +import { clearStoredReturnUrl } from "@/utils/returnUrlUtils"; import { fetchProxySettings } from "@/utils/proxyUtils"; import { MenuFoldOutlined, MenuUnfoldOutlined, MessageOutlined, MoonOutlined, SunOutlined } from "@ant-design/icons"; import { Button, Switch, Tag } from "antd"; @@ -12,6 +13,7 @@ import React, { useEffect, useState } from "react"; import { BlogDropdown } from "./Navbar/BlogDropdown/BlogDropdown"; import { CommunityEngagementButtons } from "./Navbar/CommunityEngagementButtons/CommunityEngagementButtons"; import UserDropdown from "./Navbar/UserDropdown/UserDropdown"; +import WorkerDropdown from "./Navbar/WorkerDropdown/WorkerDropdown"; interface NavbarProps { userID: string | null; @@ -77,9 +79,19 @@ const Navbar: React.FC = ({ const handleLogout = () => { clearTokenCookies(); + localStorage.removeItem("litellm_selected_worker_id"); + localStorage.removeItem("litellm_worker_url"); window.location.href = logoutUrl; }; + const handleWorkerSwitch = (workerId: string) => { + 
clearTokenCookies(); + clearStoredReturnUrl(); + localStorage.removeItem("litellm_selected_worker_id"); + localStorage.removeItem("litellm_worker_url"); + window.location.href = `/ui/login?worker=${encodeURIComponent(workerId)}`; + }; + return (