litellm/tests/test_litellm/proxy/test_proxy_cli.py
Cole McIntosh 3b911ba1b2
Enhance proxy CLI with Rich formatting and improved user experience (#11420)
* Enhance proxy CLI with Rich formatting and improved user experience

- Integrated Rich library for better console output in `proxy_cli.py`, including version display, health check results, and test completion responses.
- Updated health check and test completion methods to provide progress indicators and formatted tables.
- Refactored feedback display in `proxy_server.py` to use Rich for a more visually appealing user interface.
- Adjusted tests in `test_proxy_cli.py` to mock console output instead of using print statements, ensuring compatibility with Rich formatting.

* fix linting error

* refactor(proxy_cli.py): simplify DB setup logging

- Removed progress indicators for IAM token generation and environment variable decryption to simplify the code.
- Consolidated the logic for generating the database URL and setting environment variables.
- Enhanced error handling for configuration loading and database setup, ensuring clearer feedback.

* Update test-linting workflow to include proxy-dev dependencies in Poetry installation

* Enhance proxy server initialization with Rich console for improved model display. Added support for loading model parameters from environment variables and refined provider identification logic. Falls back to the original print formatting if Rich is not available.

* Refactor feedback handling: Moved feedback message generation and custom warning display to utils.py. Enhanced feedback box with Rich formatting and an ASCII fallback for environments without Rich. Cleaned up proxy_server.py by removing obsolete code.

* fix linting error

* Refactor model initialization display: Moved model initialization logic to a new utility function `display_model_initialization` for improved readability and maintainability. Enhanced model provider extraction with a dedicated function. Falls back to basic logging if the Rich console is unavailable.

* Refactor model provider extraction: Replace the `_extract_provider_from_model` function with a more robust approach using `get_llm_provider`. Implement fallback logic for provider identification and improve error handling. Ensure compatibility with Rich console for model initialization display.
2025-06-06 17:16:53 -07:00

248 lines
9.0 KiB
Python

import os
import sys
from unittest.mock import MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system-path
from litellm.proxy.proxy_cli import ProxyInitializationHelpers
class TestProxyInitializationHelpers:
    """Tests for the static helpers on ``ProxyInitializationHelpers`` and ``run_server``.

    External side effects (console output, HTTP requests, the OpenAI client,
    subprocesses, sockets, the event loop and uvicorn) are patched, so no real
    server or network connection is needed.
    """

    @patch("importlib.metadata.version")
    @patch("litellm.proxy.proxy_cli.console.print")
    def test_echo_litellm_version(self, mock_console_print, mock_version):
        """The version banner looks up the installed ``litellm`` version and prints it."""
        # Setup
        mock_version.return_value = "1.0.0"

        # Execute
        ProxyInitializationHelpers._echo_litellm_version()

        # Assert
        mock_version.assert_called_once_with("litellm")
        # Should call console.print multiple times (for empty lines and panel)
        assert mock_console_print.call_count >= 3

    @patch("httpx.get")
    @patch("litellm.proxy.proxy_cli.console.print")
    def test_run_health_check(self, mock_console_print, mock_get):
        """The health check requests ``/health`` on the given host/port and prints the result."""
        # Setup
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "model1": {"status": "healthy", "response_time": "0.5s"}
        }
        mock_get.return_value = mock_response

        # Execute
        ProxyInitializationHelpers._run_health_check("localhost", 8000)

        # Assert
        mock_get.assert_called_once_with(url="http://localhost:8000/health")
        mock_response.json.assert_called_once()
        # Should call console.print multiple times (progress, success message, table)
        assert mock_console_print.call_count >= 2

    @patch("openai.OpenAI")
    @patch("click.echo")
    @patch("builtins.print")
    def test_run_test_chat_completion(self, mock_print, mock_echo, mock_openai):
        """A boolean ``test`` value is rejected; a URL string is used as the client's base_url."""
        # Setup
        mock_client = MagicMock()
        mock_openai.return_value = mock_client
        mock_response = MagicMock()
        mock_stream_response = MagicMock()
        mock_stream_response.__iter__.return_value = [MagicMock(), MagicMock()]
        # A list side_effect hands out one item per create() call and takes
        # precedence over return_value, so only the ordered responses are
        # configured: a plain response first, then an iterable one.
        mock_client.chat.completions.create.side_effect = [
            mock_response,
            mock_stream_response,
        ]

        # Execute: a non-string test value must be rejected.
        with pytest.raises(ValueError, match="Invalid test value"):
            ProxyInitializationHelpers._run_test_chat_completion(
                "localhost", 8000, "gpt-3.5-turbo", True
            )

        # Test with valid string test value
        ProxyInitializationHelpers._run_test_chat_completion(
            "localhost", 8000, "gpt-3.5-turbo", "http://test-url"
        )

        # Assert: across both calls the client is built exactly once (for the
        # valid value), pointed at the supplied URL.
        mock_openai.assert_called_once_with(
            api_key="My API Key", base_url="http://test-url"
        )
        mock_client.chat.completions.create.assert_called()

    def test_get_default_unvicorn_init_args(self):
        """The default uvicorn kwargs carry app/host/port and reflect log_config / json_logs.

        (``unvicorn`` matches the spelling of the helper under test.)
        """
        # Test without log_config
        args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
            "localhost", 8000
        )
        assert args["app"] == "litellm.proxy.proxy_server:app"
        assert args["host"] == "localhost"
        assert args["port"] == 8000

        # Test with log_config
        args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
            "localhost", 8000, "log_config.json"
        )
        assert args["log_config"] == "log_config.json"

        # Test with json_logs=True: log_config ends up as None
        with patch("litellm.json_logs", True):
            args = ProxyInitializationHelpers._get_default_unvicorn_init_args(
                "localhost", 8000
            )
            assert args["log_config"] is None

    @patch("asyncio.run")
    @patch("builtins.print")
    def test_init_hypercorn_server(self, mock_print, mock_asyncio_run):
        """Hypercorn startup hands off to ``asyncio.run``, both with and without SSL files."""
        # Setup
        mock_app = MagicMock()

        # Execute
        ProxyInitializationHelpers._init_hypercorn_server(
            mock_app, "localhost", 8000, None, None
        )

        # Assert
        mock_asyncio_run.assert_called_once()

        # Test with SSL: the server must still be started.
        ProxyInitializationHelpers._init_hypercorn_server(
            mock_app, "localhost", 8000, "cert.pem", "key.pem"
        )
        assert mock_asyncio_run.call_count == 2

    @patch("subprocess.Popen")
    def test_run_ollama_serve(self, mock_popen):
        """A subprocess is spawned, and a failure to spawn it does not propagate."""
        # Execute
        ProxyInitializationHelpers._run_ollama_serve()

        # Assert
        mock_popen.assert_called_once()

        # Test exception handling
        mock_popen.side_effect = Exception("Test exception")
        ProxyInitializationHelpers._run_ollama_serve()  # Should not raise

    @patch("socket.socket")
    def test_is_port_in_use(self, mock_socket):
        """``connect_ex`` returning 0 means the port is in use; a non-zero result (1) means free."""
        # Setup for port in use
        mock_socket_instance = MagicMock()
        mock_socket_instance.connect_ex.return_value = 0
        mock_socket.return_value.__enter__.return_value = mock_socket_instance

        # Execute and Assert
        assert ProxyInitializationHelpers._is_port_in_use(8000) is True

        # Setup for port not in use
        mock_socket_instance.connect_ex.return_value = 1

        # Execute and Assert
        assert ProxyInitializationHelpers._is_port_in_use(8000) is False

    def test_get_loop_type(self):
        """``uvloop`` is selected on Linux, and no loop type is selected on Windows."""
        # Test on Windows
        with patch("sys.platform", "win32"):
            assert ProxyInitializationHelpers._get_loop_type() is None

        # Test on Linux
        with patch("sys.platform", "linux"):
            assert ProxyInitializationHelpers._get_loop_type() == "uvloop"

    @patch.dict(os.environ, {}, clear=True)
    def test_database_url_construction_with_special_characters(self):
        """Special characters in DB credentials/name are percent-encoded in the URL.

        NOTE(review): this test re-implements the URL encoding rather than
        calling ``run_server``, so it only exercises ``append_query_params``
        from the real module; a change in ``run_server`` would not be caught.
        """
        # Setup environment variables with special characters that need escaping
        test_env = {
            "DATABASE_HOST": "localhost:5432",
            "DATABASE_USERNAME": "user@with+special",
            "DATABASE_PASSWORD": "pass&word!@#$%",
            "DATABASE_NAME": "db_name/test",
        }

        with patch.dict(os.environ, test_env):
            # Call the relevant function - we'll need to extract the database URL construction logic
            # This is simulating what happens in the run_server function when database_url is None
            import urllib.parse

            from litellm.proxy.proxy_cli import append_query_params

            database_host = os.environ["DATABASE_HOST"]
            database_username = os.environ["DATABASE_USERNAME"]
            database_password = os.environ["DATABASE_PASSWORD"]
            database_name = os.environ["DATABASE_NAME"]

            # Test the URL encoding part
            database_username_enc = urllib.parse.quote_plus(database_username)
            database_password_enc = urllib.parse.quote_plus(database_password)
            database_name_enc = urllib.parse.quote_plus(database_name)

            # Construct DATABASE_URL from the provided variables
            database_url = f"postgresql://{database_username_enc}:{database_password_enc}@{database_host}/{database_name_enc}"

            # Assert the correct URL was constructed with properly escaped characters
            expected_url = "postgresql://user%40with%2Bspecial:pass%26word%21%40%23%24%25@localhost:5432/db_name%2Ftest"
            assert database_url == expected_url

            # Test appending query parameters
            params = {"connection_limit": 10, "pool_timeout": 60}
            modified_url = append_query_params(database_url, params)
            assert "connection_limit=10" in modified_url
            assert "pool_timeout=60" in modified_url

    @patch("uvicorn.run")
    @patch("litellm.proxy.proxy_cli.console.print")
    def test_skip_server_startup(self, mock_console_print, mock_uvicorn_run):
        """Test that the skip_server_startup flag prevents server startup when True"""
        from click.testing import CliRunner

        from litellm.proxy.proxy_cli import run_server

        runner = CliRunner()

        mock_app = MagicMock()
        mock_proxy_config = MagicMock()
        mock_key_mgmt = MagicMock()
        mock_save_worker_config = MagicMock()

        # Stub out the "proxy_server" module entry so run_server does not pull
        # in the real proxy server while the CLI is invoked.
        with patch.dict(
            "sys.modules",
            {
                "proxy_server": MagicMock(
                    app=mock_app,
                    ProxyConfig=mock_proxy_config,
                    KeyManagementSettings=mock_key_mgmt,
                    save_worker_config=mock_save_worker_config,
                )
            },
        ), patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args:
            mock_get_args.return_value = {
                "app": "litellm.proxy.proxy_server:app",
                "host": "localhost",
                "port": 8000,
            }

            # With the flag: the CLI exits cleanly without starting uvicorn.
            result = runner.invoke(run_server, ["--local", "--skip_server_startup"])

            assert result.exit_code == 0, result.output
            mock_uvicorn_run.assert_not_called()
            # Check that console.print was called (for skip message)
            assert mock_console_print.call_count >= 1

            mock_uvicorn_run.reset_mock()
            mock_console_print.reset_mock()

            # Without the flag: uvicorn is started exactly once.
            result = runner.invoke(run_server, ["--local"])

            assert result.exit_code == 0, result.output
            mock_uvicorn_run.assert_called_once()