diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index aaeb86e34c..34b73ab6a6 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -170,12 +170,12 @@ def _get_dynamic_logging_metadata( user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig ) -> Optional[TeamCallbackMetadata]: callback_settings_obj: Optional[TeamCallbackMetadata] = None - key_dynamic_logging_settings: Optional[dict] = ( - KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict) - ) - team_dynamic_logging_settings: Optional[dict] = ( - KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict) - ) + key_dynamic_logging_settings: Optional[ + dict + ] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict) + team_dynamic_logging_settings: Optional[ + dict + ] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict) ######################################################################################### # Key-based callbacks ######################################################################################### @@ -439,7 +439,7 @@ class LiteLLMProxyRequestSetup: user_api_key_request_route=user_api_key_dict.request_route, ) return user_api_key_logged_metadata - + @staticmethod def add_user_api_key_auth_to_request_metadata( data: dict, @@ -457,9 +457,7 @@ class LiteLLMProxyRequestSetup: data[_metadata_variable_name].update(user_api_key_logged_metadata) data[_metadata_variable_name][ "user_api_key" - ] = ( - user_api_key_dict.api_key - ) # this is just the hashed token + ] = user_api_key_dict.api_key # this is just the hashed token data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr( user_api_key_dict, "end_user_max_budget", None @@ -479,11 +477,11 @@ class LiteLLMProxyRequestSetup: ## KEY-LEVEL SPEND LOGS / TAGS if "tags" in key_metadata and key_metadata["tags"] is not None: - data[_metadata_variable_name]["tags"] = ( - LiteLLMProxyRequestSetup._merge_tags( - request_tags=data[_metadata_variable_name].get("tags"), - tags_to_add=key_metadata["tags"], - ) + data[_metadata_variable_name][ + "tags" + ] = LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=key_metadata["tags"], ) if "spend_logs_metadata" in key_metadata and isinstance( key_metadata["spend_logs_metadata"], dict @@ -632,6 +630,10 @@ async def add_litellm_data_to_request( # noqa: PLR0915 if "user" not in data: data["user"] = user + # Set user from token if not already set and token has user_id + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + # Include original request and headers in the data data["proxy_server_request"] = { "url": str(request.url), @@ -694,9 +696,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915 data[_metadata_variable_name]["litellm_api_version"] = version if general_settings is not None: - data[_metadata_variable_name]["global_max_parallel_requests"] = ( - general_settings.get("global_max_parallel_requests", None) - ) + data[_metadata_variable_name][ + "global_max_parallel_requests" + ] = general_settings.get("global_max_parallel_requests", None) ### KEY-LEVEL Controls key_metadata = user_api_key_dict.metadata diff --git a/tests/test_litellm/integrations/test_openmeter.py b/tests/test_litellm/integrations/test_openmeter.py index 5f947513a1..3da19369f6 100644 --- a/tests/test_litellm/integrations/test_openmeter.py +++ b/tests/test_litellm/integrations/test_openmeter.py @@ -1,7 +1,5 @@ -import asyncio import json import os -import sys from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -18,7 +16,7 @@ class TestOpenMeterIntegration: # Set required environment variables os.environ["OPENMETER_API_KEY"] = "test-api-key" os.environ["OPENMETER_API_ENDPOINT"] = "https://test.openmeter.com" - + def teardown_method(self): """Clean up test environment""" # Clean up environment variables @@ -40,26 +38,22 @@ class TestOpenMeterIntegration: def test_common_logic_with_string_user(self): """Test that _common_logic correctly handles string user parameter""" logger = OpenMeterLogger() - + kwargs = { "user": "test-user-123", "model": "gpt-3.5-turbo", "response_cost": 0.001, - "litellm_call_id": "test-call-id" + "litellm_call_id": "test-call-id", } - + # Mock response object response_obj = { "id": "test-response-id", - "usage": { - "prompt_tokens": 10, - "completion_tokens": 5, - "total_tokens": 15 - } + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, } - + result = logger._common_logic(kwargs, response_obj) - + # Verify subject is a string, not a tuple assert isinstance(result["subject"], str) assert result["subject"] == "test-user-123" @@ -69,25 +63,21 @@ class TestOpenMeterIntegration: def test_common_logic_with_integer_user(self): """Test that _common_logic correctly converts integer user to string""" logger = OpenMeterLogger() - + kwargs = { "user": 12345, # Integer user ID "model": "gpt-4", "response_cost": 0.002, - "litellm_call_id": "test-call-id-2" + "litellm_call_id": "test-call-id-2", } - + response_obj = { "id": "test-response-id-2", - "usage": { - "prompt_tokens": 20, - "completion_tokens": 10, - "total_tokens": 30 - } + "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30}, } - + result = logger._common_logic(kwargs, response_obj) - + # Verify subject is converted to string assert isinstance(result["subject"], str) assert result["subject"] == "12345" @@ -95,31 +85,31 @@ class TestOpenMeterIntegration: def test_common_logic_missing_user(self): """Test that _common_logic raises exception when user is missing""" logger = OpenMeterLogger() - + kwargs = { "model": "gpt-3.5-turbo", "response_cost": 0.001, - "litellm_call_id": "test-call-id" + "litellm_call_id": "test-call-id", } - + response_obj = {"id": "test-response-id"} - + with pytest.raises(Exception, match="OpenMeter: user is required"): logger._common_logic(kwargs, response_obj) def test_common_logic_none_user(self): """Test that _common_logic raises exception when user is None""" logger = OpenMeterLogger() - + kwargs = { "user": None, "model": "gpt-3.5-turbo", "response_cost": 0.001, - "litellm_call_id": "test-call-id" + "litellm_call_id": "test-call-id", } - + response_obj = {"id": "test-response-id"} - + with pytest.raises(Exception, match="OpenMeter: user is required"): logger._common_logic(kwargs, response_obj) @@ -140,44 +130,40 @@ class TestOpenMeterIntegration: assert isinstance(result["subject"], str) assert result["subject"] == "" - @patch('litellm.integrations.openmeter.HTTPHandler') + @patch("litellm.integrations.openmeter.HTTPHandler") def test_log_success_event(self, mock_http_handler): """Test synchronous log_success_event method""" mock_post = MagicMock() mock_http_handler.return_value.post = mock_post - + logger = OpenMeterLogger() - + kwargs = { "user": "test-user", "model": "gpt-3.5-turbo", "response_cost": 0.001, - "litellm_call_id": "test-call-id" + "litellm_call_id": "test-call-id", } - + response_obj = { "id": "test-response-id", - "usage": { - "prompt_tokens": 10, - "completion_tokens": 5, - "total_tokens": 15 - } + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, } - + logger.log_success_event(kwargs, response_obj, None, None) - + # Verify HTTP call was made mock_post.assert_called_once() - + # Verify the data structure call_args = mock_post.call_args - data = json.loads(call_args[1]['data']) - + data = json.loads(call_args[1]["data"]) + assert data["subject"] == "test-user" assert isinstance(data["subject"], str) assert data["data"]["model"] == "gpt-3.5-turbo" - @patch('litellm.integrations.openmeter.get_async_httpx_client') + @patch("litellm.integrations.openmeter.get_async_httpx_client") @pytest.mark.asyncio async def test_async_log_success_event(self, mock_get_client): """Test asynchronous log_success_event method""" @@ -185,34 +171,30 @@ class TestOpenMeterIntegration: mock_client = MagicMock() mock_client.post = mock_post mock_get_client.return_value = mock_client - + logger = OpenMeterLogger() - + kwargs = { "user": "async-test-user", "model": "gpt-4", "response_cost": 0.002, - "litellm_call_id": "async-test-call-id" + "litellm_call_id": "async-test-call-id", } - + response_obj = { "id": "async-test-response-id", - "usage": { - "prompt_tokens": 20, - "completion_tokens": 10, - "total_tokens": 30 - } + "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30}, } - + await logger.async_log_success_event(kwargs, response_obj, None, None) - + # Verify async HTTP call was made mock_post.assert_called_once() - - # Verify the data structure + + # Verify the data structure call_args = mock_post.call_args - data = json.loads(call_args[1]['data']) - + data = json.loads(call_args[1]["data"]) + assert data["subject"] == "async-test-user" assert isinstance(data["subject"], str) assert data["data"]["model"] == "gpt-4" @@ -220,26 +202,22 @@ class TestOpenMeterIntegration: def test_cloudevents_structure(self): """Test that the CloudEvents structure is correct""" logger = OpenMeterLogger() - + kwargs = { "user": "cloudevents-test-user", "model": "gpt-3.5-turbo", "response_cost": 0.001, - "litellm_call_id": "cloudevents-test-call-id" + "litellm_call_id": "cloudevents-test-call-id", } - + response_data = { "id": "cloudevents-test-response-id", - "usage": { - "prompt_tokens": 15, - "completion_tokens": 8, - "total_tokens": 23 - } + "usage": {"prompt_tokens": 15, "completion_tokens": 8, "total_tokens": 23}, } response_obj = litellm.ModelResponse(**response_data) - + result = logger._common_logic(kwargs, response_obj) - + # Verify CloudEvents required fields assert result["specversion"] == "1.0" assert result["type"] == "litellm_tokens" # default value @@ -248,7 +226,7 @@ class TestOpenMeterIntegration: assert "time" in result assert isinstance(result["subject"], str) assert result["subject"] == "cloudevents-test-user" - + # Verify data structure assert "data" in result assert result["data"]["model"] == "gpt-3.5-turbo" @@ -260,25 +238,109 @@ class TestOpenMeterIntegration: def test_custom_event_type(self): """Test that custom event type is used when set""" os.environ["OPENMETER_EVENT_TYPE"] = "custom_event_type" - + logger = OpenMeterLogger() - + kwargs = { "user": "custom-event-user", "model": "gpt-4", "response_cost": 0.003, - "litellm_call_id": "custom-event-call-id" + "litellm_call_id": "custom-event-call-id", } - + response_obj = { "id": "custom-event-response-id", - "usage": { - "prompt_tokens": 25, - "completion_tokens": 12, - "total_tokens": 37 - } + "usage": {"prompt_tokens": 25, "completion_tokens": 12, "total_tokens": 37}, } - + result = logger._common_logic(kwargs, response_obj) - + assert result["type"] == "custom_event_type" + + +@pytest.mark.asyncio +async def test_openmeter_integration_with_token_user_id(): + """ + Test complete integration: token with user_id -> add_litellm_data_to_request -> OpenMeter callback + + This test verifies that when a token has user_id but request has no user, + the user_id is properly passed to OpenMeter callback. + """ + from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request + from litellm.proxy._types import UserAPIKeyAuth + from unittest.mock import MagicMock + from fastapi import Request + + # Setup environment for OpenMeter + os.environ["OPENMETER_API_KEY"] = "test-api-key" + os.environ["OPENMETER_API_ENDPOINT"] = "https://test.openmeter.com" + + # Setup mock request + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = {"Content-Type": "application/json"} + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + # Setup user API key with user_id + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + user_id="integration-test-user-123", # This should reach OpenMeter + team_id="test_team_id", + ) + + # Setup request data WITHOUT user field + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + "response_cost": 0.001, + "litellm_call_id": "integration-test-call-id", + } + + # Setup proxy config + proxy_config = MagicMock() + + # Step 1: Call add_litellm_data_to_request to set user from token + processed_data = await add_litellm_data_to_request( + data=data, + request=request_mock, + user_api_key_dict=user_api_key_dict, + proxy_config=proxy_config, + ) + + # Verify that user was set from token + assert "user" in processed_data + assert processed_data["user"] == "integration-test-user-123" + + # Step 2: Test that OpenMeter callback works with this data + logger = OpenMeterLogger() + + # Mock response object + response_obj = { + "id": "integration-test-response-id", + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + # Test OpenMeter _common_logic with the processed data + result = logger._common_logic(processed_data, response_obj) + + # Verify that OpenMeter received the user from token + assert result["subject"] == "integration-test-user-123" + assert isinstance(result["subject"], str) + assert result["data"]["model"] == "gpt-3.5-turbo" + assert result["data"]["cost"] == 0.001 + + # Verify CloudEvents structure + assert result["specversion"] == "1.0" + assert result["type"] == "litellm_tokens" + assert result["id"] == "integration-test-response-id" + assert result["source"] == "litellm-proxy" + assert "time" in result + + # Clean up environment + os.environ.pop("OPENMETER_API_KEY", None) + os.environ.pop("OPENMETER_API_ENDPOINT", None) diff --git a/tests/test_litellm/proxy/test_litellm_pre_call_utils.py b/tests/test_litellm/proxy/test_litellm_pre_call_utils.py index 0ec3fd9393..a7eba45f6e 100644 --- a/tests/test_litellm/proxy/test_litellm_pre_call_utils.py +++ b/tests/test_litellm/proxy/test_litellm_pre_call_utils.py @@ -1,9 +1,7 @@ -import asyncio -import copy import json import os import sys -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from fastapi import Request @@ -14,7 +12,6 @@ from litellm.proxy.litellm_pre_call_utils import ( LiteLLMProxyRequestSetup, _get_dynamic_logging_metadata, _get_enforced_params, - add_litellm_data_to_request, check_if_token_is_service_account, ) @@ -31,17 +28,17 @@ def test_check_if_token_is_service_account(): service_account_token = UserAPIKeyAuth( api_key="test-key", metadata={"service_account_id": "test-service-account"} ) - assert check_if_token_is_service_account(service_account_token) == True + assert check_if_token_is_service_account(service_account_token) is True # Test case 2: Regular user token regular_token = UserAPIKeyAuth(api_key="test-key", metadata={}) - assert check_if_token_is_service_account(regular_token) == False + assert check_if_token_is_service_account(regular_token) is False # Test case 3: Token with other metadata other_metadata_token = UserAPIKeyAuth( api_key="test-key", metadata={"user_id": "test-user"} ) - assert check_if_token_is_service_account(other_metadata_token) == False + assert check_if_token_is_service_account(other_metadata_token) is False def test_get_enforced_params_for_service_account_settings(): @@ -210,7 +207,6 @@ async def test_add_litellm_data_to_request_audio_transcription_multipart(): # Assert metadata was parsed correctly metadata_field = updated_data.get("metadata", {}) - litellm_metadata = updated_data.get("litellm_metadata", {}) assert isinstance(metadata_field, dict) assert "tags" in metadata_field @@ -606,7 +602,6 @@ def test_get_dynamic_logging_metadata_with_arize_team_logging(): assert result.callback_vars["arize_space_id"] == "test_arize_space_id" - def test_get_num_retries_from_request(): """ Test LiteLLMProxyRequestSetup._get_num_retries_from_request method @@ -668,6 +663,7 @@ def test_get_num_retries_from_request(): ) assert result == -1 + def test_add_user_api_key_auth_to_request_metadata(): """ Test that add_user_api_key_auth_to_request_metadata properly adds user API key authentication data to request metadata @@ -676,9 +672,9 @@ def test_add_user_api_key_auth_to_request_metadata(): data = { "model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}], - "litellm_metadata": {} # This will be the metadata variable name + "litellm_metadata": {}, # This will be the metadata variable name } - + user_api_key_dict = UserAPIKeyAuth( api_key="hashed-test-key-123", user_id="test-user-123", @@ -689,21 +685,21 @@ def test_add_user_api_key_auth_to_request_metadata(): team_alias="test-team-alias", end_user_id="test-end-user-123", request_route="/chat/completions", - end_user_max_budget=500.0 + end_user_max_budget=500.0, ) - + metadata_variable_name = "litellm_metadata" - + # Call the function result = LiteLLMProxyRequestSetup.add_user_api_key_auth_to_request_metadata( data=data, user_api_key_dict=user_api_key_dict, - _metadata_variable_name=metadata_variable_name + _metadata_variable_name=metadata_variable_name, ) - + # Verify the metadata was properly added metadata = result[metadata_variable_name] - + # Check that user API key information was added assert metadata["user_api_key_hash"] == "hashed-test-key-123" assert metadata["user_api_key_alias"] == "test-key-alias" @@ -714,13 +710,211 @@ def test_add_user_api_key_auth_to_request_metadata(): assert metadata["user_api_key_end_user_id"] == "test-end-user-123" assert metadata["user_api_key_user_email"] == "test@example.com" assert metadata["user_api_key_request_route"] == "/chat/completions" - + # Check that the hashed API key was added assert metadata["user_api_key"] == "hashed-test-key-123" - + # Check that end user max budget was added assert metadata["user_api_end_user_max_budget"] == 500.0 - + # Verify original data is preserved assert result["model"] == "gpt-3.5-turbo" - assert result["messages"] == [{"role": "user", "content": "Hello"}] \ No newline at end of file + assert result["messages"] == [{"role": "user", "content": "Hello"}] + + +@pytest.mark.asyncio +async def test_add_litellm_data_to_request_sets_user_from_token(): + """ + Test that user is set from user_api_key_dict.user_id when no user is provided in request + """ + from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request + + # Setup mock request + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = {"Content-Type": "application/json"} + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + # Setup user API key with user_id + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + user_id="test-user-123", # This should be set in data["user"] + team_id="test_team_id", + ) + + # Setup request data WITHOUT user field + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + # NO "user" field + } + + # Setup proxy config + proxy_config = MagicMock() + + # Call add_litellm_data_to_request + result = await add_litellm_data_to_request( + data=data, + request=request_mock, + user_api_key_dict=user_api_key_dict, + proxy_config=proxy_config, + ) + + # Verify that user was set from token + assert "user" in result + assert result["user"] == "test-user-123" + + +@pytest.mark.asyncio +async def test_add_litellm_data_to_request_preserves_existing_user(): + """ + Test that existing user in request data is not overwritten by token user_id + """ + from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request + + # Setup mock request + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = {"Content-Type": "application/json"} + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + # Setup user API key with user_id + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + user_id="token-user-456", # This should NOT overwrite existing user + team_id="test_team_id", + ) + + # Setup request data WITH existing user field + data = { + "model": "gpt-4", + "messages": [{"role": "user", "content": "Hello"}], + "user": "existing-user-789", # This should be preserved + } + + # Setup proxy config + proxy_config = MagicMock() + + # Call add_litellm_data_to_request + result = await add_litellm_data_to_request( + data=data, + request=request_mock, + user_api_key_dict=user_api_key_dict, + proxy_config=proxy_config, + ) + + # Verify that existing user was preserved + assert "user" in result + assert result["user"] == "existing-user-789" # Original user preserved + assert result["user"] != "token-user-456" # Token user_id not used + + +@pytest.mark.asyncio +async def test_add_litellm_data_to_request_user_from_headers_takes_priority(): + """ + Test that user from headers takes priority over token user_id + """ + from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request + + # Setup mock request with user header + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = { + "Content-Type": "application/json", + "X-User-ID": "header-user-999", # User from header + } + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + # Setup user API key with user_id + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + user_id="token-user-123", # This should NOT be used + team_id="test_team_id", + ) + + # Setup request data WITHOUT user field + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + } + + # Setup general settings with user header name + general_settings = {"user_header_name": "X-User-ID"} + + # Setup proxy config + proxy_config = MagicMock() + + # Call add_litellm_data_to_request + result = await add_litellm_data_to_request( + data=data, + request=request_mock, + user_api_key_dict=user_api_key_dict, + proxy_config=proxy_config, + general_settings=general_settings, + ) + + # Verify that header user takes priority + assert "user" in result + assert result["user"] == "header-user-999" # Header user used + assert result["user"] != "token-user-123" # Token user_id not used + + +@pytest.mark.asyncio +async def test_add_litellm_data_to_request_no_user_when_token_has_no_user_id(): + """ + Test that no user is set when token has no user_id and no user provided + """ + from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request + + # Setup mock request + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = {"Content-Type": "application/json"} + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + # Setup user API key WITHOUT user_id + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + user_id=None, # No user_id in token + team_id="test_team_id", + ) + + # Setup request data WITHOUT user field + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + } + + # Setup proxy config + proxy_config = MagicMock() + + # Call add_litellm_data_to_request + result = await add_litellm_data_to_request( + data=data, + request=request_mock, + user_api_key_dict=user_api_key_dict, + proxy_config=proxy_config, + ) + + # Verify that no user is set + assert result.get("user") is None