fix(proxy): expose Prisma idle/connect timeout + extra DB URL params (#28395)

* fix(proxy): expose Prisma idle/connect timeout + extra DB URL params

Operators have reported large numbers of idle Prisma connections that
never get closed. The proxy already forwards `connection_limit` and
`pool_timeout` to the DATABASE_URL, but had no knob for capping idle
or slow connections. Add three new `general_settings` keys that thread
through to the DATABASE_URL / DIRECT_URL query string:

- `database_connect_timeout`  -> Prisma `connect_timeout`
- `database_socket_timeout`   -> Prisma `socket_timeout` (the main
  knob for closing idle connections from the LiteLLM side)
- `database_extra_connection_params` -> untyped passthrough dict for
  any other Prisma URL param (`pgbouncer`, `statement_cache_size`,
  `sslmode`, ...); keys here override LiteLLM defaults.

Refactors the duplicated DATABASE_URL/DIRECT_URL param dicts into a
single `_build_db_connection_url_params` helper.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Update litellm/proxy/proxy_cli.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: Yassin Kortam <yassinkortam@g.ucla.edu>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
This commit is contained in:
Yassin Kortam 2026-05-20 17:19:24 -07:00 committed by GitHub
parent 8acf64e16c
commit 2f9ac77b24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 203 additions and 13 deletions

View File

@ -2361,6 +2361,30 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
database_connection_timeout: Optional[float] = Field(
60, description="default timeout for a connection to the database"
)
database_connect_timeout: Optional[float] = Field(
None,
description=(
"Prisma `connect_timeout` URL param (seconds). Bounds how long the "
"engine waits to establish a new connection before failing. Defaults "
"to Prisma's built-in value when unset."
),
)
database_socket_timeout: Optional[float] = Field(
None,
description=(
"Prisma `socket_timeout` URL param (seconds). When set, an idle/slow "
"connection that has not produced data within this window is closed. "
"This is the main knob for capping idle DB connections from LiteLLM."
),
)
database_extra_connection_params: Optional[Dict[str, Any]] = Field(
None,
description=(
"Escape hatch: extra key/value pairs appended verbatim to the Prisma "
"DATABASE_URL / DIRECT_URL query string (e.g. `sslmode`, `pgbouncer`, "
"`statement_cache_size`). Keys here override any default LiteLLM sets."
),
)
database_type: Optional[Literal["dynamo_db"]] = Field(
None, description="to use dynamodb instead of postgres db"
)

View File

@ -38,6 +38,35 @@ class LiteLLMDatabaseConnectionPool(Enum):
database_connection_pool_timeout = 60
def _build_db_connection_url_params(
connection_limit: int,
pool_timeout: Optional[Union[int, float]],
connect_timeout: Optional[Union[int, float]] = None,
socket_timeout: Optional[Union[int, float]] = None,
extra_params: Optional[dict] = None,
) -> dict:
"""Build the Prisma DATABASE_URL query params controlling connection pool behavior.
`connect_timeout` / `socket_timeout` map to the Prisma URL params of the same
name (https://www.prisma.io/docs/orm/overview/databases/postgresql) and are
omitted when None so Prisma's defaults apply. `extra_params` is an
untyped passthrough keys it provides win over the named arguments above,
so it can be used to override any default we set here.
"""
params: dict = {
"connection_limit": connection_limit,
}
if pool_timeout is not None:
params["pool_timeout"] = pool_timeout
if connect_timeout is not None:
params["connect_timeout"] = connect_timeout
if socket_timeout is not None:
params["socket_timeout"] = socket_timeout
if extra_params:
params.update(extra_params)
return params
def append_query_params(url: Optional[str], params: dict) -> str:
from litellm._logging import verbose_proxy_logger
@ -807,6 +836,9 @@ def run_server( # noqa: PLR0915
db_connection_pool_limit = 100
# Starts optional due to config fallback checks; guaranteed non-None before use.
db_connection_timeout: Optional[Union[int, float]] = 60
db_connect_timeout: Optional[Union[int, float]] = None
db_socket_timeout: Optional[Union[int, float]] = None
db_extra_connection_params: Optional[dict] = None
general_settings = {}
### GET DB TOKEN FOR IAM AUTH ###
@ -924,6 +956,11 @@ def run_server( # noqa: PLR0915
db_connection_timeout = (
LiteLLMDatabaseConnectionPool.database_connection_pool_timeout.value
)
db_connect_timeout = general_settings.get("database_connect_timeout")
db_socket_timeout = general_settings.get("database_socket_timeout")
db_extra_connection_params = general_settings.get(
"database_extra_connection_params"
)
if database_url and database_url.startswith("os.environ/"):
original_dir = os.getcwd()
# set the working directory to where this script is
@ -963,27 +1000,26 @@ def run_server( # noqa: PLR0915
try:
from litellm.secret_managers.main import get_secret
connection_url_params = _build_db_connection_url_params(
connection_limit=db_connection_pool_limit,
pool_timeout=db_connection_timeout,
connect_timeout=db_connect_timeout,
socket_timeout=db_socket_timeout,
extra_params=db_extra_connection_params,
)
if os.getenv("DATABASE_URL", None) is not None:
### add connection pool + pool timeout args
params = {
"connection_limit": db_connection_pool_limit,
"pool_timeout": db_connection_timeout,
}
database_url = get_secret("DATABASE_URL", default_value=None)
modified_url = append_query_params(
str(database_url) if database_url else None, params
str(database_url) if database_url else None,
connection_url_params,
)
os.environ["DATABASE_URL"] = modified_url
if os.getenv("DIRECT_URL", None) is not None:
### add connection pool + pool timeout args
params = {
"connection_limit": db_connection_pool_limit,
"pool_timeout": db_connection_timeout,
}
database_url = os.getenv("DIRECT_URL")
modified_url = append_query_params(database_url, params)
modified_url = append_query_params(
database_url, connection_url_params
)
os.environ["DIRECT_URL"] = modified_url
###
subprocess.run(["prisma"], capture_output=True)
is_prisma_runnable = True
except FileNotFoundError:

View File

@ -483,6 +483,136 @@ class TestProxyInitializationHelpers:
assert appended_params["connection_limit"] == 5
assert appended_params["pool_timeout"] == expected_timeout
def test_build_db_connection_url_params_defaults(self):
    """Only the pool limit and pool timeout are emitted when nothing else is passed."""
    from litellm.proxy.proxy_cli import (
        _build_db_connection_url_params as build_params,
    )

    expected = {"connection_limit": 10, "pool_timeout": 60}
    assert build_params(connection_limit=10, pool_timeout=60) == expected
def test_build_db_connection_url_params_omits_none_timeouts(self):
    """Timeouts explicitly passed as None must not appear in the built params."""
    from litellm.proxy.proxy_cli import (
        _build_db_connection_url_params as build_params,
    )

    call_kwargs = dict(
        connection_limit=10,
        pool_timeout=60,
        connect_timeout=None,
        socket_timeout=None,
    )
    built = build_params(**call_kwargs)

    for absent_key in ("connect_timeout", "socket_timeout"):
        assert absent_key not in built
def test_build_db_connection_url_params_includes_optional_timeouts(self):
    """Explicit connect/socket timeouts are forwarded under their Prisma names."""
    from litellm.proxy.proxy_cli import (
        _build_db_connection_url_params as build_params,
    )

    built = build_params(
        connection_limit=10,
        pool_timeout=60,
        connect_timeout=15,
        socket_timeout=120,
    )

    wanted = {"connect_timeout": 15, "socket_timeout": 120}
    for key, value in wanted.items():
        assert built[key] == value
def test_build_db_connection_url_params_extras_override_defaults(self):
    """Keys in `extra_params` are added and win over the named arguments."""
    from litellm.proxy.proxy_cli import (
        _build_db_connection_url_params as build_params,
    )

    # `pool_timeout` appears both as a named argument (60) and as an extra (5);
    # the extra is expected to win.
    extras = {
        "pgbouncer": "true",
        "statement_cache_size": 0,
        "pool_timeout": 5,
    }
    built = build_params(connection_limit=10, pool_timeout=60, extra_params=extras)

    for key, value in (
        ("pgbouncer", "true"),
        ("statement_cache_size", 0),
        ("pool_timeout", 5),
    ):
        assert built[key] == value
@patch("subprocess.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
@patch(
    "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
)
def test_db_connection_extra_params_forwarded_to_url(
    self,
    mock_should_update,
    mock_setup_db,
    mock_atexit_register,
    mock_subprocess_run,
):
    """End-to-end check that the new general_settings keys reach `append_query_params`.

    Runs `run_server` through click's `CliRunner` with a mocked proxy_server
    module whose config sets the connect/socket timeouts and extra connection
    params, then inspects the params dict passed to the patched
    `append_query_params`.
    """
    from click.testing import CliRunner

    from litellm.proxy.proxy_cli import run_server

    mock_subprocess_run.return_value = MagicMock(returncode=0)

    # Config returned by the mocked ProxyConfig().get_config().
    general_settings = {
        "database_url": "postgresql://test:test@localhost:5432/test",
        "database_connect_timeout": 15,
        "database_socket_timeout": 120,
        "database_extra_connection_params": {
            "pgbouncer": "true",
            "statement_cache_size": 0,
        },
    }
    fake_proxy_server = MagicMock(
        app=MagicMock(),
        ProxyConfig=MagicMock(),
        KeyManagementSettings=MagicMock(),
        save_worker_config=MagicMock(),
    )
    fake_proxy_server.ProxyConfig.return_value.get_config = AsyncMock(
        return_value={"general_settings": general_settings}
    )

    # Start from the current environment minus any pre-existing DB URLs.
    env_without_db_urls = dict(os.environ)
    for db_var in ("DATABASE_URL", "DIRECT_URL"):
        env_without_db_urls.pop(db_var, None)

    def _echo_url(url, params):
        # Stand-in for append_query_params: return the URL unchanged.
        return str(url)

    cli_args = ["--local", "--config", "test-config.yaml", "--skip_server_startup"]

    with (
        patch.dict(os.environ, env_without_db_urls, clear=True),
        patch.dict(
            "sys.modules",
            {
                "proxy_server": fake_proxy_server,
                "litellm.proxy.proxy_server": fake_proxy_server,
            },
        ),
        patch(
            "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
        ) as mock_get_args,
        patch(
            "litellm.proxy.proxy_cli.append_query_params",
            side_effect=_echo_url,
        ) as mock_append_query_params,
    ):
        mock_get_args.return_value = {
            "app": "litellm.proxy.proxy_server:app",
            "host": "localhost",
            "port": 8000,
        }

        result = CliRunner().invoke(run_server, cli_args)

        assert (
            result.exit_code == 0
        ), f"exit_code={result.exit_code}, output={result.output}"

        mock_append_query_params.assert_called()
        # Second positional argument of the most recent call is the params dict.
        forwarded = mock_append_query_params.call_args.args[1]
        expected = {
            "connect_timeout": 15,
            "socket_timeout": 120,
            "pgbouncer": "true",
            "statement_cache_size": 0,
        }
        for key, value in expected.items():
            assert forwarded[key] == value
@patch("uvicorn.run")
@patch("atexit.register")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")