feat(proxy): add project_alias tracking through callback metadata pipeline
Thread project_alias alongside project_id through the metadata pipeline so callbacks receive the human-readable project name. DRY up duplicate metadata dict construction in proxy_track_cost_callback and pass_through_endpoints by reusing get_sanitized_user_information_from_key — future metadata fields only need adding in one place. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
63425b4cb4
commit
6809213957
@ -354,9 +354,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
)
|
||||
self.function_id = function_id
|
||||
self.streaming_chunks: List[Any] = [] # for generating complete stream response
|
||||
self.sync_streaming_chunks: List[Any] = (
|
||||
[]
|
||||
) # for generating complete stream response
|
||||
self.sync_streaming_chunks: List[
|
||||
Any
|
||||
] = [] # for generating complete stream response
|
||||
self.log_raw_request_response = log_raw_request_response
|
||||
|
||||
# Initialize dynamic callbacks
|
||||
@ -801,9 +801,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
prompt_spec=prompt_spec,
|
||||
dynamic_callback_params=dynamic_callback_params,
|
||||
):
|
||||
self.model_call_details["prompt_integration"] = (
|
||||
logger.__class__.__name__
|
||||
)
|
||||
self.model_call_details[
|
||||
"prompt_integration"
|
||||
] = logger.__class__.__name__
|
||||
return logger
|
||||
except Exception:
|
||||
# If check fails, continue to next logger
|
||||
@ -871,9 +871,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
|
||||
non_default_params
|
||||
):
|
||||
self.model_call_details["prompt_integration"] = (
|
||||
anthropic_cache_control_logger.__class__.__name__
|
||||
)
|
||||
self.model_call_details[
|
||||
"prompt_integration"
|
||||
] = anthropic_cache_control_logger.__class__.__name__
|
||||
return anthropic_cache_control_logger
|
||||
|
||||
#########################################################
|
||||
@ -885,9 +885,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
internal_usage_cache=None,
|
||||
llm_router=None,
|
||||
)
|
||||
self.model_call_details["prompt_integration"] = (
|
||||
vector_store_custom_logger.__class__.__name__
|
||||
)
|
||||
self.model_call_details[
|
||||
"prompt_integration"
|
||||
] = vector_store_custom_logger.__class__.__name__
|
||||
# Add to global callbacks so post-call hooks are invoked
|
||||
if (
|
||||
vector_store_custom_logger
|
||||
@ -947,9 +947,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
model
|
||||
): # if model name was changes pre-call, overwrite the initial model call name with the new one
|
||||
self.model_call_details["model"] = model
|
||||
self.model_call_details["litellm_params"]["api_base"] = (
|
||||
self._get_masked_api_base(additional_args.get("api_base", ""))
|
||||
)
|
||||
self.model_call_details["litellm_params"][
|
||||
"api_base"
|
||||
] = self._get_masked_api_base(additional_args.get("api_base", ""))
|
||||
|
||||
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
|
||||
# Log the exact input to the LLM API
|
||||
@ -978,7 +978,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
try:
|
||||
# [Non-blocking Extra Debug Information in metadata]
|
||||
if turn_off_message_logging is True:
|
||||
_metadata["raw_request"] = "redacted by litellm. \
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "redacted by litellm. \
|
||||
'litellm.turn_off_message_logging=True'"
|
||||
else:
|
||||
curl_command = self._get_request_curl_command(
|
||||
@ -990,31 +992,35 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
|
||||
_metadata["raw_request"] = str(curl_command)
|
||||
# split up, so it's easier to parse in the UI
|
||||
self.model_call_details["raw_request_typed_dict"] = (
|
||||
RawRequestTypedDict(
|
||||
raw_request_api_base=str(
|
||||
additional_args.get("api_base") or ""
|
||||
),
|
||||
raw_request_body=self._get_raw_request_body(
|
||||
additional_args.get("complete_input_dict", {})
|
||||
),
|
||||
# NOTE: setting ignore_sensitive_headers to True will cause
|
||||
# the Authorization header to be leaked when calls to the health
|
||||
# endpoint are made and fail.
|
||||
raw_request_headers=self._get_masked_headers(
|
||||
additional_args.get("headers", {}) or {},
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
self.model_call_details[
|
||||
"raw_request_typed_dict"
|
||||
] = RawRequestTypedDict(
|
||||
raw_request_api_base=str(
|
||||
additional_args.get("api_base") or ""
|
||||
),
|
||||
raw_request_body=self._get_raw_request_body(
|
||||
additional_args.get("complete_input_dict", {})
|
||||
),
|
||||
# NOTE: setting ignore_sensitive_headers to True will cause
|
||||
# the Authorization header to be leaked when calls to the health
|
||||
# endpoint are made and fail.
|
||||
raw_request_headers=self._get_masked_headers(
|
||||
additional_args.get("headers", {}) or {},
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
except Exception as e:
|
||||
self.model_call_details["raw_request_typed_dict"] = (
|
||||
RawRequestTypedDict(
|
||||
error=str(e),
|
||||
)
|
||||
self.model_call_details[
|
||||
"raw_request_typed_dict"
|
||||
] = RawRequestTypedDict(
|
||||
error=str(e),
|
||||
)
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "Unable to Log \
|
||||
raw request: {}".format(
|
||||
str(e)
|
||||
)
|
||||
_metadata["raw_request"] = "Unable to Log \
|
||||
raw request: {}".format(str(e))
|
||||
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
|
||||
try:
|
||||
self.logger_fn(
|
||||
@ -1314,13 +1320,13 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
response: Optional[MCPPostCallResponseObject] = (
|
||||
await callback.async_post_mcp_tool_call_hook(
|
||||
kwargs=kwargs,
|
||||
response_obj=post_mcp_tool_call_response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
response: Optional[
|
||||
MCPPostCallResponseObject
|
||||
] = await callback.async_post_mcp_tool_call_hook(
|
||||
kwargs=kwargs,
|
||||
response_obj=post_mcp_tool_call_response_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
######################################################################
|
||||
# if any of the callbacks modify the response, use the modified response
|
||||
@ -1521,9 +1527,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
verbose_logger.debug(
|
||||
f"response_cost_failure_debug_information: {debug_info}"
|
||||
)
|
||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||
debug_info
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost_failure_debug_information"
|
||||
] = debug_info
|
||||
return None
|
||||
|
||||
try:
|
||||
@ -1549,9 +1555,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
verbose_logger.debug(
|
||||
f"response_cost_failure_debug_information: {debug_info}"
|
||||
)
|
||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||
debug_info
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost_failure_debug_information"
|
||||
] = debug_info
|
||||
|
||||
return None
|
||||
|
||||
@ -1700,9 +1706,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
self.model_call_details["litellm_params"].setdefault("metadata", {})
|
||||
if self.model_call_details["litellm_params"]["metadata"] is None:
|
||||
self.model_call_details["litellm_params"]["metadata"] = {}
|
||||
self.model_call_details["litellm_params"]["metadata"]["hidden_params"] = (
|
||||
getattr(logging_result, "_hidden_params", {})
|
||||
)
|
||||
self.model_call_details["litellm_params"]["metadata"][
|
||||
"hidden_params"
|
||||
] = getattr(logging_result, "_hidden_params", {})
|
||||
|
||||
def _process_hidden_params_and_response_cost(
|
||||
self,
|
||||
@ -1731,9 +1737,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
result=logging_result
|
||||
)
|
||||
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(logging_result, start_time, end_time)
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(logging_result, start_time, end_time)
|
||||
|
||||
if (
|
||||
standard_logging_payload := self.model_call_details.get(
|
||||
@ -1811,9 +1817,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
end_time = datetime.datetime.now()
|
||||
if self.completion_start_time is None:
|
||||
self.completion_start_time = end_time
|
||||
self.model_call_details["completion_start_time"] = (
|
||||
self.completion_start_time
|
||||
)
|
||||
self.model_call_details[
|
||||
"completion_start_time"
|
||||
] = self.completion_start_time
|
||||
|
||||
self.model_call_details["log_event_type"] = "successful_api_call"
|
||||
self.model_call_details["end_time"] = end_time
|
||||
@ -1850,10 +1856,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
end_time=end_time,
|
||||
)
|
||||
elif isinstance(result, dict) or isinstance(result, list):
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(
|
||||
result, start_time, end_time
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(
|
||||
result, start_time, end_time
|
||||
)
|
||||
if (
|
||||
standard_logging_payload := self.model_call_details.get(
|
||||
@ -1862,9 +1868,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
) is not None:
|
||||
emit_standard_logging_payload(standard_logging_payload)
|
||||
elif standard_logging_object is not None:
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
standard_logging_object
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = standard_logging_object
|
||||
else:
|
||||
self.model_call_details["response_cost"] = None
|
||||
|
||||
@ -2022,20 +2028,20 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
verbose_logger.debug(
|
||||
"Logging Details LiteLLM-Success Call streaming complete"
|
||||
)
|
||||
self.model_call_details["complete_streaming_response"] = (
|
||||
complete_streaming_response
|
||||
)
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(result=complete_streaming_response)
|
||||
)
|
||||
self.model_call_details[
|
||||
"complete_streaming_response"
|
||||
] = complete_streaming_response
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(result=complete_streaming_response)
|
||||
self._merge_hidden_params_from_response_into_metadata(
|
||||
complete_streaming_response
|
||||
)
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(
|
||||
complete_streaming_response, start_time, end_time
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(
|
||||
complete_streaming_response, start_time, end_time
|
||||
)
|
||||
if (
|
||||
standard_logging_payload := self.model_call_details.get(
|
||||
@ -2369,10 +2375,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
)
|
||||
else:
|
||||
if self.stream and complete_streaming_response:
|
||||
self.model_call_details["complete_response"] = (
|
||||
self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
self.model_call_details[
|
||||
"complete_response"
|
||||
] = self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
result = self.model_call_details["complete_response"]
|
||||
openMeterLogger.log_success_event(
|
||||
@ -2396,10 +2402,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
)
|
||||
else:
|
||||
if self.stream and complete_streaming_response:
|
||||
self.model_call_details["complete_response"] = (
|
||||
self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
self.model_call_details[
|
||||
"complete_response"
|
||||
] = self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
result = self.model_call_details["complete_response"]
|
||||
|
||||
@ -2538,9 +2544,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
if complete_streaming_response is not None:
|
||||
print_verbose("Async success callbacks: Got a complete streaming response")
|
||||
|
||||
self.model_call_details["async_complete_streaming_response"] = (
|
||||
complete_streaming_response
|
||||
)
|
||||
self.model_call_details[
|
||||
"async_complete_streaming_response"
|
||||
] = complete_streaming_response
|
||||
|
||||
try:
|
||||
if self.model_call_details.get("cache_hit", False) is True:
|
||||
@ -2551,10 +2557,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
model_call_details=self.model_call_details
|
||||
)
|
||||
# base_model defaults to None if not set on model_info
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(
|
||||
result=complete_streaming_response
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(
|
||||
result=complete_streaming_response
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
@ -2571,10 +2577,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
)
|
||||
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(
|
||||
complete_streaming_response, start_time, end_time
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(
|
||||
complete_streaming_response, start_time, end_time
|
||||
)
|
||||
|
||||
# print standard logging payload
|
||||
@ -2601,9 +2607,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
# _success_handler_helper_fn
|
||||
if self.model_call_details.get("standard_logging_object") is None:
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
self._build_standard_logging_payload(result, start_time, end_time)
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = self._build_standard_logging_payload(result, start_time, end_time)
|
||||
|
||||
# print standard logging payload
|
||||
if (
|
||||
@ -2846,18 +2852,18 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="failure",
|
||||
error_str=str(exception),
|
||||
original_exception=exception,
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="failure",
|
||||
error_str=str(exception),
|
||||
original_exception=exception,
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
return start_time, end_time
|
||||
|
||||
@ -3825,9 +3831,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
||||
service_name=arize_config.project_name,
|
||||
)
|
||||
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, ArizeLogger)
|
||||
@ -3853,13 +3859,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
||||
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
|
||||
# Add openinference.project.name attribute
|
||||
if existing_attrs:
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
|
||||
f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_RESOURCE_ATTRIBUTES"
|
||||
] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
|
||||
else:
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
|
||||
f"openinference.project.name={arize_phoenix_config.project_name}"
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_RESOURCE_ATTRIBUTES"
|
||||
] = f"openinference.project.name={arize_phoenix_config.project_name}"
|
||||
|
||||
# Set Phoenix project name from environment variable
|
||||
phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None)
|
||||
@ -3867,19 +3873,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
||||
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
|
||||
# Add openinference.project.name attribute
|
||||
if existing_attrs:
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
|
||||
f"{existing_attrs},openinference.project.name={phoenix_project_name}"
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_RESOURCE_ATTRIBUTES"
|
||||
] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
|
||||
else:
|
||||
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
|
||||
f"openinference.project.name={phoenix_project_name}"
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_RESOURCE_ATTRIBUTES"
|
||||
] = f"openinference.project.name={phoenix_project_name}"
|
||||
|
||||
# auth can be disabled on local deployments of arize phoenix
|
||||
if arize_phoenix_config.otlp_auth_headers is not None:
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
arize_phoenix_config.otlp_auth_headers
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = arize_phoenix_config.otlp_auth_headers
|
||||
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
@ -4066,9 +4072,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
||||
exporter="otlp_http",
|
||||
endpoint="https://langtrace.ai/api/trace",
|
||||
)
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, OpenTelemetry)
|
||||
@ -4747,6 +4753,7 @@ class StandardLoggingPayloadSetup:
|
||||
user_api_key_team_id=None,
|
||||
user_api_key_org_id=None,
|
||||
user_api_key_project_id=None,
|
||||
user_api_key_project_alias=None,
|
||||
user_api_key_user_id=None,
|
||||
user_api_key_team_alias=None,
|
||||
user_api_key_user_email=None,
|
||||
@ -4992,10 +4999,10 @@ class StandardLoggingPayloadSetup:
|
||||
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
||||
if key in hidden_params:
|
||||
if key == "additional_headers":
|
||||
clean_hidden_params["additional_headers"] = (
|
||||
StandardLoggingPayloadSetup.get_additional_headers(
|
||||
hidden_params[key]
|
||||
)
|
||||
clean_hidden_params[
|
||||
"additional_headers"
|
||||
] = StandardLoggingPayloadSetup.get_additional_headers(
|
||||
hidden_params[key]
|
||||
)
|
||||
else:
|
||||
clean_hidden_params[key] = hidden_params[key] # type: ignore
|
||||
@ -5578,6 +5585,7 @@ def get_standard_logging_metadata(
|
||||
user_api_key_team_id=None,
|
||||
user_api_key_org_id=None,
|
||||
user_api_key_project_id=None,
|
||||
user_api_key_project_alias=None,
|
||||
user_api_key_user_id=None,
|
||||
user_api_key_user_email=None,
|
||||
user_api_key_team_alias=None,
|
||||
@ -5634,9 +5642,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
|
||||
):
|
||||
for k, v in metadata["user_api_key_metadata"].items():
|
||||
if k == "logging": # prevent logging user logging keys
|
||||
cleaned_user_api_key_metadata[k] = (
|
||||
"scrubbed_by_litellm_for_sensitive_keys"
|
||||
)
|
||||
cleaned_user_api_key_metadata[
|
||||
k
|
||||
] = "scrubbed_by_litellm_for_sensitive_keys"
|
||||
else:
|
||||
cleaned_user_api_key_metadata[k] = v
|
||||
|
||||
|
||||
@ -2416,6 +2416,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||
organization_metadata: Optional[dict] = None
|
||||
|
||||
# Project Params
|
||||
project_alias: Optional[str] = None
|
||||
project_metadata: Optional[dict] = None
|
||||
|
||||
# Time stamps
|
||||
@ -3228,6 +3229,7 @@ class SpendLogsMetadata(TypedDict):
|
||||
user_api_key_alias: Optional[str]
|
||||
user_api_key_team_id: Optional[str]
|
||||
user_api_key_project_id: Optional[str]
|
||||
user_api_key_project_alias: Optional[str]
|
||||
user_api_key_org_id: Optional[str]
|
||||
user_api_key_user_id: Optional[str]
|
||||
user_api_key_team_alias: Optional[str]
|
||||
|
||||
@ -836,6 +836,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
)
|
||||
if _jwt_project_obj is not None:
|
||||
valid_token.project_metadata = _jwt_project_obj.metadata
|
||||
valid_token.project_alias = _jwt_project_obj.project_alias
|
||||
|
||||
# run through common checks
|
||||
_ = await common_checks(
|
||||
@ -1431,6 +1432,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
)
|
||||
if _project_obj is not None:
|
||||
valid_token.project_metadata = _project_obj.metadata
|
||||
valid_token.project_alias = _project_obj.project_alias
|
||||
|
||||
global_proxy_spend = None
|
||||
if (
|
||||
@ -1888,6 +1890,7 @@ async def _run_post_custom_auth_checks(
|
||||
)
|
||||
if _project_obj is not None:
|
||||
valid_token.project_metadata = _project_obj.metadata
|
||||
valid_token.project_alias = _project_obj.project_alias
|
||||
|
||||
if general_settings.get("custom_auth_run_common_checks", False):
|
||||
_ = await common_checks(
|
||||
|
||||
@ -27,14 +27,16 @@ async def create_missing_views(db: _db): # noqa: PLR0915
|
||||
await db.execute_raw(
|
||||
"""
|
||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||
SELECT
|
||||
v.*,
|
||||
t.spend AS team_spend,
|
||||
t.max_budget AS team_max_budget,
|
||||
t.tpm_limit AS team_tpm_limit,
|
||||
t.rpm_limit AS team_rpm_limit
|
||||
SELECT
|
||||
v.*,
|
||||
t.spend AS team_spend,
|
||||
t.max_budget AS team_max_budget,
|
||||
t.tpm_limit AS team_tpm_limit,
|
||||
t.rpm_limit AS team_rpm_limit,
|
||||
p.project_alias AS project_alias
|
||||
FROM "LiteLLM_VerificationToken" v
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
|
||||
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@ -18,11 +18,9 @@ from litellm.proxy.auth.auth_checks import (
|
||||
log_db_metrics,
|
||||
)
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
from litellm.proxy.utils import ProxyUpdateSpend
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingUserAPIKeyMetadata,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
@ -51,25 +49,8 @@ class _ProxyDBLogger(CustomLogger):
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
_metadata = dict(
|
||||
StandardLoggingUserAPIKeyMetadata(
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
user_api_key_alias=user_api_key_dict.key_alias,
|
||||
user_api_key_spend=user_api_key_dict.spend,
|
||||
user_api_key_max_budget=user_api_key_dict.max_budget,
|
||||
user_api_key_budget_reset_at=(
|
||||
user_api_key_dict.budget_reset_at.isoformat()
|
||||
if user_api_key_dict.budget_reset_at
|
||||
else None
|
||||
),
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_project_id=user_api_key_dict.project_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
user_api_key_request_route=user_api_key_dict.request_route,
|
||||
user_api_key_auth_metadata=user_api_key_dict.metadata,
|
||||
LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
|
||||
user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
)
|
||||
_metadata["user_api_key"] = user_api_key_dict.api_key
|
||||
|
||||
@ -658,6 +658,7 @@ class LiteLLMProxyRequestSetup:
|
||||
user_api_key_max_budget=user_api_key_dict.max_budget,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_project_id=user_api_key_dict.project_id,
|
||||
user_api_key_project_alias=user_api_key_dict.project_alias,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
|
||||
@ -55,6 +55,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
_read_request_body,
|
||||
_safe_get_request_headers,
|
||||
)
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
from litellm.proxy.utils import get_server_root_path, normalize_route_for_root_path
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
@ -62,7 +63,6 @@ from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
EndpointType,
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
|
||||
|
||||
from .streaming_handler import PassThroughStreamingHandler
|
||||
from .success_handler import PassThroughEndpointLogging
|
||||
@ -502,25 +502,8 @@ class HttpPassThroughEndpointHelpers(BasePassthroughUtils):
|
||||
litellm_params_in_body[k] = _parsed_body.pop(k, None)
|
||||
|
||||
_metadata = dict(
|
||||
StandardLoggingUserAPIKeyMetadata(
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
user_api_key_alias=user_api_key_dict.key_alias,
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_project_id=user_api_key_dict.project_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
user_api_key_request_route=user_api_key_dict.request_route,
|
||||
user_api_key_spend=user_api_key_dict.spend,
|
||||
user_api_key_max_budget=user_api_key_dict.max_budget,
|
||||
user_api_key_budget_reset_at=(
|
||||
user_api_key_dict.budget_reset_at.isoformat()
|
||||
if user_api_key_dict.budget_reset_at
|
||||
else None
|
||||
),
|
||||
user_api_key_auth_metadata=user_api_key_dict.metadata,
|
||||
LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
|
||||
user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -1898,9 +1898,9 @@ class ProxyLogging:
|
||||
normalized_call_type = CallTypes.aembedding.value
|
||||
if normalized_call_type is not None:
|
||||
litellm_logging_obj.call_type = normalized_call_type
|
||||
litellm_logging_obj.model_call_details["call_type"] = (
|
||||
normalized_call_type
|
||||
)
|
||||
litellm_logging_obj.model_call_details[
|
||||
"call_type"
|
||||
] = normalized_call_type
|
||||
# Pass-through endpoints are logged via the callback loop's
|
||||
# async_post_call_failure_hook — skip pre_call and failure handlers.
|
||||
if litellm_logging_obj.call_type == CallTypes.pass_through.value:
|
||||
@ -2498,7 +2498,8 @@ class PrismaClient:
|
||||
required_view = "LiteLLM_VerificationTokenView"
|
||||
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
|
||||
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
|
||||
ret = await self.db.query_raw(f"""
|
||||
ret = await self.db.query_raw(
|
||||
f"""
|
||||
WITH existing_views AS (
|
||||
SELECT viewname
|
||||
FROM pg_views
|
||||
@ -2510,7 +2511,8 @@ class PrismaClient:
|
||||
(SELECT COUNT(*) FROM existing_views) AS view_count,
|
||||
ARRAY_AGG(viewname) AS view_names
|
||||
FROM existing_views
|
||||
""")
|
||||
"""
|
||||
)
|
||||
expected_total_views = len(expected_views)
|
||||
if ret[0]["view_count"] == expected_total_views:
|
||||
verbose_proxy_logger.info("All necessary views exist!")
|
||||
@ -2519,7 +2521,8 @@ class PrismaClient:
|
||||
## check if required view exists ##
|
||||
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
|
||||
await self.health_check() # make sure we can connect to db
|
||||
await self.db.execute_raw("""
|
||||
await self.db.execute_raw(
|
||||
"""
|
||||
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||
SELECT
|
||||
v.*,
|
||||
@ -2529,7 +2532,8 @@ class PrismaClient:
|
||||
t.rpm_limit AS team_rpm_limit
|
||||
FROM "LiteLLM_VerificationToken" v
|
||||
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
"LiteLLM_VerificationTokenView Created in DB!"
|
||||
@ -2964,6 +2968,7 @@ class PrismaClient:
|
||||
t.members_with_roles AS team_members_with_roles,
|
||||
t.object_permission_id AS team_object_permission_id,
|
||||
t.organization_id as org_id,
|
||||
p.project_alias AS project_alias,
|
||||
tm.spend AS team_member_spend,
|
||||
m.aliases AS team_model_aliases,
|
||||
-- Added comma to separate b.* columns
|
||||
@ -2981,6 +2986,7 @@ class PrismaClient:
|
||||
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
|
||||
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
|
||||
LEFT JOIN "LiteLLM_BudgetTable" AS b ON v.budget_id = b.budget_id
|
||||
LEFT JOIN "LiteLLM_ProjectTable" AS p ON v.project_id = p.project_id
|
||||
LEFT JOIN "LiteLLM_OrganizationTable" AS o ON v.organization_id = o.organization_id
|
||||
LEFT JOIN "LiteLLM_BudgetTable" AS b2 ON o.budget_id = b2.budget_id
|
||||
WHERE v.token = '{token}'
|
||||
|
||||
@ -2488,6 +2488,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
|
||||
user_api_key_org_id: Optional[str]
|
||||
user_api_key_team_id: Optional[str]
|
||||
user_api_key_project_id: Optional[str]
|
||||
user_api_key_project_alias: Optional[str]
|
||||
user_api_key_user_id: Optional[str]
|
||||
user_api_key_user_email: Optional[str]
|
||||
user_api_key_team_alias: Optional[str]
|
||||
|
||||
134
tests/test_litellm/test_project_alias_tracking.py
Normal file
134
tests/test_litellm/test_project_alias_tracking.py
Normal file
@ -0,0 +1,134 @@
|
||||
"""
|
||||
Tests for project_alias and project_id tracking through callback kwargs / metadata.
|
||||
|
||||
Verifies that project_alias flows from UserAPIKeyAuth through the metadata pipeline
|
||||
to StandardLoggingMetadata, mirroring how team_alias already works.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||
from litellm.proxy._types import LiteLLM_VerificationTokenView, UserAPIKeyAuth
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
|
||||
|
||||
|
||||
class TestProjectAliasOnTypes:
|
||||
"""project_alias field exists on the relevant types."""
|
||||
|
||||
def test_verification_token_view_has_project_alias(self):
|
||||
token_view = LiteLLM_VerificationTokenView(
|
||||
token="test-token",
|
||||
project_id="proj-123",
|
||||
project_alias="My Project",
|
||||
)
|
||||
assert token_view.project_alias == "My Project"
|
||||
|
||||
def test_verification_token_view_project_alias_defaults_none(self):
|
||||
token_view = LiteLLM_VerificationTokenView(token="test-token")
|
||||
assert token_view.project_alias is None
|
||||
|
||||
def test_user_api_key_auth_inherits_project_alias(self):
|
||||
"""UserAPIKeyAuth extends LiteLLM_VerificationTokenView, so it gets project_alias."""
|
||||
auth = UserAPIKeyAuth(
|
||||
api_key="sk-test",
|
||||
project_id="proj-1",
|
||||
project_alias="billing-service",
|
||||
)
|
||||
assert auth.project_alias == "billing-service"
|
||||
|
||||
def test_standard_logging_metadata_has_project_alias_field(self):
|
||||
metadata = StandardLoggingUserAPIKeyMetadata(
|
||||
user_api_key_hash="hash",
|
||||
user_api_key_alias=None,
|
||||
user_api_key_spend=None,
|
||||
user_api_key_max_budget=None,
|
||||
user_api_key_budget_reset_at=None,
|
||||
user_api_key_org_id=None,
|
||||
user_api_key_team_id=None,
|
||||
user_api_key_project_id="proj-1",
|
||||
user_api_key_project_alias="billing-service",
|
||||
user_api_key_user_id=None,
|
||||
user_api_key_user_email=None,
|
||||
user_api_key_team_alias=None,
|
||||
user_api_key_end_user_id=None,
|
||||
user_api_key_request_route=None,
|
||||
user_api_key_auth_metadata=None,
|
||||
)
|
||||
assert metadata["user_api_key_project_alias"] == "billing-service"
|
||||
|
||||
|
||||
class TestProjectAliasThroughMetadataPipeline:
|
||||
"""project_alias flows through the full metadata pipeline."""
|
||||
|
||||
def test_get_sanitized_user_information_includes_project_alias(self):
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="sk-hashed",
|
||||
project_id="proj-123",
|
||||
project_alias="My Cool Project",
|
||||
team_id="team-1",
|
||||
team_alias="my-team",
|
||||
)
|
||||
|
||||
result = LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
|
||||
user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
|
||||
assert result["user_api_key_project_id"] == "proj-123"
|
||||
assert result["user_api_key_project_alias"] == "My Cool Project"
|
||||
|
||||
def test_get_sanitized_user_information_project_alias_none_when_no_project(self):
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key="sk-hashed")
|
||||
|
||||
result = LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
|
||||
user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
|
||||
assert result["user_api_key_project_id"] is None
|
||||
assert result["user_api_key_project_alias"] is None
|
||||
|
||||
def test_project_alias_flows_to_standard_logging_metadata(self):
|
||||
"""get_standard_logging_metadata picks up project_alias from input metadata."""
|
||||
metadata = {
|
||||
"user_api_key_project_id": "proj-123",
|
||||
"user_api_key_project_alias": "My Cool Project",
|
||||
"user_api_key_team_id": "team-1",
|
||||
"user_api_key_team_alias": "my-team",
|
||||
}
|
||||
|
||||
result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
|
||||
assert result["user_api_key_project_alias"] == "My Cool Project"
|
||||
|
||||
def test_project_alias_defaults_to_none_in_logging_metadata(self):
|
||||
result = StandardLoggingPayloadSetup.get_standard_logging_metadata({})
|
||||
assert result["user_api_key_project_alias"] is None
|
||||
|
||||
def test_end_to_end_project_alias_flow(self):
|
||||
"""Full flow: UserAPIKeyAuth -> get_sanitized -> get_standard_logging_metadata."""
|
||||
auth = UserAPIKeyAuth(
|
||||
api_key="sk-test",
|
||||
project_id="proj-abc",
|
||||
project_alias="analytics-pipeline",
|
||||
team_id="team-1",
|
||||
team_alias="data-team",
|
||||
)
|
||||
|
||||
# Step 1: Auth → sanitized metadata
|
||||
sanitized = LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
|
||||
user_api_key_dict=auth
|
||||
)
|
||||
|
||||
# Step 2: Sanitized metadata → standard logging metadata
|
||||
logging_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
|
||||
dict(sanitized)
|
||||
)
|
||||
|
||||
assert logging_metadata["user_api_key_project_id"] == "proj-abc"
|
||||
assert logging_metadata["user_api_key_project_alias"] == "analytics-pipeline"
|
||||
assert logging_metadata["user_api_key_team_id"] == "team-1"
|
||||
assert logging_metadata["user_api_key_team_alias"] == "data-team"
|
||||
Loading…
Reference in New Issue
Block a user