[Proxy Startup]fix db config through envs (#13111)
* fix db config through envs * add helper * fix ruff * fix imports * add unit tests in db config changes
This commit is contained in:
parent
79be436c2b
commit
524a1ffd5f
@ -5,7 +5,6 @@ import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import urllib.parse
|
||||
import urllib.parse as urlparse
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
@ -686,26 +685,10 @@ def run_server( # noqa: PLR0915
|
||||
)
|
||||
database_url = general_settings.get("database_url", None)
|
||||
if database_url is None and os.getenv("DATABASE_URL") is None:
|
||||
# Check if all required variables are provided
|
||||
database_host = os.getenv("DATABASE_HOST")
|
||||
database_username = os.getenv("DATABASE_USERNAME")
|
||||
database_password = os.getenv("DATABASE_PASSWORD")
|
||||
database_name = os.getenv("DATABASE_NAME")
|
||||
|
||||
if (
|
||||
database_host
|
||||
and database_username
|
||||
and database_password
|
||||
and database_name
|
||||
):
|
||||
# Handle the problem of special character escaping in the database URL
|
||||
database_username_enc = urllib.parse.quote_plus(database_username)
|
||||
database_password_enc = urllib.parse.quote_plus(database_password)
|
||||
database_name_enc = urllib.parse.quote_plus(database_name)
|
||||
|
||||
# Construct DATABASE_URL from the provided variables
|
||||
database_url = f"postgresql://{database_username_enc}:{database_password_enc}@{database_host}/{database_name_enc}"
|
||||
|
||||
# Use helper function to construct DATABASE_URL from individual variables
|
||||
from litellm.proxy.utils import construct_database_url_from_env_vars
|
||||
database_url = construct_database_url_from_env_vars()
|
||||
if database_url:
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
db_connection_pool_limit = general_settings.get(
|
||||
"database_connection_pool_limit",
|
||||
@ -729,6 +712,19 @@ def run_server( # noqa: PLR0915
|
||||
if database_url is not None and isinstance(database_url, str):
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
|
||||
# Handle database URL construction when no config file is used
|
||||
if config is None and os.getenv("DATABASE_URL") is None:
|
||||
# Use helper function to construct DATABASE_URL from individual variables
|
||||
from litellm.proxy.utils import construct_database_url_from_env_vars
|
||||
database_url = construct_database_url_from_env_vars()
|
||||
if database_url:
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
|
||||
# Set default values for connection pool settings when no config is used
|
||||
if config is None:
|
||||
db_connection_pool_limit = LiteLLMDatabaseConnectionPool.database_connection_pool_limit.value
|
||||
db_connection_timeout = LiteLLMDatabaseConnectionPool.database_connection_pool_timeout.value
|
||||
|
||||
if (
|
||||
os.getenv("DATABASE_URL", None) is not None
|
||||
or os.getenv("DIRECT_URL", None) is not None
|
||||
|
||||
@ -3434,3 +3434,39 @@ def is_valid_api_key(key: str) -> bool:
|
||||
if re.match(r"^[a-fA-F0-9]{64}$", key):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def construct_database_url_from_env_vars() -> Optional[str]:
|
||||
"""
|
||||
Construct a DATABASE_URL from individual environment variables.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The constructed DATABASE_URL or None if required variables are missing
|
||||
"""
|
||||
import urllib.parse
|
||||
|
||||
# Check if all required variables are provided
|
||||
database_host = os.getenv("DATABASE_HOST")
|
||||
database_username = os.getenv("DATABASE_USERNAME")
|
||||
database_password = os.getenv("DATABASE_PASSWORD")
|
||||
database_name = os.getenv("DATABASE_NAME")
|
||||
|
||||
if (
|
||||
database_host
|
||||
and database_username
|
||||
and database_name
|
||||
):
|
||||
# Handle the problem of special character escaping in the database URL
|
||||
database_username_enc = urllib.parse.quote_plus(database_username)
|
||||
database_password_enc = urllib.parse.quote_plus(database_password) if database_password else ""
|
||||
database_name_enc = urllib.parse.quote_plus(database_name)
|
||||
|
||||
# Construct DATABASE_URL from the provided variables
|
||||
if database_password:
|
||||
database_url = f"postgresql://{database_username_enc}:{database_password_enc}@{database_host}/{database_name_enc}"
|
||||
else:
|
||||
database_url = f"postgresql://{database_username_enc}@{database_host}/{database_name_enc}"
|
||||
|
||||
return database_url
|
||||
|
||||
return None
|
||||
|
||||
@ -313,6 +313,131 @@ class TestProxyInitializationHelpers:
|
||||
call_args = mock_uvicorn_run.call_args
|
||||
assert call_args[1]["timeout_keep_alive"] == 30
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_construct_database_url_from_env_vars(self):
|
||||
"""Test the construct_database_url_from_env_vars function with various scenarios"""
|
||||
from litellm.proxy.utils import construct_database_url_from_env_vars
|
||||
|
||||
# Test with all required variables present
|
||||
test_env = {
|
||||
"DATABASE_HOST": "localhost:5432",
|
||||
"DATABASE_USERNAME": "testuser",
|
||||
"DATABASE_PASSWORD": "testpass",
|
||||
"DATABASE_NAME": "testdb",
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, test_env):
|
||||
result = construct_database_url_from_env_vars()
|
||||
expected_url = "postgresql://testuser:testpass@localhost:5432/testdb"
|
||||
assert result == expected_url
|
||||
|
||||
# Test with special characters that need URL encoding
|
||||
test_env_special = {
|
||||
"DATABASE_HOST": "localhost:5432",
|
||||
"DATABASE_USERNAME": "user@with+special",
|
||||
"DATABASE_PASSWORD": "pass&word!@#$%",
|
||||
"DATABASE_NAME": "db_name/test",
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, test_env_special):
|
||||
result = construct_database_url_from_env_vars()
|
||||
expected_url = "postgresql://user%40with%2Bspecial:pass%26word%21%40%23%24%25@localhost:5432/db_name%2Ftest"
|
||||
assert result == expected_url
|
||||
|
||||
# Test without password (should still work)
|
||||
test_env_no_password = {
|
||||
"DATABASE_HOST": "localhost:5432",
|
||||
"DATABASE_USERNAME": "testuser",
|
||||
"DATABASE_NAME": "testdb",
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, test_env_no_password):
|
||||
result = construct_database_url_from_env_vars()
|
||||
expected_url = "postgresql://testuser@localhost:5432/testdb"
|
||||
assert result == expected_url
|
||||
|
||||
# Test with missing required variables (should return None)
|
||||
test_env_missing = {
|
||||
"DATABASE_HOST": "localhost:5432",
|
||||
"DATABASE_USERNAME": "testuser",
|
||||
# Missing DATABASE_NAME
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, test_env_missing):
|
||||
result = construct_database_url_from_env_vars()
|
||||
assert result is None
|
||||
|
||||
# Test with empty environment (should return None)
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
result = construct_database_url_from_env_vars()
|
||||
assert result is None
|
||||
|
||||
@patch("uvicorn.run")
|
||||
@patch("builtins.print")
|
||||
def test_run_server_no_config_passed(self, mock_print, mock_uvicorn_run):
|
||||
"""Test that run_server properly handles the case when no config is passed"""
|
||||
from click.testing import CliRunner
|
||||
from litellm.proxy.proxy_cli import run_server
|
||||
import asyncio
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_proxy_config = MagicMock()
|
||||
mock_key_mgmt = MagicMock()
|
||||
mock_save_worker_config = MagicMock()
|
||||
|
||||
# Mock the ProxyConfig.get_config method to return a proper async config
|
||||
async def mock_get_config(config_file_path=None):
|
||||
return {
|
||||
"general_settings": {},
|
||||
"litellm_settings": {}
|
||||
}
|
||||
|
||||
mock_proxy_config_instance = MagicMock()
|
||||
mock_proxy_config_instance.get_config = mock_get_config
|
||||
mock_proxy_config.return_value = mock_proxy_config_instance
|
||||
|
||||
# Ensure DATABASE_URL is not set in the environment
|
||||
with patch.dict(os.environ, {"DATABASE_URL": ""}, clear=True):
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"proxy_server": MagicMock(
|
||||
app=mock_app,
|
||||
ProxyConfig=mock_proxy_config,
|
||||
KeyManagementSettings=mock_key_mgmt,
|
||||
save_worker_config=mock_save_worker_config,
|
||||
)
|
||||
},
|
||||
), patch(
|
||||
"litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
|
||||
) as mock_get_args:
|
||||
mock_get_args.return_value = {
|
||||
"app": "litellm.proxy.proxy_server:app",
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
}
|
||||
|
||||
# Test with no config parameter (config=None)
|
||||
result = runner.invoke(run_server, ["--local"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Verify that uvicorn.run was called
|
||||
mock_uvicorn_run.assert_called_once()
|
||||
|
||||
# Reset mocks for second test
|
||||
mock_uvicorn_run.reset_mock()
|
||||
|
||||
# Test with explicit --config None (should behave the same)
|
||||
result = runner.invoke(run_server, ["--local", "--config", "None"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Verify that uvicorn.run was called again
|
||||
mock_uvicorn_run.assert_called_once()
|
||||
|
||||
|
||||
class TestHealthAppFactory:
|
||||
"""Test cases for the health app factory module"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user