fix: Set user from token user_id for OpenMeter integration (#13029)
This commit is contained in:
parent
d23a6e3ea4
commit
f8c09e44f6
@ -170,12 +170,12 @@ def _get_dynamic_logging_metadata(
|
||||
user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig
|
||||
) -> Optional[TeamCallbackMetadata]:
|
||||
callback_settings_obj: Optional[TeamCallbackMetadata] = None
|
||||
key_dynamic_logging_settings: Optional[dict] = (
|
||||
KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
|
||||
)
|
||||
team_dynamic_logging_settings: Optional[dict] = (
|
||||
KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
|
||||
)
|
||||
key_dynamic_logging_settings: Optional[
|
||||
dict
|
||||
] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
|
||||
team_dynamic_logging_settings: Optional[
|
||||
dict
|
||||
] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
|
||||
#########################################################################################
|
||||
# Key-based callbacks
|
||||
#########################################################################################
|
||||
@ -439,7 +439,7 @@ class LiteLLMProxyRequestSetup:
|
||||
user_api_key_request_route=user_api_key_dict.request_route,
|
||||
)
|
||||
return user_api_key_logged_metadata
|
||||
|
||||
|
||||
@staticmethod
|
||||
def add_user_api_key_auth_to_request_metadata(
|
||||
data: dict,
|
||||
@ -457,9 +457,7 @@ class LiteLLMProxyRequestSetup:
|
||||
data[_metadata_variable_name].update(user_api_key_logged_metadata)
|
||||
data[_metadata_variable_name][
|
||||
"user_api_key"
|
||||
] = (
|
||||
user_api_key_dict.api_key
|
||||
) # this is just the hashed token
|
||||
] = user_api_key_dict.api_key # this is just the hashed token
|
||||
|
||||
data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr(
|
||||
user_api_key_dict, "end_user_max_budget", None
|
||||
@ -479,11 +477,11 @@ class LiteLLMProxyRequestSetup:
|
||||
|
||||
## KEY-LEVEL SPEND LOGS / TAGS
|
||||
if "tags" in key_metadata and key_metadata["tags"] is not None:
|
||||
data[_metadata_variable_name]["tags"] = (
|
||||
LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=key_metadata["tags"],
|
||||
)
|
||||
data[_metadata_variable_name][
|
||||
"tags"
|
||||
] = LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=key_metadata["tags"],
|
||||
)
|
||||
if "spend_logs_metadata" in key_metadata and isinstance(
|
||||
key_metadata["spend_logs_metadata"], dict
|
||||
@ -632,6 +630,10 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
if "user" not in data:
|
||||
data["user"] = user
|
||||
|
||||
# Set user from token if not already set and token has user_id
|
||||
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
# Include original request and headers in the data
|
||||
data["proxy_server_request"] = {
|
||||
"url": str(request.url),
|
||||
@ -694,9 +696,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
data[_metadata_variable_name]["litellm_api_version"] = version
|
||||
|
||||
if general_settings is not None:
|
||||
data[_metadata_variable_name]["global_max_parallel_requests"] = (
|
||||
general_settings.get("global_max_parallel_requests", None)
|
||||
)
|
||||
data[_metadata_variable_name][
|
||||
"global_max_parallel_requests"
|
||||
] = general_settings.get("global_max_parallel_requests", None)
|
||||
|
||||
### KEY-LEVEL Controls
|
||||
key_metadata = user_api_key_dict.metadata
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -18,7 +16,7 @@ class TestOpenMeterIntegration:
|
||||
# Set required environment variables
|
||||
os.environ["OPENMETER_API_KEY"] = "test-api-key"
|
||||
os.environ["OPENMETER_API_ENDPOINT"] = "https://test.openmeter.com"
|
||||
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up test environment"""
|
||||
# Clean up environment variables
|
||||
@ -40,26 +38,22 @@ class TestOpenMeterIntegration:
|
||||
def test_common_logic_with_string_user(self):
|
||||
"""Test that _common_logic correctly handles string user parameter"""
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
|
||||
kwargs = {
|
||||
"user": "test-user-123",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"response_cost": 0.001,
|
||||
"litellm_call_id": "test-call-id"
|
||||
"litellm_call_id": "test-call-id",
|
||||
}
|
||||
|
||||
|
||||
# Mock response object
|
||||
response_obj = {
|
||||
"id": "test-response-id",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
|
||||
|
||||
result = logger._common_logic(kwargs, response_obj)
|
||||
|
||||
|
||||
# Verify subject is a string, not a tuple
|
||||
assert isinstance(result["subject"], str)
|
||||
assert result["subject"] == "test-user-123"
|
||||
@ -69,25 +63,21 @@ class TestOpenMeterIntegration:
|
||||
def test_common_logic_with_integer_user(self):
|
||||
"""Test that _common_logic correctly converts integer user to string"""
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
|
||||
kwargs = {
|
||||
"user": 12345, # Integer user ID
|
||||
"model": "gpt-4",
|
||||
"response_cost": 0.002,
|
||||
"litellm_call_id": "test-call-id-2"
|
||||
"litellm_call_id": "test-call-id-2",
|
||||
}
|
||||
|
||||
|
||||
response_obj = {
|
||||
"id": "test-response-id-2",
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 10,
|
||||
"total_tokens": 30
|
||||
}
|
||||
"usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30},
|
||||
}
|
||||
|
||||
|
||||
result = logger._common_logic(kwargs, response_obj)
|
||||
|
||||
|
||||
# Verify subject is converted to string
|
||||
assert isinstance(result["subject"], str)
|
||||
assert result["subject"] == "12345"
|
||||
@ -95,31 +85,31 @@ class TestOpenMeterIntegration:
|
||||
def test_common_logic_missing_user(self):
|
||||
"""Test that _common_logic raises exception when user is missing"""
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
|
||||
kwargs = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"response_cost": 0.001,
|
||||
"litellm_call_id": "test-call-id"
|
||||
"litellm_call_id": "test-call-id",
|
||||
}
|
||||
|
||||
|
||||
response_obj = {"id": "test-response-id"}
|
||||
|
||||
|
||||
with pytest.raises(Exception, match="OpenMeter: user is required"):
|
||||
logger._common_logic(kwargs, response_obj)
|
||||
|
||||
def test_common_logic_none_user(self):
|
||||
"""Test that _common_logic raises exception when user is None"""
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
|
||||
kwargs = {
|
||||
"user": None,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"response_cost": 0.001,
|
||||
"litellm_call_id": "test-call-id"
|
||||
"litellm_call_id": "test-call-id",
|
||||
}
|
||||
|
||||
|
||||
response_obj = {"id": "test-response-id"}
|
||||
|
||||
|
||||
with pytest.raises(Exception, match="OpenMeter: user is required"):
|
||||
logger._common_logic(kwargs, response_obj)
|
||||
|
||||
@ -140,44 +130,40 @@ class TestOpenMeterIntegration:
|
||||
assert isinstance(result["subject"], str)
|
||||
assert result["subject"] == ""
|
||||
|
||||
@patch('litellm.integrations.openmeter.HTTPHandler')
|
||||
@patch("litellm.integrations.openmeter.HTTPHandler")
|
||||
def test_log_success_event(self, mock_http_handler):
|
||||
"""Test synchronous log_success_event method"""
|
||||
mock_post = MagicMock()
|
||||
mock_http_handler.return_value.post = mock_post
|
||||
|
||||
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
|
||||
kwargs = {
|
||||
"user": "test-user",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"response_cost": 0.001,
|
||||
"litellm_call_id": "test-call-id"
|
||||
"litellm_call_id": "test-call-id",
|
||||
}
|
||||
|
||||
|
||||
response_obj = {
|
||||
"id": "test-response-id",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
|
||||
|
||||
logger.log_success_event(kwargs, response_obj, None, None)
|
||||
|
||||
|
||||
# Verify HTTP call was made
|
||||
mock_post.assert_called_once()
|
||||
|
||||
|
||||
# Verify the data structure
|
||||
call_args = mock_post.call_args
|
||||
data = json.loads(call_args[1]['data'])
|
||||
|
||||
data = json.loads(call_args[1]["data"])
|
||||
|
||||
assert data["subject"] == "test-user"
|
||||
assert isinstance(data["subject"], str)
|
||||
assert data["data"]["model"] == "gpt-3.5-turbo"
|
||||
|
||||
@patch('litellm.integrations.openmeter.get_async_httpx_client')
|
||||
@patch("litellm.integrations.openmeter.get_async_httpx_client")
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_log_success_event(self, mock_get_client):
|
||||
"""Test asynchronous log_success_event method"""
|
||||
@ -185,34 +171,30 @@ class TestOpenMeterIntegration:
|
||||
mock_client = MagicMock()
|
||||
mock_client.post = mock_post
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
|
||||
kwargs = {
|
||||
"user": "async-test-user",
|
||||
"model": "gpt-4",
|
||||
"response_cost": 0.002,
|
||||
"litellm_call_id": "async-test-call-id"
|
||||
"litellm_call_id": "async-test-call-id",
|
||||
}
|
||||
|
||||
|
||||
response_obj = {
|
||||
"id": "async-test-response-id",
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 10,
|
||||
"total_tokens": 30
|
||||
}
|
||||
"usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30},
|
||||
}
|
||||
|
||||
|
||||
await logger.async_log_success_event(kwargs, response_obj, None, None)
|
||||
|
||||
|
||||
# Verify async HTTP call was made
|
||||
mock_post.assert_called_once()
|
||||
|
||||
# Verify the data structure
|
||||
|
||||
# Verify the data structure
|
||||
call_args = mock_post.call_args
|
||||
data = json.loads(call_args[1]['data'])
|
||||
|
||||
data = json.loads(call_args[1]["data"])
|
||||
|
||||
assert data["subject"] == "async-test-user"
|
||||
assert isinstance(data["subject"], str)
|
||||
assert data["data"]["model"] == "gpt-4"
|
||||
@ -220,26 +202,22 @@ class TestOpenMeterIntegration:
|
||||
def test_cloudevents_structure(self):
|
||||
"""Test that the CloudEvents structure is correct"""
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
|
||||
kwargs = {
|
||||
"user": "cloudevents-test-user",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"response_cost": 0.001,
|
||||
"litellm_call_id": "cloudevents-test-call-id"
|
||||
"litellm_call_id": "cloudevents-test-call-id",
|
||||
}
|
||||
|
||||
|
||||
response_data = {
|
||||
"id": "cloudevents-test-response-id",
|
||||
"usage": {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 8,
|
||||
"total_tokens": 23
|
||||
}
|
||||
"usage": {"prompt_tokens": 15, "completion_tokens": 8, "total_tokens": 23},
|
||||
}
|
||||
response_obj = litellm.ModelResponse(**response_data)
|
||||
|
||||
|
||||
result = logger._common_logic(kwargs, response_obj)
|
||||
|
||||
|
||||
# Verify CloudEvents required fields
|
||||
assert result["specversion"] == "1.0"
|
||||
assert result["type"] == "litellm_tokens" # default value
|
||||
@ -248,7 +226,7 @@ class TestOpenMeterIntegration:
|
||||
assert "time" in result
|
||||
assert isinstance(result["subject"], str)
|
||||
assert result["subject"] == "cloudevents-test-user"
|
||||
|
||||
|
||||
# Verify data structure
|
||||
assert "data" in result
|
||||
assert result["data"]["model"] == "gpt-3.5-turbo"
|
||||
@ -260,25 +238,109 @@ class TestOpenMeterIntegration:
|
||||
def test_custom_event_type(self):
|
||||
"""Test that custom event type is used when set"""
|
||||
os.environ["OPENMETER_EVENT_TYPE"] = "custom_event_type"
|
||||
|
||||
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
|
||||
kwargs = {
|
||||
"user": "custom-event-user",
|
||||
"model": "gpt-4",
|
||||
"response_cost": 0.003,
|
||||
"litellm_call_id": "custom-event-call-id"
|
||||
"litellm_call_id": "custom-event-call-id",
|
||||
}
|
||||
|
||||
|
||||
response_obj = {
|
||||
"id": "custom-event-response-id",
|
||||
"usage": {
|
||||
"prompt_tokens": 25,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 37
|
||||
}
|
||||
"usage": {"prompt_tokens": 25, "completion_tokens": 12, "total_tokens": 37},
|
||||
}
|
||||
|
||||
|
||||
result = logger._common_logic(kwargs, response_obj)
|
||||
|
||||
|
||||
assert result["type"] == "custom_event_type"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openmeter_integration_with_token_user_id():
|
||||
"""
|
||||
Test complete integration: token with user_id -> add_litellm_data_to_request -> OpenMeter callback
|
||||
|
||||
This test verifies that when a token has user_id but request has no user,
|
||||
the user_id is properly passed to OpenMeter callback.
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from unittest.mock import MagicMock
|
||||
from fastapi import Request
|
||||
|
||||
# Setup environment for OpenMeter
|
||||
os.environ["OPENMETER_API_KEY"] = "test-api-key"
|
||||
os.environ["OPENMETER_API_ENDPOINT"] = "https://test.openmeter.com"
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
request_mock.url.path = "/chat/completions"
|
||||
request_mock.url = MagicMock()
|
||||
request_mock.url.__str__.return_value = "http://localhost/chat/completions"
|
||||
request_mock.method = "POST"
|
||||
request_mock.query_params = {}
|
||||
request_mock.headers = {"Content-Type": "application/json"}
|
||||
request_mock.client = MagicMock()
|
||||
request_mock.client.host = "127.0.0.1"
|
||||
|
||||
# Setup user API key with user_id
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test_api_key",
|
||||
user_id="integration-test-user-123", # This should reach OpenMeter
|
||||
team_id="test_team_id",
|
||||
)
|
||||
|
||||
# Setup request data WITHOUT user field
|
||||
data = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"response_cost": 0.001,
|
||||
"litellm_call_id": "integration-test-call-id",
|
||||
}
|
||||
|
||||
# Setup proxy config
|
||||
proxy_config = MagicMock()
|
||||
|
||||
# Step 1: Call add_litellm_data_to_request to set user from token
|
||||
processed_data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request_mock,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# Verify that user was set from token
|
||||
assert "user" in processed_data
|
||||
assert processed_data["user"] == "integration-test-user-123"
|
||||
|
||||
# Step 2: Test that OpenMeter callback works with this data
|
||||
logger = OpenMeterLogger()
|
||||
|
||||
# Mock response object
|
||||
response_obj = {
|
||||
"id": "integration-test-response-id",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
|
||||
# Test OpenMeter _common_logic with the processed data
|
||||
result = logger._common_logic(processed_data, response_obj)
|
||||
|
||||
# Verify that OpenMeter received the user from token
|
||||
assert result["subject"] == "integration-test-user-123"
|
||||
assert isinstance(result["subject"], str)
|
||||
assert result["data"]["model"] == "gpt-3.5-turbo"
|
||||
assert result["data"]["cost"] == 0.001
|
||||
|
||||
# Verify CloudEvents structure
|
||||
assert result["specversion"] == "1.0"
|
||||
assert result["type"] == "litellm_tokens"
|
||||
assert result["id"] == "integration-test-response-id"
|
||||
assert result["source"] == "litellm-proxy"
|
||||
assert "time" in result
|
||||
|
||||
# Clean up environment
|
||||
os.environ.pop("OPENMETER_API_KEY", None)
|
||||
os.environ.pop("OPENMETER_API_ENDPOINT", None)
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
@ -14,7 +12,6 @@ from litellm.proxy.litellm_pre_call_utils import (
|
||||
LiteLLMProxyRequestSetup,
|
||||
_get_dynamic_logging_metadata,
|
||||
_get_enforced_params,
|
||||
add_litellm_data_to_request,
|
||||
check_if_token_is_service_account,
|
||||
)
|
||||
|
||||
@ -31,17 +28,17 @@ def test_check_if_token_is_service_account():
|
||||
service_account_token = UserAPIKeyAuth(
|
||||
api_key="test-key", metadata={"service_account_id": "test-service-account"}
|
||||
)
|
||||
assert check_if_token_is_service_account(service_account_token) == True
|
||||
assert check_if_token_is_service_account(service_account_token) is True
|
||||
|
||||
# Test case 2: Regular user token
|
||||
regular_token = UserAPIKeyAuth(api_key="test-key", metadata={})
|
||||
assert check_if_token_is_service_account(regular_token) == False
|
||||
assert check_if_token_is_service_account(regular_token) is False
|
||||
|
||||
# Test case 3: Token with other metadata
|
||||
other_metadata_token = UserAPIKeyAuth(
|
||||
api_key="test-key", metadata={"user_id": "test-user"}
|
||||
)
|
||||
assert check_if_token_is_service_account(other_metadata_token) == False
|
||||
assert check_if_token_is_service_account(other_metadata_token) is False
|
||||
|
||||
|
||||
def test_get_enforced_params_for_service_account_settings():
|
||||
@ -210,7 +207,6 @@ async def test_add_litellm_data_to_request_audio_transcription_multipart():
|
||||
|
||||
# Assert metadata was parsed correctly
|
||||
metadata_field = updated_data.get("metadata", {})
|
||||
litellm_metadata = updated_data.get("litellm_metadata", {})
|
||||
|
||||
assert isinstance(metadata_field, dict)
|
||||
assert "tags" in metadata_field
|
||||
@ -606,7 +602,6 @@ def test_get_dynamic_logging_metadata_with_arize_team_logging():
|
||||
assert result.callback_vars["arize_space_id"] == "test_arize_space_id"
|
||||
|
||||
|
||||
|
||||
def test_get_num_retries_from_request():
|
||||
"""
|
||||
Test LiteLLMProxyRequestSetup._get_num_retries_from_request method
|
||||
@ -668,6 +663,7 @@ def test_get_num_retries_from_request():
|
||||
)
|
||||
assert result == -1
|
||||
|
||||
|
||||
def test_add_user_api_key_auth_to_request_metadata():
|
||||
"""
|
||||
Test that add_user_api_key_auth_to_request_metadata properly adds user API key authentication data to request metadata
|
||||
@ -676,9 +672,9 @@ def test_add_user_api_key_auth_to_request_metadata():
|
||||
data = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"litellm_metadata": {} # This will be the metadata variable name
|
||||
"litellm_metadata": {}, # This will be the metadata variable name
|
||||
}
|
||||
|
||||
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="hashed-test-key-123",
|
||||
user_id="test-user-123",
|
||||
@ -689,21 +685,21 @@ def test_add_user_api_key_auth_to_request_metadata():
|
||||
team_alias="test-team-alias",
|
||||
end_user_id="test-end-user-123",
|
||||
request_route="/chat/completions",
|
||||
end_user_max_budget=500.0
|
||||
end_user_max_budget=500.0,
|
||||
)
|
||||
|
||||
|
||||
metadata_variable_name = "litellm_metadata"
|
||||
|
||||
|
||||
# Call the function
|
||||
result = LiteLLMProxyRequestSetup.add_user_api_key_auth_to_request_metadata(
|
||||
data=data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
_metadata_variable_name=metadata_variable_name
|
||||
_metadata_variable_name=metadata_variable_name,
|
||||
)
|
||||
|
||||
|
||||
# Verify the metadata was properly added
|
||||
metadata = result[metadata_variable_name]
|
||||
|
||||
|
||||
# Check that user API key information was added
|
||||
assert metadata["user_api_key_hash"] == "hashed-test-key-123"
|
||||
assert metadata["user_api_key_alias"] == "test-key-alias"
|
||||
@ -714,13 +710,211 @@ def test_add_user_api_key_auth_to_request_metadata():
|
||||
assert metadata["user_api_key_end_user_id"] == "test-end-user-123"
|
||||
assert metadata["user_api_key_user_email"] == "test@example.com"
|
||||
assert metadata["user_api_key_request_route"] == "/chat/completions"
|
||||
|
||||
|
||||
# Check that the hashed API key was added
|
||||
assert metadata["user_api_key"] == "hashed-test-key-123"
|
||||
|
||||
|
||||
# Check that end user max budget was added
|
||||
assert metadata["user_api_end_user_max_budget"] == 500.0
|
||||
|
||||
|
||||
# Verify original data is preserved
|
||||
assert result["model"] == "gpt-3.5-turbo"
|
||||
assert result["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
assert result["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_litellm_data_to_request_sets_user_from_token():
|
||||
"""
|
||||
Test that user is set from user_api_key_dict.user_id when no user is provided in request
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
request_mock.url.path = "/chat/completions"
|
||||
request_mock.url = MagicMock()
|
||||
request_mock.url.__str__.return_value = "http://localhost/chat/completions"
|
||||
request_mock.method = "POST"
|
||||
request_mock.query_params = {}
|
||||
request_mock.headers = {"Content-Type": "application/json"}
|
||||
request_mock.client = MagicMock()
|
||||
request_mock.client.host = "127.0.0.1"
|
||||
|
||||
# Setup user API key with user_id
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test_api_key",
|
||||
user_id="test-user-123", # This should be set in data["user"]
|
||||
team_id="test_team_id",
|
||||
)
|
||||
|
||||
# Setup request data WITHOUT user field
|
||||
data = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
# NO "user" field
|
||||
}
|
||||
|
||||
# Setup proxy config
|
||||
proxy_config = MagicMock()
|
||||
|
||||
# Call add_litellm_data_to_request
|
||||
result = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request_mock,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# Verify that user was set from token
|
||||
assert "user" in result
|
||||
assert result["user"] == "test-user-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_litellm_data_to_request_preserves_existing_user():
|
||||
"""
|
||||
Test that existing user in request data is not overwritten by token user_id
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
request_mock.url.path = "/chat/completions"
|
||||
request_mock.url = MagicMock()
|
||||
request_mock.url.__str__.return_value = "http://localhost/chat/completions"
|
||||
request_mock.method = "POST"
|
||||
request_mock.query_params = {}
|
||||
request_mock.headers = {"Content-Type": "application/json"}
|
||||
request_mock.client = MagicMock()
|
||||
request_mock.client.host = "127.0.0.1"
|
||||
|
||||
# Setup user API key with user_id
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test_api_key",
|
||||
user_id="token-user-456", # This should NOT overwrite existing user
|
||||
team_id="test_team_id",
|
||||
)
|
||||
|
||||
# Setup request data WITH existing user field
|
||||
data = {
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"user": "existing-user-789", # This should be preserved
|
||||
}
|
||||
|
||||
# Setup proxy config
|
||||
proxy_config = MagicMock()
|
||||
|
||||
# Call add_litellm_data_to_request
|
||||
result = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request_mock,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# Verify that existing user was preserved
|
||||
assert "user" in result
|
||||
assert result["user"] == "existing-user-789" # Original user preserved
|
||||
assert result["user"] != "token-user-456" # Token user_id not used
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_litellm_data_to_request_user_from_headers_takes_priority():
|
||||
"""
|
||||
Test that user from headers takes priority over token user_id
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request with user header
|
||||
request_mock = MagicMock(spec=Request)
|
||||
request_mock.url.path = "/chat/completions"
|
||||
request_mock.url = MagicMock()
|
||||
request_mock.url.__str__.return_value = "http://localhost/chat/completions"
|
||||
request_mock.method = "POST"
|
||||
request_mock.query_params = {}
|
||||
request_mock.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-User-ID": "header-user-999", # User from header
|
||||
}
|
||||
request_mock.client = MagicMock()
|
||||
request_mock.client.host = "127.0.0.1"
|
||||
|
||||
# Setup user API key with user_id
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test_api_key",
|
||||
user_id="token-user-123", # This should NOT be used
|
||||
team_id="test_team_id",
|
||||
)
|
||||
|
||||
# Setup request data WITHOUT user field
|
||||
data = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
|
||||
# Setup general settings with user header name
|
||||
general_settings = {"user_header_name": "X-User-ID"}
|
||||
|
||||
# Setup proxy config
|
||||
proxy_config = MagicMock()
|
||||
|
||||
# Call add_litellm_data_to_request
|
||||
result = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request_mock,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
general_settings=general_settings,
|
||||
)
|
||||
|
||||
# Verify that header user takes priority
|
||||
assert "user" in result
|
||||
assert result["user"] == "header-user-999" # Header user used
|
||||
assert result["user"] != "token-user-123" # Token user_id not used
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_litellm_data_to_request_no_user_when_token_has_no_user_id():
|
||||
"""
|
||||
Test that no user is set when token has no user_id and no user provided
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
# Setup mock request
|
||||
request_mock = MagicMock(spec=Request)
|
||||
request_mock.url.path = "/chat/completions"
|
||||
request_mock.url = MagicMock()
|
||||
request_mock.url.__str__.return_value = "http://localhost/chat/completions"
|
||||
request_mock.method = "POST"
|
||||
request_mock.query_params = {}
|
||||
request_mock.headers = {"Content-Type": "application/json"}
|
||||
request_mock.client = MagicMock()
|
||||
request_mock.client.host = "127.0.0.1"
|
||||
|
||||
# Setup user API key WITHOUT user_id
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test_api_key",
|
||||
user_id=None, # No user_id in token
|
||||
team_id="test_team_id",
|
||||
)
|
||||
|
||||
# Setup request data WITHOUT user field
|
||||
data = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
|
||||
# Setup proxy config
|
||||
proxy_config = MagicMock()
|
||||
|
||||
# Call add_litellm_data_to_request
|
||||
result = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request_mock,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# Verify that no user is set
|
||||
assert result.get("user") is None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user