feat(proxy): native /health/drain preStop hook for graceful shutdown (#29439)
This commit is contained in:
parent
a5ccd96152
commit
3a1c6bba97
@ -36,6 +36,7 @@ jobs:
|
||||
tests/test_litellm/proxy/a2a
|
||||
tests/test_litellm/proxy/discovery_endpoints
|
||||
tests/test_litellm/proxy/health_endpoints
|
||||
tests/test_litellm/proxy/shutdown
|
||||
tests/test_litellm/proxy/public_endpoints
|
||||
tests/test_litellm/proxy/prompts
|
||||
tests/test_litellm/proxy/rag_endpoints
|
||||
|
||||
2
Makefile
2
Makefile
@ -146,7 +146,7 @@ test-unit-proxy-core: install-test-deps
|
||||
$(UV_RUN) pytest tests/test_litellm/proxy/auth tests/test_litellm/proxy/client tests/test_litellm/proxy/db tests/test_litellm/proxy/hooks tests/test_litellm/proxy/policy_engine --tb=short -vv -n 4 --durations=20
|
||||
|
||||
test-unit-proxy-misc: install-test-deps
|
||||
$(UV_RUN) pytest tests/test_litellm/proxy/_experimental tests/test_litellm/proxy/agent_endpoints tests/test_litellm/proxy/anthropic_endpoints tests/test_litellm/proxy/common_utils tests/test_litellm/proxy/discovery_endpoints tests/test_litellm/proxy/experimental tests/test_litellm/proxy/google_endpoints tests/test_litellm/proxy/health_endpoints tests/test_litellm/proxy/image_endpoints tests/test_litellm/proxy/middleware tests/test_litellm/proxy/openai_files_endpoint tests/test_litellm/proxy/pass_through_endpoints tests/test_litellm/proxy/prompts tests/test_litellm/proxy/public_endpoints tests/test_litellm/proxy/response_api_endpoints tests/test_litellm/proxy/spend_tracking tests/test_litellm/proxy/ui_crud_endpoints tests/test_litellm/proxy/vector_store_endpoints tests/test_litellm/proxy/test_*.py --tb=short -vv -n 4 --durations=20
|
||||
$(UV_RUN) pytest tests/test_litellm/proxy/_experimental tests/test_litellm/proxy/agent_endpoints tests/test_litellm/proxy/anthropic_endpoints tests/test_litellm/proxy/common_utils tests/test_litellm/proxy/discovery_endpoints tests/test_litellm/proxy/experimental tests/test_litellm/proxy/google_endpoints tests/test_litellm/proxy/health_endpoints tests/test_litellm/proxy/image_endpoints tests/test_litellm/proxy/middleware tests/test_litellm/proxy/openai_files_endpoint tests/test_litellm/proxy/pass_through_endpoints tests/test_litellm/proxy/prompts tests/test_litellm/proxy/public_endpoints tests/test_litellm/proxy/response_api_endpoints tests/test_litellm/proxy/shutdown tests/test_litellm/proxy/spend_tracking tests/test_litellm/proxy/ui_crud_endpoints tests/test_litellm/proxy/vector_store_endpoints tests/test_litellm/proxy/test_*.py --tb=short -vv -n 4 --durations=20
|
||||
|
||||
test-unit-integrations: install-test-deps
|
||||
$(UV_RUN) pytest tests/test_litellm/integrations --tb=short -vv -n 4 --durations=20
|
||||
|
||||
@ -285,11 +285,31 @@ db:
|
||||
deployStandalone: true
|
||||
|
||||
# Lifecycle hooks for the LiteLLM container
|
||||
#
|
||||
# Prefer the native /health/drain preStop hook over a fixed `sleep`: it marks
|
||||
# the pod NotReady and blocks only until in-flight requests actually finish
|
||||
# (bounded by GRACEFUL_SHUTDOWN_TIMEOUT, default 30s), instead of always
|
||||
# waiting the worst-case duration. The drain runs once (the preStop hook and
|
||||
# the SIGTERM handler share it), so set terminationGracePeriodSeconds a few
|
||||
# seconds above GRACEFUL_SHUTDOWN_TIMEOUT to leave room for teardown before
|
||||
# SIGKILL.
|
||||
#
|
||||
# /health/drain is off by default; enable it with
|
||||
# general_settings.enable_drain_endpoint: true. The kubelet calls preStop
|
||||
# hooks without proxy credentials, so when the health port is reachable from
|
||||
# other pods (the common case) also set
|
||||
# general_settings.drain_endpoint_token (or the DRAIN_ENDPOINT_TOKEN env
|
||||
# var) and send the same value on the X-Drain-Token header from the hook.
|
||||
# Calls missing/wrong the token get a 401 and have no side effect.
|
||||
# Example:
|
||||
# lifecycle:
|
||||
# preStop:
|
||||
# exec:
|
||||
# command: ["/bin/sh", "-c", "sleep 10"]
|
||||
# httpGet:
|
||||
# path: /health/drain
|
||||
# port: 4000
|
||||
# httpHeaders:
|
||||
# - name: X-Drain-Token
|
||||
# value: <same value as drain_endpoint_token>
|
||||
lifecycle: {}
|
||||
|
||||
# Settings for Bitnami postgresql chart (if db.deployStandalone is true, ignored
|
||||
|
||||
@ -2,6 +2,7 @@ import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
@ -39,6 +40,7 @@ from litellm.proxy.health_check import (
|
||||
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
||||
get_in_flight_requests,
|
||||
)
|
||||
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
|
||||
|
||||
#### Health ENDPOINTS ####
|
||||
|
||||
@ -1551,6 +1553,50 @@ def _allow_public_health_readiness_details() -> bool:
|
||||
return general_settings.get("allow_public_health_readiness_details") is True
|
||||
|
||||
|
||||
def _drain_endpoint_enabled() -> bool:
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
return general_settings.get("enable_drain_endpoint") is True
|
||||
|
||||
|
||||
def _drain_endpoint_token() -> Optional[str]:
|
||||
"""
|
||||
Shared secret required on the X-Drain-Token header to call /health/drain.
|
||||
|
||||
Falls back to the ``DRAIN_ENDPOINT_TOKEN`` env var when unset in
|
||||
general_settings so the kubelet preStop hook can supply it via
|
||||
``valueFrom.secretKeyRef`` without a config reload.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
token = general_settings.get("drain_endpoint_token")
|
||||
if isinstance(token, str) and token:
|
||||
return token
|
||||
env_token = os.getenv("DRAIN_ENDPOINT_TOKEN")
|
||||
if env_token:
|
||||
return env_token
|
||||
return None
|
||||
|
||||
|
||||
def _authorize_drain_request(request: Request) -> None:
|
||||
"""
|
||||
Reject /health/drain calls that don't carry the configured X-Drain-Token.
|
||||
|
||||
When no token is configured the endpoint is treated as already opted-in
|
||||
(the ``enable_drain_endpoint`` flag is the only gate). Comparison uses
|
||||
``secrets.compare_digest`` to avoid timing leaks.
|
||||
"""
|
||||
expected = _drain_endpoint_token()
|
||||
if expected is None:
|
||||
return
|
||||
supplied = request.headers.get("x-drain-token") or ""
|
||||
if not secrets.compare_digest(supplied, expected):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing X-Drain-Token",
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_public_readiness_db(response: Response) -> str:
|
||||
"""
|
||||
Return the db status string for the public probe and flip the response to
|
||||
@ -1580,6 +1626,10 @@ async def health_readiness(response: Response):
|
||||
credential. Admins can opt into the legacy detailed payload with
|
||||
general_settings.allow_public_health_readiness_details.
|
||||
"""
|
||||
if GracefulShutdownManager.is_shutting_down():
|
||||
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
return {"status": "shutting_down"}
|
||||
|
||||
if _allow_public_health_readiness_details():
|
||||
return await _get_health_readiness_details(response=response)
|
||||
|
||||
@ -1616,6 +1666,54 @@ async def health_backlog():
|
||||
return {"in_flight_requests": get_in_flight_requests()}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health/drain",
|
||||
tags=["health"],
|
||||
)
|
||||
async def health_drain(request: Request):
|
||||
"""
|
||||
Graceful-drain probe for Kubernetes ``preStop`` hooks.
|
||||
|
||||
Disabled by default and returns 404 unless ``general_settings`` sets
|
||||
``enable_drain_endpoint: true``. Calling it flips a process-wide
|
||||
shutting-down flag, so a successful call permanently takes the worker out
|
||||
of rotation until the pod restarts.
|
||||
|
||||
Because the kubelet calls preStop hooks without proxy credentials, the
|
||||
endpoint does not require ``user_api_key_auth``. To prevent any
|
||||
pod-reachable caller from triggering shutdown, set
|
||||
``general_settings.drain_endpoint_token`` (or the ``DRAIN_ENDPOINT_TOKEN``
|
||||
env var) and supply the same value on the ``X-Drain-Token`` header from
|
||||
the preStop hook. Calls without the header (or with a wrong value) get a
|
||||
401 and have no side effect.
|
||||
|
||||
When enabled, it marks the worker as shutting down (so /health/readiness
|
||||
and /health/liveliness immediately start returning 503, removing the pod
|
||||
from service) and blocks until the in-flight request counter drains to
|
||||
zero or ``GRACEFUL_SHUTDOWN_TIMEOUT`` elapses. Unlike a fixed ``sleep``,
|
||||
this returns as soon as real in-flight work is done.
|
||||
|
||||
Wire it up as:
|
||||
|
||||
```yaml
|
||||
lifecycle:
|
||||
preStop:
|
||||
httpGet:
|
||||
path: /health/drain
|
||||
port: 4000
|
||||
httpHeaders:
|
||||
- name: X-Drain-Token
|
||||
value: <same value as drain_endpoint_token>
|
||||
```
|
||||
"""
|
||||
if not _drain_endpoint_enabled():
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
|
||||
_authorize_drain_request(request)
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
drained = await GracefulShutdownManager.wait_for_drain(exclude_self=True)
|
||||
return {"status": "drained", "drained_requests": drained}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health/liveliness", # Historical LiteLLM name; doesn't match k8s terminology but kept for backwards compatibility
|
||||
tags=["health"],
|
||||
@ -1624,10 +1722,16 @@ async def health_backlog():
|
||||
"/health/liveness", # Kubernetes has "liveness" probes (https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-startup-probes/#define-a-liveness-command)
|
||||
tags=["health"],
|
||||
)
|
||||
async def health_liveliness():
|
||||
async def health_liveliness(response: Response):
|
||||
"""
|
||||
Unprotected endpoint for checking if worker is alive
|
||||
Unprotected endpoint for checking if worker is alive.
|
||||
|
||||
Returns 503 once graceful shutdown has begun so Kubernetes stops counting
|
||||
the draining pod as live and terminates it on schedule.
|
||||
"""
|
||||
if GracefulShutdownManager.is_shutting_down():
|
||||
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
return {"status": "shutting_down"}
|
||||
return "I'm alive!"
|
||||
|
||||
|
||||
|
||||
@ -417,6 +417,7 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi
|
||||
from litellm.proxy.middleware.request_size_limit_middleware import (
|
||||
RequestSizeLimitMiddleware,
|
||||
)
|
||||
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
|
||||
from litellm.proxy.ocr_endpoints.endpoints import router as ocr_router
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||
router as openai_files_router,
|
||||
@ -973,6 +974,11 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915
|
||||
# End of startup event
|
||||
yield
|
||||
|
||||
# Shutdown event - drain in-flight requests before tearing down dependencies
|
||||
# so SIGTERM (rolling update, scale-down, liveness kill) doesn't drop them.
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
await GracefulShutdownManager.wait_for_drain()
|
||||
|
||||
# Shutdown event - close shared aiohttp session
|
||||
if shared_aiohttp_session is not None:
|
||||
try:
|
||||
|
||||
0
litellm/proxy/shutdown/__init__.py
Normal file
0
litellm/proxy/shutdown/__init__.py
Normal file
174
litellm/proxy/shutdown/graceful_shutdown_manager.py
Normal file
174
litellm/proxy/shutdown/graceful_shutdown_manager.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""
|
||||
Application-level graceful shutdown coordination for the LiteLLM proxy.
|
||||
|
||||
Kubernetes terminates a pod by sending ``SIGTERM`` and, after
|
||||
``terminationGracePeriodSeconds``, ``SIGKILL``. By default LiteLLM delegates
|
||||
the signal to uvicorn and tears down immediately, dropping any in-flight
|
||||
requests (streaming, batch inference, long-lived calls).
|
||||
|
||||
A fixed ``preStop`` sleep can not solve this: it has to be sized for the
|
||||
*worst-case* request, so it either wastes time on every routine shutdown or is
|
||||
too short for a long-running request. This manager instead drains based on the
|
||||
*actual* in-flight request counter (already tracked by
|
||||
``InFlightRequestsMiddleware``), so a pod terminates as soon as its real
|
||||
in-flight work is done — and never waits longer than ``GRACEFUL_SHUTDOWN_TIMEOUT``.
|
||||
|
||||
The state is process-scoped (class-level), matching the per-uvicorn-worker
|
||||
granularity of ``InFlightRequestsMiddleware``.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
||||
get_in_flight_requests,
|
||||
)
|
||||
|
||||
# Keep below terminationGracePeriodSeconds so the process exits before SIGKILL.
|
||||
DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30.0
|
||||
_DRAIN_POLL_INTERVAL = 0.1
|
||||
_DRAIN_LOG_INTERVAL = 5.0
|
||||
|
||||
|
||||
class GracefulShutdownManager:
|
||||
"""
|
||||
Process-scoped singleton that tracks whether the worker is draining and
|
||||
blocks until in-flight requests reach zero (or a timeout elapses).
|
||||
"""
|
||||
|
||||
_is_shutting_down: bool = False
|
||||
_shutdown_started_at: Optional[float] = None
|
||||
_drain_performed: bool = False
|
||||
|
||||
@classmethod
|
||||
def is_shutting_down(cls) -> bool:
|
||||
"""Whether this worker has begun graceful shutdown."""
|
||||
return cls._is_shutting_down
|
||||
|
||||
@classmethod
|
||||
def get_timeout(cls) -> float:
|
||||
"""
|
||||
Read GRACEFUL_SHUTDOWN_TIMEOUT (seconds) from the environment on each
|
||||
call so deployments can tune it without code changes. Falls back to the
|
||||
default on an unset or malformed value.
|
||||
"""
|
||||
raw = os.getenv("GRACEFUL_SHUTDOWN_TIMEOUT")
|
||||
if raw is None:
|
||||
return DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
try:
|
||||
return float(raw)
|
||||
except (TypeError, ValueError):
|
||||
verbose_proxy_logger.warning(
|
||||
"GRACEFUL_SHUTDOWN_TIMEOUT=%r is not a number; using default %ss",
|
||||
raw,
|
||||
DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT,
|
||||
)
|
||||
return DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
|
||||
@classmethod
|
||||
def start_shutdown(cls) -> None:
|
||||
"""
|
||||
Mark the worker as draining. Idempotent — repeated calls (e.g. SIGTERM
|
||||
followed by a preStop hit on /health/drain) do not reset the clock.
|
||||
"""
|
||||
if cls._is_shutting_down:
|
||||
return
|
||||
cls._is_shutting_down = True
|
||||
cls._shutdown_started_at = time.monotonic()
|
||||
verbose_proxy_logger.info(
|
||||
"graceful_shutdown_started in_flight_requests=%s",
|
||||
get_in_flight_requests(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def wait_for_drain(
|
||||
cls,
|
||||
timeout: Optional[float] = None,
|
||||
exclude_self: bool = False,
|
||||
count_fn: Optional[Callable[[], int]] = None,
|
||||
poll_interval: float = _DRAIN_POLL_INTERVAL,
|
||||
log_interval: float = _DRAIN_LOG_INTERVAL,
|
||||
) -> int:
|
||||
"""
|
||||
Poll the in-flight request counter until it reaches the drain target or
|
||||
``timeout`` seconds elapse.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait. Defaults to ``get_timeout()``.
|
||||
exclude_self: When the caller is itself an in-flight HTTP request
|
||||
(the /health/drain endpoint), set this so the caller's own
|
||||
request is not counted as outstanding work.
|
||||
count_fn: Source of the current in-flight count. Defaults to the
|
||||
live ``InFlightRequestsMiddleware`` counter; injectable for tests.
|
||||
poll_interval: Seconds between counter polls.
|
||||
log_interval: Minimum seconds between ``drain_waiting`` log lines.
|
||||
|
||||
Returns:
|
||||
Number of requests that drained while waiting (>= 0).
|
||||
"""
|
||||
# A preStop /health/drain hook and the lifespan SIGTERM handler both
|
||||
# drain; once one has run, the other must not wait again, otherwise the
|
||||
# effective window is 2x the timeout and terminationGracePeriodSeconds
|
||||
# has to be doubled to avoid a mid-drain SIGKILL.
|
||||
if cls._drain_performed:
|
||||
return 0
|
||||
cls._drain_performed = True
|
||||
|
||||
if timeout is None:
|
||||
timeout = cls.get_timeout()
|
||||
if count_fn is None:
|
||||
count_fn = get_in_flight_requests
|
||||
|
||||
# The /health/drain HTTP request flows through InFlightRequestsMiddleware
|
||||
# and so counts itself; treat <=1 as "drained" in that case.
|
||||
target = 1 if exclude_self else 0
|
||||
|
||||
start = time.monotonic()
|
||||
initial = count_fn()
|
||||
last_log = start
|
||||
|
||||
if timeout <= 0:
|
||||
return max(0, initial - target)
|
||||
|
||||
while True:
|
||||
current = count_fn()
|
||||
if current <= target:
|
||||
drained = max(0, initial - current)
|
||||
verbose_proxy_logger.info(
|
||||
"graceful_shutdown_complete drained_requests=%s elapsed_s=%.2f",
|
||||
drained,
|
||||
time.monotonic() - start,
|
||||
)
|
||||
return drained
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
verbose_proxy_logger.warning(
|
||||
"graceful_shutdown_timeout in_flight_requests=%s elapsed_s=%.2f "
|
||||
"timeout_s=%s — proceeding with teardown",
|
||||
current,
|
||||
elapsed,
|
||||
timeout,
|
||||
)
|
||||
return max(0, initial - current)
|
||||
|
||||
now = time.monotonic()
|
||||
if now - last_log >= log_interval:
|
||||
verbose_proxy_logger.info(
|
||||
"drain_waiting in_flight_requests=%s elapsed_s=%.2f",
|
||||
current,
|
||||
elapsed,
|
||||
)
|
||||
last_log = now
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
"""Reset state. Intended for use in tests."""
|
||||
cls._is_shutting_down = False
|
||||
cls._shutdown_started_at = None
|
||||
cls._drain_performed = False
|
||||
@ -14,7 +14,6 @@ import pytest
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
_PROXY_MODULE_GLOBALS_TO_ISOLATE = (
|
||||
"master_key",
|
||||
"prisma_client",
|
||||
@ -49,6 +48,18 @@ def _isolate_proxy_module_globals():
|
||||
setattr(proxy_server, name, value)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_graceful_shutdown_state():
|
||||
"""Graceful shutdown state is process-scoped; keep it from leaking between tests."""
|
||||
from litellm.proxy.shutdown.graceful_shutdown_manager import (
|
||||
GracefulShutdownManager,
|
||||
)
|
||||
|
||||
GracefulShutdownManager.reset()
|
||||
yield
|
||||
GracefulShutdownManager.reset()
|
||||
|
||||
|
||||
def build_cache_config(enable_cache: bool = True) -> Optional[Dict]:
|
||||
"""
|
||||
Build Redis cache configuration from environment variables.
|
||||
|
||||
@ -0,0 +1,166 @@
|
||||
"""
|
||||
Behaviour tests for the graceful-shutdown health probes.
|
||||
|
||||
Builds a minimal FastAPI app from the health router plus
|
||||
InFlightRequestsMiddleware so the probe responses can be asserted without
|
||||
standing up the full proxy.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from litellm.proxy.health_endpoints._health_endpoints import router
|
||||
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
||||
InFlightRequestsMiddleware,
|
||||
)
|
||||
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset():
|
||||
GracefulShutdownManager.reset()
|
||||
InFlightRequestsMiddleware._in_flight = 0
|
||||
yield
|
||||
GracefulShutdownManager.reset()
|
||||
InFlightRequestsMiddleware._in_flight = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
app.add_middleware(InFlightRequestsMiddleware)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_drain(monkeypatch):
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
monkeypatch.setattr(
|
||||
proxy_server, "general_settings", {"enable_drain_endpoint": True}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_drain_with_token(monkeypatch):
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
monkeypatch.setattr(
|
||||
proxy_server,
|
||||
"general_settings",
|
||||
{"enable_drain_endpoint": True, "drain_endpoint_token": "secret-123"},
|
||||
)
|
||||
|
||||
|
||||
def test_drain_disabled_by_default_returns_404_with_no_side_effect(client, monkeypatch):
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
monkeypatch.setattr(proxy_server, "general_settings", {})
|
||||
resp = client.get("/health/drain")
|
||||
assert resp.status_code == 404
|
||||
assert GracefulShutdownManager.is_shutting_down() is False
|
||||
|
||||
|
||||
def test_drain_disabled_ignores_token_header(client, monkeypatch):
|
||||
"""A token alone must not bypass the enable flag; otherwise enabling the
|
||||
token side-channel would silently enable the endpoint."""
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
monkeypatch.setattr(
|
||||
proxy_server, "general_settings", {"drain_endpoint_token": "secret-123"}
|
||||
)
|
||||
resp = client.get("/health/drain", headers={"X-Drain-Token": "secret-123"})
|
||||
assert resp.status_code == 404
|
||||
assert GracefulShutdownManager.is_shutting_down() is False
|
||||
|
||||
|
||||
def test_drain_when_enabled_without_token_sets_shutting_down_and_returns_drained(
|
||||
client, enable_drain
|
||||
):
|
||||
resp = client.get("/health/drain")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["status"] == "drained"
|
||||
assert body["drained_requests"] == 0
|
||||
assert GracefulShutdownManager.is_shutting_down() is True
|
||||
|
||||
|
||||
def test_drain_with_token_configured_rejects_missing_header(
|
||||
client, enable_drain_with_token
|
||||
):
|
||||
resp = client.get("/health/drain")
|
||||
assert resp.status_code == 401
|
||||
assert GracefulShutdownManager.is_shutting_down() is False
|
||||
|
||||
|
||||
def test_drain_with_token_configured_rejects_wrong_header(
|
||||
client, enable_drain_with_token
|
||||
):
|
||||
resp = client.get("/health/drain", headers={"X-Drain-Token": "wrong-value"})
|
||||
assert resp.status_code == 401
|
||||
assert GracefulShutdownManager.is_shutting_down() is False
|
||||
|
||||
|
||||
def test_drain_with_token_configured_accepts_correct_header(
|
||||
client, enable_drain_with_token
|
||||
):
|
||||
resp = client.get("/health/drain", headers={"X-Drain-Token": "secret-123"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "drained"
|
||||
assert GracefulShutdownManager.is_shutting_down() is True
|
||||
|
||||
|
||||
def test_drain_with_token_from_env_var(client, enable_drain, monkeypatch):
|
||||
monkeypatch.setenv("DRAIN_ENDPOINT_TOKEN", "env-token")
|
||||
resp = client.get("/health/drain")
|
||||
assert resp.status_code == 401
|
||||
resp = client.get("/health/drain", headers={"X-Drain-Token": "env-token"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_drain_general_settings_token_overrides_env_var(client, monkeypatch):
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
monkeypatch.setattr(
|
||||
proxy_server,
|
||||
"general_settings",
|
||||
{"enable_drain_endpoint": True, "drain_endpoint_token": "config-token"},
|
||||
)
|
||||
monkeypatch.setenv("DRAIN_ENDPOINT_TOKEN", "env-token")
|
||||
resp = client.get("/health/drain", headers={"X-Drain-Token": "env-token"})
|
||||
assert resp.status_code == 401
|
||||
resp = client.get("/health/drain", headers={"X-Drain-Token": "config-token"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_readiness_returns_503_shutting_down_during_drain(client):
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
resp = client.get("/health/readiness")
|
||||
assert resp.status_code == 503
|
||||
assert resp.json() == {"status": "shutting_down"}
|
||||
|
||||
|
||||
def test_readiness_does_not_report_shutting_down_normally(client):
|
||||
resp = client.get("/health/readiness")
|
||||
assert resp.json().get("status") != "shutting_down"
|
||||
|
||||
|
||||
def test_liveliness_returns_503_during_drain(client):
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
resp = client.get("/health/liveliness")
|
||||
assert resp.status_code == 503
|
||||
assert resp.json() == {"status": "shutting_down"}
|
||||
|
||||
|
||||
def test_liveness_alias_returns_503_during_drain(client):
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
resp = client.get("/health/liveness")
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
def test_liveliness_returns_alive_when_not_shutting_down(client):
|
||||
resp = client.get("/health/liveliness")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == "I'm alive!"
|
||||
@ -0,0 +1,181 @@
|
||||
"""
|
||||
Tests for GracefulShutdownManager.
|
||||
|
||||
These verify the drain logic that lets a pod terminate as soon as its real
|
||||
in-flight work is done (bounded by GRACEFUL_SHUTDOWN_TIMEOUT) rather than
|
||||
sleeping for a fixed worst-case duration.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.shutdown.graceful_shutdown_manager import (
|
||||
DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT,
|
||||
GracefulShutdownManager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset():
|
||||
GracefulShutdownManager.reset()
|
||||
yield
|
||||
GracefulShutdownManager.reset()
|
||||
|
||||
|
||||
def _counter_that_drains_after(calls_before_zero: int):
|
||||
"""Return a count_fn that reports N in-flight until it has been polled
|
||||
`calls_before_zero` times, then reports 0."""
|
||||
state = {"polls": 0}
|
||||
|
||||
def count_fn() -> int:
|
||||
state["polls"] += 1
|
||||
return 0 if state["polls"] > calls_before_zero else 3
|
||||
|
||||
return count_fn
|
||||
|
||||
|
||||
# ── shutdown flag ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_not_shutting_down_by_default():
|
||||
assert GracefulShutdownManager.is_shutting_down() is False
|
||||
|
||||
|
||||
def test_start_shutdown_sets_flag():
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
assert GracefulShutdownManager.is_shutting_down() is True
|
||||
|
||||
|
||||
def test_start_shutdown_is_idempotent_and_does_not_reset_clock():
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
first = GracefulShutdownManager._shutdown_started_at
|
||||
time.sleep(0.01)
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
assert GracefulShutdownManager._shutdown_started_at == first
|
||||
|
||||
|
||||
def test_reset_clears_flag():
|
||||
GracefulShutdownManager.start_shutdown()
|
||||
GracefulShutdownManager.reset()
|
||||
assert GracefulShutdownManager.is_shutting_down() is False
|
||||
|
||||
|
||||
# ── timeout config ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_timeout_defaults_when_unset(monkeypatch):
|
||||
monkeypatch.delenv("GRACEFUL_SHUTDOWN_TIMEOUT", raising=False)
|
||||
assert GracefulShutdownManager.get_timeout() == DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
|
||||
|
||||
def test_timeout_reads_env(monkeypatch):
|
||||
monkeypatch.setenv("GRACEFUL_SHUTDOWN_TIMEOUT", "5")
|
||||
assert GracefulShutdownManager.get_timeout() == 5.0
|
||||
|
||||
|
||||
def test_timeout_falls_back_on_garbage(monkeypatch):
|
||||
monkeypatch.setenv("GRACEFUL_SHUTDOWN_TIMEOUT", "not-a-number")
|
||||
assert GracefulShutdownManager.get_timeout() == DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT
|
||||
|
||||
|
||||
# ── wait_for_drain ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_immediately_when_already_drained():
|
||||
start = time.monotonic()
|
||||
drained = await GracefulShutdownManager.wait_for_drain(
|
||||
timeout=10, count_fn=lambda: 0
|
||||
)
|
||||
assert drained == 0
|
||||
assert time.monotonic() - start < 0.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_waits_until_counter_reaches_zero_then_returns_drained_count():
|
||||
count_fn = _counter_that_drains_after(calls_before_zero=3)
|
||||
drained = await GracefulShutdownManager.wait_for_drain(
|
||||
timeout=10, count_fn=count_fn
|
||||
)
|
||||
assert drained == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_times_out_when_counter_never_drains():
|
||||
start = time.monotonic()
|
||||
drained = await GracefulShutdownManager.wait_for_drain(
|
||||
timeout=0.3, count_fn=lambda: 2
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
assert 0.3 <= elapsed < 2.0
|
||||
assert drained == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_timeout_does_not_block():
|
||||
start = time.monotonic()
|
||||
drained = await GracefulShutdownManager.wait_for_drain(
|
||||
timeout=0, count_fn=lambda: 5
|
||||
)
|
||||
assert time.monotonic() - start < 0.2
|
||||
assert drained == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exclude_self_treats_one_inflight_as_drained():
|
||||
"""The /health/drain request counts itself, so a steady count of 1 must be
|
||||
treated as fully drained rather than timing out."""
|
||||
start = time.monotonic()
|
||||
drained = await GracefulShutdownManager.wait_for_drain(
|
||||
timeout=5, exclude_self=True, count_fn=lambda: 1
|
||||
)
|
||||
assert time.monotonic() - start < 0.5
|
||||
assert drained == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_without_exclude_self_one_inflight_blocks_until_timeout():
|
||||
start = time.monotonic()
|
||||
await GracefulShutdownManager.wait_for_drain(timeout=0.3, count_fn=lambda: 1)
|
||||
assert time.monotonic() - start >= 0.3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_defaults_to_get_timeout_and_live_counter(monkeypatch):
|
||||
"""With no timeout/count_fn passed, it falls back to get_timeout() and the
|
||||
live InFlightRequestsMiddleware counter."""
|
||||
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
||||
InFlightRequestsMiddleware,
|
||||
)
|
||||
|
||||
monkeypatch.delenv("GRACEFUL_SHUTDOWN_TIMEOUT", raising=False)
|
||||
InFlightRequestsMiddleware._in_flight = 0
|
||||
drained = await GracefulShutdownManager.wait_for_drain()
|
||||
assert drained == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_drain_is_a_noop_so_window_is_not_doubled():
|
||||
"""preStop /health/drain and the lifespan SIGTERM handler both drain; the
|
||||
second call must return immediately rather than waiting another full
|
||||
timeout (which would require doubling terminationGracePeriodSeconds)."""
|
||||
await GracefulShutdownManager.wait_for_drain(timeout=0.2, count_fn=lambda: 1)
|
||||
|
||||
start = time.monotonic()
|
||||
drained = await GracefulShutdownManager.wait_for_drain(
|
||||
timeout=5, count_fn=lambda: 1
|
||||
)
|
||||
assert time.monotonic() - start < 0.1
|
||||
assert drained == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emits_periodic_drain_waiting_log_while_waiting():
|
||||
"""With a zero log interval, the periodic drain_waiting branch runs on each
|
||||
poll until the counter finally drains."""
|
||||
count_fn = _counter_that_drains_after(calls_before_zero=2)
|
||||
drained = await GracefulShutdownManager.wait_for_drain(
|
||||
timeout=10, count_fn=count_fn, poll_interval=0, log_interval=0
|
||||
)
|
||||
assert drained == 3
|
||||
Loading…
Reference in New Issue
Block a user