457 lines
17 KiB
Python
457 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
|
|
"""
|
|
LiteLLM Health Check Client
|
|
|
|
A sentinel health check tool that tests all configured models on a LiteLLM proxy.
|
|
This script:
|
|
- Can read models from YAML config file or fetch from proxy API
|
|
- Sends a simple test request to each model concurrently
|
|
- Reports health status for each model
|
|
- Supports both chat/completion and embedding models
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import httpx
|
|
import yaml
|
|
|
|
# Default chat/completion prompt: one sentence repeated often enough to
# exceed 100,000 characters, then cut back to exactly 100,000.
_base_text = "This is a health check test prompt for LiteLLM proxy. "
_repeat_count = 100_000 // len(_base_text) + 1
_DEFAULT_COMPLETION_PROMPT = "".join([_base_text] * _repeat_count)[:100_000]

# Default embedding input, built the same way and also exactly 100,000
# characters long.
_embedding_base_text = "This is a test for vectorization. "
_embedding_repeat_count = 100_000 // len(_embedding_base_text) + 1
_DEFAULT_EMBEDDING_TEXT = "".join(
    [_embedding_base_text] * _embedding_repeat_count
)[:100_000]
|
|
|
|
|
|
class LiteLLMHealthCheckClient:
    """Client for health checking LiteLLM proxy models.

    Sends one test request per model (chat completion or embedding) to the
    proxy at ``base_url`` and collects a result dictionary for each model.
    Progress and diagnostic messages are written to stderr.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        timeout: int = 120,  # Match Go implementation's 120s timeout
        completion_prompt: str = _DEFAULT_COMPLETION_PROMPT,  # Default 100k chars
        embedding_text: str = _DEFAULT_EMBEDDING_TEXT,  # Default 100k chars
        custom_auth_header: Optional[str] = None,
    ):
        """
        Initialize the health check client.

        Args:
            base_url: Base URL of the LiteLLM proxy (e.g., https://litellm.example.com).
                Trailing slashes are removed.
            api_key: API key for authentication; sent as ``Bearer <api_key>``.
            timeout: Request timeout in seconds (default: 120, matching Go implementation)
            completion_prompt: Test prompt for chat/completion models
            embedding_text: Test text for embedding models
            custom_auth_header: Optional custom header name for authentication
                (e.g., "x-requester-service"). If provided and non-blank, the
                bearer token is sent in this header instead of the standard
                "Authorization" header.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        self.completion_prompt = completion_prompt
        self.embedding_text = embedding_text

        # Debug: Print prompt/text lengths
        print(
            f"DEBUG: Completion prompt length: {len(self.completion_prompt)} characters",
            file=sys.stderr,
        )
        print(
            f"DEBUG: Embedding text length: {len(self.embedding_text)} characters",
            file=sys.stderr,
        )

        # Support custom auth header for proxies with custom authentication.
        # None, empty, and whitespace-only values all select the standard header.
        header_name = (custom_auth_header or "").strip()
        if header_name:
            self.headers = {
                header_name: f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
            print(f"Using custom auth header: {header_name}", file=sys.stderr)
        else:
            self.headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
            print("Using standard Authorization header", file=sys.stderr)

    def load_models_from_yaml(self, yaml_path: str) -> List[Dict]:
        """
        Load models from a YAML config file (similar to Go implementation).

        Args:
            yaml_path: Path to the YAML config file

        Returns:
            List of model dictionaries with 'id', 'mode' and 'provider' keys.
            Returns an empty list (after printing the error to stderr) if the
            file cannot be read or parsed.
        """
        try:
            with open(yaml_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)

            # An empty YAML document parses to None, and a key present with a
            # null value yields None as well; treat both as "no models".
            model_list = (config or {}).get("model_list") or []
            models = []

            for entry in model_list:
                model_name = entry.get("model_name", "")
                litellm_params = entry.get("litellm_params") or {}
                # The nested location is kept for backward compatibility.
                # LiteLLM's proxy config places `model_info` next to
                # `litellm_params`, so fall back to the entry-level key;
                # otherwise the configured mode is never picked up.
                model_info = (
                    litellm_params.get("model_info")
                    or entry.get("model_info")
                    or {}
                )
                mode = model_info.get("mode", "")

                # Use model_name as the ID (this is what gets sent to the API)
                models.append(
                    {
                        "id": model_name,
                        "mode": mode.lower() if mode else "",
                        "provider": model_info.get("provider", ""),
                    }
                )

            return models
        except Exception as e:
            # Best-effort loader: any read/parse/shape problem is reported and
            # turned into an empty list so the caller can decide what to do.
            print(
                f"Error loading models from YAML file {yaml_path}: {e}", file=sys.stderr
            )
            return []

    async def fetch_models(self, client: httpx.AsyncClient) -> List[Dict]:
        """
        Fetch all available models from the proxy API.

        Tries ``/v1/models`` first. If that request or its parsing fails,
        falls back to ``/model/info``, which also carries mode and provider.

        Args:
            client: HTTP client used for the requests

        Returns:
            List of model dictionaries with 'id', 'mode' and 'provider' keys;
            an empty list if both endpoints fail.
        """
        try:
            # Try /v1/models first (OpenAI-compatible endpoint)
            response = await client.get(
                f"{self.base_url}/v1/models",
                headers=self.headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()
            # This endpoint only provides ids; mode/provider stay empty.
            return [
                {"id": m["id"], "mode": "", "provider": ""}
                for m in data.get("data", [])
            ]
        except Exception as e:
            print(f"Error fetching models from /v1/models: {e}", file=sys.stderr)

        # Fallback to /model/info endpoint which has more details
        try:
            response = await client.get(
                f"{self.base_url}/model/info",
                headers=self.headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()
            if isinstance(data, dict) and "data" in data:
                models_data = data["data"]
            elif isinstance(data, list):
                models_data = data
            else:
                models_data = []

            models = []
            for m in models_data:
                # `or {}` keeps a null model_info on one entry from aborting
                # the whole listing.
                model_info = m.get("model_info") or {}
                mode = model_info.get("mode", "")
                models.append(
                    {
                        "id": m.get("model_name", m.get("id", "unknown")),
                        "mode": mode.lower() if mode else "",
                        "provider": model_info.get("provider", ""),
                    }
                )
            return models
        except Exception as e2:
            print(f"Error fetching models from /model/info: {e2}", file=sys.stderr)
            return []

    async def check_model_health(
        self, client: httpx.AsyncClient, model: Dict
    ) -> Tuple[str, Dict]:
        """
        Check health of a single model by sending a test request.

        Embedding models are sent to ``/v1/embeddings``; everything else is
        sent to ``/v1/chat/completions`` with ``max_tokens`` of 10. HTTP
        errors, timeouts and any other exception raised by the request are
        captured in the result's 'error' field instead of being raised.

        Args:
            client: HTTP client
            model: Model dictionary with 'id' and 'mode' keys

        Returns:
            Tuple of (model_id, result_dict). result_dict always has 'model',
            'healthy', 'error', 'response_time_ms' and 'mode'; on success it
            also has 'dimensions' (embedding) or 'response_text' (chat).
            'response_time_ms' is only filled in on success.
        """
        model_id = model["id"]
        mode = model.get("mode", "")

        # perf_counter is monotonic, so the measured duration is not affected
        # by wall-clock adjustments.
        start_time = time.perf_counter()
        result = {
            "model": model_id,
            "healthy": False,
            "error": None,
            "response_time_ms": None,
            "mode": mode,
        }

        try:
            # Determine if this is an embedding model. A configured mode is
            # authoritative; the name heuristic is only used when no mode is
            # known, so an explicit non-embedding mode is not overridden by an
            # "embed" substring in the model name.
            if mode:
                is_embedding = mode == "embedding"
            else:
                is_embedding = "embed" in model_id.lower()

            if is_embedding:
                # Test embedding endpoint (matching Go implementation)
                embedding_text_length = len(self.embedding_text)
                print(
                    f"DEBUG: Sending embedding text of length {embedding_text_length} chars to model {model_id}",
                    file=sys.stderr,
                )
                embedding_response = await client.post(
                    f"{self.base_url}/v1/embeddings",
                    headers=self.headers,
                    json={
                        "model": model_id,
                        "input": self.embedding_text,
                    },
                    timeout=self.timeout,
                )
                embedding_response.raise_for_status()
                embedding_data = embedding_response.json()
                dimensions = 0
                if "data" in embedding_data and len(embedding_data["data"]) > 0:
                    dimensions = len(embedding_data["data"][0].get("embedding", []))

                result["healthy"] = True
                result["mode"] = "embedding"
                result["dimensions"] = dimensions
            else:
                # Test chat completion endpoint (matching Go implementation)
                prompt_length = len(self.completion_prompt)
                print(
                    f"DEBUG: Sending prompt of length {prompt_length} chars to model {model_id}",
                    file=sys.stderr,
                )
                completion_response = await client.post(
                    f"{self.base_url}/v1/chat/completions",
                    headers=self.headers,
                    json={
                        "model": model_id,
                        "messages": [
                            {"role": "user", "content": self.completion_prompt}
                        ],
                        "max_tokens": 10,  # Minimal tokens for health check
                    },
                    timeout=self.timeout,
                )
                completion_response.raise_for_status()
                completion_data = completion_response.json()
                response_text = ""
                if "choices" in completion_data and len(completion_data["choices"]) > 0:
                    # "message" or "content" may be present but null; without
                    # the `or` fallbacks slicing None below would raise and a
                    # responding model would be reported unhealthy.
                    message = completion_data["choices"][0].get("message") or {}
                    response_text = message.get("content") or ""

                result["healthy"] = True
                result["mode"] = "chat"
                result["response_text"] = response_text[:100]  # Truncate for display

            elapsed_ms = (time.perf_counter() - start_time) * 1000
            result["response_time_ms"] = round(elapsed_ms, 2)

        except httpx.HTTPStatusError as e:
            result["error"] = f"HTTP {e.response.status_code}: {e.response.text[:200]}"
        except httpx.TimeoutException:
            result["error"] = f"Request timeout after {self.timeout}s"
        except Exception as e:
            result["error"] = str(e)[:200]

        return model_id, result

    async def run_health_checks(
        self,
        models: Optional[List[Dict]] = None,
        models_only: Optional[List[str]] = None,
    ) -> Dict[str, Dict]:
        """
        Run health checks on all models concurrently.

        Args:
            models: Optional list of models to check. If None, fetches from proxy.
            models_only: Optional list of model IDs to check. If set, only these
                models are health-checked (must exist in the models list).

        Returns:
            Dictionary mapping model_id to health check result. Empty if no
            models were found or none matched ``models_only``.
        """
        async with httpx.AsyncClient() as client:
            if models is None:
                models = await self.fetch_models(client)

            if not models:
                print("No models found to health check", file=sys.stderr)
                return {}

            if models_only:
                allowlist = {m.strip() for m in models_only if m and m.strip()}
                models = [m for m in models if m.get("id") in allowlist]
                print(
                    f"Filtering to only check {len(models)} models: {', '.join(sorted(allowlist))}",
                    file=sys.stderr,
                )
                if not models:
                    print(
                        "No models matched LITELLM_MODELS_ONLY filter",
                        file=sys.stderr,
                    )
                    return {}

            print(f"Running health checks on {len(models)} models...", file=sys.stderr)

            # Run all health checks concurrently
            tasks = [self.check_model_health(client, model) for model in models]
            results_list = await asyncio.gather(*tasks, return_exceptions=True)

            # Convert to dictionary format. gather() returns outcomes in the
            # same order as the awaitables, so each outcome is paired with
            # its model.
            results = {}
            for model, outcome in zip(models, results_list):
                if isinstance(outcome, BaseException):
                    # A task that raised (including CancelledError, which is
                    # not an Exception subclass) is recorded as an unhealthy
                    # model rather than silently dropped from the report.
                    print(f"Exception in health check task: {outcome}", file=sys.stderr)
                    is_mapping = isinstance(model, dict)
                    failed_id = model.get("id", "unknown") if is_mapping else "unknown"
                    results[failed_id] = {
                        "model": failed_id,
                        "healthy": False,
                        "error": str(outcome)[:200] or type(outcome).__name__,
                        "response_time_ms": None,
                        "mode": model.get("mode", "") if is_mapping else "",
                    }
                    continue
                model_id, result_dict = outcome
                results[model_id] = result_dict

            return results

    def print_results(self, results: Dict[str, Dict], json_output: bool = False) -> None:
        """
        Print health check results.

        With ``json_output`` the results are printed to stdout as JSON and the
        method returns without setting an exit code. Otherwise a per-model
        report and a summary are written to stderr and the process is
        terminated via ``sys.exit``: status 1 if any model is unhealthy,
        status 0 otherwise.

        Args:
            results: Dictionary of health check results
            json_output: If True, output as JSON
        """
        if json_output:
            print(json.dumps(results, indent=2))
            return

        healthy_count = sum(1 for r in results.values() if r.get("healthy"))
        unhealthy_count = len(results) - healthy_count

        # Print detailed results for each model (matching Go output format)
        print(f"\n{'='*60}", file=sys.stderr)
        print("Starting health check queries\n", file=sys.stderr)

        for model_id, result in results.items():
            if result.get("healthy"):
                if result.get("mode") == "embedding":
                    dimensions = result.get("dimensions", 0)
                    print(
                        f"---- {model_id} ----\n✅ Success. "
                        f"Generated embedding vector with {dimensions} dimensions.\n\n",
                        file=sys.stderr,
                    )
                else:
                    response_text = result.get("response_text", "")
                    print(
                        f"---- {model_id} ----\n✅ Success. "
                        f"Response:\n{response_text}\n\n",
                        file=sys.stderr,
                    )
            else:
                # The "error" key is always present and may hold None, so a
                # .get() default would never be used; fall back explicitly.
                error = result.get("error") or "Unknown error"
                print(f"---- {model_id} ----\n❌ ERROR: {error}\n\n", file=sys.stderr)

        print(f"{'='*60}", file=sys.stderr)
        print("Health Check Summary", file=sys.stderr)
        print(f"{'='*60}", file=sys.stderr)
        print(f"Total models: {len(results)}", file=sys.stderr)
        print(f"Healthy: {healthy_count}", file=sys.stderr)
        print(f"Unhealthy: {unhealthy_count}", file=sys.stderr)
        print(f"{'='*60}\n", file=sys.stderr)

        # Exit with non-zero code if any models are unhealthy
        sys.exit(1 if unhealthy_count > 0 else 0)
|
|
|
|
|
|
async def main():
    """Main entry point.

    Reads its configuration from environment variables, runs the health
    checks and prints the results.

    Environment variables:
        LITELLM_BASE_URL: Proxy base URL (default: http://localhost:4000)
        LITELLM_API_KEY: API key (default: sk-1234)
        LITELLM_MODELS_YAML: Optional path to a YAML config listing the models;
            when unset the model list is fetched from the proxy.
        LITELLM_CUSTOM_AUTH_HEADER: Optional header name to carry the token
        LITELLM_TIMEOUT: Request timeout in whole seconds (default: 120)
        LITELLM_COMPLETION_PROMPT: Override for the chat test prompt
        LITELLM_EMBEDDING_TEXT: Override for the embedding test text
        LITELLM_JSON_OUTPUT: "true" to print results as JSON
        LITELLM_MODELS_ONLY: Optional comma-separated list of model IDs to check

    Exits with status 1 on invalid configuration or when no model was
    health-checked at all.
    """
    base_url = os.environ.get("LITELLM_BASE_URL", "http://localhost:4000")
    api_key = os.environ.get("LITELLM_API_KEY", "sk-1234")
    yaml_path = os.environ.get("LITELLM_MODELS_YAML")
    custom_auth_header = os.environ.get(
        "LITELLM_CUSTOM_AUTH_HEADER"
    )  # e.g., "x-requester-service"

    # Debug: Print custom auth header value if set
    if custom_auth_header:
        print(f"Custom auth header from env: '{custom_auth_header}'", file=sys.stderr)

    # The defaults above only apply when a variable is unset; these checks
    # catch variables that are set but empty.
    if not base_url:
        print("Error: LITELLM_BASE_URL environment variable not set", file=sys.stderr)
        sys.exit(1)

    if not api_key:
        print("Error: LITELLM_API_KEY environment variable not set", file=sys.stderr)
        sys.exit(1)

    timeout_raw = os.environ.get("LITELLM_TIMEOUT", "120")  # Match Go's 120s default
    try:
        timeout = int(timeout_raw)
    except ValueError:
        # Report a clear configuration error instead of an uncaught traceback.
        print(
            f"Error: LITELLM_TIMEOUT must be an integer number of seconds, got '{timeout_raw}'",
            file=sys.stderr,
        )
        sys.exit(1)

    completion_prompt = os.environ.get(
        "LITELLM_COMPLETION_PROMPT", _DEFAULT_COMPLETION_PROMPT
    )
    embedding_text = os.environ.get("LITELLM_EMBEDDING_TEXT", _DEFAULT_EMBEDDING_TEXT)
    json_output = os.environ.get("LITELLM_JSON_OUTPUT", "").lower() == "true"
    # Optional: only health-check these model IDs (comma-separated). E.g.:
    # LITELLM_MODELS_ONLY=claude-3.7-sonnet,claude-3.5-sonnet,claude-4.5-haiku
    models_only_raw = os.environ.get("LITELLM_MODELS_ONLY", "")
    models_only = [m.strip() for m in models_only_raw.split(",") if m.strip()] or None

    client = LiteLLMHealthCheckClient(
        base_url=base_url,
        api_key=api_key,
        timeout=timeout,
        completion_prompt=completion_prompt,
        embedding_text=embedding_text,
        custom_auth_header=custom_auth_header,
    )

    # Load models from YAML if provided, otherwise fetch from API
    models = None
    if yaml_path:
        models = client.load_models_from_yaml(yaml_path)
        if models:
            print(
                f"Successfully loaded {len(models)} models from {yaml_path}",
                file=sys.stderr,
            )

    results = await client.run_health_checks(models=models, models_only=models_only)

    if not results:
        # Nothing was checked: the model list could not be loaded or fetched,
        # or the LITELLM_MODELS_ONLY filter matched nothing. A sentinel must
        # not report success in that situation, so fail explicitly.
        if json_output:
            # Still emit the (empty) JSON document for machine consumers.
            client.print_results(results, json_output=True)
        print("Error: no models were health checked", file=sys.stderr)
        sys.exit(1)

    client.print_results(results, json_output=json_output)
|
|
|
|
|
|
# Run the async entry point only when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
|