feat(proxy): add project_alias tracking through callback metadata pipeline

Thread project_alias alongside project_id through the metadata pipeline so
callbacks receive the human-readable project name. DRY up duplicate metadata
dict construction in proxy_track_cost_callback and pass_through_endpoints by
reusing get_sanitized_user_information_from_key — future metadata fields only
need adding in one place.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Krrish Dholakia 2026-03-23 10:44:17 -07:00
parent 63425b4cb4
commit 6809213957
10 changed files with 317 additions and 196 deletions

View File

@ -354,9 +354,9 @@ class Logging(LiteLLMLoggingBaseClass):
)
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.sync_streaming_chunks: List[
Any
] = [] # for generating complete stream response
self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks
@ -801,9 +801,9 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_spec=prompt_spec,
dynamic_callback_params=dynamic_callback_params,
):
self.model_call_details["prompt_integration"] = (
logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = logger.__class__.__name__
return logger
except Exception:
# If check fails, continue to next logger
@ -871,9 +871,9 @@ class Logging(LiteLLMLoggingBaseClass):
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
non_default_params
):
self.model_call_details["prompt_integration"] = (
anthropic_cache_control_logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = anthropic_cache_control_logger.__class__.__name__
return anthropic_cache_control_logger
#########################################################
@ -885,9 +885,9 @@ class Logging(LiteLLMLoggingBaseClass):
internal_usage_cache=None,
llm_router=None,
)
self.model_call_details["prompt_integration"] = (
vector_store_custom_logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = vector_store_custom_logger.__class__.__name__
# Add to global callbacks so post-call hooks are invoked
if (
vector_store_custom_logger
@ -947,9 +947,9 @@ class Logging(LiteLLMLoggingBaseClass):
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self.model_call_details["litellm_params"]["api_base"] = (
self._get_masked_api_base(additional_args.get("api_base", ""))
)
self.model_call_details["litellm_params"][
"api_base"
] = self._get_masked_api_base(additional_args.get("api_base", ""))
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
@ -978,7 +978,9 @@ class Logging(LiteLLMLoggingBaseClass):
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata["raw_request"] = "redacted by litellm. \
_metadata[
"raw_request"
] = "redacted by litellm. \
'litellm.turn_off_message_logging=True'"
else:
curl_command = self._get_request_curl_command(
@ -990,31 +992,35 @@ class Logging(LiteLLMLoggingBaseClass):
_metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
# NOTE: setting ignore_sensitive_headers to True will cause
# the Authorization header to be leaked when calls to the health
# endpoint are made and fail.
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
),
error=None,
)
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
# NOTE: setting ignore_sensitive_headers to True will cause
# the Authorization header to be leaked when calls to the health
# endpoint are made and fail.
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
),
error=None,
)
except Exception as e:
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
error=str(e),
)
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
error=str(e),
)
_metadata[
"raw_request"
] = "Unable to Log \
raw request: {}".format(
str(e)
)
_metadata["raw_request"] = "Unable to Log \
raw request: {}".format(str(e))
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
try:
self.logger_fn(
@ -1314,13 +1320,13 @@ class Logging(LiteLLMLoggingBaseClass):
for callback in callbacks:
try:
if isinstance(callback, CustomLogger):
response: Optional[MCPPostCallResponseObject] = (
await callback.async_post_mcp_tool_call_hook(
kwargs=kwargs,
response_obj=post_mcp_tool_call_response_obj,
start_time=start_time,
end_time=end_time,
)
response: Optional[
MCPPostCallResponseObject
] = await callback.async_post_mcp_tool_call_hook(
kwargs=kwargs,
response_obj=post_mcp_tool_call_response_obj,
start_time=start_time,
end_time=end_time,
)
######################################################################
# if any of the callbacks modify the response, use the modified response
@ -1521,9 +1527,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
return None
try:
@ -1549,9 +1555,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
return None
@ -1700,9 +1706,9 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["litellm_params"].setdefault("metadata", {})
if self.model_call_details["litellm_params"]["metadata"] is None:
self.model_call_details["litellm_params"]["metadata"] = {}
self.model_call_details["litellm_params"]["metadata"]["hidden_params"] = (
getattr(logging_result, "_hidden_params", {})
)
self.model_call_details["litellm_params"]["metadata"][
"hidden_params"
] = getattr(logging_result, "_hidden_params", {})
def _process_hidden_params_and_response_cost(
self,
@ -1731,9 +1737,9 @@ class Logging(LiteLLMLoggingBaseClass):
result=logging_result
)
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(logging_result, start_time, end_time)
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(logging_result, start_time, end_time)
if (
standard_logging_payload := self.model_call_details.get(
@ -1811,9 +1817,9 @@ class Logging(LiteLLMLoggingBaseClass):
end_time = datetime.datetime.now()
if self.completion_start_time is None:
self.completion_start_time = end_time
self.model_call_details["completion_start_time"] = (
self.completion_start_time
)
self.model_call_details[
"completion_start_time"
] = self.completion_start_time
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
@ -1850,10 +1856,10 @@ class Logging(LiteLLMLoggingBaseClass):
end_time=end_time,
)
elif isinstance(result, dict) or isinstance(result, list):
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
result, start_time, end_time
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
result, start_time, end_time
)
if (
standard_logging_payload := self.model_call_details.get(
@ -1862,9 +1868,9 @@ class Logging(LiteLLMLoggingBaseClass):
) is not None:
emit_standard_logging_payload(standard_logging_payload)
elif standard_logging_object is not None:
self.model_call_details["standard_logging_object"] = (
standard_logging_object
)
self.model_call_details[
"standard_logging_object"
] = standard_logging_object
else:
self.model_call_details["response_cost"] = None
@ -2022,20 +2028,20 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=complete_streaming_response)
self._merge_hidden_params_from_response_into_metadata(
complete_streaming_response
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
if (
standard_logging_payload := self.model_call_details.get(
@ -2369,10 +2375,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
)
result = self.model_call_details["complete_response"]
openMeterLogger.log_success_event(
@ -2396,10 +2402,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
)
result = self.model_call_details["complete_response"]
@ -2538,9 +2544,9 @@ class Logging(LiteLLMLoggingBaseClass):
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details["async_complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details[
"async_complete_streaming_response"
] = complete_streaming_response
try:
if self.model_call_details.get("cache_hit", False) is True:
@ -2551,10 +2557,10 @@ class Logging(LiteLLMLoggingBaseClass):
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response
)
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response
)
verbose_logger.debug(
@ -2571,10 +2577,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
# print standard logging payload
@ -2601,9 +2607,9 @@ class Logging(LiteLLMLoggingBaseClass):
# _success_handler_helper_fn
if self.model_call_details.get("standard_logging_object") is None:
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(result, start_time, end_time)
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(result, start_time, end_time)
# print standard logging payload
if (
@ -2846,18 +2852,18 @@ class Logging(LiteLLMLoggingBaseClass):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
return start_time, end_time
@ -3825,9 +3831,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
service_name=arize_config.project_name,
)
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
for callback in _in_memory_loggers:
if (
isinstance(callback, ArizeLogger)
@ -3853,13 +3859,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
)
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
else:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={arize_phoenix_config.project_name}"
)
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"openinference.project.name={arize_phoenix_config.project_name}"
# Set Phoenix project name from environment variable
phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None)
@ -3867,19 +3873,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={phoenix_project_name}"
)
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
else:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={phoenix_project_name}"
)
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"openinference.project.name={phoenix_project_name}"
# auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None:
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
arize_phoenix_config.otlp_auth_headers
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = arize_phoenix_config.otlp_auth_headers
for callback in _in_memory_loggers:
if (
@ -4066,9 +4072,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
exporter="otlp_http",
endpoint="https://langtrace.ai/api/trace",
)
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
@ -4747,6 +4753,7 @@ class StandardLoggingPayloadSetup:
user_api_key_team_id=None,
user_api_key_org_id=None,
user_api_key_project_id=None,
user_api_key_project_alias=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
user_api_key_user_email=None,
@ -4992,10 +4999,10 @@ class StandardLoggingPayloadSetup:
for key in StandardLoggingHiddenParams.__annotations__.keys():
if key in hidden_params:
if key == "additional_headers":
clean_hidden_params["additional_headers"] = (
StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
clean_hidden_params[
"additional_headers"
] = StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
else:
clean_hidden_params[key] = hidden_params[key] # type: ignore
@ -5578,6 +5585,7 @@ def get_standard_logging_metadata(
user_api_key_team_id=None,
user_api_key_org_id=None,
user_api_key_project_id=None,
user_api_key_project_alias=None,
user_api_key_user_id=None,
user_api_key_user_email=None,
user_api_key_team_alias=None,
@ -5634,9 +5642,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
):
for k, v in metadata["user_api_key_metadata"].items():
if k == "logging": # prevent logging user logging keys
cleaned_user_api_key_metadata[k] = (
"scrubbed_by_litellm_for_sensitive_keys"
)
cleaned_user_api_key_metadata[
k
] = "scrubbed_by_litellm_for_sensitive_keys"
else:
cleaned_user_api_key_metadata[k] = v

View File

@ -2416,6 +2416,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
organization_metadata: Optional[dict] = None
# Project Params
project_alias: Optional[str] = None
project_metadata: Optional[dict] = None
# Time stamps
@ -3228,6 +3229,7 @@ class SpendLogsMetadata(TypedDict):
user_api_key_alias: Optional[str]
user_api_key_team_id: Optional[str]
user_api_key_project_id: Optional[str]
user_api_key_project_alias: Optional[str]
user_api_key_org_id: Optional[str]
user_api_key_user_id: Optional[str]
user_api_key_team_alias: Optional[str]

View File

@ -836,6 +836,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
)
if _jwt_project_obj is not None:
valid_token.project_metadata = _jwt_project_obj.metadata
valid_token.project_alias = _jwt_project_obj.project_alias
# run through common checks
_ = await common_checks(
@ -1431,6 +1432,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
)
if _project_obj is not None:
valid_token.project_metadata = _project_obj.metadata
valid_token.project_alias = _project_obj.project_alias
global_proxy_spend = None
if (
@ -1888,6 +1890,7 @@ async def _run_post_custom_auth_checks(
)
if _project_obj is not None:
valid_token.project_metadata = _project_obj.metadata
valid_token.project_alias = _project_obj.project_alias
if general_settings.get("custom_auth_run_common_checks", False):
_ = await common_checks(

View File

@ -27,14 +27,16 @@ async def create_missing_views(db: _db): # noqa: PLR0915
await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
t.spend AS team_spend,
t.max_budget AS team_max_budget,
t.tpm_limit AS team_tpm_limit,
t.rpm_limit AS team_rpm_limit
SELECT
v.*,
t.spend AS team_spend,
t.max_budget AS team_max_budget,
t.tpm_limit AS team_tpm_limit,
t.rpm_limit AS team_rpm_limit,
p.project_alias AS project_alias
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_ProjectTable" p ON v.project_id = p.project_id;
"""
)

View File

@ -18,11 +18,9 @@ from litellm.proxy.auth.auth_checks import (
log_db_metrics,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingUserAPIKeyMetadata,
)
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
@ -51,25 +49,8 @@ class _ProxyDBLogger(CustomLogger):
from litellm.proxy.proxy_server import proxy_logging_obj
_metadata = dict(
StandardLoggingUserAPIKeyMetadata(
user_api_key_hash=user_api_key_dict.api_key,
user_api_key_alias=user_api_key_dict.key_alias,
user_api_key_spend=user_api_key_dict.spend,
user_api_key_max_budget=user_api_key_dict.max_budget,
user_api_key_budget_reset_at=(
user_api_key_dict.budget_reset_at.isoformat()
if user_api_key_dict.budget_reset_at
else None
),
user_api_key_user_email=user_api_key_dict.user_email,
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_id=user_api_key_dict.team_id,
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_project_id=user_api_key_dict.project_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_request_route=user_api_key_dict.request_route,
user_api_key_auth_metadata=user_api_key_dict.metadata,
LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
user_api_key_dict=user_api_key_dict
)
)
_metadata["user_api_key"] = user_api_key_dict.api_key

View File

@ -658,6 +658,7 @@ class LiteLLMProxyRequestSetup:
user_api_key_max_budget=user_api_key_dict.max_budget,
user_api_key_team_id=user_api_key_dict.team_id,
user_api_key_project_id=user_api_key_dict.project_id,
user_api_key_project_alias=user_api_key_dict.project_alias,
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_team_alias=user_api_key_dict.team_alias,

View File

@ -55,6 +55,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
_safe_get_request_headers,
)
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.proxy.utils import get_server_root_path, normalize_route_for_root_path
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.custom_http import httpxSpecialProvider
@ -62,7 +63,6 @@ from litellm.types.passthrough_endpoints.pass_through_endpoints import (
EndpointType,
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
from .streaming_handler import PassThroughStreamingHandler
from .success_handler import PassThroughEndpointLogging
@ -502,25 +502,8 @@ class HttpPassThroughEndpointHelpers(BasePassthroughUtils):
litellm_params_in_body[k] = _parsed_body.pop(k, None)
_metadata = dict(
StandardLoggingUserAPIKeyMetadata(
user_api_key_hash=user_api_key_dict.api_key,
user_api_key_alias=user_api_key_dict.key_alias,
user_api_key_user_email=user_api_key_dict.user_email,
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_id=user_api_key_dict.team_id,
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_project_id=user_api_key_dict.project_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_request_route=user_api_key_dict.request_route,
user_api_key_spend=user_api_key_dict.spend,
user_api_key_max_budget=user_api_key_dict.max_budget,
user_api_key_budget_reset_at=(
user_api_key_dict.budget_reset_at.isoformat()
if user_api_key_dict.budget_reset_at
else None
),
user_api_key_auth_metadata=user_api_key_dict.metadata,
LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
user_api_key_dict=user_api_key_dict
)
)

View File

@ -1898,9 +1898,9 @@ class ProxyLogging:
normalized_call_type = CallTypes.aembedding.value
if normalized_call_type is not None:
litellm_logging_obj.call_type = normalized_call_type
litellm_logging_obj.model_call_details["call_type"] = (
normalized_call_type
)
litellm_logging_obj.model_call_details[
"call_type"
] = normalized_call_type
# Pass-through endpoints are logged via the callback loop's
# async_post_call_failure_hook — skip pre_call and failure handlers.
if litellm_logging_obj.call_type == CallTypes.pass_through.value:
@ -2498,7 +2498,8 @@ class PrismaClient:
required_view = "LiteLLM_VerificationTokenView"
expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
pg_schema = os.getenv("DATABASE_SCHEMA", "public")
ret = await self.db.query_raw(f"""
ret = await self.db.query_raw(
f"""
WITH existing_views AS (
SELECT viewname
FROM pg_views
@ -2510,7 +2511,8 @@ class PrismaClient:
(SELECT COUNT(*) FROM existing_views) AS view_count,
ARRAY_AGG(viewname) AS view_names
FROM existing_views
""")
"""
)
expected_total_views = len(expected_views)
if ret[0]["view_count"] == expected_total_views:
verbose_proxy_logger.info("All necessary views exist!")
@ -2519,7 +2521,8 @@ class PrismaClient:
## check if required view exists ##
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
await self.health_check() # make sure we can connect to db
await self.db.execute_raw("""
await self.db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
@ -2529,7 +2532,8 @@ class PrismaClient:
t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
""")
"""
)
verbose_proxy_logger.info(
"LiteLLM_VerificationTokenView Created in DB!"
@ -2964,6 +2968,7 @@ class PrismaClient:
t.members_with_roles AS team_members_with_roles,
t.object_permission_id AS team_object_permission_id,
t.organization_id as org_id,
p.project_alias AS project_alias,
tm.spend AS team_member_spend,
m.aliases AS team_model_aliases,
-- Added comma to separate b.* columns
@ -2981,6 +2986,7 @@ class PrismaClient:
LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
LEFT JOIN "LiteLLM_BudgetTable" AS b ON v.budget_id = b.budget_id
LEFT JOIN "LiteLLM_ProjectTable" AS p ON v.project_id = p.project_id
LEFT JOIN "LiteLLM_OrganizationTable" AS o ON v.organization_id = o.organization_id
LEFT JOIN "LiteLLM_BudgetTable" AS b2 ON o.budget_id = b2.budget_id
WHERE v.token = '{token}'

View File

@ -2488,6 +2488,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
user_api_key_org_id: Optional[str]
user_api_key_team_id: Optional[str]
user_api_key_project_id: Optional[str]
user_api_key_project_alias: Optional[str]
user_api_key_user_id: Optional[str]
user_api_key_user_email: Optional[str]
user_api_key_team_alias: Optional[str]

View File

@ -0,0 +1,134 @@
"""
Tests for project_alias and project_id tracking through callback kwargs / metadata.
Verifies that project_alias flows from UserAPIKeyAuth through the metadata pipeline
to StandardLoggingMetadata, mirroring how team_alias already works.
"""
import os
import sys
import pytest
sys.path.insert(0, os.path.abspath("../.."))
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import LiteLLM_VerificationTokenView, UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
class TestProjectAliasOnTypes:
"""project_alias field exists on the relevant types."""
def test_verification_token_view_has_project_alias(self):
token_view = LiteLLM_VerificationTokenView(
token="test-token",
project_id="proj-123",
project_alias="My Project",
)
assert token_view.project_alias == "My Project"
def test_verification_token_view_project_alias_defaults_none(self):
token_view = LiteLLM_VerificationTokenView(token="test-token")
assert token_view.project_alias is None
def test_user_api_key_auth_inherits_project_alias(self):
"""UserAPIKeyAuth extends LiteLLM_VerificationTokenView, so it gets project_alias."""
auth = UserAPIKeyAuth(
api_key="sk-test",
project_id="proj-1",
project_alias="billing-service",
)
assert auth.project_alias == "billing-service"
def test_standard_logging_metadata_has_project_alias_field(self):
metadata = StandardLoggingUserAPIKeyMetadata(
user_api_key_hash="hash",
user_api_key_alias=None,
user_api_key_spend=None,
user_api_key_max_budget=None,
user_api_key_budget_reset_at=None,
user_api_key_org_id=None,
user_api_key_team_id=None,
user_api_key_project_id="proj-1",
user_api_key_project_alias="billing-service",
user_api_key_user_id=None,
user_api_key_user_email=None,
user_api_key_team_alias=None,
user_api_key_end_user_id=None,
user_api_key_request_route=None,
user_api_key_auth_metadata=None,
)
assert metadata["user_api_key_project_alias"] == "billing-service"
class TestProjectAliasThroughMetadataPipeline:
"""project_alias flows through the full metadata pipeline."""
def test_get_sanitized_user_information_includes_project_alias(self):
user_api_key_dict = UserAPIKeyAuth(
api_key="sk-hashed",
project_id="proj-123",
project_alias="My Cool Project",
team_id="team-1",
team_alias="my-team",
)
result = LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
user_api_key_dict=user_api_key_dict
)
assert result["user_api_key_project_id"] == "proj-123"
assert result["user_api_key_project_alias"] == "My Cool Project"
def test_get_sanitized_user_information_project_alias_none_when_no_project(self):
user_api_key_dict = UserAPIKeyAuth(api_key="sk-hashed")
result = LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
user_api_key_dict=user_api_key_dict
)
assert result["user_api_key_project_id"] is None
assert result["user_api_key_project_alias"] is None
def test_project_alias_flows_to_standard_logging_metadata(self):
"""get_standard_logging_metadata picks up project_alias from input metadata."""
metadata = {
"user_api_key_project_id": "proj-123",
"user_api_key_project_alias": "My Cool Project",
"user_api_key_team_id": "team-1",
"user_api_key_team_alias": "my-team",
}
result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
assert result["user_api_key_project_alias"] == "My Cool Project"
def test_project_alias_defaults_to_none_in_logging_metadata(self):
result = StandardLoggingPayloadSetup.get_standard_logging_metadata({})
assert result["user_api_key_project_alias"] is None
def test_end_to_end_project_alias_flow(self):
"""Full flow: UserAPIKeyAuth -> get_sanitized -> get_standard_logging_metadata."""
auth = UserAPIKeyAuth(
api_key="sk-test",
project_id="proj-abc",
project_alias="analytics-pipeline",
team_id="team-1",
team_alias="data-team",
)
# Step 1: Auth → sanitized metadata
sanitized = LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
user_api_key_dict=auth
)
# Step 2: Sanitized metadata → standard logging metadata
logging_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
dict(sanitized)
)
assert logging_metadata["user_api_key_project_id"] == "proj-abc"
assert logging_metadata["user_api_key_project_alias"] == "analytics-pipeline"
assert logging_metadata["user_api_key_team_id"] == "team-1"
assert logging_metadata["user_api_key_team_alias"] == "data-team"