feat(proxy): native /health/drain preStop hook for graceful shutdown (#29439)
This commit is contained in:
parent
a5ccd96152
commit
3a1c6bba97
@ -36,6 +36,7 @@ jobs:
|
|||||||
tests/test_litellm/proxy/a2a
|
tests/test_litellm/proxy/a2a
|
||||||
tests/test_litellm/proxy/discovery_endpoints
|
tests/test_litellm/proxy/discovery_endpoints
|
||||||
tests/test_litellm/proxy/health_endpoints
|
tests/test_litellm/proxy/health_endpoints
|
||||||
|
tests/test_litellm/proxy/shutdown
|
||||||
tests/test_litellm/proxy/public_endpoints
|
tests/test_litellm/proxy/public_endpoints
|
||||||
tests/test_litellm/proxy/prompts
|
tests/test_litellm/proxy/prompts
|
||||||
tests/test_litellm/proxy/rag_endpoints
|
tests/test_litellm/proxy/rag_endpoints
|
||||||
|
|||||||
2
Makefile
2
Makefile
@ -146,7 +146,7 @@ test-unit-proxy-core: install-test-deps
|
|||||||
$(UV_RUN) pytest tests/test_litellm/proxy/auth tests/test_litellm/proxy/client tests/test_litellm/proxy/db tests/test_litellm/proxy/hooks tests/test_litellm/proxy/policy_engine --tb=short -vv -n 4 --durations=20
|
$(UV_RUN) pytest tests/test_litellm/proxy/auth tests/test_litellm/proxy/client tests/test_litellm/proxy/db tests/test_litellm/proxy/hooks tests/test_litellm/proxy/policy_engine --tb=short -vv -n 4 --durations=20
|
||||||
|
|
||||||
test-unit-proxy-misc: install-test-deps
|
test-unit-proxy-misc: install-test-deps
|
||||||
$(UV_RUN) pytest tests/test_litellm/proxy/_experimental tests/test_litellm/proxy/agent_endpoints tests/test_litellm/proxy/anthropic_endpoints tests/test_litellm/proxy/common_utils tests/test_litellm/proxy/discovery_endpoints tests/test_litellm/proxy/experimental tests/test_litellm/proxy/google_endpoints tests/test_litellm/proxy/health_endpoints tests/test_litellm/proxy/image_endpoints tests/test_litellm/proxy/middleware tests/test_litellm/proxy/openai_files_endpoint tests/test_litellm/proxy/pass_through_endpoints tests/test_litellm/proxy/prompts tests/test_litellm/proxy/public_endpoints tests/test_litellm/proxy/response_api_endpoints tests/test_litellm/proxy/spend_tracking tests/test_litellm/proxy/ui_crud_endpoints tests/test_litellm/proxy/vector_store_endpoints tests/test_litellm/proxy/test_*.py --tb=short -vv -n 4 --durations=20
|
$(UV_RUN) pytest tests/test_litellm/proxy/_experimental tests/test_litellm/proxy/agent_endpoints tests/test_litellm/proxy/anthropic_endpoints tests/test_litellm/proxy/common_utils tests/test_litellm/proxy/discovery_endpoints tests/test_litellm/proxy/experimental tests/test_litellm/proxy/google_endpoints tests/test_litellm/proxy/health_endpoints tests/test_litellm/proxy/image_endpoints tests/test_litellm/proxy/middleware tests/test_litellm/proxy/openai_files_endpoint tests/test_litellm/proxy/pass_through_endpoints tests/test_litellm/proxy/prompts tests/test_litellm/proxy/public_endpoints tests/test_litellm/proxy/response_api_endpoints tests/test_litellm/proxy/shutdown tests/test_litellm/proxy/spend_tracking tests/test_litellm/proxy/ui_crud_endpoints tests/test_litellm/proxy/vector_store_endpoints tests/test_litellm/proxy/test_*.py --tb=short -vv -n 4 --durations=20
|
||||||
|
|
||||||
test-unit-integrations: install-test-deps
|
test-unit-integrations: install-test-deps
|
||||||
$(UV_RUN) pytest tests/test_litellm/integrations --tb=short -vv -n 4 --durations=20
|
$(UV_RUN) pytest tests/test_litellm/integrations --tb=short -vv -n 4 --durations=20
|
||||||
|
|||||||
@ -285,11 +285,31 @@ db:
|
|||||||
deployStandalone: true
|
deployStandalone: true
|
||||||
|
|
||||||
# Lifecycle hooks for the LiteLLM container
|
# Lifecycle hooks for the LiteLLM container
|
||||||
|
#
|
||||||
|
# Prefer the native /health/drain preStop hook over a fixed `sleep`: it marks
|
||||||
|
# the pod NotReady and blocks only until in-flight requests actually finish
|
||||||
|
# (bounded by GRACEFUL_SHUTDOWN_TIMEOUT, default 30s), instead of always
|
||||||
|
# waiting the worst-case duration. The drain runs once (the preStop hook and
|
||||||
|
# the SIGTERM handler share it), so set terminationGracePeriodSeconds a few
|
||||||
|
# seconds above GRACEFUL_SHUTDOWN_TIMEOUT to leave room for teardown before
|
||||||
|
# SIGKILL.
|
||||||
|
#
|
||||||
|
# /health/drain is off by default; enable it with
|
||||||
|
# general_settings.enable_drain_endpoint: true. The kubelet calls preStop
|
||||||
|
# hooks without proxy credentials, so when the health port is reachable from
|
||||||
|
# other pods (the common case) also set
|
||||||
|
# general_settings.drain_endpoint_token (or the DRAIN_ENDPOINT_TOKEN env
|
||||||
|
# var) and send the same value on the X-Drain-Token header from the hook.
|
||||||
|
# Calls with a missing or wrong token get a 401 and have no side effect.
|
||||||
# Example:
|
# Example:
|
||||||
# lifecycle:
|
# lifecycle:
|
||||||
# preStop:
|
# preStop:
|
||||||
# exec:
|
# httpGet:
|
||||||
# command: ["/bin/sh", "-c", "sleep 10"]
|
# path: /health/drain
|
||||||
|
# port: 4000
|
||||||
|
# httpHeaders:
|
||||||
|
# - name: X-Drain-Token
|
||||||
|
# value: <same value as drain_endpoint_token>
|
||||||
lifecycle: {}
|
lifecycle: {}
|
||||||
|
|
||||||
# Settings for Bitnami postgresql chart (if db.deployStandalone is true, ignored
|
# Settings for Bitnami postgresql chart (if db.deployStandalone is true, ignored
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import asyncio
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import secrets
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@ -39,6 +40,7 @@ from litellm.proxy.health_check import (
|
|||||||
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
||||||
get_in_flight_requests,
|
get_in_flight_requests,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
|
||||||
|
|
||||||
#### Health ENDPOINTS ####
|
#### Health ENDPOINTS ####
|
||||||
|
|
||||||
@ -1551,6 +1553,50 @@ def _allow_public_health_readiness_details() -> bool:
|
|||||||
return general_settings.get("allow_public_health_readiness_details") is True
|
return general_settings.get("allow_public_health_readiness_details") is True
|
||||||
|
|
||||||
|
|
||||||
|
def _drain_endpoint_enabled() -> bool:
    """
    Report whether ``general_settings.enable_drain_endpoint`` is exactly ``True``.

    Any other value (missing key, ``"true"``, ``1``, ...) leaves /health/drain
    disabled, because the check is an identity comparison against ``True``.
    """
    # Imported at call time so the current general_settings object is read on
    # every call.
    from litellm.proxy.proxy_server import general_settings

    flag = general_settings.get("enable_drain_endpoint")
    return flag is True
|
||||||
|
|
||||||
|
|
||||||
|
def _drain_endpoint_token() -> Optional[str]:
    """
    Resolve the shared secret expected on the ``X-Drain-Token`` header.

    Lookup order:
      1. ``general_settings.drain_endpoint_token`` when it is a non-empty string.
      2. The ``DRAIN_ENDPOINT_TOKEN`` environment variable when it is non-empty.

    Returns:
        The token, or ``None`` when neither source provides one.
    """
    from litellm.proxy.proxy_server import general_settings

    configured = general_settings.get("drain_endpoint_token")
    if isinstance(configured, str) and configured:
        return configured

    # An unset and an empty env var both mean "no token configured".
    return os.getenv("DRAIN_ENDPOINT_TOKEN") or None
|
||||||
|
|
||||||
|
|
||||||
|
def _authorize_drain_request(request: Request) -> None:
    """
    Reject /health/drain calls that don't carry the configured X-Drain-Token.

    When no token is configured (``_drain_endpoint_token()`` returns ``None``)
    the request is allowed through; the ``enable_drain_endpoint`` flag is then
    the only gate.

    Args:
        request: Incoming request whose ``x-drain-token`` header is checked.

    Raises:
        HTTPException: 401 when a token is configured and the header is
            missing or does not match.
    """
    expected = _drain_endpoint_token()
    if expected is None:
        return

    supplied = request.headers.get("x-drain-token") or ""

    # compare_digest() accepts str arguments only when they are pure ASCII and
    # raises TypeError otherwise. A header (or configured token) containing a
    # non-ASCII character would therefore surface as a 500 instead of a 401.
    # Comparing the UTF-8 bytes keeps the constant-time comparison and works
    # for any text.
    tokens_match = secrets.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    )
    if not tokens_match:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing X-Drain-Token",
        )
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_public_readiness_db(response: Response) -> str:
|
async def _resolve_public_readiness_db(response: Response) -> str:
|
||||||
"""
|
"""
|
||||||
Return the db status string for the public probe and flip the response to
|
Return the db status string for the public probe and flip the response to
|
||||||
@ -1580,6 +1626,10 @@ async def health_readiness(response: Response):
|
|||||||
credential. Admins can opt into the legacy detailed payload with
|
credential. Admins can opt into the legacy detailed payload with
|
||||||
general_settings.allow_public_health_readiness_details.
|
general_settings.allow_public_health_readiness_details.
|
||||||
"""
|
"""
|
||||||
|
if GracefulShutdownManager.is_shutting_down():
|
||||||
|
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
return {"status": "shutting_down"}
|
||||||
|
|
||||||
if _allow_public_health_readiness_details():
|
if _allow_public_health_readiness_details():
|
||||||
return await _get_health_readiness_details(response=response)
|
return await _get_health_readiness_details(response=response)
|
||||||
|
|
||||||
@ -1616,6 +1666,54 @@ async def health_backlog():
|
|||||||
return {"in_flight_requests": get_in_flight_requests()}
|
return {"in_flight_requests": get_in_flight_requests()}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health/drain", tags=["health"])
async def health_drain(request: Request):
    """
    Graceful-drain endpoint intended for a Kubernetes ``preStop`` hook.

    Gating:
      * Returns 404 unless ``general_settings.enable_drain_endpoint`` is ``true``.
      * The endpoint does not use ``user_api_key_auth`` because the kubelet
        calls preStop hooks without proxy credentials. To stop any
        pod-reachable caller from triggering shutdown, configure
        ``general_settings.drain_endpoint_token`` (or the
        ``DRAIN_ENDPOINT_TOKEN`` env var) and send the same value on the
        ``X-Drain-Token`` header. A missing or wrong header gets a 401 and has
        no side effect.

    Effect of a successful call:
      * Sets the process-wide shutting-down flag, so /health/readiness and
        /health/liveliness start returning 503. The flag is not cleared again
        by this endpoint, so the worker stays out of rotation until the pod
        restarts.
      * Blocks until the in-flight request counter drains (this request
        itself excluded) or ``GRACEFUL_SHUTDOWN_TIMEOUT`` elapses. Unlike a
        fixed ``sleep``, it returns as soon as real in-flight work is done.

    The response ``status`` is always ``"drained"``, including when the wait
    ended because of the timeout.

    Example wiring:

    ```yaml
    lifecycle:
      preStop:
        httpGet:
          path: /health/drain
          port: 4000
          httpHeaders:
            - name: X-Drain-Token
              value: <same value as drain_endpoint_token>
    ```
    """
    if not _drain_endpoint_enabled():
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")

    # Authorize before touching shutdown state so a rejected call is a no-op.
    _authorize_drain_request(request)

    GracefulShutdownManager.start_shutdown()
    drained_count = await GracefulShutdownManager.wait_for_drain(exclude_self=True)
    return {"status": "drained", "drained_requests": drained_count}
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/health/liveliness", # Historical LiteLLM name; doesn't match k8s terminology but kept for backwards compatibility
|
"/health/liveliness", # Historical LiteLLM name; doesn't match k8s terminology but kept for backwards compatibility
|
||||||
tags=["health"],
|
tags=["health"],
|
||||||
@ -1624,10 +1722,16 @@ async def health_backlog():
|
|||||||
"/health/liveness", # Kubernetes has "liveness" probes (https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-startup-probes/#define-a-liveness-command)
|
"/health/liveness", # Kubernetes has "liveness" probes (https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-startup-probes/#define-a-liveness-command)
|
||||||
tags=["health"],
|
tags=["health"],
|
||||||
)
|
)
|
||||||
async def health_liveliness():
|
async def health_liveliness(response: Response):
|
||||||
"""
|
"""
|
||||||
Unprotected endpoint for checking if worker is alive
|
Unprotected endpoint for checking if worker is alive.
|
||||||
|
|
||||||
|
Returns 503 once graceful shutdown has begun so Kubernetes stops counting
|
||||||
|
the draining pod as live and terminates it on schedule.
|
||||||
"""
|
"""
|
||||||
|
if GracefulShutdownManager.is_shutting_down():
|
||||||
|
response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
return {"status": "shutting_down"}
|
||||||
return "I'm alive!"
|
return "I'm alive!"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -417,6 +417,7 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi
|
|||||||
from litellm.proxy.middleware.request_size_limit_middleware import (
|
from litellm.proxy.middleware.request_size_limit_middleware import (
|
||||||
RequestSizeLimitMiddleware,
|
RequestSizeLimitMiddleware,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
|
||||||
from litellm.proxy.ocr_endpoints.endpoints import router as ocr_router
|
from litellm.proxy.ocr_endpoints.endpoints import router as ocr_router
|
||||||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||||
router as openai_files_router,
|
router as openai_files_router,
|
||||||
@ -973,6 +974,11 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915
|
|||||||
# End of startup event
|
# End of startup event
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
# Shutdown event - drain in-flight requests before tearing down dependencies
|
||||||
|
# so SIGTERM (rolling update, scale-down, liveness kill) doesn't drop them.
|
||||||
|
GracefulShutdownManager.start_shutdown()
|
||||||
|
await GracefulShutdownManager.wait_for_drain()
|
||||||
|
|
||||||
# Shutdown event - close shared aiohttp session
|
# Shutdown event - close shared aiohttp session
|
||||||
if shared_aiohttp_session is not None:
|
if shared_aiohttp_session is not None:
|
||||||
try:
|
try:
|
||||||
|
|||||||
0
litellm/proxy/shutdown/__init__.py
Normal file
0
litellm/proxy/shutdown/__init__.py
Normal file
174
litellm/proxy/shutdown/graceful_shutdown_manager.py
Normal file
174
litellm/proxy/shutdown/graceful_shutdown_manager.py
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
"""
|
||||||
|
Application-level graceful shutdown coordination for the LiteLLM proxy.
|
||||||
|
|
||||||
|
Kubernetes terminates a pod by sending ``SIGTERM`` and, after
|
||||||
|
``terminationGracePeriodSeconds``, ``SIGKILL``. By default LiteLLM delegates
|
||||||
|
the signal to uvicorn and tears down immediately, dropping any in-flight
|
||||||
|
requests (streaming, batch inference, long-lived calls).
|
||||||
|
|
||||||
|
A fixed ``preStop`` sleep can not solve this: it has to be sized for the
|
||||||
|
*worst-case* request, so it either wastes time on every routine shutdown or is
|
||||||
|
too short for a long-running request. This manager instead drains based on the
|
||||||
|
*actual* in-flight request counter (already tracked by
|
||||||
|
``InFlightRequestsMiddleware``), so a pod terminates as soon as its real
|
||||||
|
in-flight work is done — and never waits longer than ``GRACEFUL_SHUTDOWN_TIMEOUT``.
|
||||||
|
|
||||||
|
The state is process-scoped (class-level), matching the per-uvicorn-worker
|
||||||
|
granularity of ``InFlightRequestsMiddleware``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
||||||
|
get_in_flight_requests,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep below terminationGracePeriodSeconds so the process exits before SIGKILL.
|
||||||
|
DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30.0
|
||||||
|
_DRAIN_POLL_INTERVAL = 0.1
|
||||||
|
_DRAIN_LOG_INTERVAL = 5.0
|
||||||
|
|
||||||
|
|
||||||
|
class GracefulShutdownManager:
    """
    Process-scoped singleton that tracks whether the worker is draining and
    blocks until in-flight requests reach zero (or a timeout elapses).

    All state is stored on the class itself, so every caller in the process
    shares one shutting-down flag and one drain. ``reset()`` clears it for
    tests.
    """

    # True once start_shutdown() has been called.
    _is_shutting_down: bool = False
    # time.monotonic() reading captured by the first start_shutdown() call.
    _shutdown_started_at: Optional[float] = None
    # True once wait_for_drain() has been entered; later calls return at once.
    _drain_performed: bool = False

    @classmethod
    def is_shutting_down(cls) -> bool:
        """Whether this worker has begun graceful shutdown."""
        return cls._is_shutting_down

    @classmethod
    def get_timeout(cls) -> float:
        """
        Read GRACEFUL_SHUTDOWN_TIMEOUT (seconds) from the environment on each
        call so deployments can tune it without code changes.

        Returns:
            The parsed value, or ``DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT`` when the
            variable is unset, cannot be parsed as a float, or parses to NaN.
        """
        import math

        raw = os.getenv("GRACEFUL_SHUTDOWN_TIMEOUT")
        if raw is None:
            return DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT

        try:
            parsed: Optional[float] = float(raw)
        except (TypeError, ValueError):
            parsed = None

        # float("nan") parses successfully, but every comparison against NaN is
        # False, so wait_for_drain() would never hit its timeout branch and the
        # drain would be unbounded. Treat it like any other malformed value.
        if parsed is None or math.isnan(parsed):
            verbose_proxy_logger.warning(
                "GRACEFUL_SHUTDOWN_TIMEOUT=%r is not a number; using default %ss",
                raw,
                DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT,
            )
            return DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT
        return parsed

    @classmethod
    def start_shutdown(cls) -> None:
        """
        Mark the worker as draining. Idempotent — repeated calls (e.g. SIGTERM
        followed by a preStop hit on /health/drain) do not reset the clock.
        """
        if cls._is_shutting_down:
            return
        cls._is_shutting_down = True
        cls._shutdown_started_at = time.monotonic()
        verbose_proxy_logger.info(
            "graceful_shutdown_started in_flight_requests=%s",
            get_in_flight_requests(),
        )

    @classmethod
    async def wait_for_drain(
        cls,
        timeout: Optional[float] = None,
        exclude_self: bool = False,
        count_fn: Optional[Callable[[], int]] = None,
        poll_interval: float = _DRAIN_POLL_INTERVAL,
        log_interval: float = _DRAIN_LOG_INTERVAL,
    ) -> int:
        """
        Poll the in-flight request counter until it reaches the drain target or
        ``timeout`` seconds elapse.

        Only the first call in the process actually waits; once a drain has
        been started, later calls return ``0`` immediately.

        Args:
            timeout: Max seconds to wait. Defaults to ``get_timeout()``. A
                value ``<= 0`` skips waiting entirely.
            exclude_self: When the caller is itself an in-flight HTTP request
                (the /health/drain endpoint), set this so the caller's own
                request is not counted as outstanding work.
            count_fn: Source of the current in-flight count. Defaults to the
                live ``InFlightRequestsMiddleware`` counter; injectable for tests.
            poll_interval: Seconds between counter polls.
            log_interval: Minimum seconds between ``drain_waiting`` log lines.

        Returns:
            ``initial - current`` floored at 0, where ``initial`` is the count
            on entry and ``current`` the count when the wait ended. With a
            non-positive timeout, ``initial - target`` floored at 0.
        """
        # A preStop /health/drain hook and the lifespan SIGTERM handler both
        # drain; once one has run, the other must not wait again, otherwise the
        # effective window is 2x the timeout and terminationGracePeriodSeconds
        # has to be doubled to avoid a mid-drain SIGKILL.
        if cls._drain_performed:
            return 0
        cls._drain_performed = True

        if timeout is None:
            timeout = cls.get_timeout()
        if count_fn is None:
            count_fn = get_in_flight_requests

        # The /health/drain HTTP request flows through InFlightRequestsMiddleware
        # and so counts itself; treat <=1 as "drained" in that case.
        target = 1 if exclude_self else 0

        start = time.monotonic()
        initial = count_fn()
        last_log = start

        if timeout <= 0:
            return max(0, initial - target)

        while True:
            current = count_fn()
            if current <= target:
                drained = max(0, initial - current)
                verbose_proxy_logger.info(
                    "graceful_shutdown_complete drained_requests=%s elapsed_s=%.2f",
                    drained,
                    time.monotonic() - start,
                )
                return drained

            elapsed = time.monotonic() - start
            if elapsed >= timeout:
                verbose_proxy_logger.warning(
                    "graceful_shutdown_timeout in_flight_requests=%s elapsed_s=%.2f "
                    "timeout_s=%s — proceeding with teardown",
                    current,
                    elapsed,
                    timeout,
                )
                return max(0, initial - current)

            # Rate-limit progress logging to one line per log_interval.
            now = time.monotonic()
            if now - last_log >= log_interval:
                verbose_proxy_logger.info(
                    "drain_waiting in_flight_requests=%s elapsed_s=%.2f",
                    current,
                    elapsed,
                )
                last_log = now

            await asyncio.sleep(poll_interval)

    @classmethod
    def reset(cls) -> None:
        """Reset state. Intended for use in tests."""
        cls._is_shutting_down = False
        cls._shutdown_started_at = None
        cls._drain_performed = False
|
||||||
@ -14,7 +14,6 @@ import pytest
|
|||||||
import yaml
|
import yaml
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
_PROXY_MODULE_GLOBALS_TO_ISOLATE = (
|
_PROXY_MODULE_GLOBALS_TO_ISOLATE = (
|
||||||
"master_key",
|
"master_key",
|
||||||
"prisma_client",
|
"prisma_client",
|
||||||
@ -49,6 +48,18 @@ def _isolate_proxy_module_globals():
|
|||||||
setattr(proxy_server, name, value)
|
setattr(proxy_server, name, value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset_graceful_shutdown_state():
    """
    Clear GracefulShutdownManager's class-level state before and after each
    test so a shutdown flag set by one test cannot leak into the next.
    """
    from litellm.proxy.shutdown.graceful_shutdown_manager import (
        GracefulShutdownManager as shutdown_manager,
    )

    shutdown_manager.reset()
    yield
    shutdown_manager.reset()
|
||||||
|
|
||||||
|
|
||||||
def build_cache_config(enable_cache: bool = True) -> Optional[Dict]:
|
def build_cache_config(enable_cache: bool = True) -> Optional[Dict]:
|
||||||
"""
|
"""
|
||||||
Build Redis cache configuration from environment variables.
|
Build Redis cache configuration from environment variables.
|
||||||
|
|||||||
@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
Behaviour tests for the graceful-shutdown health probes.
|
||||||
|
|
||||||
|
Builds a minimal FastAPI app from the health router plus
|
||||||
|
InFlightRequestsMiddleware so the probe responses can be asserted without
|
||||||
|
standing up the full proxy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from litellm.proxy.health_endpoints._health_endpoints import router
|
||||||
|
from litellm.proxy.middleware.in_flight_requests_middleware import (
|
||||||
|
InFlightRequestsMiddleware,
|
||||||
|
)
|
||||||
|
from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset():
    """Zero the shutdown state and the in-flight counter around every test."""
    GracefulShutdownManager.reset()
    InFlightRequestsMiddleware._in_flight = 0

    yield

    GracefulShutdownManager.reset()
    InFlightRequestsMiddleware._in_flight = 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client():
    """TestClient over a minimal app: the health router plus the in-flight middleware."""
    probe_app = FastAPI()
    probe_app.include_router(router)
    probe_app.add_middleware(InFlightRequestsMiddleware)
    return TestClient(probe_app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def enable_drain(monkeypatch):
    """Turn on /health/drain with no token configured."""
    from litellm.proxy import proxy_server

    settings = {"enable_drain_endpoint": True}
    monkeypatch.setattr(proxy_server, "general_settings", settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def enable_drain_with_token(monkeypatch):
    """Turn on /health/drain and require the token ``secret-123``."""
    from litellm.proxy import proxy_server

    settings = {"enable_drain_endpoint": True, "drain_endpoint_token": "secret-123"}
    monkeypatch.setattr(proxy_server, "general_settings", settings)
|
||||||
|
|
||||||
|
|
||||||
|
def test_drain_disabled_by_default_returns_404_with_no_side_effect(client, monkeypatch):
    """With empty general_settings the endpoint 404s and leaves the flag unset."""
    from litellm.proxy import proxy_server

    monkeypatch.setattr(proxy_server, "general_settings", {})

    response = client.get("/health/drain")

    assert response.status_code == 404
    assert GracefulShutdownManager.is_shutting_down() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_drain_disabled_ignores_token_header(client, monkeypatch):
    """Configuring a token must not enable the endpoint: without the enable
    flag, even a request carrying the correct token gets a 404."""
    from litellm.proxy import proxy_server

    token_only_settings = {"drain_endpoint_token": "secret-123"}
    monkeypatch.setattr(proxy_server, "general_settings", token_only_settings)

    response = client.get("/health/drain", headers={"X-Drain-Token": "secret-123"})

    assert response.status_code == 404
    assert GracefulShutdownManager.is_shutting_down() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_drain_when_enabled_without_token_sets_shutting_down_and_returns_drained(
    client, enable_drain
):
    """Enabled with no token: the call succeeds, reports zero drained, sets the flag."""
    response = client.get("/health/drain")
    assert response.status_code == 200

    payload = response.json()
    assert payload["status"] == "drained"
    assert payload["drained_requests"] == 0
    assert GracefulShutdownManager.is_shutting_down() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_drain_with_token_configured_rejects_missing_header(
    client, enable_drain_with_token
):
    """A request without X-Drain-Token gets a 401 and does not start shutdown."""
    response = client.get("/health/drain")

    assert response.status_code == 401
    assert GracefulShutdownManager.is_shutting_down() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_drain_with_token_configured_rejects_wrong_header(
    client, enable_drain_with_token
):
    """A mismatching X-Drain-Token gets a 401 and does not start shutdown."""
    wrong_headers = {"X-Drain-Token": "wrong-value"}
    response = client.get("/health/drain", headers=wrong_headers)

    assert response.status_code == 401
    assert GracefulShutdownManager.is_shutting_down() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_drain_with_token_configured_accepts_correct_header(
    client, enable_drain_with_token
):
    """The matching X-Drain-Token is accepted and starts shutdown."""
    good_headers = {"X-Drain-Token": "secret-123"}
    response = client.get("/health/drain", headers=good_headers)

    assert response.status_code == 200
    assert response.json()["status"] == "drained"
    assert GracefulShutdownManager.is_shutting_down() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_drain_with_token_from_env_var(client, enable_drain, monkeypatch):
    """DRAIN_ENDPOINT_TOKEN alone is enough to require (and accept) the header."""
    monkeypatch.setenv("DRAIN_ENDPOINT_TOKEN", "env-token")

    without_header = client.get("/health/drain")
    assert without_header.status_code == 401

    with_header = client.get("/health/drain", headers={"X-Drain-Token": "env-token"})
    assert with_header.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_drain_general_settings_token_overrides_env_var(client, monkeypatch):
    """When both sources are set, only the general_settings token is accepted."""
    from litellm.proxy import proxy_server

    settings = {"enable_drain_endpoint": True, "drain_endpoint_token": "config-token"}
    monkeypatch.setattr(proxy_server, "general_settings", settings)
    monkeypatch.setenv("DRAIN_ENDPOINT_TOKEN", "env-token")

    env_attempt = client.get("/health/drain", headers={"X-Drain-Token": "env-token"})
    assert env_attempt.status_code == 401

    config_attempt = client.get(
        "/health/drain", headers={"X-Drain-Token": "config-token"}
    )
    assert config_attempt.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_readiness_returns_503_shutting_down_during_drain(client):
    """Once shutdown has started, /health/readiness answers 503 shutting_down."""
    GracefulShutdownManager.start_shutdown()

    response = client.get("/health/readiness")

    assert response.status_code == 503
    assert response.json() == {"status": "shutting_down"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_readiness_does_not_report_shutting_down_normally(client):
    """Without a shutdown in progress, readiness never reports shutting_down."""
    payload = client.get("/health/readiness").json()
    assert payload.get("status") != "shutting_down"
|
||||||
|
|
||||||
|
|
||||||
|
def test_liveliness_returns_503_during_drain(client):
    """Once shutdown has started, /health/liveliness answers 503 shutting_down."""
    GracefulShutdownManager.start_shutdown()

    response = client.get("/health/liveliness")

    assert response.status_code == 503
    assert response.json() == {"status": "shutting_down"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_liveness_alias_returns_503_during_drain(client):
    """The /health/liveness alias also answers 503 once shutdown has started."""
    GracefulShutdownManager.start_shutdown()

    response = client.get("/health/liveness")

    assert response.status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
def test_liveliness_returns_alive_when_not_shutting_down(client):
    """With no shutdown in progress, liveliness returns 200 and the alive string."""
    response = client.get("/health/liveliness")

    assert response.status_code == 200
    assert response.json() == "I'm alive!"
|
||||||
@ -0,0 +1,181 @@
|
|||||||
|
"""
|
||||||
|
Tests for GracefulShutdownManager.
|
||||||
|
|
||||||
|
These verify the drain logic that lets a pod terminate as soon as its real
|
||||||
|
in-flight work is done (bounded by GRACEFUL_SHUTDOWN_TIMEOUT) rather than
|
||||||
|
sleeping for a fixed worst-case duration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from litellm.proxy.shutdown.graceful_shutdown_manager import (
|
||||||
|
DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT,
|
||||||
|
GracefulShutdownManager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset():
    """Clear the manager's class-level state before and after every test."""
    GracefulShutdownManager.reset()

    yield

    GracefulShutdownManager.reset()
|
||||||
|
|
||||||
|
|
||||||
|
def _counter_that_drains_after(calls_before_zero: int):
|
||||||
|
"""Return a count_fn that reports N in-flight until it has been polled
|
||||||
|
`calls_before_zero` times, then reports 0."""
|
||||||
|
state = {"polls": 0}
|
||||||
|
|
||||||
|
def count_fn() -> int:
|
||||||
|
state["polls"] += 1
|
||||||
|
return 0 if state["polls"] > calls_before_zero else 3
|
||||||
|
|
||||||
|
return count_fn
|
||||||
|
|
||||||
|
|
||||||
|
# ── shutdown flag ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_not_shutting_down_by_default():
    """A freshly reset manager reports that it is not shutting down."""
    shutting_down = GracefulShutdownManager.is_shutting_down()
    assert shutting_down is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_shutdown_sets_flag():
    """start_shutdown() flips is_shutting_down() to True."""
    GracefulShutdownManager.start_shutdown()

    shutting_down = GracefulShutdownManager.is_shutting_down()
    assert shutting_down is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_shutdown_is_idempotent_and_does_not_reset_clock():
    """A second start_shutdown() must leave _shutdown_started_at unchanged."""
    manager = GracefulShutdownManager

    manager.start_shutdown()
    recorded_start = manager._shutdown_started_at

    # Let the clock advance so a re-stamped start time would be detectable.
    time.sleep(0.01)

    manager.start_shutdown()
    assert manager._shutdown_started_at == recorded_start
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_clears_flag():
    """reset() after start_shutdown() returns the manager to 'not shutting down'."""
    manager = GracefulShutdownManager
    manager.start_shutdown()
    manager.reset()

    shutting_down = manager.is_shutting_down()
    assert shutting_down is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── timeout config ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeout_defaults_when_unset(monkeypatch):
    """Without GRACEFUL_SHUTDOWN_TIMEOUT in the environment, the default applies."""
    # raising=False: the variable may legitimately be absent already.
    monkeypatch.delenv("GRACEFUL_SHUTDOWN_TIMEOUT", raising=False)

    resolved = GracefulShutdownManager.get_timeout()
    assert resolved == DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeout_reads_env(monkeypatch):
    """GRACEFUL_SHUTDOWN_TIMEOUT="5" is read from the environment as 5.0."""
    monkeypatch.setenv("GRACEFUL_SHUTDOWN_TIMEOUT", "5")

    resolved = GracefulShutdownManager.get_timeout()
    assert resolved == 5.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeout_falls_back_on_garbage(monkeypatch):
    """A non-numeric GRACEFUL_SHUTDOWN_TIMEOUT falls back to the default."""
    monkeypatch.setenv("GRACEFUL_SHUTDOWN_TIMEOUT", "not-a-number")

    resolved = GracefulShutdownManager.get_timeout()
    assert resolved == DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
# ── wait_for_drain ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_returns_immediately_when_already_drained():
    """A counter already at zero makes wait_for_drain return well under 0.5s."""

    def nothing_in_flight() -> int:
        return 0

    began = time.monotonic()
    drained = await GracefulShutdownManager.wait_for_drain(
        count_fn=nothing_in_flight,
        timeout=10,
    )

    assert drained == 0
    elapsed = time.monotonic() - began
    assert elapsed < 0.5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_waits_until_counter_reaches_zero_then_returns_drained_count():
    """With a counter showing 3 in-flight for three polls, 3 is returned."""
    fake_counter = _counter_that_drains_after(calls_before_zero=3)

    result = await GracefulShutdownManager.wait_for_drain(
        count_fn=fake_counter,
        timeout=10,
    )

    assert result == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_times_out_when_counter_never_drains():
    """A counter stuck at 2 makes wait_for_drain give up after the timeout."""

    def stuck_at_two() -> int:
        return 2

    began = time.monotonic()
    drained = await GracefulShutdownManager.wait_for_drain(
        count_fn=stuck_at_two,
        timeout=0.3,
    )
    took = time.monotonic() - began

    # The wait must last at least the timeout but stay bounded.
    assert took >= 0.3 and took < 2.0
    assert drained == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_zero_timeout_does_not_block():
    """timeout=0 returns promptly even though the counter still reports 5."""

    def stuck_at_five() -> int:
        return 5

    began = time.monotonic()
    drained = await GracefulShutdownManager.wait_for_drain(
        count_fn=stuck_at_five,
        timeout=0,
    )

    elapsed = time.monotonic() - began
    assert elapsed < 0.2
    # NOTE(review): test_times_out_when_counter_never_drains expects 0 for a
    # counter that never moves, while this expects 5 - confirm the return
    # contract of wait_for_drain for timeout=0.
    assert drained == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_exclude_self_treats_one_inflight_as_drained():
    """exclude_self=True must treat a steady count of 1 as fully drained.

    The /health/drain request is itself counted as in-flight, so with
    exclude_self a constant 1 must return promptly instead of timing out.
    """

    def only_self_in_flight() -> int:
        return 1

    began = time.monotonic()
    drained = await GracefulShutdownManager.wait_for_drain(
        count_fn=only_self_in_flight,
        exclude_self=True,
        timeout=5,
    )

    elapsed = time.monotonic() - began
    assert elapsed < 0.5
    assert drained == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_without_exclude_self_one_inflight_blocks_until_timeout():
    """Without exclude_self, a steady count of 1 keeps the wait going until timeout."""

    def one_in_flight() -> int:
        return 1

    began = time.monotonic()
    await GracefulShutdownManager.wait_for_drain(
        count_fn=one_in_flight,
        timeout=0.3,
    )
    waited = time.monotonic() - began

    assert waited >= 0.3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_defaults_to_get_timeout_and_live_counter(monkeypatch):
    """Calling wait_for_drain() with no arguments uses the built-in defaults.

    With neither timeout nor count_fn passed, it falls back to get_timeout()
    and the live InFlightRequestsMiddleware counter.
    """
    from litellm.proxy.middleware.in_flight_requests_middleware import (
        InFlightRequestsMiddleware,
    )

    # Pin the live counter to zero and clear any timeout override from the env.
    InFlightRequestsMiddleware._in_flight = 0
    monkeypatch.delenv("GRACEFUL_SHUTDOWN_TIMEOUT", raising=False)

    result = await GracefulShutdownManager.wait_for_drain()
    assert result == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_second_drain_is_a_noop_so_window_is_not_doubled():
    """A second wait_for_drain call must return immediately.

    Both the preStop /health/drain hook and the lifespan SIGTERM handler
    drain. If the second call waited another full timeout,
    terminationGracePeriodSeconds would have to be doubled.
    """

    def one_in_flight() -> int:
        return 1

    # First drain: runs with a 0.2s timeout against a counter stuck at 1.
    await GracefulShutdownManager.wait_for_drain(
        count_fn=one_in_flight,
        timeout=0.2,
    )

    began = time.monotonic()
    drained = await GracefulShutdownManager.wait_for_drain(
        count_fn=one_in_flight,
        timeout=5,
    )

    elapsed = time.monotonic() - began
    assert elapsed < 0.1
    assert drained == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_emits_periodic_drain_waiting_log_while_waiting():
    """Exercise the periodic drain_waiting log branch.

    log_interval=0 makes that branch run on every poll until the fake
    counter finally reports zero.
    """
    fake_counter = _counter_that_drains_after(calls_before_zero=2)

    result = await GracefulShutdownManager.wait_for_drain(
        count_fn=fake_counter,
        timeout=10,
        poll_interval=0,
        log_interval=0,
    )

    assert result == 3
|
||||||
Loading…
Reference in New Issue
Block a user