diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 33a1e4179f..1b594e20d3 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2177,6 +2177,17 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase): "`statement_cache_size`). Keys here override any default LiteLLM sets." ), ) + database_disable_prepared_statements: Optional[bool] = Field( + None, + description=( + "Disable server-side prepared statements by setting Prisma's " + "`pgbouncer=true` URL param. Use this for pgbouncer transaction-pooling " + "deployments, or to prevent the 'cached plan must not change result " + "type' error that pooled connections hit during rolling schema " + "migrations. An explicit `pgbouncer` in `database_extra_connection_params` " + "takes precedence." + ), + ) database_type: Optional[Literal["dynamo_db"]] = Field( None, description="to use dynamodb instead of postgres db" ) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index ae831ef1b5..8c3fa95290 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -44,15 +44,19 @@ def _build_db_connection_url_params( pool_timeout: Optional[Union[int, float]], connect_timeout: Optional[Union[int, float]] = None, socket_timeout: Optional[Union[int, float]] = None, + disable_prepared_statements: bool = False, extra_params: Optional[dict] = None, ) -> dict: """Build the Prisma DATABASE_URL query params controlling connection pool behavior. `connect_timeout` / `socket_timeout` map to the Prisma URL params of the same name (https://www.prisma.io/docs/orm/overview/databases/postgresql) and are - omitted when None so Prisma's defaults apply. `extra_params` is an - untyped passthrough — keys it provides win over the named arguments above, - so it can be used to override any default we set here. + omitted when None so Prisma's defaults apply. `disable_prepared_statements` + sets `pgbouncer=true`, which makes Prisma stop using server-side prepared + statements (pgbouncer transaction-pool compatible; also sidesteps the + "cached plan must not change result type" error during rolling migrations). + `extra_params` is an untyped passthrough — keys it provides win over the + named arguments above, so it can be used to override any default we set here. """ params: dict = { "connection_limit": connection_limit, @@ -63,6 +67,8 @@ def _build_db_connection_url_params( params["connect_timeout"] = connect_timeout if socket_timeout is not None: params["socket_timeout"] = socket_timeout + if disable_prepared_statements: + params["pgbouncer"] = "true" if extra_params: params.update(extra_params) return params @@ -963,6 +969,7 @@ def run_server( # noqa: PLR0915 db_connection_timeout: Optional[Union[int, float]] = 60 db_connect_timeout: Optional[Union[int, float]] = None db_socket_timeout: Optional[Union[int, float]] = None + db_disable_prepared_statements: bool = False db_extra_connection_params: Optional[dict] = None general_settings = {} ### GET DB TOKEN FOR IAM AUTH ### @@ -1083,6 +1090,17 @@ def run_server( # noqa: PLR0915 ) db_connect_timeout = general_settings.get("database_connect_timeout") db_socket_timeout = general_settings.get("database_socket_timeout") + _disable_prepared_statements = general_settings.get( + "database_disable_prepared_statements", False + ) + if isinstance(_disable_prepared_statements, str): + from litellm.secret_managers.main import str_to_bool + + db_disable_prepared_statements = ( + str_to_bool(_disable_prepared_statements) is True + ) + else: + db_disable_prepared_statements = bool(_disable_prepared_statements) db_extra_connection_params = general_settings.get( "database_extra_connection_params" ) @@ -1130,6 +1148,7 @@ def run_server( # noqa: PLR0915 pool_timeout=db_connection_timeout, connect_timeout=db_connect_timeout, socket_timeout=db_socket_timeout, + disable_prepared_statements=db_disable_prepared_statements, extra_params=db_extra_connection_params, ) if os.getenv("DATABASE_URL", None) is not None: diff --git a/tests/test_litellm/proxy/test_proxy_cli.py b/tests/test_litellm/proxy/test_proxy_cli.py index 4fb725b7ef..34c88e2fd3 100644 --- a/tests/test_litellm/proxy/test_proxy_cli.py +++ b/tests/test_litellm/proxy/test_proxy_cli.py @@ -795,6 +795,127 @@ class TestProxyInitializationHelpers: assert appended_params["pgbouncer"] == "true" assert appended_params["statement_cache_size"] == 0 + def test_build_db_connection_url_params_disable_prepared_statements(self): + from litellm.proxy.proxy_cli import _build_db_connection_url_params + + params = _build_db_connection_url_params( + connection_limit=10, + pool_timeout=60, + disable_prepared_statements=True, + ) + assert params["pgbouncer"] == "true" + + def test_build_db_connection_url_params_no_pgbouncer_by_default(self): + from litellm.proxy.proxy_cli import _build_db_connection_url_params + + params = _build_db_connection_url_params( + connection_limit=10, + pool_timeout=60, + ) + assert "pgbouncer" not in params + + def test_build_db_connection_url_params_extra_pgbouncer_overrides_flag(self): + from litellm.proxy.proxy_cli import _build_db_connection_url_params + + params = _build_db_connection_url_params( + connection_limit=10, + pool_timeout=60, + disable_prepared_statements=True, + extra_params={"pgbouncer": "false"}, + ) + assert params["pgbouncer"] == "false" + + @pytest.mark.parametrize( + "config_value, expect_pgbouncer", + [ + (True, True), + (False, False), + ("true", True), + ("false", False), + ("not-a-bool", False), + ], + ) + @patch("subprocess.run") + @patch("atexit.register") + @patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database") + @patch( + "litellm.proxy.db.prisma_client.should_update_prisma_schema", return_value=False + ) + def test_disable_prepared_statements_forwarded_to_url( + self, + mock_should_update, + mock_setup_db, + mock_atexit_register, + mock_subprocess_run, + config_value, + expect_pgbouncer, + ): + from click.testing import CliRunner + + from litellm.proxy.proxy_cli import run_server + + runner = CliRunner() + mock_subprocess_run.return_value = MagicMock(returncode=0) + + mock_proxy_module = MagicMock( + app=MagicMock(), + ProxyConfig=MagicMock(), + KeyManagementSettings=MagicMock(), + save_worker_config=MagicMock(), + ) + mock_proxy_module.ProxyConfig.return_value.get_config = AsyncMock( + return_value={ + "general_settings": { + "database_url": "postgresql://test:test@localhost:5432/test", + "database_disable_prepared_statements": config_value, + } + } + ) + + clean_env = { + k: v + for k, v in os.environ.items() + if k not in ("DATABASE_URL", "DIRECT_URL") + } + + with ( + patch.dict(os.environ, clean_env, clear=True), + patch.dict( + "sys.modules", + { + "proxy_server": mock_proxy_module, + "litellm.proxy.proxy_server": mock_proxy_module, + }, + ), + patch( + "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args" + ) as mock_get_args, + patch( + "litellm.proxy.proxy_cli.append_query_params", + side_effect=lambda url, params: str(url), + ) as mock_append_query_params, + ): + mock_get_args.return_value = { + "app": "litellm.proxy.proxy_server:app", + "host": "localhost", + "port": 8000, + } + + result = runner.invoke( + run_server, + ["--local", "--config", "test-config.yaml", "--skip_server_startup"], + ) + + assert ( + result.exit_code == 0 + ), f"exit_code={result.exit_code}, output={result.output}" + mock_append_query_params.assert_called() + appended_params = mock_append_query_params.call_args.args[1] + if expect_pgbouncer: + assert appended_params["pgbouncer"] == "true" + else: + assert "pgbouncer" not in appended_params + @patch("uvicorn.run") @patch("atexit.register") @patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database") diff --git a/ui/litellm-dashboard/src/lib/http/schema.d.ts b/ui/litellm-dashboard/src/lib/http/schema.d.ts index 203a56f615..8e47081955 100644 --- a/ui/litellm-dashboard/src/lib/http/schema.d.ts +++ b/ui/litellm-dashboard/src/lib/http/schema.d.ts @@ -22012,6 +22012,11 @@ export interface components { * @default 60 */ database_connection_timeout: number | null; + /** + * Database Disable Prepared Statements + * @description Disable server-side prepared statements by setting Prisma's `pgbouncer=true` URL param. Use this for pgbouncer transaction-pooling deployments, or to prevent the 'cached plan must not change result type' error that pooled connections hit during rolling schema migrations. An explicit `pgbouncer` in `database_extra_connection_params` takes precedence. + */ + database_disable_prepared_statements?: boolean | null; /** * Database Extra Connection Params * @description Escape hatch: extra key/value pairs appended verbatim to the Prisma DATABASE_URL / DIRECT_URL query string (e.g. `sslmode`, `pgbouncer`, `statement_cache_size`). Keys here override any default LiteLLM sets.