Agent Tracing - support context_id based trace id propagation + nested llm calls (#22626)
* style(ui/): distinguish agent calls from llm calls on ui
* feat: initial grouping working
* feat: set stable contextid for a2a calls - allows for easily passing to downstream llm/mcp calls
* feat(a2a_endpoints.py): fix tracing to avoid recreating logging objects for the same call
  allows stable trace id usage
* fix(guardrail_endpoints): handle string ui_type values in _build_field_dict
  _build_field_dict unconditionally called .value on ui_type, which crashes for guardrail configs that use plain strings (e.g. BlockCodeExecutionGuardrailConfigModel uses "multiselect" and "percentage"). Now checks with hasattr before calling .value.
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix: propagate trace/session id from headers in MCP server calls
  Cherry-picked mcp_server/server.py fixes from 6feb9bab: adds get_chain_id_from_headers to extract x-litellm-trace-id / x-litellm-session-id from raw headers, and uses it in call_tool and list_tools to keep spend logs and tracing consistent with A2A.
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ba7a6d9bfd
commit
90eb6729d5
@ -930,7 +930,7 @@ For Responses API with load balancing across deployments with **different API ke
|
||||
|
||||
Notes:
|
||||
- User-key affinity is keyed on `metadata.user_api_key_hash` (the API key hash). The OpenAI `user` request parameter is an end-user identifier and is intentionally not used for deployment affinity.
|
||||
- Session-ID affinity is keyed on `metadata.session_id`. For proxy requests, this can be passed via the `x-litellm-session-id` HTTP header. For Python SDK requests, you can pass it via `litellm_metadata={"session_id": "value"}` in request args.
|
||||
- Session-ID affinity is keyed on `metadata.session_id`. For proxy requests, this can be passed via the `x-litellm-session-id` or `x-litellm-trace-id` HTTP header (they are interchangeable for call chaining). For Python SDK requests, you can pass it via `litellm_metadata={"session_id": "value"}` in request args.
|
||||
- `user_api_key_hash` is already SHA-256, and is used as-is (no double hashing).
|
||||
- Affinity is scoped by a stable model identifier (the model-map key, e.g. `model_map_information.model_map_key`) so model aliases map to the same stickiness bucket.
|
||||
- The mapping TTL is controlled by `deployment_affinity_ttl_seconds` (configured on Router init / proxy startup).
|
||||
|
||||
@ -24,11 +24,7 @@ from litellm.utils import client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
SendMessageRequest,
|
||||
SendStreamingMessageRequest,
|
||||
)
|
||||
from a2a.types import AgentCard, SendMessageRequest, SendStreamingMessageRequest
|
||||
|
||||
# Runtime imports with availability check
|
||||
A2A_SDK_AVAILABLE = False
|
||||
@ -124,13 +120,48 @@ def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
|
||||
litellm_logging_obj.model = model
|
||||
litellm_logging_obj.custom_llm_provider = custom_llm_provider
|
||||
litellm_logging_obj.model_call_details["model"] = model
|
||||
litellm_logging_obj.model_call_details[
|
||||
"custom_llm_provider"
|
||||
] = custom_llm_provider
|
||||
litellm_logging_obj.model_call_details["custom_llm_provider"] = (
|
||||
custom_llm_provider
|
||||
)
|
||||
|
||||
return agent_name
|
||||
|
||||
|
||||
async def _send_message_via_completion_bridge(
|
||||
request: "SendMessageRequest",
|
||||
custom_llm_provider: str,
|
||||
api_base: Optional[str],
|
||||
litellm_params: Dict[str, Any],
|
||||
) -> LiteLLMSendMessageResponse:
|
||||
"""
|
||||
Route a send_message through the LiteLLM completion bridge (e.g. LangGraph, Bedrock AgentCore).
|
||||
|
||||
Requires request; api_base is optional for providers that derive endpoint from model.
|
||||
"""
|
||||
verbose_logger.info(
|
||||
f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
|
||||
)
|
||||
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
|
||||
A2ACompletionBridgeHandler,
|
||||
)
|
||||
|
||||
params = (
|
||||
request.params.model_dump(mode="json")
|
||||
if hasattr(request.params, "model_dump")
|
||||
else dict(request.params)
|
||||
)
|
||||
|
||||
response_dict = await A2ACompletionBridgeHandler.handle_non_streaming(
|
||||
request_id=str(request.id),
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
return LiteLLMSendMessageResponse.from_dict(response_dict)
|
||||
|
||||
|
||||
@client
|
||||
async def asend_message(
|
||||
a2a_client: Optional["A2AClientType"] = None,
|
||||
@ -193,39 +224,21 @@ async def asend_message(
|
||||
```
|
||||
"""
|
||||
litellm_params = litellm_params or {}
|
||||
logging_obj = kwargs.get("litellm_logging_obj")
|
||||
trace_id = getattr(logging_obj, "litellm_trace_id", None) if logging_obj else None
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
|
||||
# Route through completion bridge if custom_llm_provider is set
|
||||
if custom_llm_provider:
|
||||
if request is None:
|
||||
raise ValueError("request is required for completion bridge")
|
||||
# api_base is optional for providers that derive endpoint from model (e.g., bedrock/agentcore)
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
|
||||
)
|
||||
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
|
||||
A2ACompletionBridgeHandler,
|
||||
)
|
||||
|
||||
# Extract params from request
|
||||
params = (
|
||||
request.params.model_dump(mode="json")
|
||||
if hasattr(request.params, "model_dump")
|
||||
else dict(request.params)
|
||||
)
|
||||
|
||||
response_dict = await A2ACompletionBridgeHandler.handle_non_streaming(
|
||||
request_id=str(request.id),
|
||||
params=params,
|
||||
litellm_params=litellm_params,
|
||||
return await _send_message_via_completion_bridge(
|
||||
request=request,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
# Convert to LiteLLMSendMessageResponse
|
||||
return LiteLLMSendMessageResponse.from_dict(response_dict)
|
||||
|
||||
# Standard A2A client flow
|
||||
if request is None:
|
||||
raise ValueError("request is required")
|
||||
@ -236,11 +249,13 @@ async def asend_message(
|
||||
raise ValueError(
|
||||
"Either a2a_client or api_base is required for standard A2A flow"
|
||||
)
|
||||
trace_id = str(uuid.uuid4())
|
||||
trace_id = trace_id or str(uuid.uuid4())
|
||||
extra_headers = {"X-LiteLLM-Trace-Id": trace_id}
|
||||
if agent_id:
|
||||
extra_headers["X-LiteLLM-Agent-Id"] = agent_id
|
||||
a2a_client = await create_a2a_client(base_url=api_base, extra_headers=extra_headers)
|
||||
a2a_client = await create_a2a_client(
|
||||
base_url=api_base, extra_headers=extra_headers
|
||||
)
|
||||
|
||||
# Type assertion: a2a_client is guaranteed to be non-None here
|
||||
assert a2a_client is not None
|
||||
@ -255,6 +270,10 @@ async def asend_message(
|
||||
)
|
||||
card_url = getattr(agent_card, "url", None) if agent_card else None
|
||||
|
||||
context_id = trace_id or str(uuid.uuid4())
|
||||
if request.params.message.context_id is None:
|
||||
request.params.message.context_id = context_id
|
||||
|
||||
# Retry loop: if connection fails due to localhost URL in agent card, retry with fixed URL
|
||||
a2a_response = None
|
||||
for _ in range(2): # max 2 attempts: original + 1 retry
|
||||
@ -606,7 +625,9 @@ async def create_a2a_client(
|
||||
|
||||
if extra_headers:
|
||||
httpx_client.headers.update(extra_headers)
|
||||
verbose_proxy_logger.debug(f"A2A client created with extra_headers={extra_headers}")
|
||||
verbose_proxy_logger.debug(
|
||||
f"A2A client created with extra_headers={extra_headers}"
|
||||
)
|
||||
|
||||
# Resolve agent card
|
||||
resolver = A2ACardResolver(
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Pre-define optional kwargs keys as frozenset for O(1) lookups
|
||||
# These are extracted from kwargs only if present, avoiding unnecessary .get() calls
|
||||
_OPTIONAL_KWARGS_KEYS = frozenset({
|
||||
@ -95,6 +94,13 @@ def get_litellm_params(
|
||||
litellm_request_debug: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
# Derive litellm_session_id / litellm_trace_id from metadata when not provided (call chaining)
|
||||
_meta = metadata or {}
|
||||
if litellm_session_id is None:
|
||||
litellm_session_id = _meta.get("session_id") or _meta.get("trace_id")
|
||||
if litellm_trace_id is None:
|
||||
litellm_trace_id = _meta.get("trace_id") or _meta.get("session_id")
|
||||
|
||||
# Build base dict with explicit parameters (always included)
|
||||
litellm_params = {
|
||||
"acompletion": acompletion,
|
||||
|
||||
@ -133,8 +133,8 @@ from ..integrations.azure_sentinel.azure_sentinel import AzureSentinelLogger
|
||||
from ..integrations.azure_storage.azure_storage import AzureBlobStorageLogger
|
||||
from ..integrations.custom_prompt_management import CustomPromptManagement
|
||||
from ..integrations.datadog.datadog import DataDogLogger
|
||||
from ..integrations.datadog.datadog_metrics import DatadogMetricsLogger
|
||||
from ..integrations.datadog.datadog_llm_obs import DataDogLLMObsLogger
|
||||
from ..integrations.datadog.datadog_metrics import DatadogMetricsLogger
|
||||
from ..integrations.dotprompt import DotpromptManager
|
||||
from ..integrations.dynamodb import DyanmoDBLogger
|
||||
from ..integrations.galileo import GalileoObserve
|
||||
@ -352,9 +352,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
)
|
||||
self.function_id = function_id
|
||||
self.streaming_chunks: List[Any] = [] # for generating complete stream response
|
||||
self.sync_streaming_chunks: List[
|
||||
Any
|
||||
] = [] # for generating complete stream response
|
||||
self.sync_streaming_chunks: List[Any] = (
|
||||
[]
|
||||
) # for generating complete stream response
|
||||
self.log_raw_request_response = log_raw_request_response
|
||||
|
||||
# Initialize dynamic callbacks
|
||||
@ -746,9 +746,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
prompt_spec=prompt_spec,
|
||||
dynamic_callback_params=dynamic_callback_params,
|
||||
):
|
||||
self.model_call_details[
|
||||
"prompt_integration"
|
||||
] = logger.__class__.__name__
|
||||
self.model_call_details["prompt_integration"] = (
|
||||
logger.__class__.__name__
|
||||
)
|
||||
return logger
|
||||
except Exception:
|
||||
# If check fails, continue to next logger
|
||||
@ -816,9 +816,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
|
||||
non_default_params
|
||||
):
|
||||
self.model_call_details[
|
||||
"prompt_integration"
|
||||
] = anthropic_cache_control_logger.__class__.__name__
|
||||
self.model_call_details["prompt_integration"] = (
|
||||
anthropic_cache_control_logger.__class__.__name__
|
||||
)
|
||||
return anthropic_cache_control_logger
|
||||
|
||||
#########################################################
|
||||
@ -830,9 +830,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
internal_usage_cache=None,
|
||||
llm_router=None,
|
||||
)
|
||||
self.model_call_details[
|
||||
"prompt_integration"
|
||||
] = vector_store_custom_logger.__class__.__name__
|
||||
self.model_call_details["prompt_integration"] = (
|
||||
vector_store_custom_logger.__class__.__name__
|
||||
)
|
||||
# Add to global callbacks so post-call hooks are invoked
|
||||
if (
|
||||
vector_store_custom_logger
|
||||
@ -892,9 +892,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
model
|
||||
): # if model name was changes pre-call, overwrite the initial model call name with the new one
|
||||
self.model_call_details["model"] = model
|
||||
self.model_call_details["litellm_params"][
|
||||
"api_base"
|
||||
] = self._get_masked_api_base(additional_args.get("api_base", ""))
|
||||
self.model_call_details["litellm_params"]["api_base"] = (
|
||||
self._get_masked_api_base(additional_args.get("api_base", ""))
|
||||
)
|
||||
|
||||
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
|
||||
# Log the exact input to the LLM API
|
||||
@ -923,10 +923,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
try:
|
||||
# [Non-blocking Extra Debug Information in metadata]
|
||||
if turn_off_message_logging is True:
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "redacted by litellm. \
|
||||
_metadata["raw_request"] = (
|
||||
"redacted by litellm. \
|
||||
'litellm.turn_off_message_logging=True'"
|
||||
)
|
||||
else:
|
||||
curl_command = self._get_request_curl_command(
|
||||
api_base=additional_args.get("api_base", ""),
|
||||
@ -937,34 +937,34 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
|
||||
_metadata["raw_request"] = str(curl_command)
|
||||
# split up, so it's easier to parse in the UI
|
||||
self.model_call_details[
|
||||
"raw_request_typed_dict"
|
||||
] = RawRequestTypedDict(
|
||||
raw_request_api_base=str(
|
||||
additional_args.get("api_base") or ""
|
||||
),
|
||||
raw_request_body=self._get_raw_request_body(
|
||||
additional_args.get("complete_input_dict", {})
|
||||
),
|
||||
# NOTE: setting ignore_sensitive_headers to True will cause
|
||||
# the Authorization header to be leaked when calls to the health
|
||||
# endpoint are made and fail.
|
||||
raw_request_headers=self._get_masked_headers(
|
||||
additional_args.get("headers", {}) or {},
|
||||
),
|
||||
error=None,
|
||||
self.model_call_details["raw_request_typed_dict"] = (
|
||||
RawRequestTypedDict(
|
||||
raw_request_api_base=str(
|
||||
additional_args.get("api_base") or ""
|
||||
),
|
||||
raw_request_body=self._get_raw_request_body(
|
||||
additional_args.get("complete_input_dict", {})
|
||||
),
|
||||
# NOTE: setting ignore_sensitive_headers to True will cause
|
||||
# the Authorization header to be leaked when calls to the health
|
||||
# endpoint are made and fail.
|
||||
raw_request_headers=self._get_masked_headers(
|
||||
additional_args.get("headers", {}) or {},
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
self.model_call_details[
|
||||
"raw_request_typed_dict"
|
||||
] = RawRequestTypedDict(
|
||||
error=str(e),
|
||||
self.model_call_details["raw_request_typed_dict"] = (
|
||||
RawRequestTypedDict(
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "Unable to Log \
|
||||
_metadata["raw_request"] = (
|
||||
"Unable to Log \
|
||||
raw request: {}".format(
|
||||
str(e)
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
|
||||
try:
|
||||
@ -1265,13 +1265,13 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
response: Optional[
|
||||
MCPPostCallResponseObject
|
||||
] = await callback.async_post_mcp_tool_call_hook(
|
||||
kwargs=kwargs,
|
||||
response_obj=post_mcp_tool_call_response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
response: Optional[MCPPostCallResponseObject] = (
|
||||
await callback.async_post_mcp_tool_call_hook(
|
||||
kwargs=kwargs,
|
||||
response_obj=post_mcp_tool_call_response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
######################################################################
|
||||
# if any of the callbacks modify the response, use the modified response
|
||||
@ -1466,9 +1466,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
verbose_logger.debug(
|
||||
f"response_cost_failure_debug_information: {debug_info}"
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost_failure_debug_information"
|
||||
] = debug_info
|
||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||
debug_info
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
@ -1494,9 +1494,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
verbose_logger.debug(
|
||||
f"response_cost_failure_debug_information: {debug_info}"
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost_failure_debug_information"
|
||||
] = debug_info
|
||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||
debug_info
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@ -1652,10 +1652,8 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
result=logging_result
|
||||
)
|
||||
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(
|
||||
logging_result, start_time, end_time
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(logging_result, start_time, end_time)
|
||||
)
|
||||
|
||||
if (
|
||||
@ -1734,9 +1732,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
end_time = datetime.datetime.now()
|
||||
if self.completion_start_time is None:
|
||||
self.completion_start_time = end_time
|
||||
self.model_call_details[
|
||||
"completion_start_time"
|
||||
] = self.completion_start_time
|
||||
self.model_call_details["completion_start_time"] = (
|
||||
self.completion_start_time
|
||||
)
|
||||
|
||||
self.model_call_details["log_event_type"] = "successful_api_call"
|
||||
self.model_call_details["end_time"] = end_time
|
||||
@ -1773,10 +1771,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
end_time=end_time,
|
||||
)
|
||||
elif isinstance(result, dict) or isinstance(result, list):
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(
|
||||
result, start_time, end_time
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(
|
||||
result, start_time, end_time
|
||||
)
|
||||
)
|
||||
if (
|
||||
standard_logging_payload := self.model_call_details.get(
|
||||
@ -1785,9 +1783,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
) is not None:
|
||||
emit_standard_logging_payload(standard_logging_payload)
|
||||
elif standard_logging_object is not None:
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = standard_logging_object
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
standard_logging_object
|
||||
)
|
||||
else:
|
||||
self.model_call_details["response_cost"] = None
|
||||
|
||||
@ -1945,17 +1943,17 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
verbose_logger.debug(
|
||||
"Logging Details LiteLLM-Success Call streaming complete"
|
||||
)
|
||||
self.model_call_details[
|
||||
"complete_streaming_response"
|
||||
] = complete_streaming_response
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(result=complete_streaming_response)
|
||||
self.model_call_details["complete_streaming_response"] = (
|
||||
complete_streaming_response
|
||||
)
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(result=complete_streaming_response)
|
||||
)
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(
|
||||
complete_streaming_response, start_time, end_time
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(
|
||||
complete_streaming_response, start_time, end_time
|
||||
)
|
||||
)
|
||||
if (
|
||||
standard_logging_payload := self.model_call_details.get(
|
||||
@ -2289,10 +2287,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
)
|
||||
else:
|
||||
if self.stream and complete_streaming_response:
|
||||
self.model_call_details[
|
||||
"complete_response"
|
||||
] = self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
self.model_call_details["complete_response"] = (
|
||||
self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
)
|
||||
result = self.model_call_details["complete_response"]
|
||||
openMeterLogger.log_success_event(
|
||||
@ -2316,10 +2314,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
)
|
||||
else:
|
||||
if self.stream and complete_streaming_response:
|
||||
self.model_call_details[
|
||||
"complete_response"
|
||||
] = self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
self.model_call_details["complete_response"] = (
|
||||
self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
)
|
||||
result = self.model_call_details["complete_response"]
|
||||
|
||||
@ -2458,9 +2456,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
if complete_streaming_response is not None:
|
||||
print_verbose("Async success callbacks: Got a complete streaming response")
|
||||
|
||||
self.model_call_details[
|
||||
"async_complete_streaming_response"
|
||||
] = complete_streaming_response
|
||||
self.model_call_details["async_complete_streaming_response"] = (
|
||||
complete_streaming_response
|
||||
)
|
||||
|
||||
try:
|
||||
if self.model_call_details.get("cache_hit", False) is True:
|
||||
@ -2471,10 +2469,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
model_call_details=self.model_call_details
|
||||
)
|
||||
# base_model defaults to None if not set on model_info
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(
|
||||
result=complete_streaming_response
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(
|
||||
result=complete_streaming_response
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
@ -2487,10 +2485,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
self.model_call_details["response_cost"] = None
|
||||
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(
|
||||
complete_streaming_response, start_time, end_time
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(
|
||||
complete_streaming_response, start_time, end_time
|
||||
)
|
||||
)
|
||||
|
||||
# print standard logging payload
|
||||
@ -2517,10 +2515,8 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
# _success_handler_helper_fn
|
||||
if self.model_call_details.get("standard_logging_object") is None:
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(
|
||||
result, start_time, end_time
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(result, start_time, end_time)
|
||||
)
|
||||
|
||||
# print standard logging payload
|
||||
@ -2764,18 +2760,18 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="failure",
|
||||
error_str=str(exception),
|
||||
original_exception=exception,
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="failure",
|
||||
error_str=str(exception),
|
||||
original_exception=exception,
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
)
|
||||
return start_time, end_time
|
||||
|
||||
@ -3739,9 +3735,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
||||
service_name=arize_config.project_name,
|
||||
)
|
||||
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
|
||||
)
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, ArizeLogger)
|
||||
@ -3767,13 +3763,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
||||
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
|
||||
# Add openinference.project.name attribute
|
||||
if existing_attrs:
|
||||
os.environ[
|
||||
"OTEL_RESOURCE_ATTRIBUTES"
|
||||
] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
|
||||
f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
|
||||
)
|
||||
else:
|
||||
os.environ[
|
||||
"OTEL_RESOURCE_ATTRIBUTES"
|
||||
] = f"openinference.project.name={arize_phoenix_config.project_name}"
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
|
||||
f"openinference.project.name={arize_phoenix_config.project_name}"
|
||||
)
|
||||
|
||||
# Set Phoenix project name from environment variable
|
||||
phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None)
|
||||
@ -3781,19 +3777,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
||||
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
|
||||
# Add openinference.project.name attribute
|
||||
if existing_attrs:
|
||||
os.environ[
|
||||
"OTEL_RESOURCE_ATTRIBUTES"
|
||||
] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
|
||||
f"{existing_attrs},openinference.project.name={phoenix_project_name}"
|
||||
)
|
||||
else:
|
||||
os.environ[
|
||||
"OTEL_RESOURCE_ATTRIBUTES"
|
||||
] = f"openinference.project.name={phoenix_project_name}"
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
|
||||
f"openinference.project.name={phoenix_project_name}"
|
||||
)
|
||||
|
||||
# auth can be disabled on local deployments of arize phoenix
|
||||
if arize_phoenix_config.otlp_auth_headers is not None:
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = arize_phoenix_config.otlp_auth_headers
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
arize_phoenix_config.otlp_auth_headers
|
||||
)
|
||||
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
@ -3969,9 +3965,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
||||
exporter="otlp_http",
|
||||
endpoint="https://langtrace.ai/api/trace",
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
|
||||
)
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, OpenTelemetry)
|
||||
@ -4204,8 +4200,7 @@ def _maybe_auto_initialize_arize_phoenix(_in_memory_loggers: list) -> None:
|
||||
litellm.logging_callback_manager.add_litellm_callback(phoenix_logger)
|
||||
|
||||
verbose_logger.info(
|
||||
"Auto-initialized Arize Phoenix logger alongside otel "
|
||||
"(endpoint=%s)",
|
||||
"Auto-initialized Arize Phoenix logger alongside otel " "(endpoint=%s)",
|
||||
arize_phoenix_config.endpoint,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -4768,9 +4763,11 @@ class StandardLoggingPayloadSetup:
|
||||
).model_dump()
|
||||
if isinstance(_raw, dict):
|
||||
if ResponseAPILoggingUtils._is_response_api_usage(_raw):
|
||||
return ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
||||
_raw
|
||||
).model_dump()
|
||||
return (
|
||||
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
||||
_raw
|
||||
).model_dump()
|
||||
)
|
||||
return _raw
|
||||
if isinstance(_raw, Usage):
|
||||
return _raw.model_dump()
|
||||
@ -4884,10 +4881,10 @@ class StandardLoggingPayloadSetup:
|
||||
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
||||
if key in hidden_params:
|
||||
if key == "additional_headers":
|
||||
clean_hidden_params[
|
||||
"additional_headers"
|
||||
] = StandardLoggingPayloadSetup.get_additional_headers(
|
||||
hidden_params[key]
|
||||
clean_hidden_params["additional_headers"] = (
|
||||
StandardLoggingPayloadSetup.get_additional_headers(
|
||||
hidden_params[key]
|
||||
)
|
||||
)
|
||||
else:
|
||||
clean_hidden_params[key] = hidden_params[key] # type: ignore
|
||||
@ -5039,14 +5036,22 @@ class StandardLoggingPayloadSetup:
|
||||
dynamic_litellm_session_id = litellm_params.get("litellm_session_id")
|
||||
dynamic_litellm_trace_id = litellm_params.get("litellm_trace_id")
|
||||
|
||||
|
||||
# Note: we recommend using `litellm_session_id` for session tracking
|
||||
# `litellm_trace_id` is an internal litellm param
|
||||
if dynamic_litellm_session_id:
|
||||
return str(dynamic_litellm_session_id)
|
||||
elif dynamic_litellm_trace_id:
|
||||
return str(dynamic_litellm_trace_id)
|
||||
else:
|
||||
return logging_obj.litellm_trace_id
|
||||
# Fallback: use metadata.session_id or metadata.trace_id for call chaining
|
||||
metadata = litellm_params.get("metadata") or {}
|
||||
metadata_session_id = metadata.get("session_id")
|
||||
metadata_trace_id = metadata.get("trace_id")
|
||||
if metadata_session_id:
|
||||
return str(metadata_session_id)
|
||||
if metadata_trace_id:
|
||||
return str(metadata_trace_id)
|
||||
return logging_obj.litellm_trace_id
|
||||
|
||||
@staticmethod
|
||||
def _get_user_agent_tags(proxy_server_request: dict) -> Optional[List[str]]:
|
||||
@ -5502,9 +5507,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
|
||||
):
|
||||
for k, v in metadata["user_api_key_metadata"].items():
|
||||
if k == "logging": # prevent logging user logging keys
|
||||
cleaned_user_api_key_metadata[
|
||||
k
|
||||
] = "scrubbed_by_litellm_for_sensitive_keys"
|
||||
cleaned_user_api_key_metadata[k] = (
|
||||
"scrubbed_by_litellm_for_sensitive_keys"
|
||||
)
|
||||
else:
|
||||
cleaned_user_api_key_metadata[k] = v
|
||||
|
||||
@ -5616,4 +5621,3 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
|
||||
model_parameters={"stream": True},
|
||||
hidden_params=hidden_params,
|
||||
)
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ LiteLLM MCP Server Routes
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -44,7 +43,10 @@ from litellm.proxy._experimental.mcp_server.utils import (
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
from litellm.proxy.litellm_pre_call_utils import (
|
||||
LiteLLMProxyRequestSetup,
|
||||
get_chain_id_from_headers,
|
||||
)
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer
|
||||
from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall
|
||||
@ -331,6 +333,11 @@ if MCP_AVAILABLE:
|
||||
try:
|
||||
# Create body data for logging
|
||||
body_data = {"name": name, "arguments": arguments}
|
||||
# Set trace/session id from raw_headers so spend logs and logging_obj stay consistent (same as A2A)
|
||||
chain_id = get_chain_id_from_headers(raw_headers)
|
||||
if chain_id:
|
||||
body_data["litellm_trace_id"] = chain_id
|
||||
body_data["litellm_session_id"] = chain_id
|
||||
|
||||
request = Request(
|
||||
scope={
|
||||
@ -884,6 +891,10 @@ if MCP_AVAILABLE:
|
||||
# This is intentionally minimal: only async_success_handler / post_call_failure_hook
|
||||
rules_obj = Rules()
|
||||
list_tools_call_id = str(uuid.uuid4())
|
||||
# Derive trace_id from raw_headers when not explicitly passed (same as A2A / MCP call_tool)
|
||||
effective_litellm_trace_id = litellm_trace_id or get_chain_id_from_headers(
|
||||
raw_headers
|
||||
)
|
||||
spend_logs_metadata: Dict[str, Any] = {
|
||||
"mcp_operation": "list_tools",
|
||||
}
|
||||
@ -896,7 +907,7 @@ if MCP_AVAILABLE:
|
||||
"model": "MCP: list_tools",
|
||||
"call_type": CallTypes.list_mcp_tools.value,
|
||||
"litellm_call_id": list_tools_call_id,
|
||||
"litellm_trace_id": litellm_trace_id,
|
||||
"litellm_trace_id": effective_litellm_trace_id,
|
||||
"metadata": {
|
||||
"spend_logs_metadata": spend_logs_metadata,
|
||||
},
|
||||
|
||||
@ -69,6 +69,7 @@ async def _handle_stream_message(
|
||||
from litellm.a2a_protocol.main import A2A_SDK_AVAILABLE
|
||||
|
||||
if not A2A_SDK_AVAILABLE:
|
||||
|
||||
async def _error_stream():
|
||||
yield json.dumps(
|
||||
{
|
||||
@ -106,7 +107,12 @@ async def _handle_stream_message(
|
||||
proxy_server_request=proxy_server_request,
|
||||
)
|
||||
|
||||
if use_proxy_hooks and user_api_key_dict is not None and request_data is not None and proxy_logging_obj is not None:
|
||||
if (
|
||||
use_proxy_hooks
|
||||
and user_api_key_dict is not None
|
||||
and request_data is not None
|
||||
and proxy_logging_obj is not None
|
||||
):
|
||||
from litellm.proxy.common_request_processing import (
|
||||
ProxyBaseLLMRequestProcessing,
|
||||
)
|
||||
@ -119,20 +125,27 @@ async def _handle_stream_message(
|
||||
return json.dumps(obj) + "\n"
|
||||
|
||||
def _ndjson_error(proxy_exc: Any) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": getattr(
|
||||
proxy_exc, "message", f"Streaming error: {proxy_exc!s}"
|
||||
),
|
||||
},
|
||||
}
|
||||
) + "\n"
|
||||
return (
|
||||
json.dumps(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": getattr(
|
||||
proxy_exc,
|
||||
"message",
|
||||
f"Streaming error: {proxy_exc!s}",
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
async for line in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
|
||||
async for (
|
||||
line
|
||||
) in ProxyBaseLLMRequestProcessing.async_streaming_data_generator(
|
||||
response=a2a_stream,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=request_data,
|
||||
@ -151,7 +164,12 @@ async def _handle_stream_message(
|
||||
yield json.dumps(chunk) + "\n"
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error streaming A2A response: {e}")
|
||||
if use_proxy_hooks and proxy_logging_obj is not None and user_api_key_dict is not None and request_data is not None:
|
||||
if (
|
||||
use_proxy_hooks
|
||||
and proxy_logging_obj is not None
|
||||
and user_api_key_dict is not None
|
||||
and request_data is not None
|
||||
):
|
||||
transformed_exception = await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
@ -382,6 +400,7 @@ async def invoke_agent_a2a(
|
||||
agent_id=agent.agent_id,
|
||||
metadata=data.get("metadata", {}),
|
||||
proxy_server_request=data.get("proxy_server_request"),
|
||||
litellm_logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
|
||||
@ -1624,11 +1624,11 @@ def _build_field_dict(
|
||||
# Determine the field type from annotation
|
||||
field_type = _get_field_type_from_annotation(field_annotation)
|
||||
|
||||
# Check for custom UI type override (ui_type preferred; "type" leaks into OpenAPI and breaks schema)
|
||||
field_json_schema_extra = getattr(field, "json_schema_extra", {}) or {}
|
||||
# Check for custom UI type override
|
||||
field_json_schema_extra = getattr(field, "json_schema_extra", {})
|
||||
if field_json_schema_extra and "ui_type" in field_json_schema_extra:
|
||||
ut = field_json_schema_extra["ui_type"]
|
||||
field_type = ut if isinstance(ut, str) else getattr(ut, "value", ut)
|
||||
ui_type = field_json_schema_extra["ui_type"]
|
||||
field_type = ui_type.value if hasattr(ui_type, "value") else ui_type
|
||||
elif field_json_schema_extra and "type" in field_json_schema_extra:
|
||||
field_type = field_json_schema_extra["type"]
|
||||
|
||||
|
||||
@ -89,6 +89,25 @@ def _get_metadata_variable_name(request: Request) -> str:
|
||||
return "metadata"
|
||||
|
||||
|
||||
def get_chain_id_from_headers(headers: Optional[Dict[str, str]]) -> Optional[str]:
    """
    Return the call-chaining id carried in the request headers, if any.

    Two headers are recognised: ``x-litellm-trace-id`` and
    ``x-litellm-session-id``. They are treated as interchangeable; a truthy
    trace-id value wins, otherwise the session-id value (possibly ``None``)
    is returned. Header names are compared case-insensitively and non-string
    keys are ignored. When the same header appears under several casings,
    the last one seen in iteration order is used.

    Args:
        headers: Raw header mapping, or ``None``.

    Returns:
        The chain id, or ``None`` when ``headers`` is empty/``None`` or
        neither header is present.
    """
    if not headers:
        return None

    trace_value = None
    session_value = None
    for header_name, header_value in headers.items():
        if not isinstance(header_name, str):
            continue
        lowered_name = header_name.lower()
        if lowered_name == "x-litellm-trace-id":
            trace_value = header_value
        elif lowered_name == "x-litellm-session-id":
            session_value = header_value

    # A falsy trace id (missing or empty) falls back to the session id.
    if trace_value:
        return trace_value
    return session_value
|
||||
|
||||
|
||||
def safe_add_api_version_from_query_params(data: dict, request: Request):
|
||||
try:
|
||||
if hasattr(request, "query_params"):
|
||||
@ -177,12 +196,12 @@ def _get_dynamic_logging_metadata(
|
||||
user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig
|
||||
) -> Optional[TeamCallbackMetadata]:
|
||||
callback_settings_obj: Optional[TeamCallbackMetadata] = None
|
||||
key_dynamic_logging_settings: Optional[
|
||||
dict
|
||||
] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
|
||||
team_dynamic_logging_settings: Optional[
|
||||
dict
|
||||
] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
|
||||
key_dynamic_logging_settings: Optional[dict] = (
|
||||
KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
|
||||
)
|
||||
team_dynamic_logging_settings: Optional[dict] = (
|
||||
KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
|
||||
)
|
||||
#########################################################################################
|
||||
# Key-based callbacks
|
||||
#########################################################################################
|
||||
@ -576,9 +595,13 @@ class LiteLLMProxyRequestSetup:
|
||||
#########################################################################################
|
||||
# Finally update the requests metadata with the `metadata_from_headers`
|
||||
#########################################################################################
|
||||
|
||||
agent_id_from_header = headers.get("x-litellm-agent-id")
|
||||
trace_id_from_header = headers.get("x-litellm-trace-id")
|
||||
session_id_from_header = headers.get("x-litellm-session-id")
|
||||
# x-litellm-trace-id and x-litellm-session-id are interchangeable for call chaining
|
||||
chain_id = headers.get("x-litellm-trace-id") or headers.get(
|
||||
"x-litellm-session-id"
|
||||
)
|
||||
|
||||
|
||||
if agent_id_from_header:
|
||||
metadata_from_headers["agent_id"] = agent_id_from_header
|
||||
@ -586,16 +609,13 @@ class LiteLLMProxyRequestSetup:
|
||||
f"Extracted agent_id from header: {agent_id_from_header}"
|
||||
)
|
||||
|
||||
if trace_id_from_header:
|
||||
metadata_from_headers["trace_id"] = trace_id_from_header
|
||||
if chain_id:
|
||||
metadata_from_headers["trace_id"] = chain_id
|
||||
metadata_from_headers["session_id"] = chain_id
|
||||
data["litellm_session_id"] = chain_id
|
||||
data["litellm_trace_id"] = chain_id
|
||||
verbose_proxy_logger.debug(
|
||||
f"Extracted trace_id from header: {trace_id_from_header}"
|
||||
)
|
||||
|
||||
if session_id_from_header:
|
||||
metadata_from_headers["session_id"] = session_id_from_header
|
||||
verbose_proxy_logger.debug(
|
||||
f"Extracted session_id from header: {session_id_from_header}"
|
||||
f"Extracted chain_id from header (trace-id/session-id): {chain_id}"
|
||||
)
|
||||
|
||||
if isinstance(data[_metadata_variable_name], dict):
|
||||
@ -702,11 +722,11 @@ class LiteLLMProxyRequestSetup:
|
||||
|
||||
## KEY-LEVEL SPEND LOGS / TAGS
|
||||
if "tags" in key_metadata and key_metadata["tags"] is not None:
|
||||
data[_metadata_variable_name][
|
||||
"tags"
|
||||
] = LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=key_metadata["tags"],
|
||||
data[_metadata_variable_name]["tags"] = (
|
||||
LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=key_metadata["tags"],
|
||||
)
|
||||
)
|
||||
if "disable_global_guardrails" in key_metadata and isinstance(
|
||||
key_metadata["disable_global_guardrails"], bool
|
||||
@ -839,14 +859,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
"""
|
||||
|
||||
from litellm.proxy.proxy_server import llm_router, premium_user
|
||||
from litellm.types.proxy.litellm_pre_call_utils import (
|
||||
RedactedDict,
|
||||
SecretFields,
|
||||
)
|
||||
from litellm.types.proxy.litellm_pre_call_utils import RedactedDict, SecretFields
|
||||
|
||||
_raw_headers: Dict[str, str] = RedactedDict(
|
||||
_safe_get_request_headers(request)
|
||||
)
|
||||
_raw_headers: Dict[str, str] = RedactedDict(_safe_get_request_headers(request))
|
||||
|
||||
forward_llm_auth = False
|
||||
if general_settings:
|
||||
@ -986,9 +1001,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
data[_metadata_variable_name]["litellm_api_version"] = version
|
||||
|
||||
if general_settings is not None:
|
||||
data[_metadata_variable_name][
|
||||
"global_max_parallel_requests"
|
||||
] = general_settings.get("global_max_parallel_requests", None)
|
||||
data[_metadata_variable_name]["global_max_parallel_requests"] = (
|
||||
general_settings.get("global_max_parallel_requests", None)
|
||||
)
|
||||
|
||||
### KEY-LEVEL Controls
|
||||
key_metadata = user_api_key_dict.metadata
|
||||
|
||||
@ -11,26 +11,21 @@ from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import (
|
||||
MAX_STRING_LENGTH_PROMPT_IN_DB as DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB,
|
||||
)
|
||||
from litellm.constants import \
|
||||
MAX_STRING_LENGTH_PROMPT_IN_DB as DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB
|
||||
from litellm.constants import REDACTED_BY_LITELM_STRING
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
get_litellm_metadata_from_kwargs,
|
||||
reconstruct_model_name,
|
||||
)
|
||||
get_litellm_metadata_from_kwargs, reconstruct_model_name)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
|
||||
from litellm.proxy.utils import PrismaClient, hash_token
|
||||
from litellm.types.utils import (
|
||||
CostBreakdown,
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingMCPToolCall,
|
||||
StandardLoggingModelInformation,
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingVectorStoreRequest,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
from litellm.types.utils import (CostBreakdown,
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingMCPToolCall,
|
||||
StandardLoggingModelInformation,
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingVectorStoreRequest,
|
||||
VectorStoreSearchResponse)
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
@ -116,16 +111,15 @@ def _get_spend_logs_metadata(
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
clean_metadata = SpendLogsMetadata(
|
||||
**{ # type: ignore
|
||||
key: metadata.get(key)
|
||||
for key in SpendLogsMetadata.__annotations__.keys()
|
||||
key: metadata.get(key) for key in SpendLogsMetadata.__annotations__.keys()
|
||||
}
|
||||
)
|
||||
clean_metadata["applied_guardrails"] = applied_guardrails
|
||||
clean_metadata["batch_models"] = batch_models
|
||||
clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
|
||||
clean_metadata[
|
||||
"vector_store_request_metadata"
|
||||
] = _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata)
|
||||
clean_metadata["vector_store_request_metadata"] = (
|
||||
_get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata)
|
||||
)
|
||||
clean_metadata["guardrail_information"] = guardrail_information
|
||||
clean_metadata["usage_object"] = usage_object
|
||||
clean_metadata["model_map_information"] = model_map_information
|
||||
@ -372,9 +366,11 @@ def get_logging_payload( # noqa: PLR0915
|
||||
guardrail_information=(
|
||||
standard_logging_payload.get("guardrail_information", None)
|
||||
if standard_logging_payload is not None
|
||||
else metadata.get("standard_logging_guardrail_information", None)
|
||||
if metadata is not None
|
||||
else None
|
||||
else (
|
||||
metadata.get("standard_logging_guardrail_information", None)
|
||||
if metadata is not None
|
||||
else None
|
||||
)
|
||||
),
|
||||
cold_storage_object_key=(
|
||||
standard_logging_payload["metadata"].get("cold_storage_object_key", None)
|
||||
@ -501,6 +497,7 @@ def _get_session_id_for_spend_log(
|
||||
"""
|
||||
from litellm._uuid import uuid
|
||||
|
||||
|
||||
if (
|
||||
standard_logging_payload is not None
|
||||
and standard_logging_payload.get("trace_id") is not None
|
||||
@ -515,9 +512,7 @@ def _get_session_id_for_spend_log(
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _get_request_duration_ms(
|
||||
start_time: datetime, end_time: datetime
|
||||
) -> Optional[int]:
|
||||
def _get_request_duration_ms(start_time: datetime, end_time: datetime) -> Optional[int]:
|
||||
"""Compute request duration in milliseconds from start and end times."""
|
||||
try:
|
||||
return int((end_time - start_time).total_seconds() * 1000)
|
||||
@ -709,20 +704,20 @@ def _convert_to_json_serializable_dict(
|
||||
if max_depth <= 0:
|
||||
# Return a placeholder if max depth is exceeded
|
||||
return "<max_depth_exceeded>"
|
||||
|
||||
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
|
||||
# Get the object's memory address to track visited objects
|
||||
obj_id = id(obj)
|
||||
if obj_id in visited:
|
||||
# Circular reference detected, return placeholder
|
||||
return "<circular_reference>"
|
||||
|
||||
|
||||
# Only track mutable objects (dict, list, objects with __dict__)
|
||||
if isinstance(obj, (dict, list)) or hasattr(obj, "__dict__"):
|
||||
visited.add(obj_id)
|
||||
|
||||
|
||||
try:
|
||||
if isinstance(obj, BaseModel):
|
||||
# Use Pydantic's model_dump() instead of pickle
|
||||
@ -741,7 +736,9 @@ def _convert_to_json_serializable_dict(
|
||||
]
|
||||
elif hasattr(obj, "__dict__"):
|
||||
# Handle objects with __dict__ attribute
|
||||
return _convert_to_json_serializable_dict(obj.__dict__, visited, max_depth - 1)
|
||||
return _convert_to_json_serializable_dict(
|
||||
obj.__dict__, visited, max_depth - 1
|
||||
)
|
||||
else:
|
||||
# Primitives (str, int, float, bool, None) pass through
|
||||
return obj
|
||||
@ -777,9 +774,7 @@ def _get_proxy_server_request_for_spend_logs_payload(
|
||||
# Apply message redaction if turn_off_message_logging is enabled
|
||||
if kwargs is not None:
|
||||
from litellm.litellm_core_utils.redact_messages import (
|
||||
perform_redaction,
|
||||
should_redact_message_logging,
|
||||
)
|
||||
perform_redaction, should_redact_message_logging)
|
||||
|
||||
# Build model_call_details dict to check redaction settings
|
||||
model_call_details = {
|
||||
@ -788,12 +783,12 @@ def _get_proxy_server_request_for_spend_logs_payload(
|
||||
"standard_callback_dynamic_params"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# If redaction is enabled, convert to serializable dict before redacting
|
||||
if should_redact_message_logging(model_call_details=model_call_details):
|
||||
_request_body = _convert_to_json_serializable_dict(_request_body)
|
||||
perform_redaction(model_call_details=_request_body, result=None)
|
||||
|
||||
|
||||
_request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
|
||||
_request_body_json_str = json.dumps(_request_body, default=str)
|
||||
return _request_body_json_str
|
||||
@ -845,10 +840,8 @@ def _get_response_for_spend_logs_payload(
|
||||
# Apply message redaction if turn_off_message_logging is enabled
|
||||
if kwargs is not None:
|
||||
from litellm.litellm_core_utils.redact_messages import (
|
||||
perform_redaction,
|
||||
should_redact_message_logging,
|
||||
)
|
||||
|
||||
perform_redaction, should_redact_message_logging)
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
model_call_details = {
|
||||
"litellm_params": litellm_params,
|
||||
@ -856,11 +849,13 @@ def _get_response_for_spend_logs_payload(
|
||||
"standard_callback_dynamic_params"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# If redaction is enabled, convert to serializable dict before redacting
|
||||
if should_redact_message_logging(model_call_details=model_call_details):
|
||||
response_obj = _convert_to_json_serializable_dict(response_obj)
|
||||
response_obj = perform_redaction(model_call_details={}, result=response_obj)
|
||||
response_obj = perform_redaction(
|
||||
model_call_details={}, result=response_obj
|
||||
)
|
||||
|
||||
sanitized_wrapper = _sanitize_request_body_for_spend_logs_payload(
|
||||
{"response": response_obj}
|
||||
@ -882,7 +877,7 @@ def _should_store_prompts_and_responses_in_spend_logs() -> bool:
|
||||
|
||||
# Check general_settings (from DB or proxy_config.yaml)
|
||||
store_prompts_value = general_settings.get("store_prompts_in_spend_logs")
|
||||
|
||||
|
||||
# Normalize case: handle True/true/TRUE, False/false/FALSE, None/null
|
||||
if store_prompts_value is True:
|
||||
return True
|
||||
@ -890,7 +885,7 @@ def _should_store_prompts_and_responses_in_spend_logs() -> bool:
|
||||
# Case-insensitive string comparison
|
||||
if store_prompts_value.lower() == "true":
|
||||
return True
|
||||
|
||||
|
||||
# Also check environment variable
|
||||
return get_secret_bool("STORE_PROMPTS_IN_SPEND_LOGS") is True
|
||||
|
||||
|
||||
@ -1454,10 +1454,12 @@ def client(original_function): # noqa: PLR0915
|
||||
logging_obj, kwargs = function_setup(
|
||||
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
# Type assertion: logging_obj is guaranteed to be non-None after function_setup
|
||||
assert logging_obj is not None, "logging_obj should not be None after function_setup"
|
||||
|
||||
assert (
|
||||
logging_obj is not None
|
||||
), "logging_obj should not be None after function_setup"
|
||||
|
||||
## LOAD CREDENTIALS
|
||||
load_credentials_from_list(kwargs)
|
||||
kwargs["litellm_logging_obj"] = logging_obj
|
||||
@ -1753,7 +1755,9 @@ def client(original_function): # noqa: PLR0915
|
||||
print_args_passed_to_litellm(original_function, args, kwargs)
|
||||
start_time = datetime.datetime.now()
|
||||
result = None
|
||||
_update_response_metadata = getattr(sys.modules[__name__], "update_response_metadata")
|
||||
_update_response_metadata = getattr(
|
||||
sys.modules[__name__], "update_response_metadata"
|
||||
)
|
||||
logging_obj: Optional[LiteLLMLoggingObject] = kwargs.get(
|
||||
"litellm_logging_obj", None
|
||||
)
|
||||
@ -1776,9 +1780,11 @@ def client(original_function): # noqa: PLR0915
|
||||
logging_obj, kwargs = function_setup(
|
||||
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
# Type assertion: logging_obj is guaranteed to be non-None after function_setup
|
||||
assert logging_obj is not None, "logging_obj should not be None after function_setup"
|
||||
assert (
|
||||
logging_obj is not None
|
||||
), "logging_obj should not be None after function_setup"
|
||||
|
||||
modified_kwargs = await async_pre_call_deployment_hook(kwargs, call_type)
|
||||
if modified_kwargs is not None:
|
||||
@ -1861,6 +1867,7 @@ def client(original_function): # noqa: PLR0915
|
||||
# MODEL CALL
|
||||
result = await original_function(*args, **kwargs)
|
||||
end_time = datetime.datetime.now()
|
||||
|
||||
if _is_streaming_request(
|
||||
kwargs=kwargs,
|
||||
call_type=call_type,
|
||||
@ -2082,12 +2089,14 @@ def _is_async_request(
|
||||
return False
|
||||
|
||||
|
||||
_STREAMING_CALL_TYPES = frozenset({
|
||||
CallTypes.generate_content_stream,
|
||||
CallTypes.agenerate_content_stream,
|
||||
CallTypes.generate_content_stream.value,
|
||||
CallTypes.agenerate_content_stream.value,
|
||||
})
|
||||
_STREAMING_CALL_TYPES = frozenset(
|
||||
{
|
||||
CallTypes.generate_content_stream,
|
||||
CallTypes.agenerate_content_stream,
|
||||
CallTypes.generate_content_stream.value,
|
||||
CallTypes.agenerate_content_stream.value,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_streaming_request(
|
||||
@ -2181,7 +2190,7 @@ def encode(model="", text="", custom_tokenizer: Optional[dict] = None):
|
||||
# Normalize: HuggingFace Tokenizer.encode() returns an Encoding object;
|
||||
# extract .ids so the return type is always List[int].
|
||||
if hasattr(enc, "ids"):
|
||||
return enc.ids
|
||||
return enc.ids # type: ignore
|
||||
return enc
|
||||
|
||||
|
||||
@ -5836,7 +5845,7 @@ def get_model_info(
|
||||
_model_info[key] = value # type: ignore
|
||||
|
||||
# if verbose_logger.isEnabledFor(logging.DEBUG):
|
||||
# verbose_logger.debug(f"model_info: {_model_info}")
|
||||
# verbose_logger.debug(f"model_info: {_model_info}")
|
||||
|
||||
returned_model_info = ModelInfo(
|
||||
**_model_info, supported_openai_params=supported_openai_params
|
||||
@ -6179,8 +6188,10 @@ def validate_environment( # noqa: PLR0915
|
||||
"AWS_ROLE_ARN" in os.environ
|
||||
or "AWS_PROFILE" in os.environ
|
||||
or "AWS_WEB_IDENTITY_TOKEN_FILE" in os.environ
|
||||
or "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" in os.environ # ECS task role
|
||||
or "AWS_CONTAINER_CREDENTIALS_FULL_URI" in os.environ # ECS/Fargate full URI credential delivery
|
||||
or "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
|
||||
in os.environ # ECS task role
|
||||
or "AWS_CONTAINER_CREDENTIALS_FULL_URI"
|
||||
in os.environ # ECS/Fargate full URI credential delivery
|
||||
):
|
||||
keys_in_environment = True
|
||||
else:
|
||||
@ -7386,7 +7397,9 @@ class ModelResponseIterator:
|
||||
if convert_to_delta is True:
|
||||
_stream_response = ModelResponseStream()
|
||||
_stream_response.choices[0].delta.content = model_response.choices[0].message.content # type: ignore
|
||||
self.model_response: Union[ModelResponse, ModelResponseStream] = _stream_response
|
||||
self.model_response: Union[ModelResponse, ModelResponseStream] = (
|
||||
_stream_response
|
||||
)
|
||||
else:
|
||||
self.model_response = model_response
|
||||
self.is_done = False
|
||||
@ -7457,13 +7470,13 @@ def is_cached_message(message: AllMessageValues) -> bool:
|
||||
Used for anthropic/gemini context caching.
|
||||
|
||||
Follows the anthropic format {"cache_control": {"type": "ephemeral"}}
|
||||
|
||||
|
||||
Can be disabled globally by setting litellm.disable_anthropic_gemini_context_caching_transform = True
|
||||
"""
|
||||
# Check if context caching is disabled globally
|
||||
if litellm.disable_anthropic_gemini_context_caching_transform is True:
|
||||
return False
|
||||
|
||||
|
||||
if "content" not in message:
|
||||
return False
|
||||
|
||||
@ -7980,6 +7993,7 @@ class ProviderConfigManager:
|
||||
def _get_azure_ai_config(model: str) -> BaseConfig:
|
||||
"""Get Azure AI config based on model type."""
|
||||
from litellm.llms.azure_ai.common_utils import AzureFoundryModelInfo
|
||||
|
||||
return AzureFoundryModelInfo.get_azure_ai_config_for_model(model)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1112,6 +1112,47 @@ async def test_get_guardrail_info_endpoint_db_guardrail(mocker):
|
||||
assert result.guardrail_definition_location == "db"
|
||||
|
||||
|
||||
class TestBuildFieldDict:
    """Verify _build_field_dict accepts ui_type as either a plain string or an enum member."""

    def test_build_field_dict_with_string_ui_type(self):
        """A plain-string ui_type (e.g. BlockCodeExecutionGuardrailConfigModel) becomes the field type."""
        from unittest.mock import MagicMock

        from litellm.proxy.guardrails.guardrail_endpoints import _build_field_dict

        # Keyword arguments to MagicMock configure attributes on the mock.
        mock_field = MagicMock(
            json_schema_extra={
                "ui_type": "multiselect",
                "options": ["python", "javascript"],
            }
        )

        field_dict = _build_field_dict(
            field=mock_field,
            field_annotation=str,
            description="Test field",
            required=False,
        )

        assert field_dict["type"] == "multiselect"
        assert field_dict["description"] == "Test field"

    def test_build_field_dict_with_enum_ui_type(self):
        """A GuardrailParamUITypes enum ui_type is reported as the "bool" field type for BOOL."""
        from unittest.mock import MagicMock

        from litellm.proxy.guardrails.guardrail_endpoints import _build_field_dict
        from litellm.types.guardrails import GuardrailParamUITypes

        mock_field = MagicMock(
            json_schema_extra={"ui_type": GuardrailParamUITypes.BOOL}
        )

        field_dict = _build_field_dict(
            field=mock_field,
            field_annotation=bool,
            description="Test bool field",
            required=True,
        )

        assert field_dict["type"] == "bool"
        assert field_dict["required"] is True
|
||||
# --- Team guardrail registration (register / submissions) ---
|
||||
|
||||
MOCK_REGISTER_REQUEST = RegisterGuardrailRequest(
|
||||
@ -1571,4 +1612,4 @@ async def test_list_submissions_summary_counts_unaffected_by_filters(mocker):
|
||||
assert len(result.submissions) == 1 # filtered
|
||||
assert result.summary.total == 2 # unfiltered
|
||||
assert result.summary.pending_review == 1
|
||||
assert result.summary.active == 1
|
||||
assert result.summary.active == 1
|
||||
|
||||
@ -11,11 +11,16 @@ from fastapi import Request
|
||||
import litellm
|
||||
from litellm.proxy._types import TeamCallbackMetadata, UserAPIKeyAuth
|
||||
from litellm.proxy.litellm_pre_call_utils import (
|
||||
KeyAndTeamLoggingSettings, LiteLLMProxyRequestSetup,
|
||||
_get_dynamic_logging_metadata, _get_enforced_params,
|
||||
_get_metadata_variable_name, _update_model_if_key_alias_exists,
|
||||
add_guardrails_from_policy_engine, add_litellm_data_to_request,
|
||||
check_if_token_is_service_account)
|
||||
KeyAndTeamLoggingSettings,
|
||||
LiteLLMProxyRequestSetup,
|
||||
_get_dynamic_logging_metadata,
|
||||
_get_enforced_params,
|
||||
_get_metadata_variable_name,
|
||||
_update_model_if_key_alias_exists,
|
||||
add_guardrails_from_policy_engine,
|
||||
add_litellm_data_to_request,
|
||||
check_if_token_is_service_account,
|
||||
)
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
@ -154,8 +159,7 @@ def test_get_enforced_params(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_litellm_data_to_request_parses_string_metadata():
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup
|
||||
request_mock = MagicMock(spec=Request)
|
||||
@ -201,8 +205,7 @@ async def test_add_litellm_data_to_request_parses_string_metadata():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_litellm_data_to_request_user_spend_and_budget():
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
request_mock = MagicMock(spec=Request)
|
||||
request_mock.url.path = "/v1/completions"
|
||||
@ -240,8 +243,7 @@ async def test_add_litellm_data_to_request_user_spend_and_budget():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_litellm_data_to_request_audio_transcription_multipart():
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup request mock for /v1/audio/transcriptions
|
||||
request_mock = MagicMock(spec=Request)
|
||||
@ -306,8 +308,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks():
|
||||
"""
|
||||
Test that litellm_disabled_callbacks from key metadata is properly added to the request data.
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
@ -360,8 +361,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks_empty():
|
||||
"""
|
||||
Test that litellm_disabled_callbacks is not added when it's empty.
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
@ -413,8 +413,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks_not_present():
|
||||
"""
|
||||
Test that litellm_disabled_callbacks is not added when it's not present in metadata.
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
@ -466,8 +465,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks_invalid_type():
|
||||
"""
|
||||
Test that litellm_disabled_callbacks is not added when it's not a list.
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
@ -519,8 +517,7 @@ async def test_add_litellm_data_to_request_disabled_callbacks_with_logging_setti
|
||||
"""
|
||||
Test that litellm_disabled_callbacks works correctly alongside logging settings.
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
@ -1030,8 +1027,7 @@ from unittest.mock import AsyncMock
|
||||
from fastapi.responses import Response
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy.common_request_processing import \
|
||||
ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
@ -1149,6 +1145,47 @@ async def test_add_litellm_metadata_from_request_headers():
|
||||
litellm.callbacks = original_callbacks
|
||||
|
||||
|
||||
def test_add_litellm_metadata_from_request_headers_x_litellm_trace_id_sets_chain_id():
    """
    A lone ``x-litellm-trace-id`` header is used as the chain id.

    After the call, the header value must appear under both ``trace_id`` and
    ``session_id`` in ``metadata`` and under the top-level
    ``litellm_session_id`` / ``litellm_trace_id`` keys of the request data.
    """
    chain_id = "foo"
    payload = {"metadata": {}}

    LiteLLMProxyRequestSetup.add_litellm_metadata_from_request_headers(
        headers={"x-litellm-trace-id": chain_id},
        data=payload,
        _metadata_variable_name="metadata",
    )

    # Read metadata only after the call so the check sees what the helper left in place.
    metadata = payload["metadata"]
    assert metadata["trace_id"] == chain_id
    assert metadata["session_id"] == chain_id
    assert payload["litellm_session_id"] == chain_id
    assert payload["litellm_trace_id"] == chain_id
|
||||
|
||||
|
||||
def test_add_litellm_metadata_from_request_headers_x_litellm_session_id_sets_chain_id():
    """
    A lone ``x-litellm-session-id`` header is used as the chain id.

    After the call, the header value must appear under both ``trace_id`` and
    ``session_id`` in ``metadata`` and under the top-level
    ``litellm_session_id`` / ``litellm_trace_id`` keys of the request data.
    """
    chain_id = "bar"
    payload = {"metadata": {}}

    LiteLLMProxyRequestSetup.add_litellm_metadata_from_request_headers(
        headers={"x-litellm-session-id": chain_id},
        data=payload,
        _metadata_variable_name="metadata",
    )

    # Read metadata only after the call so the check sees what the helper left in place.
    metadata = payload["metadata"]
    assert metadata["trace_id"] == chain_id
    assert metadata["session_id"] == chain_id
    assert payload["litellm_session_id"] == chain_id
    assert payload["litellm_trace_id"] == chain_id
|
||||
|
||||
|
||||
def test_add_litellm_metadata_from_request_headers_both_headers_trace_id_precedence():
    """
    When both chaining headers are sent, ``x-litellm-trace-id`` wins.

    The trace-id value, not the session-id value, must end up in
    ``metadata.trace_id``, ``metadata.session_id`` and the top-level
    ``litellm_session_id`` / ``litellm_trace_id`` keys.
    """
    expected_chain_id = "trace-value"
    request_headers = {
        "x-litellm-trace-id": expected_chain_id,
        "x-litellm-session-id": "session-value",
    }
    payload = {"metadata": {}}

    LiteLLMProxyRequestSetup.add_litellm_metadata_from_request_headers(
        headers=request_headers,
        data=payload,
        _metadata_variable_name="metadata",
    )

    # Read metadata only after the call so the check sees what the helper left in place.
    metadata = payload["metadata"]
    assert metadata["trace_id"] == expected_chain_id
    assert metadata["session_id"] == expected_chain_id
    assert payload["litellm_session_id"] == expected_chain_id
    assert payload["litellm_trace_id"] == expected_chain_id
|
||||
|
||||
|
||||
def test_get_internal_user_header_from_mapping_returns_expected_header():
|
||||
mappings = [
|
||||
@ -1407,8 +1444,7 @@ async def test_embedding_header_forwarding_with_model_group():
|
||||
importlib.reload(pre_call_utils_module)
|
||||
|
||||
# Re-import the function after reload to get the fresh version
|
||||
from litellm.proxy.litellm_pre_call_utils import \
|
||||
add_litellm_data_to_request
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request for embeddings
|
||||
request_mock = MagicMock(spec=Request)
|
||||
@ -1542,11 +1578,13 @@ async def test_add_guardrails_from_policy_engine():
|
||||
Test that add_guardrails_from_policy_engine adds guardrails from matching policies
|
||||
and tracks applied policies in metadata.
|
||||
"""
|
||||
from litellm.proxy.policy_engine.attachment_registry import \
|
||||
get_attachment_registry
|
||||
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
from litellm.types.proxy.policy_engine import (Policy, PolicyAttachment,
|
||||
PolicyGuardrails)
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
Policy,
|
||||
PolicyAttachment,
|
||||
PolicyGuardrails,
|
||||
)
|
||||
|
||||
# Setup test data
|
||||
data = {
|
||||
@ -1659,8 +1697,7 @@ async def test_add_guardrails_from_policy_engine_policy_version_by_id():
|
||||
Test that add_guardrails_from_policy_engine executes a specific policy version
|
||||
when policy_<uuid> is passed in the request body.
|
||||
"""
|
||||
from litellm.proxy.policy_engine.attachment_registry import \
|
||||
get_attachment_registry
|
||||
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
from litellm.types.proxy.policy_engine import Policy, PolicyGuardrails
|
||||
|
||||
@ -1729,6 +1766,7 @@ async def test_bearer_token_not_in_debug_logs():
|
||||
"""
|
||||
import logging
|
||||
from io import StringIO
|
||||
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy.proxy_server import ProxyConfig
|
||||
|
||||
|
||||
@ -6,9 +6,9 @@ import {
|
||||
LeftOutlined,
|
||||
RightOutlined,
|
||||
} from "@ant-design/icons";
|
||||
import { Sparkles, Wrench } from "lucide-react";
|
||||
import { Bot, Sparkles, Wrench } from "lucide-react";
|
||||
import { LogEntry } from "../columns";
|
||||
import { MCP_CALL_TYPES } from "../constants";
|
||||
import { AGENT_CALL_TYPES, MCP_CALL_TYPES } from "../constants";
|
||||
import { getEventDisplayName } from "../utils";
|
||||
import { DrawerHeader } from "./DrawerHeader";
|
||||
import { useKeyboardNavigation } from "./useKeyboardNavigation";
|
||||
@ -46,6 +46,7 @@ interface TraceEventRowProps {
|
||||
|
||||
function TraceEventRow({ row, isSelected, onClick }: TraceEventRowProps) {
|
||||
const isMcp = MCP_CALL_TYPES.includes(row.call_type);
|
||||
const isAgent = AGENT_CALL_TYPES.includes(row.call_type);
|
||||
const durationValue =
|
||||
row.request_duration_ms != null
|
||||
? (row.request_duration_ms / 1000).toFixed(3)
|
||||
@ -64,6 +65,8 @@ function TraceEventRow({ row, isSelected, onClick }: TraceEventRowProps) {
|
||||
<div className="flex items-center gap-1">
|
||||
{isMcp ? (
|
||||
<Wrench size={12} className="text-slate-500 flex-shrink-0" />
|
||||
) : isAgent ? (
|
||||
<Bot size={12} className="text-slate-500 flex-shrink-0" />
|
||||
) : (
|
||||
<Sparkles size={12} className="text-slate-500 flex-shrink-0" />
|
||||
)}
|
||||
@ -219,7 +222,10 @@ export function LogDetailsDrawer({
|
||||
: null;
|
||||
const sessionDurationSeconds =
|
||||
sessionStart && sessionEnd ? ((sessionEnd.getTime() - sessionStart.getTime()) / 1000).toFixed(2) : "0.00";
|
||||
const llmCount = sessionLogs.filter((row) => !MCP_CALL_TYPES.includes(row.call_type)).length;
|
||||
const llmCount = sessionLogs.filter(
|
||||
(row) => !MCP_CALL_TYPES.includes(row.call_type) && !AGENT_CALL_TYPES.includes(row.call_type),
|
||||
).length;
|
||||
const agentCount = sessionLogs.filter((row) => AGENT_CALL_TYPES.includes(row.call_type)).length;
|
||||
const mcpCount = sessionLogs.filter((row) => MCP_CALL_TYPES.includes(row.call_type)).length;
|
||||
const logsForList = isSessionMode ? sessionLogs : currentLog ? [currentLog] : [];
|
||||
const leftPanelId = isSessionMode ? sessionId || "" : currentLog?.request_id || "";
|
||||
@ -302,14 +308,25 @@ export function LogDetailsDrawer({
|
||||
</div>
|
||||
<div className="mt-1 text-[11px] text-slate-500 font-mono">
|
||||
{logsForList.length} req
|
||||
<span className="mx-1.5">·</span>
|
||||
{isSessionMode
|
||||
? `${llmCount} LLM`
|
||||
: `${logsForList.filter((row) => !MCP_CALL_TYPES.includes(row.call_type)).length} LLM`}
|
||||
<span className="mx-1.5">·</span>
|
||||
{isSessionMode
|
||||
? `${mcpCount} MCP`
|
||||
: `${logsForList.filter((row) => MCP_CALL_TYPES.includes(row.call_type)).length} MCP`}
|
||||
{[
|
||||
isSessionMode
|
||||
? llmCount
|
||||
: logsForList.filter(
|
||||
(row) =>
|
||||
!MCP_CALL_TYPES.includes(row.call_type) && !AGENT_CALL_TYPES.includes(row.call_type),
|
||||
).length,
|
||||
isSessionMode ? agentCount : logsForList.filter((row) => AGENT_CALL_TYPES.includes(row.call_type)).length,
|
||||
isSessionMode ? mcpCount : logsForList.filter((row) => MCP_CALL_TYPES.includes(row.call_type)).length,
|
||||
].map((count, i) => {
|
||||
const label = [" LLM", " Agent", " MCP"][i];
|
||||
return count > 0 ? (
|
||||
<span key={label}>
|
||||
<span className="mx-1.5">·</span>
|
||||
{count}
|
||||
{label}
|
||||
</span>
|
||||
) : null;
|
||||
})}
|
||||
<span className="mx-1.5">·</span>
|
||||
{isSessionMode
|
||||
? getSpendString(totalSessionCost)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Compact type-indicator badges for LLM and MCP log entries.
|
||||
* Compact type-indicator badges for LLM, Agent, and MCP log entries.
|
||||
* Used in the request logs table and session type column.
|
||||
*/
|
||||
|
||||
@ -15,6 +15,18 @@ export const WrenchIcon = ({ size = 10 }: { size?: number }) => (
|
||||
</svg>
|
||||
);
|
||||
|
||||
/**
 * Bot glyph used to mark A2A / agent call types (drawn in the Lucide "Bot" style).
 *
 * The icon is stroked with `currentColor`, so it takes the text colour of
 * whatever element contains it.
 *
 * @param size - Value applied to both the `width` and `height` of the SVG; defaults to 12.
 */
export const AgentIcon = ({ size = 12 }: { size?: number }) => {
  return (
    <svg
      width={size}
      height={size}
      viewBox="0 0 24 24"
      fill="none"
      stroke="currentColor"
      strokeWidth="2"
      strokeLinecap="round"
      strokeLinejoin="round"
      className="flex-shrink-0"
    >
      <path d="M12 8V4H8" />
      <rect width="16" height="12" x="4" y="8" rx="2" />
      <path d="M2 14h2" />
      <path d="M20 14h2" />
      <path d="M15 13v2" />
      <path d="M9 13v2" />
    </svg>
  );
};
|
||||
|
||||
export const LlmBadge = ({ count }: { count?: number }) => (
|
||||
<span className="inline-flex items-center gap-1 px-2 py-0.5 bg-blue-50 text-blue-700 border border-blue-200 rounded-full text-[11px] font-medium whitespace-nowrap">
|
||||
<SparkleIcon />
|
||||
@ -28,3 +40,10 @@ export const McpBadge = ({ count }: { count?: number }) => (
|
||||
{count != null ? count : "MCP"}
|
||||
</span>
|
||||
);
|
||||
|
||||
/**
 * Violet pill badge that marks a log entry as an agent call.
 *
 * @param count - Optional number rendered next to the icon. When it is
 *   `null` or `undefined`, the literal text "Agent" is rendered instead
 *   (a count of `0` is still rendered as `0`).
 */
export const AgentBadge = ({ count }: { count?: number }) => {
  // `??` falls back only on null/undefined, so a zero count is kept.
  const label = count ?? "Agent";

  return (
    <span className="inline-flex items-center gap-1 px-2 py-0.5 bg-violet-50 text-violet-700 border border-violet-200 rounded-full text-[11px] font-medium whitespace-nowrap">
      <AgentIcon />
      {label}
    </span>
  );
};
|
||||
|
||||
@ -6,8 +6,8 @@ import React, { useState } from "react";
|
||||
import { getProviderLogoAndName } from "../provider_info_helpers";
|
||||
import { TableHeaderSortDropdown } from "../common_components/TableHeaderSortDropdown/TableHeaderSortDropdown";
|
||||
import { TimeCell } from "./time_cell";
|
||||
import { MCP_CALL_TYPES } from "./constants";
|
||||
import { LlmBadge, McpBadge, SparkleIcon, WrenchIcon } from "./TypeBadges";
|
||||
import { AGENT_CALL_TYPES, MCP_CALL_TYPES } from "./constants";
|
||||
import { AgentBadge, AgentIcon, LlmBadge, McpBadge, SparkleIcon, WrenchIcon } from "./TypeBadges";
|
||||
|
||||
/** API sort field mapping for /spend/logs/ui endpoint */
|
||||
export const LOGS_SORT_FIELD_MAP = {
|
||||
@ -69,6 +69,7 @@ export type LogEntry = {
|
||||
mcp_tool_call_spend?: number;
|
||||
session_llm_count?: number;
|
||||
session_mcp_count?: number;
|
||||
session_agent_count?: number;
|
||||
onKeyHashClick?: (keyHash: string) => void;
|
||||
onSessionClick?: (sessionId: string) => void;
|
||||
};
|
||||
@ -124,17 +125,26 @@ export const createColumns = (sortProps?: LogsSortProps): ColumnDef<LogEntry>[]
|
||||
const row = info.row.original;
|
||||
const sessionCount = row.session_total_count || 1;
|
||||
const isMcp = MCP_CALL_TYPES.includes(row.call_type);
|
||||
const sessionLlmCount = row.session_llm_count ?? (isMcp ? 0 : sessionCount);
|
||||
const isAgent = AGENT_CALL_TYPES.includes(row.call_type);
|
||||
const sessionLlmCount = row.session_llm_count ?? (isMcp || isAgent ? 0 : sessionCount);
|
||||
const sessionAgentCount = row.session_agent_count ?? (isAgent ? sessionCount : 0);
|
||||
const sessionMcpCount = row.session_mcp_count ?? (isMcp ? sessionCount : 0);
|
||||
|
||||
if (isMcp) return <McpBadge />;
|
||||
if (isAgent && sessionCount <= 1) return <AgentBadge />;
|
||||
if (sessionCount <= 1) return <LlmBadge />;
|
||||
|
||||
// Multi-call session — show total count, plus MCP indicator when mixed.
|
||||
// Multi-call session — show total count, plus Agent/MCP indicators when mixed.
|
||||
const sessionTypeBadge = (
|
||||
<span className="inline-flex items-center gap-1 px-2 py-0.5 bg-blue-50 text-blue-700 border border-blue-200 rounded-full text-[11px] font-medium whitespace-nowrap">
|
||||
<SparkleIcon />
|
||||
<span>{sessionCount}</span>
|
||||
{sessionAgentCount > 0 && (
|
||||
<>
|
||||
<span className="text-blue-300">·</span>
|
||||
<AgentIcon size={10} />
|
||||
</>
|
||||
)}
|
||||
{sessionMcpCount > 0 && (
|
||||
<>
|
||||
<span className="text-blue-300">·</span>
|
||||
@ -144,8 +154,13 @@ export const createColumns = (sortProps?: LogsSortProps): ColumnDef<LogEntry>[]
|
||||
</span>
|
||||
);
|
||||
|
||||
const tooltipParts = [
|
||||
sessionLlmCount > 0 && `${sessionLlmCount} LLM`,
|
||||
sessionAgentCount > 0 && `${sessionAgentCount} Agent`,
|
||||
sessionMcpCount > 0 && `${sessionMcpCount} MCP`,
|
||||
].filter(Boolean);
|
||||
return (
|
||||
<Tooltip title={`${sessionLlmCount} LLM • ${sessionMcpCount} MCP`}>
|
||||
<Tooltip title={tooltipParts.join(" • ")}>
|
||||
{sessionTypeBadge}
|
||||
</Tooltip>
|
||||
);
|
||||
|
||||
@ -15,6 +15,9 @@ export const ERROR_CODE_OPTIONS: { label: string; value: string }[] = [
|
||||
/** Call types that represent MCP tool invocations (shared across columns, index, drawer). */
|
||||
export const MCP_CALL_TYPES = ["call_mcp_tool", "list_mcp_tools"];
|
||||
|
||||
/** Call types that represent agent/A2A requests (e.g. asend_message). */
|
||||
export const AGENT_CALL_TYPES = ["asend_message"];
|
||||
|
||||
export const QUICK_SELECT_OPTIONS: { label: string; value: number; unit: string }[] = [
|
||||
{ label: "Last 15 Minutes", value: 15, unit: "minutes" },
|
||||
{ label: "Last Hour", value: 1, unit: "hours" },
|
||||
|
||||
@ -20,7 +20,7 @@ import KeyInfoView from "../templates/key_info_view";
|
||||
import AuditLogs from "./audit_logs";
|
||||
import { createColumns, LogEntry, type LogsSortField } from "./columns";
|
||||
import { ConfigInfoMessage } from "./ConfigInfoMessage";
|
||||
import { ERROR_CODE_OPTIONS, MCP_CALL_TYPES, QUICK_SELECT_OPTIONS } from "./constants";
|
||||
import { AGENT_CALL_TYPES, ERROR_CODE_OPTIONS, MCP_CALL_TYPES, QUICK_SELECT_OPTIONS } from "./constants";
|
||||
import { CostBreakdownViewer } from "./CostBreakdownViewer";
|
||||
import { ErrorViewer } from "./ErrorViewer";
|
||||
import { useLogFilterLogic } from "./log_filter_logic";
|
||||
@ -309,13 +309,15 @@ export default function SpendLogsTable({
|
||||
return matchesSearch;
|
||||
});
|
||||
|
||||
const sessionCompositionById = searchedLogs.reduce<Record<string, { llm: number; mcp: number }>>((acc, log) => {
|
||||
const sessionCompositionById = searchedLogs.reduce<Record<string, { llm: number; agent: number; mcp: number }>>((acc, log) => {
|
||||
if (!log.session_id) return acc;
|
||||
if (!acc[log.session_id]) {
|
||||
acc[log.session_id] = { llm: 0, mcp: 0 };
|
||||
acc[log.session_id] = { llm: 0, agent: 0, mcp: 0 };
|
||||
}
|
||||
if (MCP_CALL_TYPES.includes(log.call_type)) {
|
||||
acc[log.session_id].mcp += 1;
|
||||
} else if (AGENT_CALL_TYPES.includes(log.call_type)) {
|
||||
acc[log.session_id].agent += 1;
|
||||
} else {
|
||||
acc[log.session_id].llm += 1;
|
||||
}
|
||||
@ -343,6 +345,7 @@ export default function SpendLogsTable({
|
||||
request_duration_ms: log.request_duration_ms,
|
||||
session_llm_count: sessionComposition?.llm ?? undefined,
|
||||
session_mcp_count: sessionComposition?.mcp ?? undefined,
|
||||
session_agent_count: sessionComposition?.agent ?? undefined,
|
||||
onKeyHashClick: (keyHash: string) => setSelectedKeyIdInfoView(keyHash),
|
||||
onSessionClick: (sessionId: string) => {
|
||||
if (sessionId) {
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
"moduleResolution": "bundler",
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"jsx": "react-jsx",
|
||||
"jsx": "preserve",
|
||||
"incremental": true,
|
||||
"plugins": [
|
||||
{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user