Agent Tracing - support context_id based trace id propogation + nested llm calls (#22626)

* style(ui/): distinguish agent calls from llm calls on ui

* feat: initial grouping working

* feat: set stable contextid for a2a calls - allows for easily passing to downstream llm/mcp calls

* feat(a2a_endpoints.py): fix tracing to avoid recreating logging objects for the same call

allows stable trace id usage

* fix(guardrail_endpoints): handle string ui_type values in _build_field_dict

_build_field_dict unconditionally called .value on ui_type, which crashes
for guardrail configs that use plain strings (e.g. BlockCodeExecutionGuardrailConfigModel
uses "multiselect" and "percentage"). Now checks with hasattr before calling .value.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: propagate trace/session id from headers in MCP server calls

Cherry-picked mcp_server/server.py fixes from 6feb9bab: adds
get_chain_id_from_headers to extract x-litellm-trace-id /
x-litellm-session-id from raw headers, and uses it in call_tool
and list_tools to keep spend logs and tracing consistent with A2A.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Krish Dholakia 2026-03-03 18:19:12 -08:00 committed by GitHub
parent ba7a6d9bfd
commit 90eb6729d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
51 changed files with 579 additions and 358 deletions

View File

@ -930,7 +930,7 @@ For Responses API with load balancing across deployments with **different API ke
Notes:
- User-key affinity is keyed on `metadata.user_api_key_hash` (the API key hash). The OpenAI `user` request parameter is an end-user identifier and is intentionally not used for deployment affinity.
- Session-ID affinity is keyed on `metadata.session_id`. For proxy requests, this can be passed via the `x-litellm-session-id` HTTP header. For Python SDK requests, you can pass it via `litellm_metadata={"session_id": "value"}` in request args.
- Session-ID affinity is keyed on `metadata.session_id`. For proxy requests, this can be passed via the `x-litellm-session-id` or `x-litellm-trace-id` HTTP header (they are interchangeable for call chaining). For Python SDK requests, you can pass it via `litellm_metadata={"session_id": "value"}` in request args.
- `user_api_key_hash` is already SHA-256, and is used as-is (no double hashing).
- Affinity is scoped by a stable model identifier (the model-map key, e.g. `model_map_information.model_map_key`) so model aliases map to the same stickiness bucket.
- The mapping TTL is controlled by `deployment_affinity_ttl_seconds` (configured on Router init / proxy startup).

View File

@ -24,11 +24,7 @@ from litellm.utils import client
if TYPE_CHECKING:
from a2a.client import A2AClient as A2AClientType
from a2a.types import (
AgentCard,
SendMessageRequest,
SendStreamingMessageRequest,
)
from a2a.types import AgentCard, SendMessageRequest, SendStreamingMessageRequest
# Runtime imports with availability check
A2A_SDK_AVAILABLE = False
@ -124,13 +120,48 @@ def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
litellm_logging_obj.model = model
litellm_logging_obj.custom_llm_provider = custom_llm_provider
litellm_logging_obj.model_call_details["model"] = model
litellm_logging_obj.model_call_details[
"custom_llm_provider"
] = custom_llm_provider
litellm_logging_obj.model_call_details["custom_llm_provider"] = (
custom_llm_provider
)
return agent_name
async def _send_message_via_completion_bridge(
request: "SendMessageRequest",
custom_llm_provider: str,
api_base: Optional[str],
litellm_params: Dict[str, Any],
) -> LiteLLMSendMessageResponse:
"""
Route a send_message through the LiteLLM completion bridge (e.g. LangGraph, Bedrock AgentCore).
Requires request; api_base is optional for providers that derive endpoint from model.
"""
verbose_logger.info(
f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
)
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
A2ACompletionBridgeHandler,
)
params = (
request.params.model_dump(mode="json")
if hasattr(request.params, "model_dump")
else dict(request.params)
)
response_dict = await A2ACompletionBridgeHandler.handle_non_streaming(
request_id=str(request.id),
params=params,
litellm_params=litellm_params,
api_base=api_base,
)
return LiteLLMSendMessageResponse.from_dict(response_dict)
@client
async def asend_message(
a2a_client: Optional["A2AClientType"] = None,
@ -193,39 +224,21 @@ async def asend_message(
```
"""
litellm_params = litellm_params or {}
logging_obj = kwargs.get("litellm_logging_obj")
trace_id = getattr(logging_obj, "litellm_trace_id", None) if logging_obj else None
custom_llm_provider = litellm_params.get("custom_llm_provider")
# Route through completion bridge if custom_llm_provider is set
if custom_llm_provider:
if request is None:
raise ValueError("request is required for completion bridge")
# api_base is optional for providers that derive endpoint from model (e.g., bedrock/agentcore)
verbose_logger.info(
f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
)
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
A2ACompletionBridgeHandler,
)
# Extract params from request
params = (
request.params.model_dump(mode="json")
if hasattr(request.params, "model_dump")
else dict(request.params)
)
response_dict = await A2ACompletionBridgeHandler.handle_non_streaming(
request_id=str(request.id),
params=params,
litellm_params=litellm_params,
return await _send_message_via_completion_bridge(
request=request,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
litellm_params=litellm_params,
)
# Convert to LiteLLMSendMessageResponse
return LiteLLMSendMessageResponse.from_dict(response_dict)
# Standard A2A client flow
if request is None:
raise ValueError("request is required")
@ -236,11 +249,13 @@ async def asend_message(
raise ValueError(
"Either a2a_client or api_base is required for standard A2A flow"
)
trace_id = str(uuid.uuid4())
trace_id = trace_id or str(uuid.uuid4())
extra_headers = {"X-LiteLLM-Trace-Id": trace_id}
if agent_id:
extra_headers["X-LiteLLM-Agent-Id"] = agent_id
a2a_client = await create_a2a_client(base_url=api_base, extra_headers=extra_headers)
a2a_client = await create_a2a_client(
base_url=api_base, extra_headers=extra_headers
)
# Type assertion: a2a_client is guaranteed to be non-None here
assert a2a_client is not None
@ -255,6 +270,10 @@ async def asend_message(
)
card_url = getattr(agent_card, "url", None) if agent_card else None
context_id = trace_id or str(uuid.uuid4())
if request.params.message.context_id is None:
request.params.message.context_id = context_id
# Retry loop: if connection fails due to localhost URL in agent card, retry with fixed URL
a2a_response = None
for _ in range(2): # max 2 attempts: original + 1 retry
@ -606,7 +625,9 @@ async def create_a2a_client(
if extra_headers:
httpx_client.headers.update(extra_headers)
verbose_proxy_logger.debug(f"A2A client created with extra_headers={extra_headers}")
verbose_proxy_logger.debug(
f"A2A client created with extra_headers={extra_headers}"
)
# Resolve agent card
resolver = A2ACardResolver(

View File

@ -1,6 +1,5 @@
from typing import Optional
# Pre-define optional kwargs keys as frozenset for O(1) lookups
# These are extracted from kwargs only if present, avoiding unnecessary .get() calls
_OPTIONAL_KWARGS_KEYS = frozenset({
@ -95,6 +94,13 @@ def get_litellm_params(
litellm_request_debug: Optional[bool] = None,
**kwargs,
) -> dict:
# Derive litellm_session_id / litellm_trace_id from metadata when not provided (call chaining)
_meta = metadata or {}
if litellm_session_id is None:
litellm_session_id = _meta.get("session_id") or _meta.get("trace_id")
if litellm_trace_id is None:
litellm_trace_id = _meta.get("trace_id") or _meta.get("session_id")
# Build base dict with explicit parameters (always included)
litellm_params = {
"acompletion": acompletion,

View File

@ -133,8 +133,8 @@ from ..integrations.azure_sentinel.azure_sentinel import AzureSentinelLogger
from ..integrations.azure_storage.azure_storage import AzureBlobStorageLogger
from ..integrations.custom_prompt_management import CustomPromptManagement
from ..integrations.datadog.datadog import DataDogLogger
from ..integrations.datadog.datadog_metrics import DatadogMetricsLogger
from ..integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger
from ..integrations.datadog.datadog_metrics import DatadogMetricsLogger
from ..integrations.dotprompt import DotpromptManager
from ..integrations.dynamodb import DyanmoDBLogger
from ..integrations.galileo import GalileoObserve
@ -352,9 +352,9 @@ class Logging(LiteLLMLoggingBaseClass):
)
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[
Any
] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks
@ -746,9 +746,9 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_spec=prompt_spec,
dynamic_callback_params=dynamic_callback_params,
):
self.model_call_details[
"prompt_integration"
] = logger.__class__.__name__
self.model_call_details["prompt_integration"] = (
logger.__class__.__name__
)
return logger
except Exception:
# If check fails, continue to next logger
@ -816,9 +816,9 @@ class Logging(LiteLLMLoggingBaseClass):
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
non_default_params
):
self.model_call_details[
"prompt_integration"
] = anthropic_cache_control_logger.__class__.__name__
self.model_call_details["prompt_integration"] = (
anthropic_cache_control_logger.__class__.__name__
)
return anthropic_cache_control_logger
#########################################################
@ -830,9 +830,9 @@ class Logging(LiteLLMLoggingBaseClass):
internal_usage_cache=None,
llm_router=None,
)
self.model_call_details[
"prompt_integration"
] = vector_store_custom_logger.__class__.__name__
self.model_call_details["prompt_integration"] = (
vector_store_custom_logger.__class__.__name__
)
# Add to global callbacks so post-call hooks are invoked
if (
vector_store_custom_logger
@ -892,9 +892,9 @@ class Logging(LiteLLMLoggingBaseClass):
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self.model_call_details["litellm_params"][
"api_base"
] = self._get_masked_api_base(additional_args.get("api_base", ""))
self.model_call_details["litellm_params"]["api_base"] = (
self._get_masked_api_base(additional_args.get("api_base", ""))
)
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
@ -923,10 +923,10 @@ class Logging(LiteLLMLoggingBaseClass):
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata[
"raw_request"
] = "redacted by litellm. \
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
@ -937,34 +937,34 @@ class Logging(LiteLLMLoggingBaseClass):
_metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
# NOTE: setting ignore_sensitive_headers to True will cause
# the Authorization header to be leaked when calls to the health
# endpoint are made and fail.
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
),
error=None,
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
# NOTE: setting ignore_sensitive_headers to True will cause
# the Authorization header to be leaked when calls to the health
# endpoint are made and fail.
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
),
error=None,
)
)
except Exception as e:
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
error=str(e),
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
error=str(e),
)
)
_metadata[
"raw_request"
] = "Unable to Log \
_metadata["raw_request"] = (
"Unable to Log \
raw request: {}".format(
str(e)
str(e)
)
)
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
try:
@ -1265,13 +1265,13 @@ class Logging(LiteLLMLoggingBaseClass):
for callback in callbacks:
try:
if isinstance(callback, CustomLogger):
response: Optional[
MCPPostCallResponseObject
] = await callback.async_post_mcp_tool_call_hook(
kwargs=kwargs,
response_obj=post_mcp_tool_call_response_obj,
start_time=start_time,
end_time=end_time,
response: Optional[MCPPostCallResponseObject] = (
await callback.async_post_mcp_tool_call_hook(
kwargs=kwargs,
response_obj=post_mcp_tool_call_response_obj,
start_time=start_time,
end_time=end_time,
)
)
######################################################################
# if any of the callbacks modify the response, use the modified response
@ -1466,9 +1466,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
return None
try:
@ -1494,9 +1494,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
return None
@ -1652,10 +1652,8 @@ class Logging(LiteLLMLoggingBaseClass):
result=logging_result
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
logging_result, start_time, end_time
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(logging_result, start_time, end_time)
)
if (
@ -1734,9 +1732,9 @@ class Logging(LiteLLMLoggingBaseClass):
end_time = datetime.datetime.now()
if self.completion_start_time is None:
self.completion_start_time = end_time
self.model_call_details[
"completion_start_time"
] = self.completion_start_time
self.model_call_details["completion_start_time"] = (
self.completion_start_time
)
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
@ -1773,10 +1771,10 @@ class Logging(LiteLLMLoggingBaseClass):
end_time=end_time,
)
elif isinstance(result, dict) or isinstance(result, list):
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
result, start_time, end_time
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
result, start_time, end_time
)
)
if (
standard_logging_payload := self.model_call_details.get(
@ -1785,9 +1783,9 @@ class Logging(LiteLLMLoggingBaseClass):
) is not None:
emit_standard_logging_payload(standard_logging_payload)
elif standard_logging_object is not None:
self.model_call_details[
"standard_logging_object"
] = standard_logging_object
self.model_call_details["standard_logging_object"] = (
standard_logging_object
)
else:
self.model_call_details["response_cost"] = None
@ -1945,17 +1943,17 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=complete_streaming_response)
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
)
if (
standard_logging_payload := self.model_call_details.get(
@ -2289,10 +2287,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
)
result = self.model_call_details["complete_response"]
openMeterLogger.log_success_event(
@ -2316,10 +2314,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
)
result = self.model_call_details["complete_response"]
@ -2458,9 +2456,9 @@ class Logging(LiteLLMLoggingBaseClass):
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details[
"async_complete_streaming_response"
] = complete_streaming_response
self.model_call_details["async_complete_streaming_response"] = (
complete_streaming_response
)
try:
if self.model_call_details.get("cache_hit", False) is True:
@ -2471,10 +2469,10 @@ class Logging(LiteLLMLoggingBaseClass):
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response
)
)
verbose_logger.debug(
@ -2487,10 +2485,10 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["response_cost"] = None
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
)
# print standard logging payload
@ -2517,10 +2515,8 @@ class Logging(LiteLLMLoggingBaseClass):
# _success_handler_helper_fn
if self.model_call_details.get("standard_logging_object") is None:
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
result, start_time, end_time
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(result, start_time, end_time)
)
# print standard logging payload
@ -2764,18 +2760,18 @@ class Logging(LiteLLMLoggingBaseClass):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
)
return start_time, end_time
@ -3739,9 +3735,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
service_name=arize_config.project_name,
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
)
for callback in _in_memory_loggers:
if (
isinstance(callback, ArizeLogger)
@ -3767,13 +3763,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
)
else:
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"openinference.project.name={arize_phoenix_config.project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={arize_phoenix_config.project_name}"
)
# Set Phoenix project name from environment variable
phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None)
@ -3781,19 +3777,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={phoenix_project_name}"
)
else:
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"openinference.project.name={phoenix_project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={phoenix_project_name}"
)
# auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None:
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = arize_phoenix_config.otlp_auth_headers
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
arize_phoenix_config.otlp_auth_headers
)
for callback in _in_memory_loggers:
if (
@ -3969,9 +3965,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
exporter="otlp_http",
endpoint="https://langtrace.ai/api/trace",
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
)
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
@ -4204,8 +4200,7 @@ def _maybe_auto_initialize_arize_phoenix(_in_memory_loggers: list) -> None:
litellm.logging_callback_manager.add_litellm_callback(phoenix_logger)
verbose_logger.info(
"Auto-initialized Arize Phoenix logger alongside otel "
"(endpoint=%s)",
"Auto-initialized Arize Phoenix logger alongside otel " "(endpoint=%s)",
arize_phoenix_config.endpoint,
)
except Exception as e:
@ -4768,9 +4763,11 @@ class StandardLoggingPayloadSetup:
).model_dump()
if isinstance(_raw, dict):
if ResponseAPILoggingUtils._is_response_api_usage(_raw):
return ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
_raw
).model_dump()
return (
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
_raw
).model_dump()
)
return _raw
if isinstance(_raw, Usage):
return _raw.model_dump()
@ -4884,10 +4881,10 @@ class StandardLoggingPayloadSetup:
for key in StandardLoggingHiddenParams.__annotations__.keys():
if key in hidden_params:
if key == "additional_headers":
clean_hidden_params[
"additional_headers"
] = StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
clean_hidden_params["additional_headers"] = (
StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
)
else:
clean_hidden_params[key] = hidden_params[key] # type: ignore
@ -5039,14 +5036,22 @@ class StandardLoggingPayloadSetup:
dynamic_litellm_session_id = litellm_params.get("litellm_session_id")
dynamic_litellm_trace_id = litellm_params.get("litellm_trace_id")
# Note: we recommend using `litellm_session_id` for session tracking
# `litellm_trace_id` is an internal litellm param
if dynamic_litellm_session_id:
return str(dynamic_litellm_session_id)
elif dynamic_litellm_trace_id:
return str(dynamic_litellm_trace_id)
else:
return logging_obj.litellm_trace_id
# Fallback: use metadata.session_id or metadata.trace_id for call chaining
metadata = litellm_params.get("metadata") or {}
metadata_session_id = metadata.get("session_id")
metadata_trace_id = metadata.get("trace_id")
if metadata_session_id:
return str(metadata_session_id)
if metadata_trace_id:
return str(metadata_trace_id)
return logging_obj.litellm_trace_id
@staticmethod
def _get_user_agent_tags(proxy_server_request: dict) -> Optional[List[str]]:
@ -5502,9 +5507,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
):
for k, v in metadata["user_api_key_metadata"].items():
if k == "logging": # prevent logging user logging keys
cleaned_user_api_key_metadata[
k
] = "scrubbed_by_litellm_for_sensitive_keys"
cleaned_user_api_key_metadata[k] = (
"scrubbed_by_litellm_for_sensitive_keys"
)
else:
cleaned_user_api_key_metadata[k] = v
@ -5616,4 +5621,3 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
model_parameters={"stream": True},
hidden_params=hidden_params,
)

View File

@ -5,7 +5,6 @@ LiteLLM MCP Server Routes
import asyncio
import contextlib
import traceback
import uuid
from datetime import datetime
@ -44,7 +43,10 @@ from litellm.proxy._experimental.mcp_server.utils import (
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy.litellm_pre_call_utils import (
LiteLLMProxyRequestSetup,
get_chain_id_from_headers,
)
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall
@ -331,6 +333,11 @@ if MCP_AVAILABLE:
try:
# Create a body date for logging
body_data = {"name": name, "arguments": arguments}
# Set trace/session id from raw_headers so spend logs and logging_obj stay consistent (same as A2A)
chain_id = get_chain_id_from_headers(raw_headers)
if chain_id:
body_data["litellm_trace_id"] = chain_id
body_data["litellm_session_id"] = chain_id
request = Request(
scope={
@ -884,6 +891,10 @@ if MCP_AVAILABLE:
# This is intentionally minimal: only async_success_handler / post_call_failure_hook
rules_obj = Rules()
list_tools_call_id = str(uuid.uuid4())
# Derive trace_id from raw_headers when not explicitly passed (same as A2A / MCP call_tool)
effective_litellm_trace_id = litellm_trace_id or get_chain_id_from_headers(
raw_headers
)
spend_logs_metadata: Dict[str, Any] = {
"mcp_operation": "list_tools",
}
@ -896,7 +907,7 @@ if MCP_AVAILABLE:
"model": "MCP: list_tools",
"call_type": CallTypes.list_mcp_tools.value,
"litellm_call_id": list_tools_call_id,
"litellm_trace_id": litellm_trace_id,
"litellm_trace_id": effective_litellm_trace_id,
"metadata": {
"spend_logs_metadata": spend_logs_metadata,
},

View File

@ -69,6 +69,7 @@ async def _handle_stream_message(
from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
if not A2A_SDK_AVAILABLE:
async def _error_stream():
yield json.dumps(
{
@ -106,7 +107,12 @@ async def _handle_stream_message(
proxy_server_request=proxy_server_request,
)
if use_proxy_hooks and user_api_key_dict is not None and request_data is not None and proxy_logging_obj is not None:
if (
use_proxy_hooks
and user_api_key_dict is not None
and request_data is not None
and proxy_logging_obj is not None
):
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
)
@ -119,20 +125,27 @@ async def _handle_stream_message(
return json.dumps(obj) + "\n"
def _ndjson_error(proxy_exc: Any) -> str:
return json.dumps(
{
"jsonrpc": "2.0",
"id": request_id,
"error": {
"code": -32603,
"message": getattr(
proxy_exc, "message", f"Streaming error: {proxy_exc!s}"
),
},
}
) + "\n"
return (
json.dumps(
{
"jsonrpc": "2.0",
"id": request_id,
"error": {
"code": -32603,
"message": getattr(
proxy_exc,
"message",
f"Streaming error: {proxy_exc!s}",
),
},
}
)
+ "\n"
)
async for line in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
async for (
line
) in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
response=a2a_stream,
user_api_key_dict=user_api_key_dict,
request_data=request_data,
@ -151,7 +164,12 @@ async def _handle_stream_message(
yield json.dumps(chunk) + "\n"
except Exception as e:
verbose_proxy_logger.exception(f"Error streaming A2A response: {e}")
if use_proxy_hooks and proxy_logging_obj is not None and user_api_key_dict is not None and request_data is not None:
if (
use_proxy_hooks
and proxy_logging_obj is not None
and user_api_key_dict is not None
and request_data is not None
):
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
@ -382,6 +400,7 @@ async def invoke_agent_a2a(
agent_id=agent.agent_id,
metadata=data.get("metadata", {}),
proxy_server_request=data.get("proxy_server_request"),
litellm_logging_obj=logging_obj,
)
response = await proxy_logging_obj.post_call_success_hook(

View File

@ -1624,11 +1624,11 @@ def _build_field_dict(
# Determine the field type from annotation
field_type = _get_field_type_from_annotation(field_annotation)
# Check for custom UI type override (ui_type preferred; "type" leaks into OpenAPI and breaks schema)
field_json_schema_extra = getattr(field, "json_schema_extra", {}) or {}
# Check for custom UI type override
field_json_schema_extra = getattr(field, "json_schema_extra", {})
if field_json_schema_extra and "ui_type" in field_json_schema_extra:
ut = field_json_schema_extra["ui_type"]
field_type = ut if isinstance(ut, str) else getattr(ut, "value", ut)
ui_type = field_json_schema_extra["ui_type"]
field_type = ui_type.value if hasattr(ui_type, "value") else ui_type
elif field_json_schema_extra and "type" in field_json_schema_extra:
field_type = field_json_schema_extra["type"]

View File

@ -89,6 +89,25 @@ def _get_metadata_variable_name(request: Request) -> str:
return "metadata"
def get_chain_id_from_headers(headers: Optional[Dict[str, str]]) -> Optional[str]:
"""
Extract chain id for call chaining from request headers.
x-litellm-trace-id and x-litellm-session-id are interchangeable; when both
are present, x-litellm-trace-id takes precedence. Header keys are matched
case-insensitively so this works with raw header dicts from any transport.
Used by MCP (and other paths that have raw_headers but no Request) to set
litellm_trace_id/litellm_session_id for spend logs and logging consistency.
"""
if not headers:
return None
normalized = {k.lower(): v for k, v in headers.items() if isinstance(k, str)}
return normalized.get("x-litellm-trace-id") or normalized.get(
"x-litellm-session-id"
)
def safe_add_api_version_from_query_params(data: dict, request: Request):
try:
if hasattr(request, "query_params"):
@ -177,12 +196,12 @@ def _get_dynamic_logging_metadata(
user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig
) -> Optional[TeamCallbackMetadata]:
callback_settings_obj: Optional[TeamCallbackMetadata] = None
key_dynamic_logging_settings: Optional[
dict
] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
team_dynamic_logging_settings: Optional[
dict
] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
key_dynamic_logging_settings: Optional[dict] = (
KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
)
team_dynamic_logging_settings: Optional[dict] = (
KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
)
#########################################################################################
# Key-based callbacks
#########################################################################################
@ -576,9 +595,13 @@ class LiteLLMProxyRequestSetup:
#########################################################################################
# Finally update the requests metadata with the `metadata_from_headers`
#########################################################################################
agent_id_from_header = headers.get("x-litellm-agent-id")
trace_id_from_header = headers.get("x-litellm-trace-id")
session_id_from_header = headers.get("x-litellm-session-id")
# x-litellm-trace-id and x-litellm-session-id are interchangeable for call chaining
chain_id = headers.get("x-litellm-trace-id") or headers.get(
"x-litellm-session-id"
)
if agent_id_from_header:
metadata_from_headers["agent_id"] = agent_id_from_header
@ -586,16 +609,13 @@ class LiteLLMProxyRequestSetup:
f"Extracted agent_id from header: {agent_id_from_header}"
)
if trace_id_from_header:
metadata_from_headers["trace_id"] = trace_id_from_header
if chain_id:
metadata_from_headers["trace_id"] = chain_id
metadata_from_headers["session_id"] = chain_id
data["litellm_session_id"] = chain_id
data["litellm_trace_id"] = chain_id
verbose_proxy_logger.debug(
f"Extracted trace_id from header: {trace_id_from_header}"
)
if session_id_from_header:
metadata_from_headers["session_id"] = session_id_from_header
verbose_proxy_logger.debug(
f"Extracted session_id from header: {session_id_from_header}"
f"Extracted chain_id from header (trace-id/session-id): {chain_id}"
)
if isinstance(data[_metadata_variable_name], dict):
@ -702,11 +722,11 @@ class LiteLLMProxyRequestSetup:
## KEY-LEVEL SPEND LOGS / TAGS
if "tags" in key_metadata and key_metadata["tags"] is not None:
data[_metadata_variable_name][
"tags"
] = LiteLLMProxyRequestSetup._merge_tags(
request_tags=data[_metadata_variable_name].get("tags"),
tags_to_add=key_metadata["tags"],
data[_metadata_variable_name]["tags"] = (
LiteLLMProxyRequestSetup._merge_tags(
request_tags=data[_metadata_variable_name].get("tags"),
tags_to_add=key_metadata["tags"],
)
)
if "disable_global_guardrails" in key_metadata and isinstance(
key_metadata["disable_global_guardrails"], bool
@ -839,14 +859,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
"""
from litellm.proxy.proxy_server import llm_router, premium_user
from litellm.types.proxy.litellm_pre_call_utils import (
RedactedDict,
SecretFields,
)
from litellm.types.proxy.litellm_pre_call_utils import RedactedDict, SecretFields
_raw_headers: Dict[str, str] = RedactedDict(
_safe_get_request_headers(request)
)
_raw_headers: Dict[str, str] = RedactedDict(_safe_get_request_headers(request))
forward_llm_auth = False
if general_settings:
@ -986,9 +1001,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
data[_metadata_variable_name]["litellm_api_version"] = version
if general_settings is not None:
data[_metadata_variable_name][
"global_max_parallel_requests"
] = general_settings.get("global_max_parallel_requests", None)
data[_metadata_variable_name]["global_max_parallel_requests"] = (
general_settings.get("global_max_parallel_requests", None)
)
### KEY-LEVEL Controls
key_metadata = user_api_key_dict.metadata

View File

@ -11,26 +11,21 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import (
MAX_STRING_LENGTH_PROMPT_IN_DB as DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB,
)
from litellm.constants import \
MAX_STRING_LENGTH_PROMPT_IN_DB as DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB
from litellm.constants import REDACTED_BY_LITELM_STRING
from litellm.litellm_core_utils.core_helpers import (
get_litellm_metadata_from_kwargs,
reconstruct_model_name,
)
get_litellm_metadata_from_kwargs, reconstruct_model_name)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import (
CostBreakdown,
StandardLoggingGuardrailInformation,
StandardLoggingMCPToolCall,
StandardLoggingModelInformation,
StandardLoggingPayload,
StandardLoggingVectorStoreRequest,
VectorStoreSearchResponse,
)
from litellm.types.utils import (CostBreakdown,
StandardLoggingGuardrailInformation,
StandardLoggingMCPToolCall,
StandardLoggingModelInformation,
StandardLoggingPayload,
StandardLoggingVectorStoreRequest,
VectorStoreSearchResponse)
from litellm.utils import get_end_user_id_for_cost_tracking
@ -116,16 +111,15 @@ def _get_spend_logs_metadata(
# Filter the metadata dictionary to include only the specified keys
clean_metadata = SpendLogsMetadata(
**{ # type: ignore
key: metadata.get(key)
for key in SpendLogsMetadata.__annotations__.keys()
key: metadata.get(key) for key in SpendLogsMetadata.__annotations__.keys()
}
)
clean_metadata["applied_guardrails"] = applied_guardrails
clean_metadata["batch_models"] = batch_models
clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
clean_metadata[
"vector_store_request_metadata"
] = _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata)
clean_metadata["vector_store_request_metadata"] = (
_get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata)
)
clean_metadata["guardrail_information"] = guardrail_information
clean_metadata["usage_object"] = usage_object
clean_metadata["model_map_information"] = model_map_information
@ -372,9 +366,11 @@ def get_logging_payload( # noqa: PLR0915
guardrail_information=(
standard_logging_payload.get("guardrail_information", None)
if standard_logging_payload is not None
else metadata.get("standard_logging_guardrail_information", None)
if metadata is not None
else None
else (
metadata.get("standard_logging_guardrail_information", None)
if metadata is not None
else None
)
),
cold_storage_object_key=(
standard_logging_payload["metadata"].get("cold_storage_object_key", None)
@ -501,6 +497,7 @@ def _get_session_id_for_spend_log(
"""
from litellm._uuid import uuid
if (
standard_logging_payload is not None
and standard_logging_payload.get("trace_id") is not None
@ -515,9 +512,7 @@ def _get_session_id_for_spend_log(
return str(uuid.uuid4())
def _get_request_duration_ms(
start_time: datetime, end_time: datetime
) -> Optional[int]:
def _get_request_duration_ms(start_time: datetime, end_time: datetime) -> Optional[int]:
"""Compute request duration in milliseconds from start and end times."""
try:
return int((end_time - start_time).total_seconds() * 1000)
@ -709,20 +704,20 @@ def _convert_to_json_serializable_dict(
if max_depth <= 0:
# Return a placeholder if max depth is exceeded
return "<max_depth_exceeded>"
if visited is None:
visited = set()
# Get the object's memory address to track visited objects
obj_id = id(obj)
if obj_id in visited:
# Circular reference detected, return placeholder
return "<circular_reference>"
# Only track mutable objects (dict, list, objects with __dict__)
if isinstance(obj, (dict, list)) or hasattr(obj, "__dict__"):
visited.add(obj_id)
try:
if isinstance(obj, BaseModel):
# Use Pydantic's model_dump() instead of pickle
@ -741,7 +736,9 @@ def _convert_to_json_serializable_dict(
]
elif hasattr(obj, "__dict__"):
# Handle objects with __dict__ attribute
return _convert_to_json_serializable_dict(obj.__dict__, visited, max_depth - 1)
return _convert_to_json_serializable_dict(
obj.__dict__, visited, max_depth - 1
)
else:
# Primitives (str, int, float, bool, None) pass through
return obj
@ -777,9 +774,7 @@ def _get_proxy_server_request_for_spend_logs_payload(
# Apply message redaction if turn_off_message_logging is enabled
if kwargs is not None:
from litellm.litellm_core_utils.redact_messages import (
perform_redaction,
should_redact_message_logging,
)
perform_redaction, should_redact_message_logging)
# Build model_call_details dict to check redaction settings
model_call_details = {
@ -788,12 +783,12 @@ def _get_proxy_server_request_for_spend_logs_payload(
"standard_callback_dynamic_params"
),
}
# If redaction is enabled, convert to serializable dict before redacting
if should_redact_message_logging(model_call_details=model_call_details):
_request_body = _convert_to_json_serializable_dict(_request_body)
perform_redaction(model_call_details=_request_body, result=None)
_request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
_request_body_json_str = json.dumps(_request_body, default=str)
return _request_body_json_str
@ -845,10 +840,8 @@ def _get_response_for_spend_logs_payload(
# Apply message redaction if turn_off_message_logging is enabled
if kwargs is not None:
from litellm.litellm_core_utils.redact_messages import (
perform_redaction,
should_redact_message_logging,
)
perform_redaction, should_redact_message_logging)
litellm_params = kwargs.get("litellm_params", {})
model_call_details = {
"litellm_params": litellm_params,
@ -856,11 +849,13 @@ def _get_response_for_spend_logs_payload(
"standard_callback_dynamic_params"
),
}
# If redaction is enabled, convert to serializable dict before redacting
if should_redact_message_logging(model_call_details=model_call_details):
response_obj = _convert_to_json_serializable_dict(response_obj)
response_obj = perform_redaction(model_call_details={}, result=response_obj)
response_obj = perform_redaction(
model_call_details={}, result=response_obj
)
sanitized_wrapper = _sanitize_request_body_for_spend_logs_payload(
{"response": response_obj}
@ -882,7 +877,7 @@ def _should_store_prompts_and_responses_in_spend_logs() -> bool:
# Check general_settings (from DB or proxy_config.yaml)
store_prompts_value = general_settings.get("store_prompts_in_spend_logs")
# Normalize case: handle True/true/TRUE, False/false/FALSE, None/null
if store_prompts_value is True:
return True
@ -890,7 +885,7 @@ def _should_store_prompts_and_responses_in_spend_logs() -> bool:
# Case-insensitive string comparison
if store_prompts_value.lower() == "true":
return True
# Also check environment variable
return get_secret_bool("STORE_PROMPTS_IN_SPEND_LOGS") is True

View File

@ -1454,10 +1454,12 @@ def client(original_function): # noqa: PLR0915
logging_obj, kwargs = function_setup(
original_function.__name__, rules_obj, start_time, *args, **kwargs
)
# Type assertion: logging_obj is guaranteed to be non-None after function_setup
assert logging_obj is not None, "logging_obj should not be None after function_setup"
assert (
logging_obj is not None
), "logging_obj should not be None after function_setup"
## LOAD CREDENTIALS
load_credentials_from_list(kwargs)
kwargs["litellm_logging_obj"] = logging_obj
@ -1753,7 +1755,9 @@ def client(original_function): # noqa: PLR0915
print_args_passed_to_litellm(original_function, args, kwargs)
start_time = datetime.datetime.now()
result = None
_update_response_metadata = getattr(sys.modules[__name__], "update_response_metadata")
_update_response_metadata = getattr(
sys.modules[__name__], "update_response_metadata"
)
logging_obj: Optional[LiteLLMLoggingObject] = kwargs.get(
"litellm_logging_obj", None
)
@ -1776,9 +1780,11 @@ def client(original_function): # noqa: PLR0915
logging_obj, kwargs = function_setup(
original_function.__name__, rules_obj, start_time, *args, **kwargs
)
# Type assertion: logging_obj is guaranteed to be non-None after function_setup
assert logging_obj is not None, "logging_obj should not be None after function_setup"
assert (
logging_obj is not None
), "logging_obj should not be None after function_setup"
modified_kwargs = await async_pre_call_deployment_hook(kwargs, call_type)
if modified_kwargs is not None:
@ -1861,6 +1867,7 @@ def client(original_function): # noqa: PLR0915
# MODEL CALL
result = await original_function(*args, **kwargs)
end_time = datetime.datetime.now()
if _is_streaming_request(
kwargs=kwargs,
call_type=call_type,
@ -2082,12 +2089,14 @@ def _is_async_request(
return False
_STREAMING_CALL_TYPES = frozenset({
CallTypes.generate_content_stream,
CallTypes.agenerate_content_stream,
CallTypes.generate_content_stream.value,
CallTypes.agenerate_content_stream.value,
})
_STREAMING_CALL_TYPES = frozenset(
{
CallTypes.generate_content_stream,
CallTypes.agenerate_content_stream,
CallTypes.generate_content_stream.value,
CallTypes.agenerate_content_stream.value,
}
)
def _is_streaming_request(
@ -2181,7 +2190,7 @@ def encode(model="", text="", custom_tokenizer: Optional[dict] = None):
# Normalize: HuggingFace Tokenizer.encode() returns an Encoding object;
# extract .ids so the return type is always List[int].
if hasattr(enc, "ids"):
return enc.ids
return enc.ids # type: ignore
return enc
@ -5836,7 +5845,7 @@ def get_model_info(
_model_info[key] = value # type: ignore
# if verbose_logger.isEnabledFor(logging.DEBUG):
# verbose_logger.debug(f"model_info: {_model_info}")
# verbose_logger.debug(f"model_info: {_model_info}")
returned_model_info = ModelInfo(
**_model_info, supported_openai_params=supported_openai_params
@ -6179,8 +6188,10 @@ def validate_environment( # noqa: PLR0915
"AWS_ROLE_ARN" in os.environ
or "AWS_PROFILE" in os.environ
or "AWS_WEB_IDENTITY_TOKEN_FILE" in os.environ
or "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" in os.environ # ECS task role
or "AWS_CONTAINER_CREDENTIALS_FULL_URI" in os.environ # ECS/Fargate full URI credential delivery
or "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
in os.environ # ECS task role
or "AWS_CONTAINER_CREDENTIALS_FULL_URI"
in os.environ # ECS/Fargate full URI credential delivery
):
keys_in_environment = True
else:
@ -7386,7 +7397,9 @@ class ModelResponseIterator:
if convert_to_delta is True:
_stream_response = ModelResponseStream()
_stream_response.choices[0].delta.content = model_response.choices[0].message.content # type: ignore
self.model_response: Union[ModelResponse, ModelResponseStream] = _stream_response
self.model_response: Union[ModelResponse, ModelResponseStream] = (
_stream_response
)
else:
self.model_response = model_response
self.is_done = False
@ -7457,13 +7470,13 @@ def is_cached_message(message: AllMessageValues) -> bool:
Used for anthropic/gemini context caching.
Follows the anthropic format {"cache_control": {"type": "ephemeral"}}
Can be disabled globally by setting litellm.disable_anthropic_gemini_context_caching_transform = True
"""
# Check if context caching is disabled globally
if litellm.disable_anthropic_gemini_context_caching_transform is True:
return False
if "content" not in message:
return False
@ -7980,6 +7993,7 @@ class ProviderConfigManager:
def _get_azure_ai_config(model: str) -> BaseConfig:
"""Get Azure AI config based on model type."""
from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo
return AzureFoundryModelInfo.get_azure_ai_config_for_model(model)
@staticmethod

View File

@ -1112,6 +1112,47 @@ async def test_get_guardrail_info_endpoint_db_guardrail(mocker):
assert result.guardrail_definition_location == "db"
class TestBuildFieldDict:
"""Test _build_field_dict handles both enum and string ui_type values."""
def test_build_field_dict_with_string_ui_type(self):
"""Test that _build_field_dict works when ui_type is a plain string (e.g. BlockCodeExecutionGuardrailConfigModel)."""
from unittest.mock import MagicMock
from litellm.proxy.guardrails.guardrail_endpoints import _build_field_dict
field = MagicMock()
field.json_schema_extra = {"ui_type": "multiselect", "options": ["python", "javascript"]}
result = _build_field_dict(
field=field,
field_annotation=str,
description="Test field",
required=False,
)
assert result["type"] == "multiselect"
assert result["description"] == "Test field"
def test_build_field_dict_with_enum_ui_type(self):
"""Test that _build_field_dict works when ui_type is a GuardrailParamUITypes enum."""
from unittest.mock import MagicMock
from litellm.proxy.guardrails.guardrail_endpoints import _build_field_dict
from litellm.types.guardrails import GuardrailParamUITypes
field = MagicMock()
field.json_schema_extra = {"ui_type": GuardrailParamUITypes.BOOL}
result = _build_field_dict(
field=field,
field_annotation=bool,
description="Test bool field",
required=True,
)
assert result["type"] == "bool"
assert result["required"] is True
# --- Team guardrail registration (register / submissions) ---
MOCK_REGISTER_REQUEST = RegisterGuardrailRequest(
@ -1571,4 +1612,4 @@ async def test_list_submissions_summary_counts_unaffected_by_filters(mocker):
assert len(result.submissions) == 1 # filtered
assert result.summary.total == 2 # unfiltered
assert result.summary.pending_review == 1
assert result.summary.active == 1
assert result.summary.active == 1

View File

@ -11,11 +11,16 @@ from fastapi import Request
import litellm
from litellm.proxy._types import TeamCallbackMetadata, UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import (
KeyAndTeamLoggingSettings, LiteLLMProxyRequestSetup,
_get_dynamic_logging_metadata, _get_enforced_params,
_get_metadata_variable_name, _update_model_if_key_alias_exists,
add_guardrails_from_policy_engine, add_litellm_data_to_request,
check_if_token_is_service_account)
KeyAndTeamLoggingSettings,
LiteLLMProxyRequestSetup,
_get_dynamic_logging_metadata,
_get_enforced_params,
_get_metadata_variable_name,
_update_model_if_key_alias_exists,
add_guardrails_from_policy_engine,
add_litellm_data_to_request,
check_if_token_is_service_account,
)
sys.path.insert(
0, os.path.abspath("../../..")
@ -154,8 +159,7 @@ def test_get_enforced_params(
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_parses_string_metadata():
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
# Setup
request_mock = MagicMock(spec=Request)
@ -201,8 +205,7 @@ async def test_add_litellm_data_to_request_parses_string_metadata():
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_user_spend_and_budget():
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
request_mock = MagicMock(spec=Request)
request_mock.url.path = "/v1/completions"
@ -240,8 +243,7 @@ async def test_add_litellm_data_to_request_user_spend_and_budget():
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_audio_transcription_multipart():
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
# Setup request mock for /v1/audio/transcriptions
request_mock = MagicMock(spec=Request)
@ -306,8 +308,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks():
"""
Test that litellm_disabled_callbacks from key metadata is properly added to the request data.
"""
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
# Setup mock request
request_mock = MagicMock(spec=Request)
@ -360,8 +361,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks_empty():
"""
Test that litellm_disabled_callbacks is not added when it's empty.
"""
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
# Setup mock request
request_mock = MagicMock(spec=Request)
@ -413,8 +413,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks_not_present():
"""
Test that litellm_disabled_callbacks is not added when it's not present in metadata.
"""
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
# Setup mock request
request_mock = MagicMock(spec=Request)
@ -466,8 +465,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks_invalid_type():
"""
Test that litellm_disabled_callbacks is not added when it's not a list.
"""
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
# Setup mock request
request_mock = MagicMock(spec=Request)
@ -519,8 +517,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks_with_logging_setti
"""
Test that litellm_disabled_callbacks works correctly alongside logging settings.
"""
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
# Setup mock request
request_mock = MagicMock(spec=Request)
@ -1030,8 +1027,7 @@ from unittest.mock import AsyncMock
from fastapi.responses import Response
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.common_request_processing import \
ProxyBaseLLMRequestProcessing
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.utils import ProxyLogging
from litellm.types.utils import StandardLoggingPayload
@ -1149,6 +1145,47 @@ async def test_add_litellm_metadata_from_request_headers():
litellm.callbacks = original_callbacks
def test_add_litellm_metadata_from_request_headers_x_litellm_trace_id_sets_chain_id():
"""x-litellm-trace-id sets both metadata and top-level litellm_session_id/litellm_trace_id for call chaining."""
headers = {"x-litellm-trace-id": "foo"}
data = {"metadata": {}}
LiteLLMProxyRequestSetup.add_litellm_metadata_from_request_headers(
headers=headers, data=data, _metadata_variable_name="metadata"
)
assert data["metadata"]["trace_id"] == "foo"
assert data["metadata"]["session_id"] == "foo"
assert data["litellm_session_id"] == "foo"
assert data["litellm_trace_id"] == "foo"
def test_add_litellm_metadata_from_request_headers_x_litellm_session_id_sets_chain_id():
"""x-litellm-session-id sets both metadata and top-level litellm_session_id/litellm_trace_id for call chaining."""
headers = {"x-litellm-session-id": "bar"}
data = {"metadata": {}}
LiteLLMProxyRequestSetup.add_litellm_metadata_from_request_headers(
headers=headers, data=data, _metadata_variable_name="metadata"
)
assert data["metadata"]["trace_id"] == "bar"
assert data["metadata"]["session_id"] == "bar"
assert data["litellm_session_id"] == "bar"
assert data["litellm_trace_id"] == "bar"
def test_add_litellm_metadata_from_request_headers_both_headers_trace_id_precedence():
"""When both x-litellm-trace-id and x-litellm-session-id are present, trace-id takes precedence for chain_id."""
headers = {
"x-litellm-trace-id": "trace-value",
"x-litellm-session-id": "session-value",
}
data = {"metadata": {}}
LiteLLMProxyRequestSetup.add_litellm_metadata_from_request_headers(
headers=headers, data=data, _metadata_variable_name="metadata"
)
assert data["metadata"]["trace_id"] == "trace-value"
assert data["metadata"]["session_id"] == "trace-value"
assert data["litellm_session_id"] == "trace-value"
assert data["litellm_trace_id"] == "trace-value"
def test_get_internal_user_header_from_mapping_returns_expected_header():
mappings = [
@ -1407,8 +1444,7 @@ async def test_embedding_header_forwarding_with_model_group():
importlib.reload(pre_call_utils_module)
# Re-import the function after reload to get the fresh version
from litellm.proxy.litellm_pre_call_utils import \
add_litellm_data_to_request
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
# Setup mock request for embeddings
request_mock = MagicMock(spec=Request)
@ -1542,11 +1578,13 @@ async def test_add_guardrails_from_policy_engine():
Test that add_guardrails_from_policy_engine adds guardrails from matching policies
and tracks applied policies in metadata.
"""
from litellm.proxy.policy_engine.attachment_registry import \
get_attachment_registry
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
from litellm.types.proxy.policy_engine import (Policy, PolicyAttachment,
PolicyGuardrails)
from litellm.types.proxy.policy_engine import (
Policy,
PolicyAttachment,
PolicyGuardrails,
)
# Setup test data
data = {
@ -1659,8 +1697,7 @@ async def test_add_guardrails_from_policy_engine_policy_version_by_id():
Test that add_guardrails_from_policy_engine executes a specific policy version
when policy_<uuid> is passed in the request body.
"""
from litellm.proxy.policy_engine.attachment_registry import \
get_attachment_registry
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
from litellm.types.proxy.policy_engine import Policy, PolicyGuardrails
@ -1729,6 +1766,7 @@ async def test_bearer_token_not_in_debug_logs():
"""
import logging
from io import StringIO
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.proxy_server import ProxyConfig

View File

@ -6,9 +6,9 @@ import {
LeftOutlined,
RightOutlined,
} from "@ant-design/icons";
import { Sparkles, Wrench } from "lucide-react";
import { Bot, Sparkles, Wrench } from "lucide-react";
import { LogEntry } from "../columns";
import { MCP_CALL_TYPES } from "../constants";
import { AGENT_CALL_TYPES, MCP_CALL_TYPES } from "../constants";
import { getEventDisplayName } from "../utils";
import { DrawerHeader } from "./DrawerHeader";
import { useKeyboardNavigation } from "./useKeyboardNavigation";
@ -46,6 +46,7 @@ interface TraceEventRowProps {
function TraceEventRow({ row, isSelected, onClick }: TraceEventRowProps) {
const isMcp = MCP_CALL_TYPES.includes(row.call_type);
const isAgent = AGENT_CALL_TYPES.includes(row.call_type);
const durationValue =
row.request_duration_ms != null
? (row.request_duration_ms / 1000).toFixed(3)
@ -64,6 +65,8 @@ function TraceEventRow({ row, isSelected, onClick }: TraceEventRowProps) {
<div className="flex items-center gap-1">
{isMcp ? (
<Wrench size={12} className="text-slate-500 flex-shrink-0" />
) : isAgent ? (
<Bot size={12} className="text-slate-500 flex-shrink-0" />
) : (
<Sparkles size={12} className="text-slate-500 flex-shrink-0" />
)}
@ -219,7 +222,10 @@ export function LogDetailsDrawer({
: null;
const sessionDurationSeconds =
sessionStart && sessionEnd ? ((sessionEnd.getTime() - sessionStart.getTime()) / 1000).toFixed(2) : "0.00";
const llmCount = sessionLogs.filter((row) => !MCP_CALL_TYPES.includes(row.call_type)).length;
const llmCount = sessionLogs.filter(
(row) => !MCP_CALL_TYPES.includes(row.call_type) && !AGENT_CALL_TYPES.includes(row.call_type),
).length;
const agentCount = sessionLogs.filter((row) => AGENT_CALL_TYPES.includes(row.call_type)).length;
const mcpCount = sessionLogs.filter((row) => MCP_CALL_TYPES.includes(row.call_type)).length;
const logsForList = isSessionMode ? sessionLogs : currentLog ? [currentLog] : [];
const leftPanelId = isSessionMode ? sessionId || "" : currentLog?.request_id || "";
@ -302,14 +308,25 @@ export function LogDetailsDrawer({
</div>
<div className="mt-1 text-[11px] text-slate-500 font-mono">
{logsForList.length} req
<span className="mx-1.5">·</span>
{isSessionMode
? `${llmCount} LLM`
: `${logsForList.filter((row) => !MCP_CALL_TYPES.includes(row.call_type)).length} LLM`}
<span className="mx-1.5">·</span>
{isSessionMode
? `${mcpCount} MCP`
: `${logsForList.filter((row) => MCP_CALL_TYPES.includes(row.call_type)).length} MCP`}
{[
isSessionMode
? llmCount
: logsForList.filter(
(row) =>
!MCP_CALL_TYPES.includes(row.call_type) && !AGENT_CALL_TYPES.includes(row.call_type),
).length,
isSessionMode ? agentCount : logsForList.filter((row) => AGENT_CALL_TYPES.includes(row.call_type)).length,
isSessionMode ? mcpCount : logsForList.filter((row) => MCP_CALL_TYPES.includes(row.call_type)).length,
].map((count, i) => {
const label = [" LLM", " Agent", " MCP"][i];
return count > 0 ? (
<span key={label}>
<span className="mx-1.5">·</span>
{count}
{label}
</span>
) : null;
})}
<span className="mx-1.5">·</span>
{isSessionMode
? getSpendString(totalSessionCost)

View File

@ -1,5 +1,5 @@
/**
* Compact type-indicator badges for LLM and MCP log entries.
* Compact type-indicator badges for LLM, Agent, and MCP log entries.
* Used in the request logs table and session type column.
*/
@ -15,6 +15,18 @@ export const WrenchIcon = ({ size = 10 }: { size?: number }) => (
</svg>
);
/** Agent/bot icon for A2A and agent call types (Lucide Bot-style). */
export const AgentIcon = ({ size = 12 }: { size?: number }) => (
<svg width={size} height={size} viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" className="flex-shrink-0">
<path d="M12 8V4H8" />
<rect width="16" height="12" x="4" y="8" rx="2" />
<path d="M2 14h2" />
<path d="M20 14h2" />
<path d="M15 13v2" />
<path d="M9 13v2" />
</svg>
);
export const LlmBadge = ({ count }: { count?: number }) => (
<span className="inline-flex items-center gap-1 px-2 py-0.5 bg-blue-50 text-blue-700 border border-blue-200 rounded-full text-[11px] font-medium whitespace-nowrap">
<SparkleIcon />
@ -28,3 +40,10 @@ export const McpBadge = ({ count }: { count?: number }) => (
{count != null ? count : "MCP"}
</span>
);
export const AgentBadge = ({ count }: { count?: number }) => (
<span className="inline-flex items-center gap-1 px-2 py-0.5 bg-violet-50 text-violet-700 border border-violet-200 rounded-full text-[11px] font-medium whitespace-nowrap">
<AgentIcon />
{count != null ? count : "Agent"}
</span>
);

View File

@ -6,8 +6,8 @@ import React, { useState } from "react";
import { getProviderLogoAndName } from "../provider_info_helpers";
import { TableHeaderSortDropdown } from "../common_components/TableHeaderSortDropdown/TableHeaderSortDropdown";
import { TimeCell } from "./time_cell";
import { MCP_CALL_TYPES } from "./constants";
import { LlmBadge, McpBadge, SparkleIcon, WrenchIcon } from "./TypeBadges";
import { AGENT_CALL_TYPES, MCP_CALL_TYPES } from "./constants";
import { AgentBadge, AgentIcon, LlmBadge, McpBadge, SparkleIcon, WrenchIcon } from "./TypeBadges";
/** API sort field mapping for /spend/logs/ui endpoint */
export const LOGS_SORT_FIELD_MAP = {
@ -69,6 +69,7 @@ export type LogEntry = {
mcp_tool_call_spend?: number;
session_llm_count?: number;
session_mcp_count?: number;
session_agent_count?: number;
onKeyHashClick?: (keyHash: string) => void;
onSessionClick?: (sessionId: string) => void;
};
@ -124,17 +125,26 @@ export const createColumns = (sortProps?: LogsSortProps): ColumnDef<LogEntry>[]
const row = info.row.original;
const sessionCount = row.session_total_count || 1;
const isMcp = MCP_CALL_TYPES.includes(row.call_type);
const sessionLlmCount = row.session_llm_count ?? (isMcp ? 0 : sessionCount);
const isAgent = AGENT_CALL_TYPES.includes(row.call_type);
const sessionLlmCount = row.session_llm_count ?? (isMcp || isAgent ? 0 : sessionCount);
const sessionAgentCount = row.session_agent_count ?? (isAgent ? sessionCount : 0);
const sessionMcpCount = row.session_mcp_count ?? (isMcp ? sessionCount : 0);
if (isMcp) return <McpBadge />;
if (isAgent && sessionCount <= 1) return <AgentBadge />;
if (sessionCount <= 1) return <LlmBadge />;
// Multi-call session — show total count, plus MCP indicator when mixed.
// Multi-call session — show total count, plus Agent/MCP indicators when mixed.
const sessionTypeBadge = (
<span className="inline-flex items-center gap-1 px-2 py-0.5 bg-blue-50 text-blue-700 border border-blue-200 rounded-full text-[11px] font-medium whitespace-nowrap">
<SparkleIcon />
<span>{sessionCount}</span>
{sessionAgentCount > 0 && (
<>
<span className="text-blue-300">·</span>
<AgentIcon size={10} />
</>
)}
{sessionMcpCount > 0 && (
<>
<span className="text-blue-300">·</span>
@ -144,8 +154,13 @@ export const createColumns = (sortProps?: LogsSortProps): ColumnDef<LogEntry>[]
</span>
);
const tooltipParts = [
sessionLlmCount > 0 && `${sessionLlmCount} LLM`,
sessionAgentCount > 0 && `${sessionAgentCount} Agent`,
sessionMcpCount > 0 && `${sessionMcpCount} MCP`,
].filter(Boolean);
return (
<Tooltip title={`${sessionLlmCount} LLM • ${sessionMcpCount} MCP`}>
<Tooltip title={tooltipParts.join(" • ")}>
{sessionTypeBadge}
</Tooltip>
);

View File

@ -15,6 +15,9 @@ export const ERROR_CODE_OPTIONS: { label: string; value: string }[] = [
/** Call types that represent MCP tool invocations (shared across columns, index, drawer). */
export const MCP_CALL_TYPES = ["call_mcp_tool", "list_mcp_tools"];
/** Call types that represent agent/A2A requests (e.g. asend_message). */
export const AGENT_CALL_TYPES = ["asend_message"];
export const QUICK_SELECT_OPTIONS: { label: string; value: number; unit: string }[] = [
{ label: "Last 15 Minutes", value: 15, unit: "minutes" },
{ label: "Last Hour", value: 1, unit: "hours" },

View File

@ -20,7 +20,7 @@ import KeyInfoView from "../templates/key_info_view";
import AuditLogs from "./audit_logs";
import { createColumns, LogEntry, type LogsSortField } from "./columns";
import { ConfigInfoMessage } from "./ConfigInfoMessage";
import { ERROR_CODE_OPTIONS, MCP_CALL_TYPES, QUICK_SELECT_OPTIONS } from "./constants";
import { AGENT_CALL_TYPES, ERROR_CODE_OPTIONS, MCP_CALL_TYPES, QUICK_SELECT_OPTIONS } from "./constants";
import { CostBreakdownViewer } from "./CostBreakdownViewer";
import { ErrorViewer } from "./ErrorViewer";
import { useLogFilterLogic } from "./log_filter_logic";
@ -309,13 +309,15 @@ export default function SpendLogsTable({
return matchesSearch;
});
const sessionCompositionById = searchedLogs.reduce<Record<string, { llm: number; mcp: number }>>((acc, log) => {
const sessionCompositionById = searchedLogs.reduce<Record<string, { llm: number; agent: number; mcp: number }>>((acc, log) => {
if (!log.session_id) return acc;
if (!acc[log.session_id]) {
acc[log.session_id] = { llm: 0, mcp: 0 };
acc[log.session_id] = { llm: 0, agent: 0, mcp: 0 };
}
if (MCP_CALL_TYPES.includes(log.call_type)) {
acc[log.session_id].mcp += 1;
} else if (AGENT_CALL_TYPES.includes(log.call_type)) {
acc[log.session_id].agent += 1;
} else {
acc[log.session_id].llm += 1;
}
@ -343,6 +345,7 @@ export default function SpendLogsTable({
request_duration_ms: log.request_duration_ms,
session_llm_count: sessionComposition?.llm ?? undefined,
session_mcp_count: sessionComposition?.mcp ?? undefined,
session_agent_count: sessionComposition?.agent ?? undefined,
onKeyHashClick: (keyHash: string) => setSelectedKeyIdInfoView(keyHash),
onSessionClick: (sessionId: string) => {
if (sessionId) {

View File

@ -14,7 +14,7 @@
"moduleResolution": "bundler",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "react-jsx",
"jsx": "preserve",
"incremental": true,
"plugins": [
{