fix(proxy): expose Prisma idle/connect timeout + extra DB URL params (#28395)
* fix(proxy): expose Prisma idle/connect timeout + extra DB URL params

  Operators have reported large numbers of idle Prisma connections that never get closed. The proxy already forwards `connection_limit` and `pool_timeout` to the DATABASE_URL, but had no knob for capping idle or slow connections.

  Add three new `general_settings` keys that thread through to the DATABASE_URL / DIRECT_URL query string:

  - `database_connect_timeout` -> Prisma `connect_timeout`
  - `database_socket_timeout` -> Prisma `socket_timeout` (the main knob for closing idle connections from the LiteLLM side)
  - `database_extra_connection_params` -> untyped passthrough dict for any other Prisma URL param (`pgbouncer`, `statement_cache_size`, `sslmode`, ...); keys here override LiteLLM defaults.

  Refactors the duplicated DATABASE_URL/DIRECT_URL param dicts into a single `_build_db_connection_url_params` helper.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Update litellm/proxy/proxy_cli.py

  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: Yassin Kortam <yassinkortam@g.ucla.edu>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
This commit is contained in:
parent
8acf64e16c
commit
2f9ac77b24
@ -2361,6 +2361,30 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
|
||||
database_connection_timeout: Optional[float] = Field(
|
||||
60, description="default timeout for a connection to the database"
|
||||
)
|
||||
database_connect_timeout: Optional[float] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Prisma `connect_timeout` URL param (seconds). Bounds how long the "
|
||||
"engine waits to establish a new connection before failing. Defaults "
|
||||
"to Prisma's built-in value when unset."
|
||||
),
|
||||
)
|
||||
database_socket_timeout: Optional[float] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Prisma `socket_timeout` URL param (seconds). When set, an idle/slow "
|
||||
"connection that has not produced data within this window is closed. "
|
||||
"This is the main knob for capping idle DB connections from LiteLLM."
|
||||
),
|
||||
)
|
||||
database_extra_connection_params: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Escape hatch: extra key/value pairs appended verbatim to the Prisma "
|
||||
"DATABASE_URL / DIRECT_URL query string (e.g. `sslmode`, `pgbouncer`, "
|
||||
"`statement_cache_size`). Keys here override any default LiteLLM sets."
|
||||
),
|
||||
)
|
||||
database_type: Optional[Literal["dynamo_db"]] = Field(
|
||||
None, description="to use dynamodb instead of postgres db"
|
||||
)
|
||||
|
||||
@ -38,6 +38,35 @@ class LiteLLMDatabaseConnectionPool(Enum):
|
||||
database_connection_pool_timeout = 60
|
||||
|
||||
|
||||
def _build_db_connection_url_params(
|
||||
connection_limit: int,
|
||||
pool_timeout: Optional[Union[int, float]],
|
||||
connect_timeout: Optional[Union[int, float]] = None,
|
||||
socket_timeout: Optional[Union[int, float]] = None,
|
||||
extra_params: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Build the Prisma DATABASE_URL query params controlling connection pool behavior.
|
||||
|
||||
`connect_timeout` / `socket_timeout` map to the Prisma URL params of the same
|
||||
name (https://www.prisma.io/docs/orm/overview/databases/postgresql) and are
|
||||
omitted when None so Prisma's defaults apply. `extra_params` is an
|
||||
untyped passthrough — keys it provides win over the named arguments above,
|
||||
so it can be used to override any default we set here.
|
||||
"""
|
||||
params: dict = {
|
||||
"connection_limit": connection_limit,
|
||||
}
|
||||
if pool_timeout is not None:
|
||||
params["pool_timeout"] = pool_timeout
|
||||
if connect_timeout is not None:
|
||||
params["connect_timeout"] = connect_timeout
|
||||
if socket_timeout is not None:
|
||||
params["socket_timeout"] = socket_timeout
|
||||
if extra_params:
|
||||
params.update(extra_params)
|
||||
return params
|
||||
|
||||
|
||||
def append_query_params(url: Optional[str], params: dict) -> str:
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
@ -807,6 +836,9 @@ def run_server( # noqa: PLR0915
|
||||
db_connection_pool_limit = 100
|
||||
# Starts optional due to config fallback checks; guaranteed non-None before use.
|
||||
db_connection_timeout: Optional[Union[int, float]] = 60
|
||||
db_connect_timeout: Optional[Union[int, float]] = None
|
||||
db_socket_timeout: Optional[Union[int, float]] = None
|
||||
db_extra_connection_params: Optional[dict] = None
|
||||
general_settings = {}
|
||||
### GET DB TOKEN FOR IAM AUTH ###
|
||||
|
||||
@ -924,6 +956,11 @@ def run_server( # noqa: PLR0915
|
||||
db_connection_timeout = (
|
||||
LiteLLMDatabaseConnectionPool.database_connection_pool_timeout.value
|
||||
)
|
||||
db_connect_timeout = general_settings.get("database_connect_timeout")
|
||||
db_socket_timeout = general_settings.get("database_socket_timeout")
|
||||
db_extra_connection_params = general_settings.get(
|
||||
"database_extra_connection_params"
|
||||
)
|
||||
if database_url and database_url.startswith("os.environ/"):
|
||||
original_dir = os.getcwd()
|
||||
# set the working directory to where this script is
|
||||
@ -963,27 +1000,26 @@ def run_server( # noqa: PLR0915
|
||||
try:
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
connection_url_params = _build_db_connection_url_params(
|
||||
connection_limit=db_connection_pool_limit,
|
||||
pool_timeout=db_connection_timeout,
|
||||
connect_timeout=db_connect_timeout,
|
||||
socket_timeout=db_socket_timeout,
|
||||
extra_params=db_extra_connection_params,
|
||||
)
|
||||
if os.getenv("DATABASE_URL", None) is not None:
|
||||
### add connection pool + pool timeout args
|
||||
params = {
|
||||
"connection_limit": db_connection_pool_limit,
|
||||
"pool_timeout": db_connection_timeout,
|
||||
}
|
||||
database_url = get_secret("DATABASE_URL", default_value=None)
|
||||
modified_url = append_query_params(
|
||||
str(database_url) if database_url else None, params
|
||||
str(database_url) if database_url else None,
|
||||
connection_url_params,
|
||||
)
|
||||
os.environ["DATABASE_URL"] = modified_url
|
||||
if os.getenv("DIRECT_URL", None) is not None:
|
||||
### add connection pool + pool timeout args
|
||||
params = {
|
||||
"connection_limit": db_connection_pool_limit,
|
||||
"pool_timeout": db_connection_timeout,
|
||||
}
|
||||
database_url = os.getenv("DIRECT_URL")
|
||||
modified_url = append_query_params(database_url, params)
|
||||
modified_url = append_query_params(
|
||||
database_url, connection_url_params
|
||||
)
|
||||
os.environ["DIRECT_URL"] = modified_url
|
||||
###
|
||||
subprocess.run(["prisma"], capture_output=True)
|
||||
is_prisma_runnable = True
|
||||
except FileNotFoundError:
|
||||
|
||||
@ -483,6 +483,136 @@ class TestProxyInitializationHelpers:
|
||||
assert appended_params["connection_limit"] == 5
|
||||
assert appended_params["pool_timeout"] == expected_timeout
|
||||
|
||||
def test_build_db_connection_url_params_defaults(self):
    """With only the required arguments, just the two pool params are returned."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    expected = {"connection_limit": 10, "pool_timeout": 60}
    built = _build_db_connection_url_params(connection_limit=10, pool_timeout=60)

    assert built == expected
|
||||
|
||||
def test_build_db_connection_url_params_omits_none_timeouts(self):
    """Timeouts passed explicitly as None must not appear in the result."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    kwargs = {
        "connection_limit": 10,
        "pool_timeout": 60,
        "connect_timeout": None,
        "socket_timeout": None,
    }
    built = _build_db_connection_url_params(**kwargs)

    for absent_key in ("connect_timeout", "socket_timeout"):
        assert absent_key not in built
|
||||
|
||||
def test_build_db_connection_url_params_includes_optional_timeouts(self):
    """Non-None connect/socket timeouts are forwarded under the same key names."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    built = _build_db_connection_url_params(
        connection_limit=10,
        pool_timeout=60,
        connect_timeout=15,
        socket_timeout=120,
    )

    forwarded = (built["connect_timeout"], built["socket_timeout"])
    assert forwarded == (15, 120)
|
||||
|
||||
def test_build_db_connection_url_params_extras_override_defaults(self):
    """Extra params are passed through and win over the named pool_timeout."""
    from litellm.proxy.proxy_cli import _build_db_connection_url_params

    extras = {
        "pgbouncer": "true",
        "statement_cache_size": 0,
        "pool_timeout": 5,
    }
    built = _build_db_connection_url_params(
        connection_limit=10,
        pool_timeout=60,
        extra_params=extras,
    )

    # pool_timeout=60 was passed by name, but the extras value (5) must win.
    for key, expected_value in (
        ("pgbouncer", "true"),
        ("statement_cache_size", 0),
        ("pool_timeout", 5),
    ):
        assert built[key] == expected_value
|
||||
|
||||
@patch("subprocess.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_db_connection_extra_params_forwarded_to_url(
    self,
    mock_should_update,
    mock_setup_db,
    mock_atexit_register,
    mock_subprocess_run,
):
    """The new DB settings in general_settings reach append_query_params.

    Runs `run_server` through click's CliRunner against a mocked proxy_server
    module whose config sets the connect/socket timeouts and extra connection
    params, then inspects the params dict handed to the (patched)
    `append_query_params`.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    mock_subprocess_run.return_value = MagicMock(returncode=0)
    cli_runner = CliRunner()

    # Config returned by the mocked ProxyConfig().get_config().
    fake_config = {
        "general_settings": {
            "database_url": "postgresql://test:test@localhost:5432/test",
            "database_connect_timeout": 15,
            "database_socket_timeout": 120,
            "database_extra_connection_params": {
                "pgbouncer": "true",
                "statement_cache_size": 0,
            },
        }
    }
    mock_proxy_module = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )
    mock_proxy_module.ProxyConfig.return_value.get_config = AsyncMock(
        return_value=fake_config
    )

    # Copy of the environment with any pre-existing DB URLs removed.
    clean_env = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        clean_env.pop(db_var, None)

    patched_modules = {
        "proxy_server": mock_proxy_module,
        "litellm.proxy.proxy_server": mock_proxy_module,
    }

    with (
        patch.dict(os.environ, clean_env, clear=True),
        patch.dict("sys.modules", patched_modules),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.append_query_params",
            side_effect=lambda url, params: str(url),
        ) as mock_append_query_params,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }

        cli_args = ["--local", "--config", "test-config.yaml", "--skip_server_startup"]
        result = cli_runner.invoke(run_server, cli_args)

        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"
        mock_append_query_params.assert_called()

        # Second positional argument of the most recent call is the params dict.
        appended_params = mock_append_query_params.call_args.args[1]
        assert appended_params["connect_timeout"] == 15
        assert appended_params["socket_timeout"] == 120
        assert appended_params["pgbouncer"] == "true"
        assert appended_params["statement_cache_size"] == 0
|
||||
|
||||
@patch("uvicorn.run")
|
||||
@patch("atexit.register")
|
||||
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user