feat(proxy): add option to disable server-side prepared statements for DB lookups (#29984)
This commit is contained in:
parent
3bd3951e37
commit
dff25fef44
@ -2177,6 +2177,17 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
|
||||
"`statement_cache_size`). Keys here override any default LiteLLM sets."
|
||||
),
|
||||
)
|
||||
database_disable_prepared_statements: Optional[bool] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Disable server-side prepared statements by setting Prisma's "
|
||||
"`pgbouncer=true` URL param. Use this for pgbouncer transaction-pooling "
|
||||
"deployments, or to prevent the 'cached plan must not change result "
|
||||
"type' error that pooled connections hit during rolling schema "
|
||||
"migrations. An explicit `pgbouncer` in `database_extra_connection_params` "
|
||||
"takes precedence."
|
||||
),
|
||||
)
|
||||
database_type: Optional[Literal["dynamo_db"]] = Field(
|
||||
None, description="to use dynamodb instead of postgres db"
|
||||
)
|
||||
|
||||
@ -44,15 +44,19 @@ def _build_db_connection_url_params(
|
||||
pool_timeout: Optional[Union[int, float]],
|
||||
connect_timeout: Optional[Union[int, float]] = None,
|
||||
socket_timeout: Optional[Union[int, float]] = None,
|
||||
disable_prepared_statements: bool = False,
|
||||
extra_params: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Build the Prisma DATABASE_URL query params controlling connection pool behavior.
|
||||
|
||||
`connect_timeout` / `socket_timeout` map to the Prisma URL params of the same
|
||||
name (https://www.prisma.io/docs/orm/overview/databases/postgresql) and are
|
||||
omitted when None so Prisma's defaults apply. `extra_params` is an
|
||||
untyped passthrough — keys it provides win over the named arguments above,
|
||||
so it can be used to override any default we set here.
|
||||
omitted when None so Prisma's defaults apply. `disable_prepared_statements`
|
||||
sets `pgbouncer=true`, which makes Prisma stop using server-side prepared
|
||||
statements (pgbouncer transaction-pool compatible; also sidesteps the
|
||||
"cached plan must not change result type" error during rolling migrations).
|
||||
`extra_params` is an untyped passthrough — keys it provides win over the
|
||||
named arguments above, so it can be used to override any default we set here.
|
||||
"""
|
||||
params: dict = {
|
||||
"connection_limit": connection_limit,
|
||||
@ -63,6 +67,8 @@ def _build_db_connection_url_params(
|
||||
params["connect_timeout"] = connect_timeout
|
||||
if socket_timeout is not None:
|
||||
params["socket_timeout"] = socket_timeout
|
||||
if disable_prepared_statements:
|
||||
params["pgbouncer"] = "true"
|
||||
if extra_params:
|
||||
params.update(extra_params)
|
||||
return params
|
||||
@ -963,6 +969,7 @@ def run_server( # noqa: PLR0915
|
||||
db_connection_timeout: Optional[Union[int, float]] = 60
|
||||
db_connect_timeout: Optional[Union[int, float]] = None
|
||||
db_socket_timeout: Optional[Union[int, float]] = None
|
||||
db_disable_prepared_statements: bool = False
|
||||
db_extra_connection_params: Optional[dict] = None
|
||||
general_settings = {}
|
||||
### GET DB TOKEN FOR IAM AUTH ###
|
||||
@ -1083,6 +1090,17 @@ def run_server( # noqa: PLR0915
|
||||
)
|
||||
db_connect_timeout = general_settings.get("database_connect_timeout")
|
||||
db_socket_timeout = general_settings.get("database_socket_timeout")
|
||||
_disable_prepared_statements = general_settings.get(
|
||||
"database_disable_prepared_statements", False
|
||||
)
|
||||
if isinstance(_disable_prepared_statements, str):
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
db_disable_prepared_statements = (
|
||||
str_to_bool(_disable_prepared_statements) is True
|
||||
)
|
||||
else:
|
||||
db_disable_prepared_statements = bool(_disable_prepared_statements)
|
||||
db_extra_connection_params = general_settings.get(
|
||||
"database_extra_connection_params"
|
||||
)
|
||||
@ -1130,6 +1148,7 @@ def run_server( # noqa: PLR0915
|
||||
pool_timeout=db_connection_timeout,
|
||||
connect_timeout=db_connect_timeout,
|
||||
socket_timeout=db_socket_timeout,
|
||||
disable_prepared_statements=db_disable_prepared_statements,
|
||||
extra_params=db_extra_connection_params,
|
||||
)
|
||||
if os.getenv("DATABASE_URL", None) is not None:
|
||||
|
||||
@ -795,6 +795,127 @@ class TestProxyInitializationHelpers:
|
||||
assert appended_params["pgbouncer"] == "true"
|
||||
assert appended_params["statement_cache_size"] == 0
|
||||
|
||||
def test_build_db_connection_url_params_disable_prepared_statements(self):
|
||||
from litellm.proxy.proxy_cli import _build_db_connection_url_params
|
||||
|
||||
params = _build_db_connection_url_params(
|
||||
connection_limit=10,
|
||||
pool_timeout=60,
|
||||
disable_prepared_statements=True,
|
||||
)
|
||||
assert params["pgbouncer"] == "true"
|
||||
|
||||
def test_build_db_connection_url_params_no_pgbouncer_by_default(self):
|
||||
from litellm.proxy.proxy_cli import _build_db_connection_url_params
|
||||
|
||||
params = _build_db_connection_url_params(
|
||||
connection_limit=10,
|
||||
pool_timeout=60,
|
||||
)
|
||||
assert "pgbouncer" not in params
|
||||
|
||||
def test_build_db_connection_url_params_extra_pgbouncer_overrides_flag(self):
|
||||
from litellm.proxy.proxy_cli import _build_db_connection_url_params
|
||||
|
||||
params = _build_db_connection_url_params(
|
||||
connection_limit=10,
|
||||
pool_timeout=60,
|
||||
disable_prepared_statements=True,
|
||||
extra_params={"pgbouncer": "false"},
|
||||
)
|
||||
assert params["pgbouncer"] == "false"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_value, expect_pgbouncer",
|
||||
[
|
||||
(True, True),
|
||||
(False, False),
|
||||
("true", True),
|
||||
("false", False),
|
||||
("not-a-bool", False),
|
||||
],
|
||||
)
|
||||
@patch("subprocess.run")
|
||||
@patch("atexit.register")
|
||||
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
|
||||
@patch(
|
||||
"litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False
|
||||
)
|
||||
def test_disable_prepared_statements_forwarded_to_url(
|
||||
self,
|
||||
mock_should_update,
|
||||
mock_setup_db,
|
||||
mock_atexit_register,
|
||||
mock_subprocess_run,
|
||||
config_value,
|
||||
expect_pgbouncer,
|
||||
):
|
||||
from click.testing import CliRunner
|
||||
|
||||
from litellm.proxy.proxy_cli import run_server
|
||||
|
||||
runner = CliRunner()
|
||||
mock_subprocess_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
mock_proxy_module = MagicMock(
|
||||
app=MagicMock(),
|
||||
ProxyConfig=MagicMock(),
|
||||
KeyManagementSettings=MagicMock(),
|
||||
save_worker_config=MagicMock(),
|
||||
)
|
||||
mock_proxy_module.ProxyConfig.return_value.get_config = AsyncMock(
|
||||
return_value={
|
||||
"general_settings": {
|
||||
"database_url": "postgresql://test:test@localhost:5432/test",
|
||||
"database_disable_prepared_statements": config_value,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
clean_env = {
|
||||
k: v
|
||||
for k, v in os.environ.items()
|
||||
if k not in ("DATABASE_URL", "DIRECT_URL")
|
||||
}
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, clean_env, clear=True),
|
||||
patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"proxy_server": mock_proxy_module,
|
||||
"litellm.proxy.proxy_server": mock_proxy_module,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
|
||||
) as mock_get_args,
|
||||
patch(
|
||||
"litellm.proxy.proxy_cli.append_query_params",
|
||||
side_effect=lambda url, params: str(url),
|
||||
) as mock_append_query_params,
|
||||
):
|
||||
mock_get_args.return_value = {
|
||||
"app": "litellm.proxy.proxy_server:app",
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
}
|
||||
|
||||
result = runner.invoke(
|
||||
run_server,
|
||||
["--local", "--config", "test-config.yaml", "--skip_server_startup"],
|
||||
)
|
||||
|
||||
assert (
|
||||
result.exit_code == 0
|
||||
), f"exit_code={result.exit_code}, output={result.output}"
|
||||
mock_append_query_params.assert_called()
|
||||
appended_params = mock_append_query_params.call_args.args[1]
|
||||
if expect_pgbouncer:
|
||||
assert appended_params["pgbouncer"] == "true"
|
||||
else:
|
||||
assert "pgbouncer" not in appended_params
|
||||
|
||||
@patch("uvicorn.run")
|
||||
@patch("atexit.register")
|
||||
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
|
||||
|
||||
5
ui/litellm-dashboard/src/lib/http/schema.d.ts
generated
vendored
5
ui/litellm-dashboard/src/lib/http/schema.d.ts
generated
vendored
@ -22012,6 +22012,11 @@ export interface components {
|
||||
* @default 60
|
||||
*/
|
||||
database_connection_timeout: number | null;
|
||||
/**
|
||||
* Database Disable Prepared Statements
|
||||
* @description Disable server-side prepared statements by setting Prisma's `pgbouncer=true` URL param. Use this for pgbouncer transaction-pooling deployments, or to prevent the 'cached plan must not change result type' error that pooled connections hit during rolling schema migrations. An explicit `pgbouncer` in `database_extra_connection_params` takes precedence.
|
||||
*/
|
||||
database_disable_prepared_statements?: boolean | null;
|
||||
/**
|
||||
* Database Extra Connection Params
|
||||
* @description Escape hatch: extra key/value pairs appended verbatim to the Prisma DATABASE_URL / DIRECT_URL query string (e.g. `sslmode`, `pgbouncer`, `statement_cache_size`). Keys here override any default LiteLLM sets.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user