Populate request metadata with the API key's user/team/org ids.

For the chat_completion, completion and embeddings endpoints, copy the
authenticated key's user_id, team_id and org_id into data["metadata"] under
user_api_key_user_id / user_api_key_team_id / user_api_key_org_id. Fields that
are None on the key are skipped, and metadata that is not a dict (for example
a JSON string sent by the client) is left untouched instead of raising.

NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Context-line indentation (notably in the embeddings hunk) is a best guess and
the "index" blob ids are carried over from the original patch. Apply with
"git apply --ignore-whitespace" and regenerate the patch from the result.

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 64dfc6a5d8..8268e612bf 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -5412,6 +5412,19 @@ async def chat_completion(  # noqa: PLR0915
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     data = await _read_request_body(request=request)
+    if user_api_key_dict is not None:
+        # Copy the authenticated caller's identity into the request metadata.
+        # Fields that are unset (None) on the key are skipped.
+        metadata = data.get("metadata")
+        if metadata is None:
+            metadata = data["metadata"] = {}
+        # Clients may send non-dict metadata (e.g. a JSON string); leave it
+        # untouched instead of failing on item assignment.
+        if isinstance(metadata, dict):
+            for field_name in ("user_id", "team_id", "org_id"):
+                field_value = getattr(user_api_key_dict, field_name, None)
+                if field_value is not None:
+                    metadata[f"user_api_key_{field_name}"] = field_value
     base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
     try:
         result = await base_llm_response_processor.base_process_llm_request(
@@ -5563,6 +5576,19 @@ async def completion(  # noqa: PLR0915
     data = {}
     try:
         data = await _read_request_body(request=request)
+        if user_api_key_dict is not None:
+            # Copy the authenticated caller's identity into the request
+            # metadata. Fields that are unset (None) on the key are skipped.
+            metadata = data.get("metadata")
+            if metadata is None:
+                metadata = data["metadata"] = {}
+            # Clients may send non-dict metadata (e.g. a JSON string); leave
+            # it untouched instead of failing on item assignment.
+            if isinstance(metadata, dict):
+                for field_name in ("user_id", "team_id", "org_id"):
+                    field_value = getattr(user_api_key_dict, field_name, None)
+                    if field_value is not None:
+                        metadata[f"user_api_key_{field_name}"] = field_value
         base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
         return await base_llm_response_processor.base_process_llm_request(
             request=request,
@@ -5782,6 +5808,20 @@ async def embeddings(  # noqa: PLR0915
                     )
                 data["input"] = input_list
 
+        if user_api_key_dict is not None:
+            # Copy the authenticated caller's identity into the request
+            # metadata. Fields that are unset (None) on the key are skipped.
+            metadata = data.get("metadata")
+            if metadata is None:
+                metadata = data["metadata"] = {}
+            # Clients may send non-dict metadata (e.g. a JSON string); leave
+            # it untouched instead of failing on item assignment.
+            if isinstance(metadata, dict):
+                for field_name in ("user_id", "team_id", "org_id"):
+                    field_value = getattr(user_api_key_dict, field_name, None)
+                    if field_value is not None:
+                        metadata[f"user_api_key_{field_name}"] = field_value
+
         # Use unified request processor (same as chat/completions and responses)
         base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
 
diff --git a/tests/test_litellm/proxy/test_chat_completion_metadata.py b/tests/test_litellm/proxy/test_chat_completion_metadata.py
new file mode 100644
index 0000000000..38dcdc13c5
--- /dev/null
+++ b/tests/test_litellm/proxy/test_chat_completion_metadata.py
@@ -0,0 +1,126 @@
+"""Tests that LLM endpoints copy API-key identity into request metadata."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from fastapi import Request, Response
+
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.proxy_server import chat_completion, completion, embeddings
+
+READ_BODY = "litellm.proxy.proxy_server._read_request_body"
+PROCESSOR = "litellm.proxy.proxy_server.ProxyBaseLLMRequestProcessing"
+
+
+async def _captured_data(endpoint, body, auth):
+    """Call *endpoint* with *body*; return the data given to the processor."""
+    with patch(READ_BODY, new=AsyncMock(return_value=body)):
+        with patch(PROCESSOR) as mock_processor:
+            mock_instance = mock_processor.return_value
+            mock_instance.base_process_llm_request = AsyncMock(
+                return_value={"choices": []}
+            )
+            await endpoint(
+                request=MagicMock(spec=Request),
+                fastapi_response=MagicMock(spec=Response),
+                user_api_key_dict=auth,
+            )
+
+    call_args = mock_processor.call_args
+    assert call_args is not None
+    data = call_args.kwargs.get("data")
+    assert data is not None
+    return data
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_metadata_population():
+    """chat_completion adds user, team and org ids to the request metadata."""
+    data = await _captured_data(
+        chat_completion,
+        {"model": "gpt-3.5-turbo", "messages": []},
+        UserAPIKeyAuth(
+            user_id="test_user_id", team_id="test_team_id", org_id="test_org_id"
+        ),
+    )
+
+    assert data["metadata"]["user_api_key_user_id"] == "test_user_id"
+    assert data["metadata"]["user_api_key_team_id"] == "test_team_id"
+    assert data["metadata"]["user_api_key_org_id"] == "test_org_id"
+
+
+@pytest.mark.asyncio
+async def test_completion_metadata_population():
+    """completion adds user, team and org ids to the request metadata."""
+    data = await _captured_data(
+        completion,
+        {"model": "gpt-3.5-turbo-instruct", "prompt": "test"},
+        UserAPIKeyAuth(
+            user_id="test_user_id_2", team_id="test_team_id_2", org_id="test_org_id_2"
+        ),
+    )
+
+    assert data["metadata"]["user_api_key_user_id"] == "test_user_id_2"
+    assert data["metadata"]["user_api_key_team_id"] == "test_team_id_2"
+    assert data["metadata"]["user_api_key_org_id"] == "test_org_id_2"
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_keeps_client_metadata():
+    """Client-supplied metadata survives; unset key fields add nothing."""
+    data = await _captured_data(
+        chat_completion,
+        {"model": "gpt-3.5-turbo", "messages": [], "metadata": {"trace": "abc"}},
+        UserAPIKeyAuth(user_id="test_user_id"),
+    )
+
+    assert data["metadata"]["trace"] == "abc"
+    assert data["metadata"]["user_api_key_user_id"] == "test_user_id"
+    assert "user_api_key_team_id" not in data["metadata"]
+    assert "user_api_key_org_id" not in data["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_non_dict_metadata_is_left_untouched():
+    """A non-dict metadata value must not make the endpoint raise."""
+    data = await _captured_data(
+        chat_completion,
+        {"model": "gpt-3.5-turbo", "messages": [], "metadata": '{"trace": "abc"}'},
+        UserAPIKeyAuth(user_id="test_user_id"),
+    )
+
+    assert data["metadata"] == '{"trace": "abc"}'
+
+
+@pytest.mark.asyncio
+async def test_embedding_metadata_population():
+    """embeddings adds user, team and org ids to the request metadata."""
+    mock_user_auth = MagicMock(spec=UserAPIKeyAuth)
+    mock_user_auth.user_id = "test_user_id_emb"
+    mock_user_auth.team_id = "test_team_id_emb"
+    mock_user_auth.org_id = "test_org_id_emb"
+
+    body = {"model": "gpt-3.5-turbo", "input": "hello"}
+    mock_request = MagicMock(spec=Request)
+    mock_request.json = AsyncMock(return_value=dict(body))
+
+    with patch(f"{PROCESSOR}.base_process_llm_request"):
+        with patch(f"{PROCESSOR}.__init__", return_value=None) as mock_init:
+            with patch(READ_BODY, new=AsyncMock(return_value=body)):
+                await embeddings(
+                    request=mock_request,
+                    fastapi_response=MagicMock(spec=Response),
+                    user_api_key_dict=mock_user_auth,
+                )
+
+    mock_init.assert_called_once()
+    call_args = mock_init.call_args
+    # The processor may receive data positionally or by keyword.
+    if "data" in call_args.kwargs:
+        data = call_args.kwargs["data"]
+    else:
+        data = call_args.args[0]
+
+    assert data["metadata"]["user_api_key_user_id"] == "test_user_id_emb"
+    assert data["metadata"]["user_api_key_team_id"] == "test_team_id_emb"
+    assert data["metadata"]["user_api_key_org_id"] == "test_org_id_emb"