fix(vertex): propagate Vertex AI metadata in streaming success callbacks (#29899)
* fix(vertex): propagate Vertex AI metadata in streaming success callbacks Streaming calls assembled via stream_chunk_builder were missing vertex_ai_grounding_metadata and vertex_ai_url_context_metadata in standard_logging_object.response. Merge metadata from chunks into the assembled response and mirror non-streaming hidden_params on Gemini chunks. Co-authored-by: Cursor <cursoragent@cursor.com> * refactor(vertex): move streaming metadata merge into provider config hook Address review feedback by delegating assembled-stream metadata propagation to VertexGeminiConfig via BaseConfig.apply_assembled_streaming_response_metadata, and only write chunk hidden_params when metadata is non-empty. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(redaction): scrub Vertex provider metadata when message logging is off Clear vertex_ai_grounding_metadata and related fields from standard logging responses and assembled streaming ModelResponse objects so turn_off_message_logging cannot leak prompt-derived web search queries. Co-authored-by: Cursor <cursoragent@cursor.com> * Use assembled model for streaming metadata hook * Fix Vertex metadata redaction bypass in logging callbacks. Scrub Vertex provider fields from litellm_params.metadata.hidden_params during perform_redaction so streaming success_handler merges do not leak prompt-derived metadata when message logging is disabled. Co-authored-by: Cursor <cursoragent@cursor.com> * Fix Vertex streaming metadata from hidden params * fix(vertex): mirror vertex_ai_safety_results on assembled streaming responses The non-streaming transform_response stores safety data under vertex_ai_safety_results, but the streaming path only wrote vertex_ai_safety_ratings. Assembled streaming responses therefore never carried vertex_ai_safety_results, so any consumer reading that field saw a silent difference between streaming and non-streaming calls. 
Set vertex_ai_safety_results alongside vertex_ai_safety_ratings in the shared stream metadata setter and add it to the assembled metadata field list so it propagates through stream_chunk_builder. * fix(streaming): log provider streaming metadata hook failures instead of swallowing them * refactor(vertex): share single Vertex metadata field tuple across redaction and streaming * refactor(vertex): move Vertex metadata redaction helpers into llms/vertex_ai --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: mateo-berri <277851410+mateo-berri@users.noreply.github.com>
This commit is contained in:
parent
1c881eee5d
commit
dfd6cbc514
@ -17,6 +17,10 @@ from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
get_metadata_variable_name_from_kwargs,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
redact_vertex_ai_metadata_from_litellm_params,
|
||||
redact_vertex_ai_metadata_from_logged_object,
|
||||
)
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
@ -119,10 +123,12 @@ def _redact_standard_logging_object(model_call_details: dict):
|
||||
# ResponsesAPIResponse format - redact content in output items
|
||||
if isinstance(response.get("output"), list):
|
||||
_redact_responses_api_output_dict(response["output"], redacted_str)
|
||||
redact_vertex_ai_metadata_from_logged_object(response)
|
||||
elif isinstance(response, dict) and "choices" in response:
|
||||
# ModelResponse dict format - redact content in choices
|
||||
if isinstance(response.get("choices"), list):
|
||||
_redact_model_response_dict_choices(response["choices"], redacted_str)
|
||||
redact_vertex_ai_metadata_from_logged_object(response)
|
||||
elif isinstance(response, str):
|
||||
standard_logging_object["response"] = redacted_str
|
||||
else:
|
||||
@ -164,6 +170,7 @@ def perform_redaction(model_call_details: dict, result):
|
||||
model_call_details["prompt"] = ""
|
||||
model_call_details["input"] = ""
|
||||
_redact_standard_logging_object(model_call_details)
|
||||
redact_vertex_ai_metadata_from_litellm_params(model_call_details)
|
||||
|
||||
# Redact streaming response
|
||||
if (
|
||||
@ -174,6 +181,7 @@ def perform_redaction(model_call_details: dict, result):
|
||||
if hasattr(_streaming_response, "choices"):
|
||||
for choice in _streaming_response.choices:
|
||||
_redact_choice_content(choice)
|
||||
redact_vertex_ai_metadata_from_logged_object(_streaming_response)
|
||||
elif hasattr(_streaming_response, "output"):
|
||||
_redact_responses_api_output(_streaming_response.output)
|
||||
# Redact reasoning field in ResponsesAPIResponse
|
||||
@ -200,12 +208,14 @@ def perform_redaction(model_call_details: dict, result):
|
||||
if hasattr(_result, "choices") and _result.choices is not None:
|
||||
for choice in _result.choices:
|
||||
_redact_choice_content(choice)
|
||||
redact_vertex_ai_metadata_from_logged_object(_result)
|
||||
elif isinstance(_result, dict) and "choices" in _result:
|
||||
# Handle dict representation of ModelResponse (e.g., from model_dump())
|
||||
if _result.get("choices") is not None:
|
||||
_redact_model_response_dict_choices(
|
||||
_result["choices"], "redacted-by-litellm"
|
||||
)
|
||||
redact_vertex_ai_metadata_from_logged_object(_result)
|
||||
elif isinstance(_result, dict) and "output" in _result:
|
||||
if isinstance(_result.get("output"), list):
|
||||
_redact_responses_api_output_dict(
|
||||
|
||||
@ -20,6 +20,7 @@ from litellm.types.utils import (
|
||||
ServerToolUse,
|
||||
Usage,
|
||||
)
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.utils import print_verbose, token_counter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -79,6 +80,54 @@ class ChunkProcessor:
|
||||
model_response._hidden_params = chunk.get("_hidden_params", {})
|
||||
return model_response
|
||||
|
||||
@staticmethod
def apply_provider_assembled_streaming_metadata(
    response: ModelResponse,
    chunks: List[Any],
    logging_obj: Optional[Any] = None,
) -> None:
    """
    Let the provider's chat config merge chunk metadata into ``response``.

    The provider is taken from ``logging_obj.model_call_details`` when it
    carries a ``custom_llm_provider``; otherwise it is derived from the
    assembled response's model name. The resolved config's
    ``apply_assembled_streaming_response_metadata`` hook is then invoked.

    This is a best-effort step: every failure (unknown provider, a logging
    object without ``model_call_details``, a hook error, ...) is logged at
    debug level and never propagated to the caller.

    Args:
        response: Assembled streaming response; mutated in place by the hook.
        chunks: The streamed chunks the response was assembled from.
        logging_obj: Optional logging object used to look up the provider.
    """
    if not chunks:
        return

    model = getattr(response, "model", None)
    if not model:
        return

    try:
        from litellm.litellm_core_utils.get_llm_provider_logic import (
            get_llm_provider,
        )
        from litellm.types.utils import LlmProviders
        from litellm.utils import ProviderConfigManager

        # Read the provider inside the guarded region so that a logging
        # object lacking ``model_call_details`` cannot break assembly.
        custom_llm_provider = None
        model_call_details = getattr(logging_obj, "model_call_details", None)
        if model_call_details is not None:
            custom_llm_provider = model_call_details.get("custom_llm_provider")

        if custom_llm_provider:
            provider = LlmProviders(custom_llm_provider)
        else:
            _, provider_str, _, _ = get_llm_provider(model)
            provider = LlmProviders(provider_str)

        provider_config = ProviderConfigManager.get_provider_chat_config(
            model=model,
            provider=provider,
        )
        if provider_config is not None:
            provider_config.apply_assembled_streaming_response_metadata(
                response=response,
                chunks=chunks,
            )
    except Exception as e:
        verbose_logger.debug(
            "apply_provider_assembled_streaming_metadata failed for model=%s: %s",
            model,
            e,
        )
|
||||
|
||||
@staticmethod
|
||||
def _get_chunk_id(chunks: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
|
||||
@ -442,6 +442,14 @@ class BaseConfig(ABC):
|
||||
"""Hook for providers to post-process streaming responses. Default: pass-through."""
|
||||
return stream
|
||||
|
||||
def apply_assembled_streaming_response_metadata(
    self,
    response: "ModelResponse",
    chunks: List[Any],
) -> None:
    """
    Hook for providers to merge chunk metadata into assembled streaming responses.

    The default implementation is a no-op: ``response`` and ``chunks`` are
    left untouched and nothing is returned.
    """
    return
|
||||
|
||||
def calculate_additional_costs(
|
||||
self, model: str, prompt_tokens: int, completion_tokens: int
|
||||
) -> Optional[dict]:
|
||||
|
||||
@ -12,7 +12,11 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo, BaseTokenCounter
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.vertex_ai import PartType, Schema
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
VERTEX_AI_PROVIDER_METADATA_FIELDS,
|
||||
PartType,
|
||||
Schema,
|
||||
)
|
||||
from litellm.types.utils import TokenCountResponse
|
||||
from litellm.utils import supports_response_schema, supports_system_messages
|
||||
|
||||
@ -27,6 +31,47 @@ class VertexAIError(BaseLLMException):
|
||||
super().__init__(message=message, status_code=status_code, headers=headers)
|
||||
|
||||
|
||||
def redact_vertex_ai_metadata_from_logged_object(obj: Any) -> None:
    """
    Scrub Vertex AI provider metadata from a logged response, in place.

    Works on both dict payloads and attribute-style objects. Every field in
    ``VERTEX_AI_PROVIDER_METADATA_FIELDS`` that is present on ``obj`` is reset
    to an empty list, and the same keys are removed from ``_hidden_params``
    when that is a dict.
    """
    is_mapping = isinstance(obj, dict)

    # Blank the public fields, but only the ones that already exist.
    for name in VERTEX_AI_PROVIDER_METADATA_FIELDS:
        if is_mapping:
            if name in obj:
                obj[name] = []
        elif hasattr(obj, name):
            setattr(obj, name, [])

    # Hidden params hold a second copy of the same metadata; drop it entirely.
    if is_mapping:
        hidden = obj.get("_hidden_params")
    else:
        hidden = getattr(obj, "_hidden_params", None)
    if not isinstance(hidden, dict):
        return
    for name in VERTEX_AI_PROVIDER_METADATA_FIELDS:
        hidden.pop(name, None)
|
||||
|
||||
|
||||
def redact_vertex_ai_metadata_from_litellm_params(model_call_details: dict) -> None:
    """
    Remove Vertex AI provider metadata from the logged ``hidden_params`` copies.

    success_handler() merges response._hidden_params into
    litellm_params.metadata['hidden_params'] before redaction runs, so the
    Vertex metadata must be scrubbed from that copy too. Both the
    ``metadata`` and ``litellm_metadata`` containers are checked; anything
    that is not a dict is ignored.
    """
    params = model_call_details.get("litellm_params")
    if not isinstance(params, dict):
        return

    containers = (params.get(key) for key in ("metadata", "litellm_metadata"))
    for container in containers:
        hidden = container.get("hidden_params") if isinstance(container, dict) else None
        if isinstance(hidden, dict):
            for name in VERTEX_AI_PROVIDER_METADATA_FIELDS:
                hidden.pop(name, None)
|
||||
|
||||
|
||||
def vertex_request_labels_from_litellm_params(
|
||||
litellm_params: Optional[dict],
|
||||
) -> Optional[Dict[str, str]]:
|
||||
|
||||
@ -63,6 +63,7 @@ from litellm.types.llms.openai import (
|
||||
OpenAIChatCompletionFinishReason,
|
||||
)
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
VERTEX_AI_PROVIDER_METADATA_FIELDS,
|
||||
VERTEX_CREDENTIALS_TYPES,
|
||||
Candidates,
|
||||
ContentType,
|
||||
@ -2258,6 +2259,71 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||
citation_metadata,
|
||||
)
|
||||
|
||||
@staticmethod
def _get_stream_chunk_attr(chunk: Any, field_name: str) -> Any:
    """
    Read ``field_name`` from a streaming chunk, wherever it is stored.

    Dict chunks are searched in order: the top-level key, then
    ``model_extra``, then ``_hidden_params``. Object chunks are read via
    attribute access and, when the attribute is missing or ``None``, fall
    back to the object's ``_hidden_params`` dict so that both chunk shapes
    behave the same way.

    Returns:
        The first value that is not ``None``, or ``None`` when the field is
        not found in any location.
    """
    if isinstance(chunk, dict):
        for container in (
            chunk,
            chunk.get("model_extra"),
            chunk.get("_hidden_params"),
        ):
            if isinstance(container, dict):
                value = container.get(field_name)
                if value is not None:
                    return value
        return None

    value = getattr(chunk, field_name, None)
    if value is not None:
        return value

    # Object chunks may carry the metadata only in their hidden params.
    hidden_params = getattr(chunk, "_hidden_params", None)
    if isinstance(hidden_params, dict):
        return hidden_params.get(field_name)
    return None
|
||||
|
||||
@staticmethod
def _set_stream_metadata_on_response(
    model_response: Any,
    grounding_metadata: List[dict],
    url_context_metadata: List[dict],
    safety_ratings: List[dict],
    citation_metadata: List[dict],
) -> None:
    """
    Attach Vertex AI metadata to a streaming response object.

    Each value is always set as an attribute on ``model_response``. It is
    additionally stored in ``model_response._hidden_params`` only when it is
    non-empty. Safety ratings are published under both
    ``vertex_ai_safety_ratings`` and ``vertex_ai_safety_results``.
    """

    def _publish(payload, *field_names):
        # Attributes are written unconditionally; hidden params only get
        # non-empty payloads.
        for name in field_names:
            setattr(model_response, name, payload)  # type: ignore
        if payload:
            for name in field_names:
                model_response._hidden_params[name] = payload

    _publish(grounding_metadata, "vertex_ai_grounding_metadata")
    _publish(url_context_metadata, "vertex_ai_url_context_metadata")
    _publish(
        safety_ratings,
        "vertex_ai_safety_ratings",
        "vertex_ai_safety_results",
    )
    _publish(citation_metadata, "vertex_ai_citation_metadata")
|
||||
|
||||
def apply_assembled_streaming_response_metadata(
    self,
    response: ModelResponse,
    chunks: List[Any],
) -> None:
    """
    Merge Vertex AI metadata from streamed chunks into the assembled response.

    For every name in ``VERTEX_AI_PROVIDER_METADATA_FIELDS`` the truthy
    values found on the chunks are concatenated in chunk order (list values
    are flattened one level, anything else is appended as-is). A non-empty
    result is set both as an attribute on ``response`` and in
    ``response._hidden_params``; fields with no data are left untouched.
    """
    read_field = VertexGeminiConfig._get_stream_chunk_attr

    for name in VERTEX_AI_PROVIDER_METADATA_FIELDS:
        collected: List[Any] = []
        for found in (read_field(chunk, name) for chunk in chunks):
            if isinstance(found, list):
                # An empty list contributes nothing, same as skipping it.
                collected += found
            elif found:
                collected.append(found)

        if not collected:
            continue
        setattr(response, name, collected)
        response._hidden_params[name] = collected
|
||||
|
||||
@staticmethod
|
||||
def _convert_grounding_metadata_to_annotations(
|
||||
grounding_metadata: List[dict],
|
||||
@ -3390,10 +3456,13 @@ class ModelResponseIterator:
|
||||
if choice.finish_reason == "stop":
|
||||
choice.finish_reason = "tool_calls"
|
||||
|
||||
setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata) # type: ignore
|
||||
setattr(model_response, "vertex_ai_url_context_metadata", url_context_metadata) # type: ignore
|
||||
setattr(model_response, "vertex_ai_safety_ratings", safety_ratings) # type: ignore
|
||||
setattr(model_response, "vertex_ai_citation_metadata", citation_metadata) # type: ignore
|
||||
VertexGeminiConfig._set_stream_metadata_on_response(
|
||||
model_response,
|
||||
grounding_metadata,
|
||||
url_context_metadata,
|
||||
safety_ratings,
|
||||
citation_metadata,
|
||||
)
|
||||
|
||||
return (
|
||||
grounding_metadata,
|
||||
|
||||
@ -7761,6 +7761,9 @@ def stream_chunk_builder( # noqa: PLR0915
|
||||
"cost",
|
||||
logging_obj._response_cost_calculator(result=response),
|
||||
)
|
||||
processor.apply_provider_assembled_streaming_metadata(
|
||||
response, chunks, logging_obj
|
||||
)
|
||||
return response
|
||||
|
||||
tool_call_chunks = [
|
||||
@ -7940,6 +7943,9 @@ def stream_chunk_builder( # noqa: PLR0915
|
||||
usage, "cost", logging_obj._response_cost_calculator(result=response)
|
||||
)
|
||||
|
||||
processor.apply_provider_assembled_streaming_metadata(
|
||||
response, chunks, logging_obj
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
|
||||
@ -758,3 +758,12 @@ class VertexPartnerProvider(str, Enum):
|
||||
llama = "llama"
|
||||
ai21 = "ai21"
|
||||
claude = "claude"
|
||||
|
||||
|
||||
# Names under which Vertex AI provider metadata is stored on responses and in
# ``_hidden_params``. Kept as one shared tuple so that redaction and streaming
# assembly iterate over the same set of fields.
VERTEX_AI_PROVIDER_METADATA_FIELDS = (
    "vertex_ai_grounding_metadata",
    "vertex_ai_url_context_metadata",
    "vertex_ai_safety_ratings",
    "vertex_ai_safety_results",
    "vertex_ai_citation_metadata",
)
|
||||
|
||||
@ -2165,6 +2165,41 @@ def test_get_assembled_streaming_response_returns_result_for_streaming():
|
||||
assert assembled is result
|
||||
|
||||
|
||||
def test_streaming_success_handler_includes_vertex_ai_metadata_in_standard_logging():
    """Vertex AI metadata on an assembled streaming response must reach the logging payload."""
    import datetime

    from litellm.types.utils import Choices, Message

    logging_obj = _make_logging_obj(stream=True)
    expected = {
        "vertex_ai_grounding_metadata": [{"webSearchQueries": ["weather in SF"]}],
        "vertex_ai_url_context_metadata": [
            {"urlMetadata": [{"retrievedUrl": "https://example.com"}]}
        ],
    }

    assistant_choice = Choices(
        index=0,
        message=Message(role="assistant", content="hello"),
        finish_reason="stop",
    )
    result = ModelResponse(
        id="resp-1",
        choices=[assistant_choice],
        model="gemini-2.5-flash",
    )
    # Mirror what the provider does: expose the metadata as attributes first,
    # then store the same values in the hidden params.
    for field_name, value in expected.items():
        setattr(result, field_name, value)
    for field_name, value in expected.items():
        result._hidden_params[field_name] = value

    started_at = datetime.datetime.now()
    finished_at = datetime.datetime.now()
    logging_obj.success_handler(
        result=result, start_time=started_at, end_time=finished_at
    )

    payload = logging_obj.model_call_details.get("standard_logging_object")
    assert payload is not None
    for field_name, value in expected.items():
        assert payload["response"][field_name] == value
|
||||
|
||||
|
||||
def test_get_assembled_streaming_response_returns_none_for_non_streaming_text_completion():
|
||||
"""Non-streaming TextCompletionResponse should also return None."""
|
||||
import datetime
|
||||
|
||||
@ -349,3 +349,96 @@ class TestPerformRedaction:
|
||||
|
||||
assert redacted.output[0].content[0].text == "redacted-by-litellm"
|
||||
assert response.output[0].content[0].text == "sensitive output"
|
||||
|
||||
def test_redacts_vertex_provider_metadata_in_standard_logging_response(self):
    """Vertex metadata in the standard logging response dict is blanked on redaction."""
    logged_response = {
        "choices": [
            {
                "message": {
                    "content": "sensitive answer",
                    "role": "assistant",
                }
            }
        ],
        "vertex_ai_grounding_metadata": [
            {"webSearchQueries": ["sensitive search term"]}
        ],
        "vertex_ai_url_context_metadata": [
            {"urlMetadata": [{"retrievedUrl": "https://example.com"}]}
        ],
    }
    details = {
        "standard_logging_object": {
            "messages": [{"role": "user", "content": "sensitive prompt"}],
            "response": logged_response,
        }
    }

    perform_redaction(details, None)

    # Re-read from the details dict: that is the payload callbacks receive.
    response = details["standard_logging_object"]["response"]
    assert response["choices"][0]["message"]["content"] == "redacted-by-litellm"
    for field_name in (
        "vertex_ai_grounding_metadata",
        "vertex_ai_url_context_metadata",
    ):
        assert response[field_name] == []
|
||||
|
||||
def test_redacts_vertex_provider_metadata_on_streaming_model_response(self):
    """Redaction clears Vertex metadata on the assembled streaming ModelResponse."""
    field_name = "vertex_ai_grounding_metadata"

    def _sensitive_metadata():
        # A fresh list per call so the attribute and the hidden param do not alias.
        return [{"webSearchQueries": ["sensitive search term"]}]

    answer = litellm.Message(content="sensitive answer", role="assistant")
    response = litellm.ModelResponse(
        id="resp-1",
        choices=[litellm.Choices(message=answer)],
        model="gemini-2.5-flash",
    )
    setattr(response, field_name, _sensitive_metadata())
    response._hidden_params[field_name] = _sensitive_metadata()

    details = {"stream": True, "complete_streaming_response": response}

    perform_redaction(details, response)

    assert response.choices[0].message.content == "redacted-by-litellm"
    assert getattr(response, field_name) == []
    assert field_name not in response._hidden_params
|
||||
|
||||
def test_redacts_vertex_provider_metadata_from_metadata_hidden_params(self):
    """Streaming success_handler copies _hidden_params into metadata before redaction."""
    vertex_entries = {
        "vertex_ai_grounding_metadata": [
            {"webSearchQueries": ["sensitive search term"]}
        ],
        "vertex_ai_url_context_metadata": [
            {"urlMetadata": [{"retrievedUrl": "https://example.com"}]}
        ],
        "vertex_ai_safety_ratings": [{"category": "HARM"}],
        "vertex_ai_citation_metadata": [{"citations": ["source"]}],
    }
    details = {
        "stream": True,
        "litellm_params": {
            "metadata": {
                "hidden_params": {"response_cost": 0.01, **vertex_entries},
            }
        },
    }

    perform_redaction(details, None)

    hidden_params = details["litellm_params"]["metadata"]["hidden_params"]
    # Non-Vertex bookkeeping must survive the scrub.
    assert hidden_params["response_cost"] == 0.01
    for field_name in vertex_entries:
        assert field_name not in hidden_params
|
||||
|
||||
@ -613,3 +613,153 @@ def test_stream_chunk_builder_dict_snapshot_preserves_hidden_provider_fields():
|
||||
assert (
|
||||
response._hidden_params["provider_specific_fields"]["traffic_type"] == "default"
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chunk_builder_propagates_vertex_ai_metadata_from_chunks():
    """Metadata carried by individual Vertex AI chunks must land on the assembled response."""
    grounding_metadata = [{"webSearchQueries": ["weather in SF"]}]
    url_context_metadata = [{"urlMetadata": [{"retrievedUrl": "https://example.com"}]}]

    def _chunk(text, finish_reason):
        return ModelResponseStream(
            id="chatcmpl-vertex-1",
            created=1,
            model="gemini-2.5-flash",
            object="chat.completion.chunk",
            choices=[
                StreamingChoices(
                    finish_reason=finish_reason,
                    index=0,
                    delta=Delta(content=text, role="assistant"),
                )
            ],
        )

    # The two metadata kinds arrive on different chunks.
    first = _chunk("The weather", None)
    setattr(first, "vertex_ai_grounding_metadata", grounding_metadata)
    first._hidden_params["vertex_ai_grounding_metadata"] = grounding_metadata

    second = _chunk(" is sunny.", "stop")
    setattr(second, "vertex_ai_url_context_metadata", url_context_metadata)
    second._hidden_params["vertex_ai_url_context_metadata"] = url_context_metadata

    response = stream_chunk_builder(chunks=[first, second])
    assert response is not None

    expected = {
        "vertex_ai_grounding_metadata": grounding_metadata,
        "vertex_ai_url_context_metadata": url_context_metadata,
    }
    for field_name, value in expected.items():
        assert getattr(response, field_name) == value
    for field_name, value in expected.items():
        assert response._hidden_params[field_name] == value

    dumped = response.model_dump()
    for field_name, value in expected.items():
        assert dumped[field_name] == value
|
||||
|
||||
|
||||
def test_stream_chunk_builder_uses_assembled_model_for_provider_metadata():
    """The assembled response reports the later chunk's model and still gets Vertex metadata."""
    grounding_metadata = [{"webSearchQueries": ["weather in SF"]}]

    def _chunk(model, text, role, finish_reason):
        return ModelResponseStream(
            id="chatcmpl-vertex-router",
            created=1,
            model=model,
            object="chat.completion.chunk",
            choices=[
                StreamingChoices(
                    finish_reason=finish_reason,
                    index=0,
                    delta=Delta(content=text, role=role),
                )
            ],
        )

    # The first chunk reports a different model than the final one.
    opening = _chunk("gpt-4o", "The weather", "assistant", None)
    closing = _chunk("gemini-2.5-flash", " is sunny.", None, "stop")
    setattr(closing, "vertex_ai_grounding_metadata", grounding_metadata)
    closing._hidden_params["vertex_ai_grounding_metadata"] = grounding_metadata

    response = stream_chunk_builder(chunks=[opening, closing])

    assert response is not None
    assert response.model == "gemini-2.5-flash"
    assert getattr(response, "vertex_ai_grounding_metadata") == grounding_metadata
|
||||
|
||||
|
||||
def test_stream_chunk_builder_propagates_vertex_ai_safety_results():
    """Safety data must be readable under the non-streaming field name after assembly."""
    safety_ratings = [
        [{"category": "HARM_CATEGORY_HATE_SPEECH", "probability": "NEGLIGIBLE"}]
    ]
    only_choice = StreamingChoices(
        finish_reason="stop",
        index=0,
        delta=Delta(content="hello", role="assistant"),
    )
    chunk = ModelResponseStream(
        id="chatcmpl-vertex-safety",
        created=1,
        model="gemini-2.5-flash",
        object="chat.completion.chunk",
        choices=[only_choice],
    )
    safety_fields = ("vertex_ai_safety_ratings", "vertex_ai_safety_results")
    for field_name in safety_fields:
        setattr(chunk, field_name, safety_ratings)
    for field_name in safety_fields:
        chunk._hidden_params[field_name] = safety_ratings

    response = stream_chunk_builder(chunks=[chunk])

    assert response is not None
    assert getattr(response, "vertex_ai_safety_results") == safety_ratings
    assert response._hidden_params["vertex_ai_safety_results"] == safety_ratings
    assert response.model_dump()["vertex_ai_safety_results"] == safety_ratings
|
||||
|
||||
|
||||
def test_stream_chunk_builder_propagates_vertex_ai_metadata_from_dict_chunks():
    """Chunks supplied as model_dump() dicts must propagate Vertex AI metadata too."""
    streamed = ModelResponseStream(
        id="chatcmpl-vertex-2",
        created=1,
        model="gemini-2.5-flash",
        object="chat.completion.chunk",
        choices=[
            StreamingChoices(
                finish_reason="stop",
                index=0,
                delta=Delta(content="hello", role="assistant"),
            )
        ],
    )
    chunk_dict = streamed.model_dump()
    # The metadata lives only in the hidden params of the dict snapshot.
    chunk_dict["_hidden_params"] = {
        "vertex_ai_grounding_metadata": [{"webSearchQueries": ["test query"]}]
    }

    response = stream_chunk_builder(chunks=[chunk_dict])

    assert response is not None
    expected = [{"webSearchQueries": ["test query"]}]
    assert getattr(response, "vertex_ai_grounding_metadata") == expected
    assert response.model_dump()["vertex_ai_grounding_metadata"] == expected
|
||||
|
||||
@ -1459,6 +1459,26 @@ def test_vertex_ai_process_candidates_with_grounding_metadata():
|
||||
assert len(result[0]) == 1
|
||||
|
||||
|
||||
def test_set_stream_metadata_mirrors_non_streaming_safety_field_names():
    """Safety ratings are published under both the ratings and the results names."""
    safety_ratings = [
        [{"category": "HARM_CATEGORY_HATE_SPEECH", "probability": "NEGLIGIBLE"}]
    ]
    model_response = ModelResponse()

    VertexGeminiConfig._set_stream_metadata_on_response(
        model_response=model_response,
        grounding_metadata=[],
        url_context_metadata=[],
        safety_ratings=safety_ratings,
        citation_metadata=[],
    )

    safety_fields = ("vertex_ai_safety_ratings", "vertex_ai_safety_results")
    for field_name in safety_fields:
        assert getattr(model_response, field_name) == safety_ratings
    for field_name in safety_fields:
        assert model_response._hidden_params[field_name] == safety_ratings
|
||||
|
||||
|
||||
def test_vertex_ai_tool_call_id_format():
|
||||
"""
|
||||
Test that tool call IDs have the correct format and length.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user