diff --git a/litellm/litellm_core_utils/redact_messages.py b/litellm/litellm_core_utils/redact_messages.py index dbc9cabdc7..763596336a 100644 --- a/litellm/litellm_core_utils/redact_messages.py +++ b/litellm/litellm_core_utils/redact_messages.py @@ -17,6 +17,10 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import ( get_metadata_variable_name_from_kwargs, ) +from litellm.llms.vertex_ai.common_utils import ( + redact_vertex_ai_metadata_from_litellm_params, + redact_vertex_ai_metadata_from_logged_object, +) from litellm.secret_managers.main import str_to_bool from litellm.types.utils import StandardCallbackDynamicParams @@ -119,10 +123,12 @@ def _redact_standard_logging_object(model_call_details: dict): # ResponsesAPIResponse format - redact content in output items if isinstance(response.get("output"), list): _redact_responses_api_output_dict(response["output"], redacted_str) + redact_vertex_ai_metadata_from_logged_object(response) elif isinstance(response, dict) and "choices" in response: # ModelResponse dict format - redact content in choices if isinstance(response.get("choices"), list): _redact_model_response_dict_choices(response["choices"], redacted_str) + redact_vertex_ai_metadata_from_logged_object(response) elif isinstance(response, str): standard_logging_object["response"] = redacted_str else: @@ -164,6 +170,7 @@ def perform_redaction(model_call_details: dict, result): model_call_details["prompt"] = "" model_call_details["input"] = "" _redact_standard_logging_object(model_call_details) + redact_vertex_ai_metadata_from_litellm_params(model_call_details) # Redact streaming response if ( @@ -174,6 +181,7 @@ def perform_redaction(model_call_details: dict, result): if hasattr(_streaming_response, "choices"): for choice in _streaming_response.choices: _redact_choice_content(choice) + redact_vertex_ai_metadata_from_logged_object(_streaming_response) elif hasattr(_streaming_response, "output"): _redact_responses_api_output(_streaming_response.output) # Redact reasoning field in ResponsesAPIResponse @@ -200,12 +208,14 @@ def perform_redaction(model_call_details: dict, result): if hasattr(_result, "choices") and _result.choices is not None: for choice in _result.choices: _redact_choice_content(choice) + redact_vertex_ai_metadata_from_logged_object(_result) elif isinstance(_result, dict) and "choices" in _result: # Handle dict representation of ModelResponse (e.g., from model_dump()) if _result.get("choices") is not None: _redact_model_response_dict_choices( _result["choices"], "redacted-by-litellm" ) + redact_vertex_ai_metadata_from_logged_object(_result) elif isinstance(_result, dict) and "output" in _result: if isinstance(_result.get("output"), list): _redact_responses_api_output_dict( diff --git a/litellm/litellm_core_utils/streaming_chunk_builder_utils.py b/litellm/litellm_core_utils/streaming_chunk_builder_utils.py index fe7c62c384..6257cce9ae 100644 --- a/litellm/litellm_core_utils/streaming_chunk_builder_utils.py +++ b/litellm/litellm_core_utils/streaming_chunk_builder_utils.py @@ -20,6 +20,7 @@ from litellm.types.utils import ( ServerToolUse, Usage, ) +from litellm._logging import verbose_logger from litellm.utils import print_verbose, token_counter if TYPE_CHECKING: @@ -79,6 +80,54 @@ class ChunkProcessor: model_response._hidden_params = chunk.get("_hidden_params", {}) return model_response + @staticmethod + def apply_provider_assembled_streaming_metadata( + response: ModelResponse, + chunks: List[Any], + logging_obj: Optional[Any] = None, + ) -> None: + if not chunks: + return + + model = getattr(response, "model", None) + if not model: + return + + custom_llm_provider = None + if logging_obj is not None: + custom_llm_provider = logging_obj.model_call_details.get( + "custom_llm_provider" + ) + + try: + from litellm.litellm_core_utils.get_llm_provider_logic import ( + get_llm_provider, + ) + from litellm.types.utils import LlmProviders + from litellm.utils import ProviderConfigManager + + if custom_llm_provider: + provider = LlmProviders(custom_llm_provider) + else: + _, provider_str, _, _ = get_llm_provider(model) + provider = LlmProviders(provider_str) + + provider_config = ProviderConfigManager.get_provider_chat_config( + model=model, + provider=provider, + ) + if provider_config is not None: + provider_config.apply_assembled_streaming_response_metadata( + response=response, + chunks=chunks, + ) + except Exception as e: + verbose_logger.debug( + "apply_provider_assembled_streaming_metadata failed for model=%s: %s", + model, + e, + ) + @staticmethod def _get_chunk_id(chunks: List[Dict[str, Any]]) -> str: """ diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index 5f35a58ce1..8f9d5cad7c 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -442,6 +442,14 @@ class BaseConfig(ABC): """Hook for providers to post-process streaming responses. Default: pass-through.""" return stream + def apply_assembled_streaming_response_metadata( + self, + response: "ModelResponse", + chunks: List[Any], + ) -> None: + """Hook for providers to merge chunk metadata into assembled streaming responses.""" + return None + def calculate_additional_costs( self, model: str, prompt_tokens: int, completion_tokens: int ) -> Optional[dict]: diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index e6e3965110..85c23d8603 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -12,7 +12,11 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs from litellm.llms.base_llm.base_utils import BaseLLMModelInfo, BaseTokenCounter from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.types.llms.openai import AllMessageValues -from litellm.types.llms.vertex_ai import PartType, Schema +from litellm.types.llms.vertex_ai import ( + VERTEX_AI_PROVIDER_METADATA_FIELDS, + PartType, + Schema, +) from litellm.types.utils import TokenCountResponse from litellm.utils import supports_response_schema, supports_system_messages @@ -27,6 +31,47 @@ class VertexAIError(BaseLLMException): super().__init__(message=message, status_code=status_code, headers=headers) +def redact_vertex_ai_metadata_from_logged_object(obj: Any) -> None: + if isinstance(obj, dict): + for field in VERTEX_AI_PROVIDER_METADATA_FIELDS: + if field in obj: + obj[field] = [] + hidden_params = obj.get("_hidden_params") + if isinstance(hidden_params, dict): + for field in VERTEX_AI_PROVIDER_METADATA_FIELDS: + hidden_params.pop(field, None) + return + + for field in VERTEX_AI_PROVIDER_METADATA_FIELDS: + if hasattr(obj, field): + setattr(obj, field, []) + hidden_params = getattr(obj, "_hidden_params", None) + if isinstance(hidden_params, dict): + for field in VERTEX_AI_PROVIDER_METADATA_FIELDS: + hidden_params.pop(field, None) + + +def redact_vertex_ai_metadata_from_litellm_params(model_call_details: dict) -> None: + """ + success_handler() merges response._hidden_params into + litellm_params.metadata['hidden_params'] before redaction runs, so the Vertex + metadata must be scrubbed from that copy too. + """ + litellm_params = model_call_details.get("litellm_params") + if not isinstance(litellm_params, dict): + return + + for metadata_key in ("metadata", "litellm_metadata"): + metadata = litellm_params.get(metadata_key) + if not isinstance(metadata, dict): + continue + hidden_params = metadata.get("hidden_params") + if not isinstance(hidden_params, dict): + continue + for field in VERTEX_AI_PROVIDER_METADATA_FIELDS: + hidden_params.pop(field, None) + + def vertex_request_labels_from_litellm_params( litellm_params: Optional[dict], ) -> Optional[Dict[str, str]]: diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 7d355a2e90..430a789d2a 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -63,6 +63,7 @@ from litellm.types.llms.openai import ( OpenAIChatCompletionFinishReason, ) from litellm.types.llms.vertex_ai import ( + VERTEX_AI_PROVIDER_METADATA_FIELDS, VERTEX_CREDENTIALS_TYPES, Candidates, ContentType, @@ -2258,6 +2259,71 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): citation_metadata, ) + @staticmethod + def _get_stream_chunk_attr(chunk: Any, field_name: str) -> Any: + if isinstance(chunk, dict): + value = chunk.get(field_name) + if value is not None: + return value + model_extra = chunk.get("model_extra") + if isinstance(model_extra, dict): + value = model_extra.get(field_name) + if value is not None: + return value + hidden_params = chunk.get("_hidden_params") + if isinstance(hidden_params, dict): + return hidden_params.get(field_name) + return None + return getattr(chunk, field_name, None) + + @staticmethod + def _set_stream_metadata_on_response( + model_response: Any, + grounding_metadata: List[dict], + url_context_metadata: List[dict], + safety_ratings: List[dict], + citation_metadata: List[dict], + ) -> None: + setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata) # type: ignore + if grounding_metadata: + model_response._hidden_params["vertex_ai_grounding_metadata"] = ( + grounding_metadata + ) + setattr(model_response, "vertex_ai_url_context_metadata", url_context_metadata) # type: ignore + if url_context_metadata: + model_response._hidden_params["vertex_ai_url_context_metadata"] = ( + url_context_metadata + ) + setattr(model_response, "vertex_ai_safety_ratings", safety_ratings) # type: ignore + setattr(model_response, "vertex_ai_safety_results", safety_ratings) # type: ignore + if safety_ratings: + model_response._hidden_params["vertex_ai_safety_ratings"] = safety_ratings + model_response._hidden_params["vertex_ai_safety_results"] = safety_ratings + setattr(model_response, "vertex_ai_citation_metadata", citation_metadata) # type: ignore + if citation_metadata: + model_response._hidden_params["vertex_ai_citation_metadata"] = ( + citation_metadata + ) + + def apply_assembled_streaming_response_metadata( + self, + response: ModelResponse, + chunks: List[Any], + ) -> None: + for field_name in VERTEX_AI_PROVIDER_METADATA_FIELDS: + merged: List[Any] = [] + for chunk in chunks: + value = VertexGeminiConfig._get_stream_chunk_attr(chunk, field_name) + if not value: + continue + if isinstance(value, list): + merged.extend(value) + else: + merged.append(value) + if merged: + setattr(response, field_name, merged) + response._hidden_params[field_name] = merged + @staticmethod def _convert_grounding_metadata_to_annotations( grounding_metadata: List[dict], @@ -3390,10 +3456,13 @@ class ModelResponseIterator: if choice.finish_reason == "stop": choice.finish_reason = "tool_calls" - setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata) # type: ignore - setattr(model_response, "vertex_ai_url_context_metadata", url_context_metadata) # type: ignore - setattr(model_response, "vertex_ai_safety_ratings", safety_ratings) # type: ignore - setattr(model_response, "vertex_ai_citation_metadata", citation_metadata) # type: ignore + VertexGeminiConfig._set_stream_metadata_on_response( + model_response, + grounding_metadata, + url_context_metadata, + safety_ratings, + citation_metadata, + ) return ( grounding_metadata, diff --git a/litellm/main.py b/litellm/main.py index 64891e2def..1a0d0312d7 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -7761,6 +7761,9 @@ def stream_chunk_builder( # noqa: PLR0915 "cost", logging_obj._response_cost_calculator(result=response), ) + processor.apply_provider_assembled_streaming_metadata( + response, chunks, logging_obj + ) return response tool_call_chunks = [ @@ -7940,6 +7943,9 @@ def stream_chunk_builder( # noqa: PLR0915 usage, "cost", logging_obj._response_cost_calculator(result=response) ) + processor.apply_provider_assembled_streaming_metadata( + response, chunks, logging_obj + ) return response except Exception as e: verbose_logger.exception( diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 00db7b199b..b28fee5128 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -758,3 +758,12 @@ class VertexPartnerProvider(str, Enum): llama = "llama" ai21 = "ai21" claude = "claude" + + +VERTEX_AI_PROVIDER_METADATA_FIELDS = ( + "vertex_ai_grounding_metadata", + "vertex_ai_url_context_metadata", + "vertex_ai_safety_ratings", + "vertex_ai_safety_results", + "vertex_ai_citation_metadata", +) diff --git a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py index d57d8dafdb..34edd6eccf 100644 --- a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py +++ b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py @@ -2165,6 +2165,41 @@ def test_get_assembled_streaming_response_returns_result_for_streaming(): assert assembled is result +def test_streaming_success_handler_includes_vertex_ai_metadata_in_standard_logging(): + """Assembled streaming responses should include Vertex AI metadata in logging payload.""" + import datetime + + from litellm.types.utils import Choices, Message + + logging_obj = _make_logging_obj(stream=True) + grounding_metadata = [{"webSearchQueries": ["weather in SF"]}] + url_context_metadata = [{"urlMetadata": [{"retrievedUrl": "https://example.com"}]}] + result = ModelResponse( + id="resp-1", + choices=[ + Choices( + index=0, + message=Message(role="assistant", content="hello"), + finish_reason="stop", + ) + ], + model="gemini-2.5-flash", + ) + setattr(result, "vertex_ai_grounding_metadata", grounding_metadata) + setattr(result, "vertex_ai_url_context_metadata", url_context_metadata) + result._hidden_params["vertex_ai_grounding_metadata"] = grounding_metadata + result._hidden_params["vertex_ai_url_context_metadata"] = url_context_metadata + + start = datetime.datetime.now() + end = datetime.datetime.now() + logging_obj.success_handler(result=result, start_time=start, end_time=end) + + payload = logging_obj.model_call_details.get("standard_logging_object") + assert payload is not None + assert payload["response"]["vertex_ai_grounding_metadata"] == grounding_metadata + assert payload["response"]["vertex_ai_url_context_metadata"] == url_context_metadata + + def test_get_assembled_streaming_response_returns_none_for_non_streaming_text_completion(): """Non-streaming TextCompletionResponse should also return None.""" import datetime diff --git a/tests/test_litellm/litellm_core_utils/test_redact_messages.py b/tests/test_litellm/litellm_core_utils/test_redact_messages.py index 60cfff6e4a..36f220f9a2 100644 --- a/tests/test_litellm/litellm_core_utils/test_redact_messages.py +++ b/tests/test_litellm/litellm_core_utils/test_redact_messages.py @@ -349,3 +349,96 @@ class TestPerformRedaction: assert redacted.output[0].content[0].text == "redacted-by-litellm" assert response.output[0].content[0].text == "sensitive output" + + def test_redacts_vertex_provider_metadata_in_standard_logging_response(self): + details = { + "standard_logging_object": { + "messages": [{"role": "user", "content": "sensitive prompt"}], + "response": { + "choices": [ + { + "message": { + "content": "sensitive answer", + "role": "assistant", + } + } + ], + "vertex_ai_grounding_metadata": [ + {"webSearchQueries": ["sensitive search term"]} + ], + "vertex_ai_url_context_metadata": [ + {"urlMetadata": [{"retrievedUrl": "https://example.com"}]} + ], + }, + } + } + + perform_redaction(details, None) + + response = details["standard_logging_object"]["response"] + assert response["choices"][0]["message"]["content"] == "redacted-by-litellm" + assert response["vertex_ai_grounding_metadata"] == [] + assert response["vertex_ai_url_context_metadata"] == [] + + def test_redacts_vertex_provider_metadata_on_streaming_model_response(self): + response = litellm.ModelResponse( + id="resp-1", + choices=[ + litellm.Choices( + message=litellm.Message( + content="sensitive answer", + role="assistant", + ) + ) + ], + model="gemini-2.5-flash", + ) + setattr( + response, + "vertex_ai_grounding_metadata", + [{"webSearchQueries": ["sensitive search term"]}], + ) + response._hidden_params["vertex_ai_grounding_metadata"] = [ + {"webSearchQueries": ["sensitive search term"]} + ] + + details = { + "stream": True, + "complete_streaming_response": response, + } + + perform_redaction(details, response) + + assert response.choices[0].message.content == "redacted-by-litellm" + assert getattr(response, "vertex_ai_grounding_metadata") == [] + assert "vertex_ai_grounding_metadata" not in response._hidden_params + + def test_redacts_vertex_provider_metadata_from_metadata_hidden_params(self): + """Streaming success_handler copies _hidden_params into metadata before redaction.""" + details = { + "stream": True, + "litellm_params": { + "metadata": { + "hidden_params": { + "response_cost": 0.01, + "vertex_ai_grounding_metadata": [ + {"webSearchQueries": ["sensitive search term"]} + ], + "vertex_ai_url_context_metadata": [ + {"urlMetadata": [{"retrievedUrl": "https://example.com"}]} + ], + "vertex_ai_safety_ratings": [{"category": "HARM"}], + "vertex_ai_citation_metadata": [{"citations": ["source"]}], + } + } + }, + } + + perform_redaction(details, None) + + hidden_params = details["litellm_params"]["metadata"]["hidden_params"] + assert hidden_params["response_cost"] == 0.01 + assert "vertex_ai_grounding_metadata" not in hidden_params + assert "vertex_ai_url_context_metadata" not in hidden_params + assert "vertex_ai_safety_ratings" not in hidden_params + assert "vertex_ai_citation_metadata" not in hidden_params diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_chunk_builder_utils.py b/tests/test_litellm/litellm_core_utils/test_streaming_chunk_builder_utils.py index e40a0817fd..77765340c6 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_chunk_builder_utils.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_chunk_builder_utils.py @@ -613,3 +613,153 @@ def test_stream_chunk_builder_dict_snapshot_preserves_hidden_provider_fields(): assert ( response._hidden_params["provider_specific_fields"]["traffic_type"] == "default" ) + + +def test_stream_chunk_builder_propagates_vertex_ai_metadata_from_chunks(): + """Vertex AI metadata on streaming chunks must appear on assembled response.""" + grounding_metadata = [{"webSearchQueries": ["weather in SF"]}] + url_context_metadata = [{"urlMetadata": [{"retrievedUrl": "https://example.com"}]}] + + chunk1 = ModelResponseStream( + id="chatcmpl-vertex-1", + created=1, + model="gemini-2.5-flash", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content="The weather", role="assistant"), + ) + ], + ) + setattr(chunk1, "vertex_ai_grounding_metadata", grounding_metadata) + chunk1._hidden_params["vertex_ai_grounding_metadata"] = grounding_metadata + + chunk2 = ModelResponseStream( + id="chatcmpl-vertex-1", + created=1, + model="gemini-2.5-flash", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason="stop", + index=0, + delta=Delta(content=" is sunny.", role="assistant"), + ) + ], + ) + setattr(chunk2, "vertex_ai_url_context_metadata", url_context_metadata) + chunk2._hidden_params["vertex_ai_url_context_metadata"] = url_context_metadata + + response = stream_chunk_builder(chunks=[chunk1, chunk2]) + assert response is not None + assert getattr(response, "vertex_ai_grounding_metadata") == grounding_metadata + assert getattr(response, "vertex_ai_url_context_metadata") == url_context_metadata + assert response._hidden_params["vertex_ai_grounding_metadata"] == grounding_metadata + assert ( + response._hidden_params["vertex_ai_url_context_metadata"] + == url_context_metadata + ) + + dumped = response.model_dump() + assert dumped["vertex_ai_grounding_metadata"] == grounding_metadata + assert dumped["vertex_ai_url_context_metadata"] == url_context_metadata + + +def test_stream_chunk_builder_uses_assembled_model_for_provider_metadata(): + grounding_metadata = [{"webSearchQueries": ["weather in SF"]}] + + chunk1 = ModelResponseStream( + id="chatcmpl-vertex-router", + created=1, + model="gpt-4o", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content="The weather", role="assistant"), + ) + ], + ) + chunk2 = ModelResponseStream( + id="chatcmpl-vertex-router", + created=1, + model="gemini-2.5-flash", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason="stop", + index=0, + delta=Delta(content=" is sunny.", role=None), + ) + ], + ) + setattr(chunk2, "vertex_ai_grounding_metadata", grounding_metadata) + chunk2._hidden_params["vertex_ai_grounding_metadata"] = grounding_metadata + + response = stream_chunk_builder(chunks=[chunk1, chunk2]) + assert response is not None + assert response.model == "gemini-2.5-flash" + assert getattr(response, "vertex_ai_grounding_metadata") == grounding_metadata + + +def test_stream_chunk_builder_propagates_vertex_ai_safety_results(): + """Assembled response must expose safety data under the non-streaming field name.""" + safety_ratings = [ + [{"category": "HARM_CATEGORY_HATE_SPEECH", "probability": "NEGLIGIBLE"}] + ] + + chunk = ModelResponseStream( + id="chatcmpl-vertex-safety", + created=1, + model="gemini-2.5-flash", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason="stop", + index=0, + delta=Delta(content="hello", role="assistant"), + ) + ], + ) + setattr(chunk, "vertex_ai_safety_ratings", safety_ratings) + setattr(chunk, "vertex_ai_safety_results", safety_ratings) + chunk._hidden_params["vertex_ai_safety_ratings"] = safety_ratings + chunk._hidden_params["vertex_ai_safety_results"] = safety_ratings + + response = stream_chunk_builder(chunks=[chunk]) + assert response is not None + assert getattr(response, "vertex_ai_safety_results") == safety_ratings + assert response._hidden_params["vertex_ai_safety_results"] == safety_ratings + assert response.model_dump()["vertex_ai_safety_results"] == safety_ratings + + +def test_stream_chunk_builder_propagates_vertex_ai_metadata_from_dict_chunks(): + """Dict snapshot chunks (model_dump) should also propagate Vertex AI metadata.""" + chunk_dict = ModelResponseStream( + id="chatcmpl-vertex-2", + created=1, + model="gemini-2.5-flash", + object="chat.completion.chunk", + choices=[ + StreamingChoices( + finish_reason="stop", + index=0, + delta=Delta(content="hello", role="assistant"), + ) + ], + ).model_dump() + chunk_dict["_hidden_params"] = { + "vertex_ai_grounding_metadata": [{"webSearchQueries": ["test query"]}] + } + + response = stream_chunk_builder(chunks=[chunk_dict]) + assert response is not None + assert getattr(response, "vertex_ai_grounding_metadata") == [ + {"webSearchQueries": ["test query"]} + ] + assert response.model_dump()["vertex_ai_grounding_metadata"] == [ + {"webSearchQueries": ["test query"]} + ] diff --git a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py index 0d02521433..671d7355e8 100644 --- a/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py +++ b/tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py @@ -1459,6 +1459,26 @@ def test_vertex_ai_process_candidates_with_grounding_metadata(): assert len(result[0]) == 1 +def test_set_stream_metadata_mirrors_non_streaming_safety_field_names(): + safety_ratings = [ + [{"category": "HARM_CATEGORY_HATE_SPEECH", "probability": "NEGLIGIBLE"}] + ] + + model_response = ModelResponse() + VertexGeminiConfig._set_stream_metadata_on_response( + model_response=model_response, + grounding_metadata=[], + url_context_metadata=[], + safety_ratings=safety_ratings, + citation_metadata=[], + ) + + assert getattr(model_response, "vertex_ai_safety_ratings") == safety_ratings + assert getattr(model_response, "vertex_ai_safety_results") == safety_ratings + assert model_response._hidden_params["vertex_ai_safety_ratings"] == safety_ratings + assert model_response._hidden_params["vertex_ai_safety_results"] == safety_ratings + + def test_vertex_ai_tool_call_id_format(): """ Test that tool call IDs have the correct format and length.