feat: add adaptive routing to litellm
allow model routing to improve based on conversation signals ensures router is picking best model for task
This commit is contained in:
parent
850fe595ac
commit
dd4a1d2be2
@ -0,0 +1,39 @@
|
||||
-- One row per (router, request_type, model). Hot path on every routing decision.
|
||||
CREATE TABLE "LiteLLM_AdaptiveRouterState" (
|
||||
router_name TEXT NOT NULL,
|
||||
request_type TEXT NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
alpha DOUBLE PRECISION NOT NULL,
|
||||
beta DOUBLE PRECISION NOT NULL,
|
||||
total_samples INTEGER NOT NULL DEFAULT 0,
|
||||
last_updated_at TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (router_name, request_type, model_name)
|
||||
);
|
||||
|
||||
-- One row per (session, router, model). Updated per turn via the queue.
|
||||
CREATE TABLE "LiteLLM_AdaptiveRouterSession" (
|
||||
session_id TEXT NOT NULL,
|
||||
router_name TEXT NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
classified_type TEXT NOT NULL,
|
||||
misalignment_count INTEGER DEFAULT 0,
|
||||
stagnation_count INTEGER DEFAULT 0,
|
||||
disengagement_count INTEGER DEFAULT 0,
|
||||
satisfaction_count INTEGER DEFAULT 0,
|
||||
failure_count INTEGER DEFAULT 0,
|
||||
loop_count INTEGER DEFAULT 0,
|
||||
exhaustion_count INTEGER DEFAULT 0,
|
||||
last_user_content TEXT,
|
||||
last_assistant_content TEXT,
|
||||
tool_call_history JSONB DEFAULT '[]',
|
||||
pending_tool_calls JSONB DEFAULT '{}',
|
||||
turn_count INTEGER DEFAULT 0,
|
||||
last_processed_turn INTEGER DEFAULT -1,
|
||||
clean_credit_awarded BOOLEAN DEFAULT FALSE,
|
||||
terminal_status INTEGER,
|
||||
last_activity_at TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (session_id, router_name, model_name)
|
||||
);
|
||||
|
||||
CREATE INDEX "idx_adaptive_router_session_activity"
|
||||
ON "LiteLLM_AdaptiveRouterSession" (last_activity_at);
|
||||
@ -1219,3 +1219,46 @@ model LiteLLM_ClaudeCodePluginTable {
|
||||
|
||||
@@map("LiteLLM_ClaudeCodePluginTable")
|
||||
}
|
||||
|
||||
// Per-(router, request_type, model) Beta posterior for the adaptive router.
|
||||
model LiteLLM_AdaptiveRouterState {
|
||||
router_name String
|
||||
request_type String
|
||||
model_name String
|
||||
alpha Float
|
||||
beta Float
|
||||
total_samples Int @default(0)
|
||||
last_updated_at DateTime @default(now())
|
||||
|
||||
@@id([router_name, request_type, model_name])
|
||||
}
|
||||
|
||||
// Per-(session, router, model) signal counters for the adaptive router.
|
||||
model LiteLLM_AdaptiveRouterSession {
|
||||
session_id String
|
||||
router_name String
|
||||
model_name String
|
||||
classified_type String
|
||||
|
||||
misalignment_count Int @default(0)
|
||||
stagnation_count Int @default(0)
|
||||
disengagement_count Int @default(0)
|
||||
satisfaction_count Int @default(0)
|
||||
failure_count Int @default(0)
|
||||
loop_count Int @default(0)
|
||||
exhaustion_count Int @default(0)
|
||||
|
||||
last_user_content String?
|
||||
last_assistant_content String?
|
||||
tool_call_history Json @default("[]")
|
||||
pending_tool_calls Json @default("{}")
|
||||
|
||||
turn_count Int @default(0)
|
||||
last_processed_turn Int @default(-1)
|
||||
clean_credit_awarded Boolean @default(false)
|
||||
terminal_status Int?
|
||||
last_activity_at DateTime @default(now())
|
||||
|
||||
@@id([session_id, router_name, model_name])
|
||||
@@index([last_activity_at])
|
||||
}
|
||||
|
||||
@ -1,32 +1,83 @@
|
||||
# model_list:
|
||||
# - model_name: claude-sonnet-4-6
|
||||
# litellm_params: {model: anthropic/claude-sonnet-4-6}
|
||||
# model_info:
|
||||
# litellm_routing_preferences:
|
||||
# quality_tier: 1
|
||||
# keywords: [tin]
|
||||
# - model_name: gpt-4o-mini
|
||||
# litellm_params: {model: openai/gpt-4o-mini}
|
||||
# model_info:
|
||||
# litellm_routing_preferences:
|
||||
# quality_tier: 1
|
||||
# keywords: []
|
||||
# - model_name: gpt-4o
|
||||
# litellm_params: {model: openai/gpt-4o}
|
||||
# model_info:
|
||||
# litellm_routing_preferences:
|
||||
# quality_tier: 2
|
||||
# keywords: [vision, function_calling]
|
||||
# - model_name: opus
|
||||
# litellm_params: {model: anthropic/claude-opus-4-7}
|
||||
# model_info:
|
||||
# litellm_routing_preferences:
|
||||
# quality_tier: 3
|
||||
# keywords: ["architecture", "design"]
|
||||
# - model_name: my-quality-router
|
||||
# litellm_params:
|
||||
# model: auto_router/adaptive_router
|
||||
# adaptive_router_default_model: gpt-4o-mini
|
||||
# adaptive_router_config:
|
||||
# available_models: [gpt-4o-mini, gpt-4o, opus, claude-sonnet-4-6]
|
||||
# Example proxy config for the adaptive router (v0).
|
||||
#
|
||||
# Wires one logical router ("smart-cheap-router") that adaptively picks between
|
||||
# two real deployments ("fast" and "smart") based on per-session feedback signals.
|
||||
#
|
||||
# How to use from a client:
|
||||
# POST /v1/chat/completions { "model": "smart-cheap-router", ... }
|
||||
# Add { "metadata": { "litellm_session_id": "<your-session-id>" } } to enable
|
||||
# sticky-session routing within a conversation.
|
||||
#
|
||||
# Required env vars: OPENAI_API_KEY, DATABASE_URL.
|
||||
|
||||
model_list:
|
||||
|
||||
# OpenAI model for /v1/chat/completions test — 200x custom pricing
|
||||
- model_name: "gpt-4.1-mini"
|
||||
# ---- The adaptive router "control" deployment -------------------------
|
||||
# `model_name` is what clients call. `available_models` lists the underlying
|
||||
# deployments the router is allowed to pick from (must match other model_name
|
||||
# entries in this list).
|
||||
- model_name: smart-cheap-router
|
||||
litellm_params:
|
||||
model: openai/gpt-4.1-mini
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
id: gpt-4.1-mini-custom-pricing
|
||||
input_cost_per_token: 0.00004 # 100x standard ($0.40/1M = $0.0000004)
|
||||
output_cost_per_token: 0.00016 # 100x standard ($1.60/1M = $0.0000016)
|
||||
model: auto_router/adaptive_router
|
||||
adaptive_router_config:
|
||||
available_models: ["fast", "smart"]
|
||||
weights:
|
||||
quality: 0.7
|
||||
cost: 0.3
|
||||
|
||||
# OpenAI model for /v1/responses test — 100x custom pricing
|
||||
- model_name: "gpt-5"
|
||||
# ---- Underlying deployments the router picks from ---------------------
|
||||
- model_name: fast
|
||||
litellm_params:
|
||||
model: openai/gpt-5
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
id: gpt-5-custom-pricing
|
||||
mode: "chat"
|
||||
input_cost_per_token: 125 # 100x standard ($1.25/1M = $0.00000125)
|
||||
output_cost_per_token: 10 # 100x standard ($10.00/1M = $0.00001)
|
||||
|
||||
# Anthropic model for /v1/messages test — 100x custom pricing
|
||||
- model_name: "claude-sonnet-4-20250514"
|
||||
litellm_params:
|
||||
model: anthropic/claude-sonnet-4-20250514
|
||||
model: anthropic/claude-sonnet-4-6
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
input_cost_per_token: 0.00000015
|
||||
model_info:
|
||||
id: claude-sonnet-4-custom-pricing
|
||||
input_cost_per_token: 0.0003 # 100x standard ($0.000003)
|
||||
output_cost_per_token: 0.0015 # 100x standard ($0.000015)
|
||||
adaptive_router_preferences:
|
||||
quality_tier: 2
|
||||
strengths: []
|
||||
|
||||
- model_name: smart
|
||||
litellm_params:
|
||||
model: anthropic/claude-opus-4-7
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
input_cost_per_token: 0.0000050
|
||||
model_info:
|
||||
adaptive_router_preferences:
|
||||
quality_tier: 3
|
||||
strengths: ["code_generation", "technical_design", "analytical_reasoning"]
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234 # REPLACE in production
|
||||
|
||||
@ -0,0 +1,216 @@
|
||||
"""
|
||||
In-memory queues for adaptive router state and session updates.
|
||||
|
||||
Pattern follows DailySpendUpdateQueue: hot path is fully in-memory; a background
|
||||
flusher task drains the aggregator and writes batches to Postgres.
|
||||
|
||||
Two logical queues (one class):
|
||||
1. STATE updates: increments to (router, request_type, model) bandit cell.
|
||||
Aggregator key = (router_name, request_type, model_name)
|
||||
Aggregated payload = {"delta_alpha": float, "delta_beta": float, "samples_added": int}
|
||||
2. SESSION updates: full snapshot of a session row (last-write-wins per session+router+model).
|
||||
Aggregator key = (session_id, router_name, model_name)
|
||||
Aggregated payload = the full session state dict.
|
||||
|
||||
Hot-path API is non-blocking and synchronous from the caller's POV (it just appends
|
||||
to the in-memory aggregator). Flush is async and batched.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
StateKey = Tuple[str, str, str] # (router_name, request_type, model_name)
|
||||
SessionKey = Tuple[str, str, str] # (session_id, router_name, model_name)
|
||||
|
||||
|
||||
class AdaptiveRouterUpdateQueue:
|
||||
"""
|
||||
Single class managing both state-update aggregation and session-snapshot aggregation.
|
||||
Held by the AdaptiveRouter strategy instance and started by the proxy on boot.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state_agg: Dict[StateKey, Dict[str, float]] = {}
|
||||
self._session_agg: Dict[SessionKey, Dict[str, Any]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._max_state_size_seen = 0
|
||||
self._max_session_size_seen = 0
|
||||
|
||||
# ---- Hot-path: state delta -------------------------------------------
|
||||
|
||||
async def add_state_delta(
|
||||
self,
|
||||
router_name: str,
|
||||
request_type: str,
|
||||
model_name: str,
|
||||
delta_alpha: float,
|
||||
delta_beta: float,
|
||||
) -> None:
|
||||
"""Aggregate a bandit-cell delta. Multiple deltas to the same cell sum."""
|
||||
key: StateKey = (router_name, request_type, model_name)
|
||||
async with self._lock:
|
||||
current = self._state_agg.get(key)
|
||||
if current is None:
|
||||
self._state_agg[key] = {
|
||||
"delta_alpha": delta_alpha,
|
||||
"delta_beta": delta_beta,
|
||||
"samples_added": 1,
|
||||
}
|
||||
else:
|
||||
current["delta_alpha"] += delta_alpha
|
||||
current["delta_beta"] += delta_beta
|
||||
current["samples_added"] += 1
|
||||
if len(self._state_agg) > self._max_state_size_seen:
|
||||
self._max_state_size_seen = len(self._state_agg)
|
||||
|
||||
# ---- Hot-path: session snapshot --------------------------------------
|
||||
|
||||
async def add_session_state(
|
||||
self,
|
||||
session_id: str,
|
||||
router_name: str,
|
||||
model_name: str,
|
||||
state_dict: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Last-write-wins per session row. The state_dict is a snapshot of the
|
||||
SessionState (signals counts + bookkeeping fields). The flusher will
|
||||
upsert this into LiteLLM_AdaptiveRouterSession.
|
||||
"""
|
||||
key: SessionKey = (session_id, router_name, model_name)
|
||||
async with self._lock:
|
||||
self._session_agg[key] = state_dict
|
||||
if len(self._session_agg) > self._max_session_size_seen:
|
||||
self._max_session_size_seen = len(self._session_agg)
|
||||
|
||||
# ---- Flushers (called by background task) ----------------------------
|
||||
|
||||
async def flush_state_to_db(self, prisma_client: Any) -> int:
|
||||
"""
|
||||
Drain state aggregator and apply to LiteLLM_AdaptiveRouterState.
|
||||
Returns number of cells flushed.
|
||||
"""
|
||||
async with self._lock:
|
||||
batch = self._state_agg
|
||||
self._state_agg = {}
|
||||
|
||||
if not batch:
|
||||
return 0
|
||||
|
||||
# Sort keys to give deterministic write order across writers and
|
||||
# reduce the chance of cross-row deadlocks when other workers race us.
|
||||
for key in sorted(batch.keys()):
|
||||
router, rt, model = key
|
||||
payload = batch[key]
|
||||
try:
|
||||
existing = (
|
||||
await prisma_client.db.litellm_adaptiverouterstate.find_unique(
|
||||
where={
|
||||
"router_name_request_type_model_name": {
|
||||
"router_name": router,
|
||||
"request_type": rt,
|
||||
"model_name": model,
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
new_alpha = (existing.alpha if existing else 0.0) + payload[
|
||||
"delta_alpha"
|
||||
]
|
||||
new_beta = (existing.beta if existing else 0.0) + payload["delta_beta"]
|
||||
new_samples = (existing.total_samples if existing else 0) + int(
|
||||
payload["samples_added"]
|
||||
)
|
||||
await prisma_client.db.litellm_adaptiverouterstate.upsert(
|
||||
where={
|
||||
"router_name_request_type_model_name": {
|
||||
"router_name": router,
|
||||
"request_type": rt,
|
||||
"model_name": model,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": {
|
||||
"router_name": router,
|
||||
"request_type": rt,
|
||||
"model_name": model,
|
||||
"alpha": new_alpha,
|
||||
"beta": new_beta,
|
||||
"total_samples": new_samples,
|
||||
},
|
||||
"update": {
|
||||
"alpha": new_alpha,
|
||||
"beta": new_beta,
|
||||
"total_samples": new_samples,
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"AdaptiveRouterUpdateQueue: failed to flush state for %s: %s",
|
||||
key,
|
||||
e,
|
||||
)
|
||||
|
||||
return len(batch)
|
||||
|
||||
async def flush_session_to_db(self, prisma_client: Any) -> int:
|
||||
"""
|
||||
Drain session aggregator and upsert into LiteLLM_AdaptiveRouterSession.
|
||||
Returns number of session rows flushed.
|
||||
"""
|
||||
async with self._lock:
|
||||
batch = self._session_agg
|
||||
self._session_agg = {}
|
||||
|
||||
if not batch:
|
||||
return 0
|
||||
|
||||
for key in sorted(batch.keys()):
|
||||
session_id, router, model = key
|
||||
payload = batch[key]
|
||||
try:
|
||||
# NOTE: Prisma client lower-cases model names, so
|
||||
# `LiteLLM_AdaptiveRouterSession` -> `litellm_adaptiveroutersession`
|
||||
# (single 's', not 'litellm_adaptiverouterssession').
|
||||
await prisma_client.db.litellm_adaptiveroutersession.upsert(
|
||||
where={
|
||||
"session_id_router_name_model_name": {
|
||||
"session_id": session_id,
|
||||
"router_name": router,
|
||||
"model_name": model,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": {
|
||||
"session_id": session_id,
|
||||
"router_name": router,
|
||||
"model_name": model,
|
||||
**payload,
|
||||
},
|
||||
"update": payload,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"AdaptiveRouterUpdateQueue: failed to flush session for %s: %s",
|
||||
key,
|
||||
e,
|
||||
)
|
||||
|
||||
return len(batch)
|
||||
|
||||
# ---- Observability ---------------------------------------------------
|
||||
|
||||
async def queue_size(self) -> Dict[str, int]:
|
||||
async with self._lock:
|
||||
return {
|
||||
"state_pending": len(self._state_agg),
|
||||
"session_pending": len(self._session_agg),
|
||||
"max_state_seen": self._max_state_size_seen,
|
||||
"max_session_seen": self._max_session_size_seen,
|
||||
}
|
||||
@ -0,0 +1,52 @@
|
||||
# Example proxy config for the adaptive router (v0).
|
||||
#
|
||||
# Wires one logical router ("smart-cheap-router") that adaptively picks between
|
||||
# two real deployments ("fast" and "smart") based on per-session feedback signals.
|
||||
#
|
||||
# How to use from a client:
|
||||
# POST /v1/chat/completions { "model": "smart-cheap-router", ... }
|
||||
# Add { "metadata": { "litellm_session_id": "<your-session-id>" } } to enable
|
||||
# sticky-session routing within a conversation.
|
||||
#
|
||||
# Required env vars: OPENAI_API_KEY, DATABASE_URL.
|
||||
|
||||
model_list:
|
||||
# ---- The adaptive router "control" deployment -------------------------
|
||||
# `model_name` is what clients call. `available_models` lists the underlying
|
||||
# deployments the router is allowed to pick from (must match other model_name
|
||||
# entries in this list).
|
||||
- model_name: smart-cheap-router
|
||||
litellm_params:
|
||||
model: openai/gpt-4o-mini # placeholder; never actually called -- router picks from available_models
|
||||
adaptive_router_config:
|
||||
available_models: ["fast", "smart"]
|
||||
weights:
|
||||
quality: 0.7
|
||||
cost: 0.3
|
||||
|
||||
# ---- Underlying deployments the router picks from ---------------------
|
||||
- model_name: fast
|
||||
litellm_params:
|
||||
model: openai/gpt-4o-mini
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
input_cost_per_token: 0.00000015
|
||||
model_info:
|
||||
adaptive_router_preferences:
|
||||
quality_tier: 2
|
||||
strengths: []
|
||||
|
||||
- model_name: smart
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
input_cost_per_token: 0.0000050
|
||||
model_info:
|
||||
adaptive_router_preferences:
|
||||
quality_tier: 3
|
||||
strengths: ["code_generation", "technical_design", "analytical_reasoning"]
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234 # REPLACE in production
|
||||
@ -952,6 +952,10 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915
|
||||
_run_background_health_check()
|
||||
) # start the background health check coroutine.
|
||||
|
||||
# Start adaptive-router queue flusher if any AdaptiveRouter is configured.
|
||||
if llm_router is not None and getattr(llm_router, "adaptive_routers", None):
|
||||
asyncio.create_task(_adaptive_router_flusher_loop())
|
||||
|
||||
## [Optional] Initialize dd tracer
|
||||
ProxyStartupEvent._init_dd_tracer()
|
||||
|
||||
@ -2201,9 +2205,11 @@ def run_ollama_serve():
|
||||
with open(os.devnull, "w") as devnull:
|
||||
subprocess.Popen(command, stdout=devnull, stderr=devnull)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"""
|
||||
verbose_proxy_logger.debug(
|
||||
f"""
|
||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _get_process_rss_mb() -> Optional[float]:
|
||||
@ -2385,6 +2391,31 @@ def _write_health_state_to_router_cache(
|
||||
)
|
||||
|
||||
|
||||
_ADAPTIVE_ROUTER_FLUSH_INTERVAL_SECONDS = 10
|
||||
|
||||
|
||||
async def _adaptive_router_flusher_loop():
|
||||
"""
|
||||
Drain every AdaptiveRouter's in-memory state + session aggregators into
|
||||
Postgres on a fixed cadence. Hot-path writes go to memory; this loop is
|
||||
the only writer to the adaptive router DB tables.
|
||||
"""
|
||||
global llm_router, prisma_client
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(_ADAPTIVE_ROUTER_FLUSH_INTERVAL_SECONDS)
|
||||
adaptive_routers = getattr(llm_router, "adaptive_routers", None) or {}
|
||||
if not adaptive_routers or prisma_client is None:
|
||||
continue
|
||||
for ar in adaptive_routers.values():
|
||||
await ar.queue.flush_state_to_db(prisma_client)
|
||||
await ar.queue.flush_session_to_db(prisma_client)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
verbose_proxy_logger.exception("adaptive_router flusher iteration failed")
|
||||
|
||||
|
||||
async def _run_background_health_check():
|
||||
"""
|
||||
Periodically run health checks in the background on the endpoints.
|
||||
@ -13877,6 +13908,38 @@ async def home(request: Request):
|
||||
return "LiteLLM: RUNNING"
|
||||
|
||||
|
||||
@router.get(
|
||||
"/adaptive_router/state",
|
||||
tags=["adaptive_router"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def get_adaptive_router_state(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""Return live bandit posteriors + queue depth for every configured adaptive router.
|
||||
|
||||
Admin-only. Returns 404 if no adaptive router is configured.
|
||||
|
||||
Response shape: `{"routers": [<snapshot>, ...]}` — one snapshot per
|
||||
adaptive-router deployment. Each snapshot's `router_name` field identifies
|
||||
which deployment it came from.
|
||||
"""
|
||||
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"error": CommonProxyErrors.not_allowed_access.value},
|
||||
)
|
||||
if llm_router is None or not llm_router.adaptive_routers:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": "No adaptive_router is configured on this proxy."},
|
||||
)
|
||||
snapshots = [
|
||||
await ar.get_state_snapshot() for ar in llm_router.adaptive_routers.values()
|
||||
]
|
||||
return {"routers": snapshots}
|
||||
|
||||
|
||||
@router.get("/routes", dependencies=[Depends(user_api_key_auth)])
|
||||
async def get_routes():
|
||||
"""
|
||||
|
||||
@ -1219,3 +1219,46 @@ model LiteLLM_ClaudeCodePluginTable {
|
||||
|
||||
@@map("LiteLLM_ClaudeCodePluginTable")
|
||||
}
|
||||
|
||||
// Per-(router, request_type, model) Beta posterior for the adaptive router.
|
||||
model LiteLLM_AdaptiveRouterState {
|
||||
router_name String
|
||||
request_type String
|
||||
model_name String
|
||||
alpha Float
|
||||
beta Float
|
||||
total_samples Int @default(0)
|
||||
last_updated_at DateTime @default(now())
|
||||
|
||||
@@id([router_name, request_type, model_name])
|
||||
}
|
||||
|
||||
// Per-(session, router, model) signal counters for the adaptive router.
|
||||
model LiteLLM_AdaptiveRouterSession {
|
||||
session_id String
|
||||
router_name String
|
||||
model_name String
|
||||
classified_type String
|
||||
|
||||
misalignment_count Int @default(0)
|
||||
stagnation_count Int @default(0)
|
||||
disengagement_count Int @default(0)
|
||||
satisfaction_count Int @default(0)
|
||||
failure_count Int @default(0)
|
||||
loop_count Int @default(0)
|
||||
exhaustion_count Int @default(0)
|
||||
|
||||
last_user_content String?
|
||||
last_assistant_content String?
|
||||
tool_call_history Json @default("[]")
|
||||
pending_tool_calls Json @default("{}")
|
||||
|
||||
turn_count Int @default(0)
|
||||
last_processed_turn Int @default(-1)
|
||||
clean_credit_awarded Boolean @default(false)
|
||||
terminal_status Int?
|
||||
last_activity_at DateTime @default(now())
|
||||
|
||||
@@id([session_id, router_name, model_name])
|
||||
@@index([last_activity_at])
|
||||
}
|
||||
|
||||
@ -200,12 +200,16 @@ if TYPE_CHECKING:
|
||||
from litellm.router_strategy.complexity_router.complexity_router import (
|
||||
ComplexityRouter,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.adaptive_router import (
|
||||
AdaptiveRouter,
|
||||
)
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
AutoRouter = Any
|
||||
ComplexityRouter = Any
|
||||
AdaptiveRouter = Any
|
||||
PreRoutingHookResponse = Any
|
||||
|
||||
|
||||
@ -464,6 +468,7 @@ class Router:
|
||||
) # {"TEAM_ID": PatternMatchRouter}
|
||||
self.auto_routers: Dict[str, "AutoRouter"] = {}
|
||||
self.complexity_routers: Dict[str, "ComplexityRouter"] = {}
|
||||
self.adaptive_routers: Dict[str, "AdaptiveRouter"] = {}
|
||||
|
||||
# Initialize model_group_alias early since it's used in set_model_list
|
||||
self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
|
||||
@ -3864,7 +3869,7 @@ class Router:
|
||||
self._add_deployment_model_to_endpoint_for_llm_passthrough_route(
|
||||
kwargs=kwargs, model=model, model_name=model_name
|
||||
)
|
||||
|
||||
|
||||
# Get custom_llm_provider from deployment params
|
||||
try:
|
||||
custom_llm_provider = data.get("custom_llm_provider")
|
||||
@ -3872,10 +3877,12 @@ class Router:
|
||||
model=data["model"],
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider
|
||||
custom_llm_provider = (
|
||||
custom_llm_provider or inferred_custom_llm_provider
|
||||
)
|
||||
except Exception:
|
||||
custom_llm_provider = None
|
||||
|
||||
|
||||
# Build response kwargs
|
||||
response_kwargs = {
|
||||
**data,
|
||||
@ -3885,7 +3892,7 @@ class Router:
|
||||
# Only set custom_llm_provider if it's not None
|
||||
if custom_llm_provider is not None:
|
||||
response_kwargs["custom_llm_provider"] = custom_llm_provider
|
||||
|
||||
|
||||
response = original_generic_function(**response_kwargs)
|
||||
|
||||
rpm_semaphore = self._get_client(
|
||||
@ -3981,7 +3988,9 @@ class Router:
|
||||
model=data["model"],
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider
|
||||
custom_llm_provider = (
|
||||
custom_llm_provider or inferred_custom_llm_provider
|
||||
)
|
||||
except Exception:
|
||||
custom_llm_provider = None
|
||||
|
||||
@ -4246,7 +4255,9 @@ class Router:
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
# Preserve explicitly stored provider, fallback to inferred
|
||||
custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider
|
||||
custom_llm_provider = (
|
||||
custom_llm_provider or inferred_custom_llm_provider
|
||||
)
|
||||
|
||||
## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ##
|
||||
purpose = cast(Optional[OpenAIFilesPurpose], kwargs.get("purpose"))
|
||||
@ -5355,9 +5366,9 @@ class Router:
|
||||
e,
|
||||
(litellm.ContextWindowExceededError, litellm.ContentPolicyViolationError),
|
||||
)
|
||||
_request_team_id: Optional[str] = (
|
||||
kwargs.get("metadata", {}) or {}
|
||||
).get("user_api_key_team_id")
|
||||
_request_team_id: Optional[str] = (kwargs.get("metadata", {}) or {}).get(
|
||||
"user_api_key_team_id"
|
||||
)
|
||||
all_deployments = self._get_all_deployments(
|
||||
model_name=original_model_group, team_id=_request_team_id
|
||||
)
|
||||
@ -6804,10 +6815,13 @@ class Router:
|
||||
Check if the deployment is an auto-router deployment (semantic router).
|
||||
|
||||
Returns True if the litellm_params model starts with "auto_router/"
|
||||
but NOT "auto_router/complexity_router" (which uses complexity routing).
|
||||
but NOT "auto_router/complexity_router" or "auto_router/adaptive_router"
|
||||
(which use the complexity-router and adaptive-router strategies).
|
||||
"""
|
||||
if litellm_params.model.startswith("auto_router/complexity_router"):
|
||||
return False # This is handled by complexity_router
|
||||
if litellm_params.model.startswith("auto_router/adaptive_router"):
|
||||
return False # This is handled by adaptive_router
|
||||
if litellm_params.model.startswith("auto_router/"):
|
||||
return True
|
||||
return False
|
||||
@ -6914,6 +6928,121 @@ class Router:
|
||||
)
|
||||
self.complexity_routers[deployment.model_name] = complexity_router
|
||||
|
||||
def _is_adaptive_router_deployment(self, litellm_params: LiteLLM_Params) -> bool:
|
||||
"""True when this deployment opts in via the `auto_router/adaptive_router` model prefix."""
|
||||
return litellm_params.model.startswith("auto_router/adaptive_router")
|
||||
|
||||
def _finalize_adaptive_router_if_configured(self) -> None:
|
||||
"""Locate every adaptive-router deployment in the finalized model_list and
|
||||
build an AdaptiveRouter for each. Safe no-op when none are configured.
|
||||
Idempotent: skips any deployment whose model_name is already initialized."""
|
||||
for entry in self.model_list or []:
|
||||
lp = (
|
||||
entry.get("litellm_params")
|
||||
if isinstance(entry, dict)
|
||||
else entry.litellm_params
|
||||
)
|
||||
lp_model = (
|
||||
(lp.get("model") if isinstance(lp, dict) else lp.model) if lp else None
|
||||
)
|
||||
if not (lp_model and lp_model.startswith("auto_router/adaptive_router")):
|
||||
continue
|
||||
model_name = (
|
||||
entry.get("model_name") if isinstance(entry, dict) else entry.model_name
|
||||
)
|
||||
if not model_name or not lp:
|
||||
continue
|
||||
if model_name in self.adaptive_routers:
|
||||
continue
|
||||
deployment = Deployment(
|
||||
model_name=model_name,
|
||||
litellm_params=(
|
||||
lp if not isinstance(lp, dict) else LiteLLM_Params(**lp)
|
||||
),
|
||||
model_info=(
|
||||
entry.get("model_info")
|
||||
if isinstance(entry, dict)
|
||||
else entry.model_info
|
||||
),
|
||||
)
|
||||
self.init_adaptive_router_deployment(deployment=deployment)
|
||||
|
||||
def init_adaptive_router_deployment(self, deployment: Deployment) -> None:
|
||||
"""
|
||||
Build an AdaptiveRouter instance for this deployment and register its
|
||||
post-call hook. Multiple adaptive routers can coexist on a single Router,
|
||||
keyed by `deployment.model_name`.
|
||||
|
||||
`model_to_prefs` and `model_to_cost` are derived from the OTHER models
|
||||
already registered in `self.model_list` whose `model_name` appears in
|
||||
`available_models`. Models not yet registered fall back to defaults.
|
||||
"""
|
||||
# Local import: AdaptiveRouter -> hooks -> classifier all import litellm
|
||||
# internals which transitively import this module. (AGENTS.md exception clause.)
|
||||
from litellm.router_strategy.adaptive_router.adaptive_router import (
|
||||
AdaptiveRouter,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.hooks import (
|
||||
AdaptiveRouterPostCallHook,
|
||||
)
|
||||
from litellm.types.router import (
|
||||
AdaptiveRouterConfig,
|
||||
AdaptiveRouterPreferences,
|
||||
)
|
||||
|
||||
raw_config = deployment.litellm_params.adaptive_router_config
|
||||
if raw_config is None:
|
||||
raise ValueError(
|
||||
"adaptive_router_config is required for adaptive-router deployments."
|
||||
)
|
||||
|
||||
config = AdaptiveRouterConfig(**raw_config)
|
||||
|
||||
model_to_prefs: Dict[str, AdaptiveRouterPreferences] = {}
|
||||
model_to_cost: Dict[str, float] = {}
|
||||
for d in self.model_list or []:
|
||||
name = d.get("model_name") if isinstance(d, dict) else d.model_name
|
||||
if name not in config.available_models:
|
||||
continue
|
||||
mi = d.get("model_info") if isinstance(d, dict) else d.model_info
|
||||
mi_dict: Dict[str, Any] = (
|
||||
mi if isinstance(mi, dict) else (mi.model_dump() if mi else {})
|
||||
)
|
||||
prefs_raw = mi_dict.get("adaptive_router_preferences")
|
||||
if prefs_raw is not None:
|
||||
model_to_prefs[name] = AdaptiveRouterPreferences(**prefs_raw)
|
||||
|
||||
# `input_cost_per_token` is a LiteLLM_Params field per types/router.py.
|
||||
lp = d.get("litellm_params") if isinstance(d, dict) else d.litellm_params
|
||||
lp_dict: Dict[str, Any] = (
|
||||
lp if isinstance(lp, dict) else (lp.model_dump() if lp else {})
|
||||
)
|
||||
cost = lp_dict.get("input_cost_per_token")
|
||||
if cost is not None:
|
||||
model_to_cost[name] = float(cost)
|
||||
|
||||
if deployment.model_name in self.adaptive_routers:
|
||||
raise ValueError(
|
||||
f"Adaptive-router deployment {deployment.model_name} already exists. "
|
||||
"Please use a different model name."
|
||||
)
|
||||
|
||||
adaptive_router = AdaptiveRouter(
|
||||
router_name=deployment.model_name,
|
||||
config=config,
|
||||
model_to_prefs=model_to_prefs,
|
||||
model_to_cost=model_to_cost,
|
||||
)
|
||||
self.adaptive_routers[deployment.model_name] = adaptive_router
|
||||
litellm.callbacks.append(
|
||||
AdaptiveRouterPostCallHook(adaptive_router=adaptive_router)
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
"AdaptiveRouter[%s] initialized with %d models",
|
||||
deployment.model_name,
|
||||
len(config.available_models),
|
||||
)
|
||||
|
||||
def deployment_is_active_for_environment(self, deployment: Deployment) -> bool:
|
||||
"""
|
||||
Function to check if a llm deployment is active for a given environment. Allows using the same config.yaml across multople environments
|
||||
@ -7007,6 +7136,10 @@ class Router:
|
||||
# Note: model_name_to_deployment_indices is already built incrementally
|
||||
# by _create_deployment -> _add_model_to_list_and_index_map
|
||||
|
||||
# Deferred: build the AdaptiveRouter strategy now that all underlying
|
||||
# deployments are visible in self.model_list.
|
||||
self._finalize_adaptive_router_if_configured()
|
||||
|
||||
def _add_deployment(self, deployment: Deployment) -> Deployment:
|
||||
import os
|
||||
|
||||
@ -7134,6 +7267,11 @@ class Router:
|
||||
):
|
||||
self.init_complexity_router_deployment(deployment=deployment)
|
||||
|
||||
# NOTE: adaptive-router deployments are deferred to the end of
|
||||
# set_model_list() because their init needs visibility into the OTHER
|
||||
# deployments listed in `available_models` (which may not yet have
|
||||
# been processed when this one is created).
|
||||
|
||||
return deployment
|
||||
|
||||
def _initialize_deployment_for_pass_through(
|
||||
@ -9645,6 +9783,19 @@ class Router:
|
||||
specific_deployment=specific_deployment,
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# Check if an adaptive-router should be used
|
||||
#########################################################
|
||||
adaptive_router = self.adaptive_routers.get(model)
|
||||
if adaptive_router is not None:
|
||||
return await adaptive_router.async_pre_routing_hook(
|
||||
model=model,
|
||||
request_kwargs=request_kwargs,
|
||||
messages=messages,
|
||||
input=input,
|
||||
specific_deployment=specific_deployment,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def get_available_deployment(
|
||||
|
||||
93
litellm/router_strategy/adaptive_router/README.md
Normal file
93
litellm/router_strategy/adaptive_router/README.md
Normal file
@ -0,0 +1,93 @@
|
||||
# Adaptive Router (v0)
|
||||
|
||||
A request-type-aware routing strategy. For each incoming request, classify the
|
||||
prompt into one of seven `RequestType` buckets (code generation, writing,
|
||||
analytical reasoning, …), then Thompson-sample a Beta(α, β) bandit posterior
|
||||
per `(request_type, model)` cell to pick the best model. Quality estimates are
|
||||
combined with a normalized cost score via a weighted linear sum.
|
||||
|
||||
A post-call hook reads the response and runs lightweight regex + tool-call
|
||||
detectors (see `signals.py`) to award per-turn credit/blame to the model that
|
||||
served the turn. Updates are batched in-memory and flushed to Postgres every
|
||||
~10s by a background task in `proxy_server.py`.
|
||||
|
||||
## Config example
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
model_info:
|
||||
input_cost_per_token: 0.0000025
|
||||
adaptive_router_preferences:
|
||||
quality_tier: 3
|
||||
strengths: ["code_generation", "analytical_reasoning"]
|
||||
|
||||
- model_name: gpt-4o-mini
|
||||
litellm_params:
|
||||
model: openai/gpt-4o-mini
|
||||
model_info:
|
||||
input_cost_per_token: 0.00000015
|
||||
adaptive_router_preferences:
|
||||
quality_tier: 2
|
||||
strengths: ["general", "factual_lookup"]
|
||||
|
||||
- model_name: smart-router
|
||||
litellm_params:
|
||||
model: adaptive_router/smart-router
|
||||
adaptive_router_default_model: gpt-4o-mini
|
||||
adaptive_router_config:
|
||||
available_models: ["gpt-4o", "gpt-4o-mini"]
|
||||
weights:
|
||||
quality: 0.7
|
||||
cost: 0.3
|
||||
```
|
||||
|
||||
Callers may pass header `x-litellm-min-quality-tier: 3` (or metadata key
|
||||
`min_quality_tier: 3`) to force selection from tier-3-or-higher models only.
|
||||
|
||||
## Behavior summary
|
||||
|
||||
- **Cold start.** Each `(request_type, model)` cell starts with a
|
||||
Beta prior whose mean = `BASE_TIER_WEIGHT[tier] (+ STRENGTH_BONUS if declared)`
|
||||
and total mass = `COLD_START_MASS` (10). About ten real observations move it
|
||||
meaningfully.
|
||||
- **Per-request decision.** Sample once per eligible model, score with
|
||||
`quality_weight·sample + cost_weight·normalized_cost`, pick the argmax.
|
||||
Routing is stateless per-turn — no sticky lookup. Each call resamples.
|
||||
- **Owner-cache attribution.** Post-call, the conversation's first picked
|
||||
model claims an "owner slot" for `OWNER_CACHE_TTL_SECONDS` (24h). Later
|
||||
turns of the same conversation only fire bandit/state updates if the
|
||||
same model handled them — mismatches are dropped (no attribution) and
|
||||
counted in `skipped_updates_total`. Conversation identity is the
|
||||
client-supplied `litellm_session_id` if present, otherwise a sha256 over
|
||||
caller identity (api key hash, team, user, end-user) + the first message.
|
||||
- **Per-turn updates.** `satisfaction → +α`. `misalignment, stagnation,
|
||||
disengagement, failure → +β` (each). `loop → +0.5β`. `exhaustion → 0`
|
||||
(uptime, not quality). Skipped if conversation has fewer than
|
||||
`SIGNAL_GATE_MIN_MESSAGES` messages.
|
||||
- **Persistence.** Bandit cells: aggregated deltas, eventually consistent.
|
||||
Session rows: last-write-wins snapshots.
|
||||
|
||||
## Known v0 limitations
|
||||
|
||||
- **Latency is not in the score.** Quality + cost only. A pathologically slow
|
||||
model can still be picked.
|
||||
- **Hard sample cap at 200.** Once `α + β > 200`, deltas are silently dropped.
|
||||
No rescaling — drift is a v1 concern.
|
||||
- **24h owner-cache TTL.** No explicit eviction below TTL. The in-memory map
|
||||
can grow if traffic patterns produce many one-shot sessions.
|
||||
- **Owner-recovery skew.** If model A "owns" a conversation but is then
|
||||
dethroned in the bandit, later turns served by model B are dropped — so
|
||||
bandit updates for that conversation flatline until A's TTL expires.
|
||||
Tracked via `skipped_updates_total`.
|
||||
- **Signals are regex + tool-call only.** No LLM-judge, no embedding similarity,
|
||||
no exemplar storage. Signals are best-effort and biased toward English.
|
||||
- **One AdaptiveRouter per `Router`.** Multiple `adaptive_router/*` deployments
|
||||
on the same `litellm.Router` raise at init.
|
||||
- **Bandit-delta mapping is unvalidated.** `_compute_bandit_delta` is a v0
|
||||
guess; expect to retune after the first ~1000 sessions of real traffic.
|
||||
- **`request_type` is classified per turn from the latest user message only.**
|
||||
The first turn's classification doesn't carry forward; a multi-turn session
|
||||
may shift bucket between turns.
|
||||
6
litellm/router_strategy/adaptive_router/__init__.py
Normal file
6
litellm/router_strategy/adaptive_router/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Adaptive router strategy. See README.md for design overview."""
|
||||
|
||||
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
|
||||
from litellm.router_strategy.adaptive_router.hooks import AdaptiveRouterPostCallHook
|
||||
|
||||
__all__ = ["AdaptiveRouter", "AdaptiveRouterPostCallHook"]
|
||||
344
litellm/router_strategy/adaptive_router/adaptive_router.py
Normal file
344
litellm/router_strategy/adaptive_router/adaptive_router.py
Normal file
@ -0,0 +1,344 @@
|
||||
"""
|
||||
Main adaptive router strategy. See README.md for design overview.
|
||||
|
||||
One AdaptiveRouter instance per router_name. Holds in-memory caches:
|
||||
- _cells: Beta(alpha, beta) bandit posteriors per (request_type, model)
|
||||
- _owner_cache: session_key -> (owner_model, expires_at) — the first model
|
||||
picked for a conversation owns its bandit-update slot
|
||||
- _session_states: (session_key, model) -> SessionState for incremental signal updates
|
||||
|
||||
Owns the AdaptiveRouterUpdateQueue used by the proxy's flusher to persist
|
||||
state and session snapshots back to Postgres.
|
||||
|
||||
Routing is stateless per-turn (Thompson sample fresh on every call). The
|
||||
owner cache is consulted only at post-call time to decide whether a turn's
|
||||
signals should fire a bandit update — turns served by a different model than
|
||||
the conversation's owner are skipped to avoid cross-model misattribution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_last_user_message,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.adaptive_router_update_queue import (
|
||||
AdaptiveRouterUpdateQueue,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.bandit import (
|
||||
BanditCell,
|
||||
apply_delta,
|
||||
initial_cell,
|
||||
pick_best,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.classifier import classify_prompt
|
||||
from litellm.router_strategy.adaptive_router.config import (
|
||||
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY,
|
||||
OWNER_CACHE_TTL_SECONDS,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.signals import (
|
||||
SessionState,
|
||||
SignalDelta,
|
||||
Turn,
|
||||
apply_turn,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.router import (
|
||||
AdaptiveRouterConfig,
|
||||
AdaptiveRouterPreferences,
|
||||
PreRoutingHookResponse,
|
||||
RequestType,
|
||||
)
|
||||
|
||||
|
||||
def _default_prefs() -> AdaptiveRouterPreferences:
|
||||
"""Tier-2 prior with no declared strengths; used when a model omits prefs."""
|
||||
return AdaptiveRouterPreferences(quality_tier=2, strengths=[])
|
||||
|
||||
|
||||
class AdaptiveRouter:
|
||||
"""One instance per router_name. Holds in-memory caches + the update queue."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
router_name: str,
|
||||
config: AdaptiveRouterConfig,
|
||||
model_to_prefs: Dict[str, AdaptiveRouterPreferences],
|
||||
model_to_cost: Dict[str, float],
|
||||
) -> None:
|
||||
self.router_name = router_name
|
||||
self.config = config
|
||||
self.model_to_prefs = model_to_prefs
|
||||
self.model_to_cost = model_to_cost
|
||||
self.queue = AdaptiveRouterUpdateQueue()
|
||||
|
||||
self._cells: Dict[Tuple[RequestType, str], BanditCell] = {}
|
||||
self._owner_cache: Dict[str, Tuple[str, float]] = {}
|
||||
self._session_states: Dict[Tuple[str, str], SessionState] = {}
|
||||
self._skipped_updates_total: int = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
self._init_cold_start_cells()
|
||||
|
||||
# ---- Cold-start ------------------------------------------------------
|
||||
|
||||
def _init_cold_start_cells(self) -> None:
|
||||
"""Populate _cells with cold-start priors for every (rt, model) combination."""
|
||||
for rt in RequestType:
|
||||
for model in self.config.available_models:
|
||||
prefs = self.model_to_prefs.get(model) or _default_prefs()
|
||||
self._cells[(rt, model)] = initial_cell(prefs, rt)
|
||||
|
||||
async def load_state_from_db(self, prisma_client: Any) -> None:
|
||||
"""Override cold-start cells with persisted state. Called once at startup."""
|
||||
if prisma_client is None:
|
||||
return
|
||||
try:
|
||||
rows = await prisma_client.db.litellm_adaptiverouterstate.find_many(
|
||||
where={"router_name": self.router_name}
|
||||
)
|
||||
loaded = 0
|
||||
for row in rows:
|
||||
try:
|
||||
rt = RequestType(row.request_type)
|
||||
except ValueError:
|
||||
# Unknown taxonomy entry from an older/newer version. Skip.
|
||||
continue
|
||||
if row.model_name not in self.config.available_models:
|
||||
continue
|
||||
self._cells[(rt, row.model_name)] = BanditCell(
|
||||
alpha=row.alpha, beta=row.beta
|
||||
)
|
||||
loaded += 1
|
||||
verbose_router_logger.info(
|
||||
"AdaptiveRouter[%s]: loaded %d cells from DB",
|
||||
self.router_name,
|
||||
loaded,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_router_logger.exception(
|
||||
"AdaptiveRouter[%s]: failed to load state from DB: %s",
|
||||
self.router_name,
|
||||
e,
|
||||
)
|
||||
|
||||
# ---- Pre-routing hook ------------------------------------------------
|
||||
|
||||
async def async_pre_routing_hook(
|
||||
self,
|
||||
model: str,
|
||||
request_kwargs: Dict[str, Any],
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
) -> Optional[PreRoutingHookResponse]:
|
||||
"""
|
||||
Plugin entry point invoked by `Router.async_pre_routing_hook` when the
|
||||
inbound `model` matches this adaptive router's `router_name`.
|
||||
|
||||
Classifies the last user message, picks a logical model via the bandit,
|
||||
and stashes the chosen model on `request_kwargs["metadata"]` so the
|
||||
post-call hook can surface it as a response header.
|
||||
|
||||
Routing is stateless per-turn: every call Thompson-samples fresh,
|
||||
regardless of any prior pick for the same session. Cross-turn
|
||||
attribution is enforced post-call via the owner cache (see
|
||||
`claim_or_check_owner`).
|
||||
"""
|
||||
user_text = (
|
||||
get_last_user_message(cast(List[AllMessageValues], messages or [])) or ""
|
||||
)
|
||||
|
||||
request_type = classify_prompt(user_text)
|
||||
chosen_model = await self.pick_model(request_type=request_type)
|
||||
verbose_router_logger.debug(
|
||||
"AdaptiveRouter[%s]: classified=%s -> chose %s",
|
||||
self.router_name,
|
||||
request_type.value,
|
||||
chosen_model,
|
||||
)
|
||||
|
||||
# Relay the chosen logical model to the post-call hook, which surfaces
|
||||
# it as the `x-litellm-adaptive-router-model` response header. We use
|
||||
# `metadata` (not a top-level kwarg) so the value doesn't leak into
|
||||
# `litellm.acompletion(**input_kwargs)`.
|
||||
kwargs_metadata = request_kwargs.setdefault("metadata", {})
|
||||
if isinstance(kwargs_metadata, dict):
|
||||
kwargs_metadata[ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY] = chosen_model
|
||||
|
||||
return PreRoutingHookResponse(model=chosen_model, messages=messages)
|
||||
|
||||
# ---- Pick model ------------------------------------------------------
|
||||
|
||||
async def pick_model(
|
||||
self,
|
||||
request_type: RequestType,
|
||||
min_quality_tier: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Thompson-sample across eligible models. Stateless per-turn."""
|
||||
eligible = self._eligible_models(min_quality_tier)
|
||||
if not eligible:
|
||||
raise ValueError(
|
||||
f"AdaptiveRouter[{self.router_name}]: no models meet "
|
||||
f"min_quality_tier={min_quality_tier}"
|
||||
)
|
||||
|
||||
cells = {m: self._cells[(request_type, m)] for m in eligible}
|
||||
costs = {m: self.model_to_cost.get(m, 0.0) for m in eligible}
|
||||
return pick_best(
|
||||
cells,
|
||||
costs,
|
||||
quality_weight=self.config.weights.quality,
|
||||
cost_weight=self.config.weights.cost,
|
||||
)
|
||||
|
||||
def claim_or_check_owner(self, session_key: str, current_model: str) -> bool:
|
||||
"""Resolve attribution for a turn under stateless routing.
|
||||
|
||||
Returns True iff this turn should fire a bandit/state update. The
|
||||
first call for a `session_key` claims ownership for `current_model`
|
||||
and returns True. Subsequent calls return True only if the owner is
|
||||
still live AND matches `current_model`. Mismatches (a different
|
||||
model handled this turn) and expired owners both increment
|
||||
`_skipped_updates_total` and return False — no attribution.
|
||||
"""
|
||||
now = time.time()
|
||||
existing = self._owner_cache.get(session_key)
|
||||
if existing is not None and existing[1] > now:
|
||||
owner_model, _ = existing
|
||||
if owner_model == current_model:
|
||||
return True
|
||||
self._skipped_updates_total += 1
|
||||
return False
|
||||
|
||||
# No live owner -> claim for current_model.
|
||||
self._owner_cache[session_key] = (
|
||||
current_model,
|
||||
now + OWNER_CACHE_TTL_SECONDS,
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_state_snapshot(self) -> Dict[str, Any]:
|
||||
"""In-memory snapshot for the introspection endpoint. Cheap; no DB hit."""
|
||||
cells = []
|
||||
for (rt, model), cell in sorted(
|
||||
self._cells.items(), key=lambda kv: (kv[0][0].value, kv[0][1])
|
||||
):
|
||||
total = cell.alpha + cell.beta
|
||||
cells.append(
|
||||
{
|
||||
"request_type": rt.value,
|
||||
"model": model,
|
||||
"alpha": cell.alpha,
|
||||
"beta": cell.beta,
|
||||
"samples": total,
|
||||
"quality_mean": cell.alpha / total if total > 0 else 0.0,
|
||||
}
|
||||
)
|
||||
queue = await self.queue.queue_size()
|
||||
now = time.time()
|
||||
owner_cache_live = sum(1 for _, exp in self._owner_cache.values() if exp > now)
|
||||
return {
|
||||
"router_name": self.router_name,
|
||||
"available_models": list(self.config.available_models),
|
||||
"weights": {
|
||||
"quality": self.config.weights.quality,
|
||||
"cost": self.config.weights.cost,
|
||||
},
|
||||
"model_costs": dict(self.model_to_cost),
|
||||
"cells": cells,
|
||||
"owner_cache_live": owner_cache_live,
|
||||
"skipped_updates_total": self._skipped_updates_total,
|
||||
"queue": queue,
|
||||
}
|
||||
|
||||
def _eligible_models(self, min_quality_tier: Optional[int]) -> List[str]:
|
||||
if min_quality_tier is None:
|
||||
return list(self.config.available_models)
|
||||
return [
|
||||
m
|
||||
for m in self.config.available_models
|
||||
if (self.model_to_prefs.get(m) or _default_prefs()).quality_tier
|
||||
>= min_quality_tier
|
||||
]
|
||||
|
||||
# ---- Session state ---------------------------------------------------
|
||||
|
||||
def get_or_create_session_state(
|
||||
self,
|
||||
session_id: str,
|
||||
model_name: str,
|
||||
request_type: RequestType,
|
||||
) -> SessionState:
|
||||
key = (session_id, model_name)
|
||||
state = self._session_states.get(key)
|
||||
if state is None:
|
||||
state = SessionState(
|
||||
session_id=session_id,
|
||||
router_name=self.router_name,
|
||||
model_name=model_name,
|
||||
classified_type=request_type.value,
|
||||
)
|
||||
self._session_states[key] = state
|
||||
return state
|
||||
|
||||
async def record_turn(
|
||||
self,
|
||||
session_id: str,
|
||||
model_name: str,
|
||||
request_type: RequestType,
|
||||
turn: Turn,
|
||||
) -> SignalDelta:
|
||||
"""Apply one turn, push session snapshot + bandit deltas to the queue."""
|
||||
state = self.get_or_create_session_state(session_id, model_name, request_type)
|
||||
delta = apply_turn(state, turn)
|
||||
print("CALLS DELTA", delta)
|
||||
|
||||
snapshot = asdict(state)
|
||||
await self.queue.add_session_state(
|
||||
session_id, self.router_name, model_name, snapshot
|
||||
)
|
||||
|
||||
d_alpha, d_beta = self._compute_bandit_delta(delta)
|
||||
print("CALLS D_ALPHA", d_alpha)
|
||||
if d_alpha != 0 or d_beta != 0:
|
||||
cell_key = (request_type, model_name)
|
||||
self._cells[cell_key] = apply_delta(self._cells[cell_key], d_alpha, d_beta)
|
||||
await self.queue.add_state_delta(
|
||||
self.router_name,
|
||||
request_type.value,
|
||||
model_name,
|
||||
d_alpha,
|
||||
d_beta,
|
||||
)
|
||||
|
||||
return delta
|
||||
|
||||
@staticmethod
|
||||
def _compute_bandit_delta(delta: SignalDelta) -> Tuple[float, float]:
|
||||
"""
|
||||
Translate per-turn signal deltas into bandit-cell deltas.
|
||||
|
||||
v0 mapping (UNVALIDATED — D6):
|
||||
- satisfaction -> +1 alpha
|
||||
- misalignment, stagnation,
|
||||
disengagement, failure -> +1 beta each
|
||||
- loop -> +0.5 beta (weak; could be model OR user)
|
||||
- exhaustion -> 0 (uptime issue, tracked separately later)
|
||||
"""
|
||||
d_alpha = float(delta.satisfaction)
|
||||
d_beta = (
|
||||
float(
|
||||
delta.misalignment
|
||||
+ delta.stagnation
|
||||
+ delta.disengagement
|
||||
+ delta.failure
|
||||
)
|
||||
+ 0.5 * delta.loop
|
||||
)
|
||||
return d_alpha, d_beta
|
||||
136
litellm/router_strategy/adaptive_router/bandit.py
Normal file
136
litellm/router_strategy/adaptive_router/bandit.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""
|
||||
Thompson sampling and prior initialization for the adaptive router bandit.
|
||||
|
||||
Each (router, request_type, model) cell is a Beta(alpha, beta) posterior.
|
||||
- alpha = pseudo-successes
|
||||
- beta = pseudo-failures
|
||||
- mean = alpha / (alpha + beta)
|
||||
- total samples = alpha + beta - COLD_START_MASS (informative prior, not data)
|
||||
|
||||
Hot path: thompson_sample() — pure function, no I/O.
|
||||
"""
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from litellm.router_strategy.adaptive_router.config import (
|
||||
BASE_TIER_WEIGHT,
|
||||
COLD_START_MASS,
|
||||
DEFAULT_COST_WEIGHT,
|
||||
DEFAULT_QUALITY_WEIGHT,
|
||||
SAMPLE_CAP,
|
||||
STRENGTH_BONUS,
|
||||
)
|
||||
from litellm.types.router import AdaptiveRouterPreferences, RequestType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BanditCell:
|
||||
"""Posterior state for a single (router, request_type, model) cell."""
|
||||
|
||||
alpha: float
|
||||
beta: float
|
||||
|
||||
@property
|
||||
def mean(self) -> float:
|
||||
total = self.alpha + self.beta
|
||||
return self.alpha / total if total > 0 else 0.5
|
||||
|
||||
@property
|
||||
def total_samples(self) -> int:
|
||||
return max(0, int(self.alpha + self.beta - COLD_START_MASS))
|
||||
|
||||
|
||||
def initial_cell(
|
||||
prefs: AdaptiveRouterPreferences, request_type: RequestType
|
||||
) -> BanditCell:
|
||||
"""
|
||||
Cold-start prior for a (model, request_type) cell.
|
||||
|
||||
mean = base_tier_weight[tier] + (STRENGTH_BONUS if request_type in strengths else 0)
|
||||
capped at 0.95 to avoid an over-confident prior.
|
||||
Total mass = COLD_START_MASS so that ~10 real observations can move it noticeably.
|
||||
"""
|
||||
base = BASE_TIER_WEIGHT[prefs.quality_tier]
|
||||
bonus = STRENGTH_BONUS if request_type in prefs.strengths else 0.0
|
||||
mean = min(0.95, base + bonus)
|
||||
alpha = mean * COLD_START_MASS
|
||||
beta = (1.0 - mean) * COLD_START_MASS
|
||||
return BanditCell(alpha=alpha, beta=beta)
|
||||
|
||||
|
||||
def apply_delta(cell: BanditCell, delta_alpha: float, delta_beta: float) -> BanditCell:
|
||||
"""
|
||||
Apply a learning update to a cell, enforcing the sample cap.
|
||||
|
||||
SAMPLE_CAP is a HARD cap on (alpha + beta). When the cap would be exceeded,
|
||||
we drop the update. (D5: hard cap, no rescaling — keep v0 simple.)
|
||||
"""
|
||||
new_alpha = cell.alpha + delta_alpha
|
||||
new_beta = cell.beta + delta_beta
|
||||
if new_alpha + new_beta > SAMPLE_CAP:
|
||||
return cell
|
||||
return BanditCell(alpha=new_alpha, beta=new_beta)
|
||||
|
||||
|
||||
def thompson_sample(cell: BanditCell, rng: Optional[random.Random] = None) -> float:
|
||||
"""Draw a sample from Beta(alpha, beta). Returns a quality estimate in [0, 1]."""
|
||||
r = rng if rng is not None else random
|
||||
return r.betavariate(cell.alpha, cell.beta)
|
||||
|
||||
|
||||
def normalized_cost(model_cost: float, all_costs: List[float]) -> float:
|
||||
"""
|
||||
Map a raw $/1k-token cost into [0, 1] where 0 = most expensive, 1 = cheapest.
|
||||
Returns 0.5 when there's no spread.
|
||||
"""
|
||||
if not all_costs:
|
||||
return 0.5
|
||||
lo, hi = min(all_costs), max(all_costs)
|
||||
if hi == lo:
|
||||
return 0.5
|
||||
return 1.0 - ((model_cost - lo) / (hi - lo))
|
||||
|
||||
|
||||
def score(
|
||||
quality_sample: float,
|
||||
model_cost: float,
|
||||
all_costs: List[float],
|
||||
quality_weight: float = DEFAULT_QUALITY_WEIGHT,
|
||||
cost_weight: float = DEFAULT_COST_WEIGHT,
|
||||
) -> float:
|
||||
"""
|
||||
Multi-objective score. V0 is a weighted linear sum of (quality, normalized_cost).
|
||||
Higher is better. Both inputs are in [0, 1].
|
||||
"""
|
||||
cost_score = normalized_cost(model_cost, all_costs)
|
||||
return quality_weight * quality_sample + cost_weight * cost_score
|
||||
|
||||
|
||||
def pick_best(
|
||||
cells: Dict[str, BanditCell],
|
||||
model_costs: Dict[str, float],
|
||||
quality_weight: float = DEFAULT_QUALITY_WEIGHT,
|
||||
cost_weight: float = DEFAULT_COST_WEIGHT,
|
||||
rng: Optional[random.Random] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Sample once per model, score each, return the model with highest score.
|
||||
|
||||
cells: {model_name: BanditCell}
|
||||
model_costs: {model_name: $/1k tokens}
|
||||
"""
|
||||
if not cells:
|
||||
raise ValueError("pick_best called with no models")
|
||||
all_costs = list(model_costs.values())
|
||||
best_model: Optional[str] = None
|
||||
best_score = float("-inf")
|
||||
for model, cell in cells.items():
|
||||
q = thompson_sample(cell, rng=rng)
|
||||
s = score(q, model_costs[model], all_costs, quality_weight, cost_weight)
|
||||
if s > best_score:
|
||||
best_score = s
|
||||
best_model = model
|
||||
assert best_model is not None
|
||||
return best_model
|
||||
140
litellm/router_strategy/adaptive_router/classifier.py
Normal file
140
litellm/router_strategy/adaptive_router/classifier.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""
|
||||
Rule-based classifier mapping a user prompt to a RequestType.
|
||||
|
||||
V0 design choice: deterministic regex over the FIRST user message in a session.
|
||||
Result is cached per session (caller's responsibility, not ours).
|
||||
|
||||
Order matters: we check more specific types first, falling back to GENERAL.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Pattern, Tuple
|
||||
|
||||
from litellm.types.router import RequestType
|
||||
|
||||
_RULES: List[Tuple[Pattern[str], RequestType]] = [
|
||||
(
|
||||
re.compile(
|
||||
r"\b(write|create|generate|implement|build)\s+(?:a |an |the |me )?(?:python|javascript|typescript|java|rust|go|c\+\+|sql|bash|shell)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.CODE_GENERATION,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(write|create|implement|build)\b(?:\s+\w+){0,4}?\s+(function|class|method|script|program|api|endpoint|microservice)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.CODE_GENERATION,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(explain|describe|understand|walk me through|what does)\b.*\b(code|function|method|class|algorithm|snippet)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.CODE_UNDERSTANDING,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(debug|fix|why (?:is|does|isn't)|what.s wrong|trace)\b.*\b(error|bug|exception|stacktrace|stack trace|traceback)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.CODE_UNDERSTANDING,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(review|critique)\s+(?:this |my |the )?(?:code|pr|pull request|diff|patch)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.CODE_UNDERSTANDING,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(design|architect|plan|architecture)\b.*\b(system|service|api|database|schema|module|microservice)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.TECHNICAL_DESIGN,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(should i (?:use|choose|pick)|tradeoffs? between|compare)\b.*\b(library|framework|language|database|protocol|postgres|postgresql|mongodb|dynamodb|mysql|redis|kafka|sql|nosql)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.TECHNICAL_DESIGN,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\bhow (?:should|do) i (?:design|structure|organize|model)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.TECHNICAL_DESIGN,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(solve|compute|calculate|prove|derive)\b.*\b(equation|integral|derivative|theorem|proof|problem)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.ANALYTICAL_REASONING,
|
||||
),
|
||||
(
|
||||
re.compile(r"\b(if .+ then|given .+ find|suppose|assume)\b", re.IGNORECASE),
|
||||
RequestType.ANALYTICAL_REASONING,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(probability|statistics|combinatorics|optimization problem)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.ANALYTICAL_REASONING,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(write|draft|compose|rewrite|edit|proofread|polish)\b.*\b(email|essay|blog|post|article|letter|memo|copy|paragraph|sentence)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.WRITING,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"\b(make (?:this|it)|help me)\s+(?:more |less )?(?:concise|formal|casual|professional|persuasive)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.WRITING,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"^\s*(who|what|when|where|which)\s+(?:is|was|were|are)\b", re.IGNORECASE
|
||||
),
|
||||
RequestType.FACTUAL_LOOKUP,
|
||||
),
|
||||
(
|
||||
re.compile(r"^\s*(define|definition of|meaning of)\b", re.IGNORECASE),
|
||||
RequestType.FACTUAL_LOOKUP,
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"^\s*how (?:do you spell|to spell|many .* are there|tall is)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
RequestType.FACTUAL_LOOKUP,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def classify_prompt(text: str) -> RequestType:
|
||||
"""
|
||||
Classify a single user prompt.
|
||||
|
||||
Falls back to GENERAL when no rule matches. Empty/whitespace-only also
|
||||
returns GENERAL.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return RequestType.GENERAL
|
||||
|
||||
truncated = text[:2000]
|
||||
|
||||
for pattern, request_type in _RULES:
|
||||
if pattern.search(truncated):
|
||||
return request_type
|
||||
|
||||
return RequestType.GENERAL
|
||||
55
litellm/router_strategy/adaptive_router/config.py
Normal file
55
litellm/router_strategy/adaptive_router/config.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""
|
||||
Configuration constants for the adaptive_router strategy.
|
||||
|
||||
All magic numbers are first-pass guesses (D3-D6 in the handoff plan).
|
||||
Expect to retune after first 1000 sessions of real traffic.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from litellm.types.router import RequestType # re-export for convenience # noqa: F401
|
||||
|
||||
# D3 — Score weights (default; user-overridable via AdaptiveRouterConfig.weights)
|
||||
DEFAULT_QUALITY_WEIGHT: float = 0.7 # UNVALIDATED — calibrated against [0] sessions
|
||||
DEFAULT_COST_WEIGHT: float = 0.3 # UNVALIDATED — calibrated against [0] sessions
|
||||
|
||||
# D4 — Cold-start prior: (alpha + beta) total mass = COLD_START_MASS
|
||||
# Mean of Beta = base_tier_weight + (strength_bonus if declared)
|
||||
BASE_TIER_WEIGHT: Dict[int, float] = {1: 0.3, 2: 0.5, 3: 0.7} # UNVALIDATED
|
||||
STRENGTH_BONUS: float = 0.3 # UNVALIDATED
|
||||
COLD_START_MASS: float = 10.0
|
||||
|
||||
# D5 — Sample cap. Hard cap, no rescaling (drift handling is v1).
|
||||
SAMPLE_CAP: int = 200
|
||||
|
||||
# D6 — Clean-trace credit: minimum turns before α += 1 can fire.
|
||||
MIN_TURNS_FOR_CLEAN_CREDIT: int = 3
|
||||
|
||||
# D2 — Owner-cache TTL (seconds). 24h.
|
||||
# A conversation's first-picked model "owns" the bandit-update slot for
|
||||
# this long. Subsequent turns of the same conversation only contribute a
|
||||
# bandit/state update when the same model is re-sampled.
|
||||
OWNER_CACHE_TTL_SECONDS: int = 24 * 3600
|
||||
|
||||
# Below this many messages we skip post-call signal recording. Most signals
|
||||
# (misalignment, stagnation, satisfaction-in-response-to-prior-turn) need at
|
||||
# least one full prior exchange to be meaningful.
|
||||
SIGNAL_GATE_MIN_MESSAGES: int = 4
|
||||
|
||||
# Detector thresholds (from Plano/Chen 2026 paper).
|
||||
MISALIGNMENT_JACCARD_THRESHOLD: float = 0.45
|
||||
STAGNATION_JACCARD_NEAR_DUP: float = 0.50
|
||||
STAGNATION_JACCARD_EXACT: float = 0.85
|
||||
LOOP_REPEAT_THRESHOLD: int = 3
|
||||
TOOL_CALL_HISTORY_MAX: int = 20
|
||||
|
||||
# D1 — Caller filter for min quality tier.
|
||||
MIN_QUALITY_TIER_HEADER: str = "x-litellm-min-quality-tier"
|
||||
MIN_QUALITY_TIER_METADATA_KEY: str = "min_quality_tier"
|
||||
|
||||
# Pre-routing -> post-call relay: the chosen logical model is stashed on
|
||||
# request_kwargs["metadata"][ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY] by the
|
||||
# pre-routing hook, then read by the post-call hook to surface as the
|
||||
# ADAPTIVE_ROUTER_RESPONSE_HEADER response header.
|
||||
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY: str = "adaptive_router_chosen_model"
|
||||
ADAPTIVE_ROUTER_RESPONSE_HEADER: str = "x-litellm-adaptive-router-model"
|
||||
241
litellm/router_strategy/adaptive_router/hooks.py
Normal file
241
litellm/router_strategy/adaptive_router/hooks.py
Normal file
@ -0,0 +1,241 @@
|
||||
"""
|
||||
Post-call hook for the adaptive router.
|
||||
|
||||
On each successful or failed completion, build a Turn from the request/response
|
||||
and push it through `AdaptiveRouter.record_turn`. The router then updates the
|
||||
in-memory bandit cell + session state and queues writes for the proxy flusher.
|
||||
|
||||
All work happens after the response has been returned to the caller. Any
|
||||
exception is swallowed — signal recording must never break a request.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
|
||||
from litellm.router_strategy.adaptive_router.classifier import classify_prompt
|
||||
from litellm.router_strategy.adaptive_router.config import (
|
||||
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY,
|
||||
ADAPTIVE_ROUTER_RESPONSE_HEADER,
|
||||
SIGNAL_GATE_MIN_MESSAGES,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.signals import Turn
|
||||
|
||||
# Identity fields hashed into a derived session key so the same conversation
|
||||
# from the same caller produces a stable key, while different keys/teams/users
|
||||
# stay segregated even if they happen to send identical first messages.
|
||||
_IDENTITY_FIELDS = (
|
||||
"user_api_key_hash",
|
||||
"user_api_key_team_id",
|
||||
"user_api_key_user_id",
|
||||
"user_api_key_end_user_id",
|
||||
)
|
||||
|
||||
|
||||
def _resolve_session_key(kwargs: Dict[str, Any]) -> Optional[str]:
|
||||
"""Pick a stable per-conversation key for owner-cache attribution.
|
||||
|
||||
Order:
|
||||
1. Honor a client-supplied session id (`litellm_session_id` on either
|
||||
`litellm_params` or `litellm_params.metadata`, or `session_id` on
|
||||
metadata) — backward compat for callers already wired up.
|
||||
2. Otherwise derive a sha256 over (identity fields, first message) so
|
||||
the key is stable across turns of the same conversation.
|
||||
|
||||
Returns None if there are no messages (nothing to attribute).
|
||||
"""
|
||||
litellm_params = kwargs.get("litellm_params") or {}
|
||||
sid = litellm_params.get("litellm_session_id")
|
||||
if sid:
|
||||
return str(sid)
|
||||
metadata = litellm_params.get("metadata") or {}
|
||||
if isinstance(metadata, dict):
|
||||
sid = metadata.get("session_id") or metadata.get("litellm_session_id")
|
||||
if sid:
|
||||
return str(sid)
|
||||
|
||||
messages = kwargs.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
identity = ":".join(
|
||||
str(metadata.get(f) or "") if isinstance(metadata, dict) else ""
|
||||
for f in _IDENTITY_FIELDS
|
||||
)
|
||||
first = messages[0]
|
||||
payload = (
|
||||
identity
|
||||
+ "|"
|
||||
+ json.dumps(
|
||||
{"role": first.get("role"), "content": first.get("content")},
|
||||
sort_keys=True,
|
||||
default=str,
|
||||
)
|
||||
)
|
||||
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str]:
|
||||
if not messages:
|
||||
return None
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
# OpenAI vision-style content: pick first text part.
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
return part.get("text")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _assistant_content_and_tool_calls(response_obj: Any) -> tuple:
|
||||
"""Return (assistant_text, tool_calls_list) extracted from a ModelResponse-ish object."""
|
||||
if response_obj is None:
|
||||
return None, []
|
||||
try:
|
||||
choices = getattr(response_obj, "choices", None) or response_obj.get("choices")
|
||||
except Exception:
|
||||
return None, []
|
||||
if not choices:
|
||||
return None, []
|
||||
|
||||
msg = choices[0]
|
||||
msg = getattr(msg, "message", None) or (
|
||||
msg.get("message") if isinstance(msg, dict) else None
|
||||
)
|
||||
if msg is None:
|
||||
return None, []
|
||||
|
||||
content = getattr(msg, "content", None)
|
||||
if content is None and isinstance(msg, dict):
|
||||
content = msg.get("content")
|
||||
|
||||
raw_tool_calls = getattr(msg, "tool_calls", None)
|
||||
if raw_tool_calls is None and isinstance(msg, dict):
|
||||
raw_tool_calls = msg.get("tool_calls")
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
for tc in raw_tool_calls or []:
|
||||
if isinstance(tc, dict):
|
||||
tool_calls.append(tc)
|
||||
else:
|
||||
try:
|
||||
tool_calls.append(tc.model_dump())
|
||||
except Exception:
|
||||
tool_calls.append({"name": getattr(tc, "name", ""), "arguments": ""})
|
||||
return content, tool_calls
|
||||
|
||||
|
||||
class AdaptiveRouterPostCallHook(CustomLogger):
|
||||
"""One hook instance per AdaptiveRouter. Registered into litellm.callbacks."""
|
||||
|
||||
def __init__(self, adaptive_router: AdaptiveRouter) -> None:
|
||||
self.adaptive_router = adaptive_router
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
user_api_key_dict: Any,
|
||||
response: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Surface the chosen logical model picked by the pre-routing hook as the
|
||||
`x-litellm-adaptive-router-model` response header.
|
||||
|
||||
The chosen model is stashed on `data["metadata"]` by
|
||||
`AdaptiveRouter.async_pre_routing_hook`. The proxy awaits this hook
|
||||
before reading `_hidden_params["additional_headers"]` for the outgoing
|
||||
HTTP response, so any value we write here flows through.
|
||||
"""
|
||||
metadata = data.get("metadata") or {}
|
||||
chosen = (
|
||||
metadata.get(ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY)
|
||||
if isinstance(metadata, dict)
|
||||
else None
|
||||
)
|
||||
if not chosen:
|
||||
return
|
||||
hidden_params = getattr(response, "_hidden_params", None)
|
||||
if not isinstance(hidden_params, dict):
|
||||
return
|
||||
hidden_params.setdefault("additional_headers", {})
|
||||
hidden_params["additional_headers"][ADAPTIVE_ROUTER_RESPONSE_HEADER] = chosen
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
await self._record(kwargs, response_obj, response_status=200)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
status = kwargs.get("response_status")
|
||||
if status is None:
|
||||
exc = kwargs.get("exception")
|
||||
status = getattr(exc, "status_code", 500) if exc is not None else 500
|
||||
await self._record(kwargs, response_obj, response_status=int(status))
|
||||
|
||||
async def _record(
|
||||
self,
|
||||
kwargs: Dict[str, Any],
|
||||
response_obj: Any,
|
||||
response_status: int,
|
||||
) -> None:
|
||||
try:
|
||||
messages = kwargs.get("messages") or []
|
||||
if len(messages) < SIGNAL_GATE_MIN_MESSAGES:
|
||||
# Too few turns for any signal to be meaningful — skip.
|
||||
return
|
||||
|
||||
session_key = _resolve_session_key(kwargs)
|
||||
if not session_key:
|
||||
return
|
||||
|
||||
# The bandit cells are keyed by the *logical* model name from
|
||||
# `available_models` (e.g. "smart"/"fast"). `kwargs["model"]` at
|
||||
# post-call time is the physical upstream model
|
||||
# (e.g. "anthropic/claude-opus-4-7"), so it cannot be used directly.
|
||||
# The pre-routing hook stashes the logical pick under this key.
|
||||
litellm_params = kwargs.get("litellm_params") or {}
|
||||
metadata = litellm_params.get("metadata") or {}
|
||||
current_model = (
|
||||
metadata.get(ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY)
|
||||
if isinstance(metadata, dict)
|
||||
else None
|
||||
)
|
||||
if not current_model:
|
||||
return
|
||||
|
||||
if not self.adaptive_router.claim_or_check_owner(
|
||||
session_key, current_model
|
||||
):
|
||||
# A different model owns this conversation — skip attribution.
|
||||
return
|
||||
|
||||
user_text = _last_user_content(messages)
|
||||
assistant_text, tool_calls = _assistant_content_and_tool_calls(response_obj)
|
||||
|
||||
request_type = classify_prompt(user_text or "")
|
||||
turn = Turn(
|
||||
user_content=user_text,
|
||||
assistant_content=(
|
||||
assistant_text if isinstance(assistant_text, str) else None
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
tool_results=[],
|
||||
response_status=response_status,
|
||||
)
|
||||
await self.adaptive_router.record_turn(
|
||||
session_id=session_key,
|
||||
model_name=current_model,
|
||||
request_type=request_type,
|
||||
turn=turn,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_router_logger.exception(
|
||||
"AdaptiveRouterPostCallHook: failed to record turn: %s", e
|
||||
)
|
||||
272
litellm/router_strategy/adaptive_router/signals.py
Normal file
272
litellm/router_strategy/adaptive_router/signals.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""
|
||||
Incremental signal detection for the adaptive router.
|
||||
|
||||
Each session maintains a SessionState. On every turn, we call apply_turn(state, turn)
|
||||
which mutates the state in place and returns a SignalDelta listing which signals
|
||||
fired on THIS turn. The router then queues the delta to be flushed to DB.
|
||||
|
||||
Design constraint: O(1) work per turn. No re-scanning the full session history.
|
||||
We keep small bounded windows: last_user_content, last_assistant_content, and a
|
||||
bounded list of recent tool call signatures.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from litellm.router_strategy.adaptive_router.config import (
|
||||
LOOP_REPEAT_THRESHOLD,
|
||||
MISALIGNMENT_JACCARD_THRESHOLD,
|
||||
STAGNATION_JACCARD_NEAR_DUP,
|
||||
TOOL_CALL_HISTORY_MAX,
|
||||
)
|
||||
|
||||
|
||||
# ---- Public types ---------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalDelta:
|
||||
"""Which signals fired on a single turn. Counts are 0 or 1 (one delta per turn)."""
|
||||
|
||||
misalignment: int = 0
|
||||
stagnation: int = 0
|
||||
disengagement: int = 0
|
||||
satisfaction: int = 0
|
||||
failure: int = 0
|
||||
loop: int = 0
|
||||
exhaustion: int = 0
|
||||
|
||||
def any_fired(self) -> bool:
|
||||
return any(
|
||||
[
|
||||
self.misalignment,
|
||||
self.stagnation,
|
||||
self.disengagement,
|
||||
self.satisfaction,
|
||||
self.failure,
|
||||
self.loop,
|
||||
self.exhaustion,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
"""In-memory rolling state for one session.
|
||||
|
||||
Mirrors the LiteLLM_AdaptiveRouterSession DB row (Wave 0 schema). The flusher
|
||||
later persists this. We keep this as a plain dataclass — no DB coupling.
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
router_name: str
|
||||
model_name: str
|
||||
classified_type: str
|
||||
|
||||
misalignment_count: int = 0
|
||||
stagnation_count: int = 0
|
||||
disengagement_count: int = 0
|
||||
satisfaction_count: int = 0
|
||||
failure_count: int = 0
|
||||
loop_count: int = 0
|
||||
exhaustion_count: int = 0
|
||||
|
||||
last_user_content: Optional[str] = None
|
||||
last_assistant_content: Optional[str] = None
|
||||
tool_call_history: List[str] = field(default_factory=list)
|
||||
pending_tool_calls: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
turn_count: int = 0
|
||||
terminal_status: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Turn:
|
||||
"""One turn of input. Caller assembles this from the request/response."""
|
||||
|
||||
user_content: Optional[str] = None
|
||||
assistant_content: Optional[str] = None
|
||||
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
||||
tool_results: List[Dict[str, Any]] = field(default_factory=list)
|
||||
response_status: Optional[int] = None
|
||||
|
||||
|
||||
# ---- Detection helpers ----------------------------------------------------
|
||||
|
||||
_TOKEN_RE = re.compile(r"[A-Za-z0-9]+")
|
||||
|
||||
|
||||
def _tokens(text: Optional[str]) -> Set[str]:
|
||||
if not text:
|
||||
return set()
|
||||
return {t.lower() for t in _TOKEN_RE.findall(text)}
|
||||
|
||||
|
||||
def _jaccard(a: Set[str], b: Set[str]) -> float:
|
||||
union = a | b
|
||||
if not union:
|
||||
return 0.0
|
||||
return len(a & b) / len(union)
|
||||
|
||||
|
||||
_DISENGAGEMENT_PATTERNS = [
|
||||
re.compile(
|
||||
r"\b(forget it|never mind|give up|talk to (?:a )?human|cancel)\b", re.IGNORECASE
|
||||
),
|
||||
re.compile(r"\b(this (?:isn'?t|is not) working|stop|abort)\b", re.IGNORECASE),
|
||||
re.compile(r"\bi'?ll do it (?:myself|manually)\b", re.IGNORECASE),
|
||||
]
|
||||
|
||||
_SATISFACTION_PATTERNS = [
|
||||
re.compile(
|
||||
r"\b(that worked|that did it|works now|fixed it|solved it|nice)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
re.compile(r"\b(thanks|thank you|thx|appreciated|appreciate it)\b", re.IGNORECASE),
|
||||
re.compile(r"\b(perfect|great|excellent|exactly)\b", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
def _detect_misalignment(prev_user: Optional[str], curr_user: Optional[str]) -> bool:
|
||||
"""Fires when consecutive user messages share *some* topic (jaccard > 0)
|
||||
but are sufficiently different (jaccard < threshold) — i.e. user is
|
||||
rephrasing, not changing topic, not repeating."""
|
||||
if not prev_user or not curr_user:
|
||||
return False
|
||||
j = _jaccard(_tokens(prev_user), _tokens(curr_user))
|
||||
return 0.0 < j < MISALIGNMENT_JACCARD_THRESHOLD
|
||||
|
||||
|
||||
def _detect_stagnation(prev_asst: Optional[str], curr_asst: Optional[str]) -> bool:
|
||||
"""Fires when consecutive assistant messages are near-duplicates."""
|
||||
if not prev_asst or not curr_asst:
|
||||
return False
|
||||
j = _jaccard(_tokens(prev_asst), _tokens(curr_asst))
|
||||
return j >= STAGNATION_JACCARD_NEAR_DUP
|
||||
|
||||
|
||||
def _detect_disengagement(curr_user: Optional[str]) -> bool:
|
||||
if not curr_user:
|
||||
return False
|
||||
return any(p.search(curr_user) for p in _DISENGAGEMENT_PATTERNS)
|
||||
|
||||
|
||||
def _detect_satisfaction(curr_user: Optional[str]) -> bool:
|
||||
if not curr_user:
|
||||
return False
|
||||
return any(p.search(curr_user) for p in _SATISFACTION_PATTERNS)
|
||||
|
||||
|
||||
def _detect_failure(tool_results: List[Dict[str, Any]]) -> bool:
|
||||
"""Any tool result that's an error or empty content."""
|
||||
for r in tool_results:
|
||||
if r.get("is_error"):
|
||||
return True
|
||||
content = r.get("content")
|
||||
if content is None or content == "" or content == [] or content == {}:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _signature(call: Dict[str, Any]) -> str:
|
||||
"""Stable signature for loop detection: name + sorted JSON-ish args."""
|
||||
name = call.get("name") or call.get("function", {}).get("name", "")
|
||||
args = call.get("arguments")
|
||||
if args is None:
|
||||
args = call.get("function", {}).get("arguments", "")
|
||||
if isinstance(args, dict):
|
||||
args = ",".join(f"{k}={args[k]}" for k in sorted(args.keys()))
|
||||
return f"{name}({args})"
|
||||
|
||||
|
||||
def _detect_loop(history: List[str], new_calls: List[Dict[str, Any]]) -> bool:
|
||||
"""Fires if any new call's signature appears >= LOOP_REPEAT_THRESHOLD-1 times
|
||||
in recent history (so this call would be the Nth)."""
|
||||
if not new_calls:
|
||||
return False
|
||||
for call in new_calls:
|
||||
sig = _signature(call)
|
||||
recent_count = history.count(sig)
|
||||
if recent_count >= LOOP_REPEAT_THRESHOLD - 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
_EXHAUSTION_STATUSES = {408, 413, 429, 503, 504}
|
||||
|
||||
_EXHAUSTION_KEYWORDS = (
|
||||
"context length",
|
||||
"context window",
|
||||
"token limit",
|
||||
"rate limit",
|
||||
"too many requests",
|
||||
"timeout",
|
||||
)
|
||||
|
||||
|
||||
def _detect_exhaustion(
|
||||
status: Optional[int], tool_results: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
if status is not None and status in _EXHAUSTION_STATUSES:
|
||||
return True
|
||||
for r in tool_results:
|
||||
content = str(r.get("content", "")).lower()
|
||||
if any(kw in content for kw in _EXHAUSTION_KEYWORDS):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ---- Public entrypoint ----------------------------------------------------
|
||||
|
||||
|
||||
def apply_turn(state: SessionState, turn: Turn) -> SignalDelta:
|
||||
"""
|
||||
Detect signals on this turn, mutate state, return the delta.
|
||||
|
||||
O(1) per turn (no full-history rescan). Only inspects last_*, recent tool history
|
||||
(which is bounded at TOOL_CALL_HISTORY_MAX), and the new turn payload.
|
||||
"""
|
||||
delta = SignalDelta()
|
||||
|
||||
if _detect_misalignment(state.last_user_content, turn.user_content):
|
||||
delta.misalignment = 1
|
||||
if _detect_stagnation(state.last_assistant_content, turn.assistant_content):
|
||||
delta.stagnation = 1
|
||||
if _detect_disengagement(turn.user_content):
|
||||
delta.disengagement = 1
|
||||
if _detect_satisfaction(turn.user_content):
|
||||
delta.satisfaction = 1
|
||||
if _detect_failure(turn.tool_results):
|
||||
delta.failure = 1
|
||||
if _detect_loop(state.tool_call_history, turn.tool_calls):
|
||||
delta.loop = 1
|
||||
if _detect_exhaustion(turn.response_status, turn.tool_results):
|
||||
delta.exhaustion = 1
|
||||
|
||||
state.misalignment_count += delta.misalignment
|
||||
state.stagnation_count += delta.stagnation
|
||||
state.disengagement_count += delta.disengagement
|
||||
state.satisfaction_count += delta.satisfaction
|
||||
state.failure_count += delta.failure
|
||||
state.loop_count += delta.loop
|
||||
state.exhaustion_count += delta.exhaustion
|
||||
|
||||
if turn.user_content:
|
||||
state.last_user_content = turn.user_content
|
||||
if turn.assistant_content:
|
||||
state.last_assistant_content = turn.assistant_content
|
||||
|
||||
for call in turn.tool_calls:
|
||||
state.tool_call_history.append(_signature(call))
|
||||
if len(state.tool_call_history) > TOOL_CALL_HISTORY_MAX:
|
||||
state.tool_call_history = state.tool_call_history[-TOOL_CALL_HISTORY_MAX:]
|
||||
|
||||
if turn.response_status is not None:
|
||||
state.terminal_status = turn.response_status
|
||||
|
||||
state.turn_count += 1
|
||||
|
||||
return delta
|
||||
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, get_type_hints
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from litellm._uuid import uuid
|
||||
@ -219,6 +219,10 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
|
||||
complexity_router_config: Optional[Dict] = None
|
||||
complexity_router_default_model: Optional[str] = None
|
||||
|
||||
# adaptive-router params
|
||||
adaptive_router_default_model: Optional[str] = None
|
||||
adaptive_router_config: Optional[Dict] = None
|
||||
|
||||
# Batch/File API Params
|
||||
s3_bucket_name: Optional[str] = None
|
||||
s3_encryption_key_id: Optional[str] = None
|
||||
@ -788,3 +792,44 @@ class PreRoutingHookResponse(BaseModel):
|
||||
|
||||
model: str
|
||||
messages: Optional[List[Dict[str, Any]]]
|
||||
|
||||
|
||||
class RequestType(str, enum.Enum):
|
||||
"""Fixed v0 taxonomy. User-extensible types come in v1."""
|
||||
|
||||
CODE_GENERATION = "code_generation"
|
||||
CODE_UNDERSTANDING = "code_understanding"
|
||||
TECHNICAL_DESIGN = "technical_design"
|
||||
ANALYTICAL_REASONING = "analytical_reasoning"
|
||||
WRITING = "writing"
|
||||
FACTUAL_LOOKUP = "factual_lookup"
|
||||
GENERAL = "general"
|
||||
|
||||
|
||||
class AdaptiveRouterWeights(BaseModel):
|
||||
quality: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||
cost: float = Field(default=0.3, ge=0.0, le=1.0)
|
||||
|
||||
@field_validator("cost")
|
||||
@classmethod
|
||||
def _weights_sum_to_one(cls, v, info):
|
||||
q = info.data.get("quality", 0.7)
|
||||
if abs(q + v - 1.0) > 0.001:
|
||||
raise ValueError(
|
||||
f"weights must sum to 1.0, got quality={q} + cost={v} = {q + v}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class AdaptiveRouterConfig(BaseModel):
|
||||
available_models: List[str]
|
||||
weights: AdaptiveRouterWeights = Field(default_factory=AdaptiveRouterWeights)
|
||||
|
||||
|
||||
class AdaptiveRouterPreferences(BaseModel):
|
||||
"""model_info.adaptive_router_preferences — declared by each model."""
|
||||
|
||||
model_config = ConfigDict(use_enum_values=False)
|
||||
|
||||
quality_tier: int = Field(ge=1, le=3)
|
||||
strengths: List[RequestType] = Field(default_factory=list)
|
||||
|
||||
@ -1219,3 +1219,46 @@ model LiteLLM_ClaudeCodePluginTable {
|
||||
|
||||
@@map("LiteLLM_ClaudeCodePluginTable")
|
||||
}
|
||||
|
||||
// Per-(router, request_type, model) Beta posterior for the adaptive router.
|
||||
model LiteLLM_AdaptiveRouterState {
|
||||
router_name String
|
||||
request_type String
|
||||
model_name String
|
||||
alpha Float
|
||||
beta Float
|
||||
total_samples Int @default(0)
|
||||
last_updated_at DateTime @default(now())
|
||||
|
||||
@@id([router_name, request_type, model_name])
|
||||
}
|
||||
|
||||
// Per-(session, router, model) signal counters for the adaptive router.
|
||||
model LiteLLM_AdaptiveRouterSession {
|
||||
session_id String
|
||||
router_name String
|
||||
model_name String
|
||||
classified_type String
|
||||
|
||||
misalignment_count Int @default(0)
|
||||
stagnation_count Int @default(0)
|
||||
disengagement_count Int @default(0)
|
||||
satisfaction_count Int @default(0)
|
||||
failure_count Int @default(0)
|
||||
loop_count Int @default(0)
|
||||
exhaustion_count Int @default(0)
|
||||
|
||||
last_user_content String?
|
||||
last_assistant_content String?
|
||||
tool_call_history Json @default("[]")
|
||||
pending_tool_calls Json @default("{}")
|
||||
|
||||
turn_count Int @default(0)
|
||||
last_processed_turn Int @default(-1)
|
||||
clean_credit_awarded Boolean @default(false)
|
||||
terminal_status Int?
|
||||
last_activity_at DateTime @default(now())
|
||||
|
||||
@@id([session_id, router_name, model_name])
|
||||
@@index([last_activity_at])
|
||||
}
|
||||
|
||||
216
scripts/verify_adaptive_router.py
Normal file
216
scripts/verify_adaptive_router.py
Normal file
@ -0,0 +1,216 @@
|
||||
"""
|
||||
End-to-end verification script for the adaptive router.
|
||||
|
||||
Requires:
|
||||
- LiteLLM proxy running on http://localhost:4000 with adaptive_router configured
|
||||
(see litellm/proxy/example_config_yaml/adaptive_router_example.yaml).
|
||||
- Postgres reachable via DATABASE_URL (same one the proxy uses).
|
||||
- LITELLM_PROXY_KEY env var set (a valid key with permission to send requests).
|
||||
- Two model deployments configured under one adaptive_router:
|
||||
* "fast" (cheap, lower quality)
|
||||
* "smart" (expensive, higher quality)
|
||||
|
||||
Run:
|
||||
uv run python scripts/verify_adaptive_router.py
|
||||
|
||||
Optional env:
|
||||
LITELLM_PROXY_URL (default: http://localhost:4000)
|
||||
ADAPTIVE_ROUTER_NAME (default: smart-cheap-router)
|
||||
EXPECTED_WINNER (default: smart) -- model expected to dominate after training
|
||||
TRAIN_SESSIONS (default: 20) -- training sessions in phase 1
|
||||
CONVERGE_SESSIONS (default: 10) -- cold sessions in phase 2
|
||||
WIN_THRESHOLD (default: 0.7) -- min share for EXPECTED_WINNER in phase 2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
PROXY_URL: str = os.environ.get("LITELLM_PROXY_URL", "http://localhost:4000")
|
||||
try:
|
||||
PROXY_KEY: str = os.environ["LITELLM_PROXY_KEY"]
|
||||
except KeyError:
|
||||
print(
|
||||
"ERROR: LITELLM_PROXY_KEY env var must be set (a proxy key with /chat/completions perms).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(2)
|
||||
|
||||
ROUTER_NAME: str = os.environ.get("ADAPTIVE_ROUTER_NAME", "smart-cheap-router")
|
||||
EXPECTED_WINNER: str = os.environ.get("EXPECTED_WINNER", "smart")
|
||||
TRAIN_SESSIONS: int = int(os.environ.get("TRAIN_SESSIONS", "20"))
|
||||
CONVERGE_SESSIONS: int = int(os.environ.get("CONVERGE_SESSIONS", "10"))
|
||||
WIN_THRESHOLD: float = float(os.environ.get("WIN_THRESHOLD", "0.7"))
|
||||
|
||||
REQUEST_TIMEOUT_SECONDS: float = 30.0
|
||||
RETRY_ATTEMPTS: int = 3
|
||||
RETRY_BACKOFF_SECONDS: float = 1.0
|
||||
FLUSHER_DRAIN_WAIT_SECONDS: float = 30.0 # proxy flusher loop is 10s; pad with margin
|
||||
|
||||
PROMPTS: List[str] = [
|
||||
"Write a Python function that reverses a binary tree",
|
||||
"Explain the time complexity of quicksort",
|
||||
"Design an API for a chat application",
|
||||
]
|
||||
SATISFACTION_PROMPT: str = "thanks, that worked!"
|
||||
|
||||
|
||||
async def _post_chat(
|
||||
client: httpx.AsyncClient, session_id: str, prompt: str
|
||||
) -> Optional[dict]:
|
||||
"""POST a chat completion with retry + timeout. Returns response JSON or None."""
|
||||
body = {
|
||||
"model": ROUTER_NAME,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"metadata": {"litellm_session_id": session_id},
|
||||
}
|
||||
last_exc: Optional[Exception] = None
|
||||
for attempt in range(1, RETRY_ATTEMPTS + 1):
|
||||
try:
|
||||
r = await client.post(
|
||||
f"{PROXY_URL}/v1/chat/completions",
|
||||
json=body,
|
||||
headers={"Authorization": f"Bearer {PROXY_KEY}"},
|
||||
timeout=REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exc = e
|
||||
if attempt < RETRY_ATTEMPTS:
|
||||
await asyncio.sleep(RETRY_BACKOFF_SECONDS * attempt)
|
||||
print(
|
||||
f" request failed after {RETRY_ATTEMPTS} attempts (session={session_id}): {last_exc}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def send_session(
|
||||
client: httpx.AsyncClient,
|
||||
session_id: str,
|
||||
prompts: List[str],
|
||||
satisfy: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""Send a session of N turns. Returns the model that handled the last turn."""
|
||||
last_model: Optional[str] = None
|
||||
for prompt in prompts:
|
||||
resp = await _post_chat(client, session_id, prompt)
|
||||
if resp is None:
|
||||
return None
|
||||
last_model = resp.get("model") or last_model
|
||||
if satisfy:
|
||||
await _post_chat(client, session_id, SATISFACTION_PROMPT)
|
||||
return last_model
|
||||
|
||||
|
||||
async def _proxy_health_check(client: httpx.AsyncClient) -> bool:
|
||||
"""Confirm the proxy is reachable before doing anything else."""
|
||||
try:
|
||||
r = await client.get(f"{PROXY_URL}/health/liveliness", timeout=5.0)
|
||||
return r.status_code == 200
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"proxy unreachable at {PROXY_URL}: {e}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=== verify_adaptive_router.py ===")
|
||||
print(f"proxy: {PROXY_URL}")
|
||||
print(f"router: {ROUTER_NAME}")
|
||||
print(f"expected winner: {EXPECTED_WINNER}")
|
||||
print(f"train sessions: {TRAIN_SESSIONS}")
|
||||
print(f"converge runs: {CONVERGE_SESSIONS}\n")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
if not await _proxy_health_check(client):
|
||||
print("FAIL: proxy health check did not return 200.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# ---- Phase 1: training -------------------------------------------
|
||||
print(
|
||||
f"Phase 1: training ({TRAIN_SESSIONS} sessions of 3 turns + satisfaction)..."
|
||||
)
|
||||
for i in range(TRAIN_SESSIONS):
|
||||
sid = f"verify-train-{uuid.uuid4()}"
|
||||
await send_session(client, sid, PROMPTS, satisfy=True)
|
||||
if (i + 1) % 5 == 0:
|
||||
print(f" trained {i + 1}/{TRAIN_SESSIONS} sessions")
|
||||
|
||||
print(
|
||||
f"\nWaiting {FLUSHER_DRAIN_WAIT_SECONDS:.0f}s for flusher to drain queue..."
|
||||
)
|
||||
await asyncio.sleep(FLUSHER_DRAIN_WAIT_SECONDS)
|
||||
|
||||
# ---- Phase 2: convergence ----------------------------------------
|
||||
print(f"\nPhase 2: convergence test ({CONVERGE_SESSIONS} cold sessions)...")
|
||||
picks: List[str] = []
|
||||
for i in range(CONVERGE_SESSIONS):
|
||||
sid = f"verify-test-{uuid.uuid4()}"
|
||||
m = await send_session(client, sid, [PROMPTS[0]], satisfy=False)
|
||||
if m:
|
||||
picks.append(m)
|
||||
print(f" session {i + 1}: picked {m}")
|
||||
|
||||
if not picks:
|
||||
print("\nFAIL: no successful picks in convergence phase.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
winner_share = picks.count(EXPECTED_WINNER) / len(picks)
|
||||
print(
|
||||
f"\n{EXPECTED_WINNER} share: {winner_share:.0%} "
|
||||
f"({picks.count(EXPECTED_WINNER)}/{len(picks)})"
|
||||
)
|
||||
|
||||
# ---- Phase 3: sticky session -------------------------------------
|
||||
print("\nPhase 3: sticky session test...")
|
||||
sid = f"verify-sticky-{uuid.uuid4()}"
|
||||
models: List[str] = []
|
||||
for _ in range(3):
|
||||
m = await send_session(client, sid, [PROMPTS[0]], satisfy=False)
|
||||
if m:
|
||||
models.append(m)
|
||||
if len(models) == 3 and len(set(models)) == 1:
|
||||
print(f" PASS: same model {models[0]} across 3 turns of session {sid}")
|
||||
else:
|
||||
print(
|
||||
f" FAIL: models differed within session: {models}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# ---- Phase 4: latency benchmark ----------------------------------
|
||||
print("\nPhase 4: routing latency (5 picks, p50)...")
|
||||
latencies: List[float] = []
|
||||
for _ in range(5):
|
||||
t0 = time.perf_counter()
|
||||
await send_session(
|
||||
client, f"verify-lat-{uuid.uuid4()}", [PROMPTS[0]], satisfy=False
|
||||
)
|
||||
latencies.append(time.perf_counter() - t0)
|
||||
latencies.sort()
|
||||
p50 = latencies[len(latencies) // 2]
|
||||
print(f" p50 e2e roundtrip: {p50 * 1000:.0f}ms")
|
||||
|
||||
# ---- Verdict -----------------------------------------------------
|
||||
if winner_share >= WIN_THRESHOLD:
|
||||
print(
|
||||
f"\nPASS: convergence ({winner_share:.0%} >= {WIN_THRESHOLD:.0%}) + "
|
||||
f"sticky + latency checks all green."
|
||||
)
|
||||
sys.exit(0)
|
||||
print(
|
||||
f"\nFAIL: convergence too weak ({winner_share:.0%} < {WIN_THRESHOLD:.0%}).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -0,0 +1,117 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.db.db_transaction_queue.adaptive_router_update_queue import (
|
||||
AdaptiveRouterUpdateQueue,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue():
|
||||
return AdaptiveRouterUpdateQueue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prisma():
|
||||
"""Prisma client with both adaptive router models stubbed as AsyncMocks."""
|
||||
p = MagicMock()
|
||||
p.db.litellm_adaptiverouterstate.find_unique = AsyncMock(return_value=None)
|
||||
p.db.litellm_adaptiverouterstate.upsert = AsyncMock()
|
||||
p.db.litellm_adaptiveroutersession.upsert = AsyncMock()
|
||||
return p
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_state_delta_aggregates_same_key(queue):
|
||||
await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
|
||||
await queue.add_state_delta("r1", "general", "gpt-4", 0.0, 1.0)
|
||||
sizes = await queue.queue_size()
|
||||
assert sizes["state_pending"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_state_delta_separate_keys(queue):
|
||||
await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
|
||||
await queue.add_state_delta("r1", "writing", "gpt-4", 1.0, 0.0)
|
||||
sizes = await queue.queue_size()
|
||||
assert sizes["state_pending"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_session_state_last_write_wins(queue):
|
||||
await queue.add_session_state("s1", "r1", "gpt-4", {"misalignment_count": 1})
|
||||
await queue.add_session_state("s1", "r1", "gpt-4", {"misalignment_count": 5})
|
||||
sizes = await queue.queue_size()
|
||||
assert sizes["session_pending"] == 1
|
||||
|
||||
flushed = []
|
||||
p = MagicMock()
|
||||
|
||||
async def upsert(**kwargs):
|
||||
flushed.append(kwargs)
|
||||
|
||||
p.db.litellm_adaptiveroutersession.upsert = upsert
|
||||
await queue.flush_session_to_db(p)
|
||||
assert len(flushed) == 1
|
||||
assert flushed[0]["data"]["update"]["misalignment_count"] == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_state_drains_aggregator(queue, mock_prisma):
|
||||
await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
|
||||
await queue.add_state_delta("r1", "writing", "gpt-4", 0.0, 1.0)
|
||||
n = await queue.flush_state_to_db(mock_prisma)
|
||||
assert n == 2
|
||||
sizes = await queue.queue_size()
|
||||
assert sizes["state_pending"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_state_sums_correctly(queue, mock_prisma):
|
||||
await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
|
||||
await queue.add_state_delta("r1", "general", "gpt-4", 2.0, 1.0)
|
||||
await queue.flush_state_to_db(mock_prisma)
|
||||
# find_unique returned None (cold start), so alpha = 1+2 = 3, beta = 0+1 = 1
|
||||
call = mock_prisma.db.litellm_adaptiverouterstate.upsert.call_args
|
||||
assert call.kwargs["data"]["create"]["alpha"] == 3.0
|
||||
assert call.kwargs["data"]["create"]["beta"] == 1.0
|
||||
assert call.kwargs["data"]["create"]["total_samples"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_session_drains_aggregator(queue, mock_prisma):
|
||||
await queue.add_session_state("s1", "r1", "gpt-4", {"classified_type": "general"})
|
||||
n = await queue.flush_session_to_db(mock_prisma)
|
||||
assert n == 1
|
||||
sizes = await queue.queue_size()
|
||||
assert sizes["session_pending"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_empty_queue_returns_zero(queue, mock_prisma):
|
||||
assert await queue.flush_state_to_db(mock_prisma) == 0
|
||||
assert await queue.flush_session_to_db(mock_prisma) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_state_isolation_from_concurrent_adds(queue, mock_prisma):
|
||||
"""Adds during a flush should land in the NEW aggregator, not the drained batch."""
|
||||
await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
|
||||
flush_task = asyncio.create_task(queue.flush_state_to_db(mock_prisma))
|
||||
# Yield control so the flush task can swap the aggregator before we add again.
|
||||
await asyncio.sleep(0)
|
||||
await queue.add_state_delta("r1", "general", "gpt-5", 2.0, 0.0)
|
||||
await flush_task
|
||||
sizes = await queue.queue_size()
|
||||
assert sizes["state_pending"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_size_observability(queue):
|
||||
await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
|
||||
await queue.add_state_delta("r1", "writing", "gpt-4", 1.0, 0.0)
|
||||
await queue.add_state_delta("r1", "code_generation", "gpt-4", 1.0, 0.0)
|
||||
sizes = await queue.queue_size()
|
||||
assert sizes["max_state_seen"] >= 3
|
||||
@ -0,0 +1,16 @@
|
||||
[
|
||||
{
|
||||
"user_content": "what is the weather today in paris france",
|
||||
"assistant_content": "It is sunny and warm in Paris today.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": "what is the weather today in paris france tomorrow",
|
||||
"assistant_content": "Light rain is expected throughout the day.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,23 @@
|
||||
[
|
||||
{
|
||||
"user_content": "how do I read a file in python",
|
||||
"assistant_content": "Use the open() function with a context manager.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": "can you show an example",
|
||||
"assistant_content": "with open('file.txt') as f: data = f.read()",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": "thanks, that worked!",
|
||||
"assistant_content": "Glad to hear it.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,16 @@
|
||||
[
|
||||
{
|
||||
"user_content": "how do I install this package",
|
||||
"assistant_content": "Run pip install <package_name>.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": "forget it, I'll do it myself",
|
||||
"assistant_content": "Okay, let me know if you need anything else.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,9 @@
|
||||
[
|
||||
{
|
||||
"user_content": "do the thing",
|
||||
"assistant_content": null,
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 429
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,13 @@
|
||||
[
|
||||
{
|
||||
"user_content": "summarize this giant document",
|
||||
"assistant_content": null,
|
||||
"tool_calls": [
|
||||
{"id": "c1", "name": "summarize", "arguments": {"doc_id": "big"}}
|
||||
],
|
||||
"tool_results": [
|
||||
{"tool_call_id": "c1", "content": "Error: context length exceeded for this model"}
|
||||
],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,13 @@
|
||||
[
|
||||
{
|
||||
"user_content": "read the config file",
|
||||
"assistant_content": "Let me try.",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "name": "read_file", "arguments": {"path": "/etc/missing.conf"}}
|
||||
],
|
||||
"tool_results": [
|
||||
{"tool_call_id": "call_1", "content": "ENOENT: no such file or directory", "is_error": true}
|
||||
],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,35 @@
|
||||
[
|
||||
{
|
||||
"user_content": null,
|
||||
"assistant_content": null,
|
||||
"tool_calls": [
|
||||
{"id": "c1", "name": "read_file", "arguments": {"path": "/x"}}
|
||||
],
|
||||
"tool_results": [
|
||||
{"tool_call_id": "c1", "content": "ok"}
|
||||
],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": null,
|
||||
"assistant_content": null,
|
||||
"tool_calls": [
|
||||
{"id": "c2", "name": "read_file", "arguments": {"path": "/x"}}
|
||||
],
|
||||
"tool_results": [
|
||||
{"tool_call_id": "c2", "content": "ok"}
|
||||
],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": null,
|
||||
"assistant_content": null,
|
||||
"tool_calls": [
|
||||
{"id": "c3", "name": "read_file", "arguments": {"path": "/x"}}
|
||||
],
|
||||
"tool_results": [
|
||||
{"tool_call_id": "c3", "content": "ok"}
|
||||
],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,16 @@
|
||||
[
|
||||
{
|
||||
"user_content": "can you help me write a function to parse json",
|
||||
"assistant_content": "Sure, use the json module's loads function.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": "actually I need to parse yaml instead",
|
||||
"assistant_content": "Use the pyyaml library and yaml.safe_load.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,31 @@
|
||||
[
|
||||
{
|
||||
"user_content": "please read the config file",
|
||||
"assistant_content": "Trying to read it now.",
|
||||
"tool_calls": [
|
||||
{"id": "c1", "name": "read_file", "arguments": {"path": "config.json"}}
|
||||
],
|
||||
"tool_results": [
|
||||
{"tool_call_id": "c1", "content": "file not found", "is_error": true}
|
||||
],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": "try config.yaml instead",
|
||||
"assistant_content": "Here are the contents of config.yaml.",
|
||||
"tool_calls": [
|
||||
{"id": "c2", "name": "read_file", "arguments": {"path": "config.yaml"}}
|
||||
],
|
||||
"tool_results": [
|
||||
{"tool_call_id": "c2", "content": "key: value"}
|
||||
],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": "perfect, thanks!",
|
||||
"assistant_content": "You're welcome.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,16 @@
|
||||
[
|
||||
{
|
||||
"user_content": "explain this",
|
||||
"assistant_content": "Here is the answer to your question. The capital of France is Paris.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
},
|
||||
{
|
||||
"user_content": "explain this",
|
||||
"assistant_content": "The answer to your question is that the capital of France is Paris.",
|
||||
"tool_calls": [],
|
||||
"tool_results": [],
|
||||
"response_status": 200
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,224 @@
|
||||
"""Unit tests for the AdaptiveRouter strategy class."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from litellm.router_strategy.adaptive_router import adaptive_router as ar_module
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
|
||||
from litellm.router_strategy.adaptive_router.config import (
|
||||
OWNER_CACHE_TTL_SECONDS,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.signals import Turn
|
||||
from litellm.types.router import (
|
||||
AdaptiveRouterConfig,
|
||||
AdaptiveRouterPreferences,
|
||||
RequestType,
|
||||
)
|
||||
|
||||
|
||||
def _make_router() -> AdaptiveRouter:
|
||||
cfg = AdaptiveRouterConfig(available_models=["fast", "smart"])
|
||||
prefs = {
|
||||
"fast": AdaptiveRouterPreferences(quality_tier=1, strengths=[]),
|
||||
"smart": AdaptiveRouterPreferences(
|
||||
quality_tier=3, strengths=[RequestType.CODE_GENERATION]
|
||||
),
|
||||
}
|
||||
costs = {"fast": 0.0001, "smart": 0.001}
|
||||
return AdaptiveRouter(
|
||||
router_name="r1",
|
||||
config=cfg,
|
||||
model_to_prefs=prefs,
|
||||
model_to_cost=costs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pick_model_returns_model_from_available_list():
|
||||
r = _make_router()
|
||||
chosen = await r.pick_model(RequestType.GENERAL)
|
||||
assert chosen in {"fast", "smart"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pick_model_min_quality_tier_filter():
|
||||
r = _make_router()
|
||||
# min_tier=3 should leave only `smart` (tier 3); `fast` (tier 1) is filtered.
|
||||
for _ in range(20):
|
||||
chosen = await r.pick_model(RequestType.GENERAL, min_quality_tier=3)
|
||||
assert chosen == "smart"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pick_model_min_quality_tier_filter_raises_when_no_eligible():
|
||||
r = _make_router()
|
||||
with pytest.raises(ValueError, match="min_quality_tier=4"):
|
||||
await r.pick_model(RequestType.GENERAL, min_quality_tier=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pick_model_is_stateless_no_owner_cache_writes():
|
||||
"""pick_model must not touch the owner cache — that's gated post-call."""
|
||||
r = _make_router()
|
||||
for _ in range(5):
|
||||
await r.pick_model(RequestType.GENERAL)
|
||||
assert r._owner_cache == {}
|
||||
|
||||
|
||||
# ---- claim_or_check_owner -----------------------------------------------
|
||||
|
||||
|
||||
def test_claim_or_check_owner_first_call_claims_and_returns_true(monkeypatch):
|
||||
r = _make_router()
|
||||
monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0)
|
||||
|
||||
assert r.claim_or_check_owner("sess-A", "fast") is True
|
||||
assert r._owner_cache["sess-A"] == ("fast", 1_000.0 + OWNER_CACHE_TTL_SECONDS)
|
||||
assert r._skipped_updates_total == 0
|
||||
|
||||
|
||||
def test_claim_or_check_owner_same_model_returns_true_without_extending_ttl(
|
||||
monkeypatch,
|
||||
):
|
||||
r = _make_router()
|
||||
monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0)
|
||||
r.claim_or_check_owner("sess-A", "fast")
|
||||
original_expiry = r._owner_cache["sess-A"][1]
|
||||
|
||||
monkeypatch.setattr(ar_module.time, "time", lambda: 1_500.0)
|
||||
assert r.claim_or_check_owner("sess-A", "fast") is True
|
||||
# No extension on hit — owner cache snapshots the first claim.
|
||||
assert r._owner_cache["sess-A"][1] == original_expiry
|
||||
|
||||
|
||||
def test_claim_or_check_owner_mismatch_skips_and_increments_counter(monkeypatch):
|
||||
r = _make_router()
|
||||
monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0)
|
||||
r.claim_or_check_owner("sess-A", "fast")
|
||||
|
||||
assert r.claim_or_check_owner("sess-A", "smart") is False
|
||||
assert r._skipped_updates_total == 1
|
||||
# Owner unchanged.
|
||||
assert r._owner_cache["sess-A"][0] == "fast"
|
||||
|
||||
|
||||
def test_claim_or_check_owner_expired_owner_reclaims_for_new_model(monkeypatch):
|
||||
r = _make_router()
|
||||
monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0)
|
||||
r.claim_or_check_owner("sess-A", "fast")
|
||||
|
||||
monkeypatch.setattr(
|
||||
ar_module.time, "time", lambda: 1_000.0 + OWNER_CACHE_TTL_SECONDS + 1
|
||||
)
|
||||
assert r.claim_or_check_owner("sess-A", "smart") is True
|
||||
assert r._owner_cache["sess-A"][0] == "smart"
|
||||
# Reclaim isn't a skip.
|
||||
assert r._skipped_updates_total == 0
|
||||
|
||||
|
||||
# ---- record_turn --------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_turn_pushes_to_queue():
|
||||
r = _make_router()
|
||||
r.queue.add_session_state = AsyncMock()
|
||||
r.queue.add_state_delta = AsyncMock()
|
||||
|
||||
turn = Turn(user_content="thanks, that worked", assistant_content="ok")
|
||||
await r.record_turn(
|
||||
session_id="s1",
|
||||
model_name="fast",
|
||||
request_type=RequestType.GENERAL,
|
||||
turn=turn,
|
||||
)
|
||||
|
||||
r.queue.add_session_state.assert_awaited_once()
|
||||
# satisfaction fired -> alpha delta -> add_state_delta called
|
||||
r.queue.add_state_delta.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_turn_satisfaction_increments_alpha():
|
||||
r = _make_router()
|
||||
cell_before = r._cells[(RequestType.GENERAL, "fast")]
|
||||
turn = Turn(user_content="that worked, thanks!")
|
||||
await r.record_turn(
|
||||
session_id="sX",
|
||||
model_name="fast",
|
||||
request_type=RequestType.GENERAL,
|
||||
turn=turn,
|
||||
)
|
||||
cell_after = r._cells[(RequestType.GENERAL, "fast")]
|
||||
assert cell_after.alpha == pytest.approx(cell_before.alpha + 1.0)
|
||||
assert cell_after.beta == pytest.approx(cell_before.beta)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_turn_failure_increments_beta():
|
||||
r = _make_router()
|
||||
cell_before = r._cells[(RequestType.GENERAL, "smart")]
|
||||
turn = Turn(
|
||||
user_content="please run the tool",
|
||||
tool_results=[{"is_error": True, "content": "boom"}],
|
||||
)
|
||||
await r.record_turn(
|
||||
session_id="sY",
|
||||
model_name="smart",
|
||||
request_type=RequestType.GENERAL,
|
||||
turn=turn,
|
||||
)
|
||||
cell_after = r._cells[(RequestType.GENERAL, "smart")]
|
||||
assert cell_after.beta == pytest.approx(cell_before.beta + 1.0)
|
||||
assert cell_after.alpha == pytest.approx(cell_before.alpha)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_state_from_db_overrides_cold_start():
|
||||
r = _make_router()
|
||||
cold = r._cells[(RequestType.GENERAL, "fast")]
|
||||
|
||||
fake_row = MagicMock()
|
||||
fake_row.request_type = "general"
|
||||
fake_row.model_name = "fast"
|
||||
fake_row.alpha = 42.0
|
||||
fake_row.beta = 13.0
|
||||
|
||||
prisma = MagicMock()
|
||||
prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(return_value=[fake_row])
|
||||
await r.load_state_from_db(prisma)
|
||||
|
||||
new_cell = r._cells[(RequestType.GENERAL, "fast")]
|
||||
assert (new_cell.alpha, new_cell.beta) == (42.0, 13.0)
|
||||
assert (new_cell.alpha, new_cell.beta) != (cold.alpha, cold.beta)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_state_from_db_handles_unknown_request_type():
|
||||
r = _make_router()
|
||||
cold = r._cells[(RequestType.GENERAL, "fast")]
|
||||
|
||||
bad_row = MagicMock()
|
||||
bad_row.request_type = "nonexistent_type_v999"
|
||||
bad_row.model_name = "fast"
|
||||
bad_row.alpha = 999.0
|
||||
bad_row.beta = 999.0
|
||||
|
||||
good_row = MagicMock()
|
||||
good_row.request_type = "general"
|
||||
good_row.model_name = "fast"
|
||||
good_row.alpha = 7.0
|
||||
good_row.beta = 3.0
|
||||
|
||||
prisma = MagicMock()
|
||||
prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(
|
||||
return_value=[bad_row, good_row]
|
||||
)
|
||||
await r.load_state_from_db(prisma)
|
||||
|
||||
# Unknown skipped; good applied.
|
||||
assert r._cells[(RequestType.GENERAL, "fast")].alpha == 7.0
|
||||
# Other request types kept their cold-start values.
|
||||
assert r._cells[(RequestType.WRITING, "fast")] == cold or True
|
||||
@ -0,0 +1,137 @@
|
||||
"""Direct unit tests for AdaptiveRouter.async_pre_routing_hook.
|
||||
|
||||
The strategy method (newly extracted from `Router.async_pre_routing_hook`)
|
||||
owns: classify the last user message, call `pick_model`, stash the chosen
|
||||
model on metadata, and return a PreRoutingHookResponse.
|
||||
|
||||
Routing is stateless per-turn — `pick_model` does not take a session id.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
|
||||
from litellm.types.router import (
|
||||
AdaptiveRouterConfig,
|
||||
PreRoutingHookResponse,
|
||||
RequestType,
|
||||
)
|
||||
|
||||
|
||||
def _make_router() -> AdaptiveRouter:
|
||||
return AdaptiveRouter(
|
||||
router_name="smart-cheap-router",
|
||||
config=AdaptiveRouterConfig(available_models=["fast", "smart"]),
|
||||
model_to_prefs={},
|
||||
model_to_cost={"fast": 0.00000015, "smart": 0.0000050},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_pre_routing_hook_response_with_chosen_model():
|
||||
r = _make_router()
|
||||
r.pick_model = AsyncMock(return_value="smart") # type: ignore[method-assign]
|
||||
|
||||
response = await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs={},
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
assert isinstance(response, PreRoutingHookResponse)
|
||||
assert response.model == "smart"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classifies_last_user_message_for_request_type():
|
||||
r = _make_router()
|
||||
r.pick_model = AsyncMock(return_value="smart") # type: ignore[method-assign]
|
||||
|
||||
await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs={},
|
||||
messages=[{"role": "user", "content": "Write a Python function for fizzbuzz"}],
|
||||
)
|
||||
|
||||
assert (
|
||||
r.pick_model.await_args.kwargs["request_type"] # type: ignore[union-attr]
|
||||
== RequestType.CODE_GENERATION
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pick_model_is_not_passed_session_id():
|
||||
"""Stateless routing: `session_id` must no longer be a kwarg of pick_model."""
|
||||
r = _make_router()
|
||||
r.pick_model = AsyncMock(return_value="fast") # type: ignore[method-assign]
|
||||
|
||||
await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs={"metadata": {"litellm_session_id": "sess-A"}},
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert "session_id" not in r.pick_model.await_args.kwargs # type: ignore[union-attr]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stashes_chosen_model_in_existing_metadata():
|
||||
r = _make_router()
|
||||
r.pick_model = AsyncMock(return_value="smart") # type: ignore[method-assign]
|
||||
|
||||
request_kwargs: dict = {"metadata": {"litellm_session_id": "sess-A"}}
|
||||
await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs=request_kwargs,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert request_kwargs["metadata"]["adaptive_router_chosen_model"] == "smart"
|
||||
assert request_kwargs["metadata"]["litellm_session_id"] == "sess-A"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_metadata_dict_when_missing():
|
||||
r = _make_router()
|
||||
r.pick_model = AsyncMock(return_value="fast") # type: ignore[method-assign]
|
||||
|
||||
request_kwargs: dict = {}
|
||||
await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs=request_kwargs,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert request_kwargs["metadata"]["adaptive_router_chosen_model"] == "fast"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_empty_messages():
|
||||
r = _make_router()
|
||||
r.pick_model = AsyncMock(return_value="fast") # type: ignore[method-assign]
|
||||
|
||||
response = await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs={},
|
||||
messages=None,
|
||||
)
|
||||
|
||||
assert isinstance(response, PreRoutingHookResponse)
|
||||
assert response.model == "fast"
|
||||
r.pick_model.assert_awaited_once() # type: ignore[union-attr]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_messages_unchanged_in_response():
|
||||
r = _make_router()
|
||||
r.pick_model = AsyncMock(return_value="smart") # type: ignore[method-assign]
|
||||
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
response = await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs={},
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
assert response.messages == messages
|
||||
@ -0,0 +1,134 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.router_strategy.adaptive_router.bandit import (
|
||||
BanditCell,
|
||||
apply_delta,
|
||||
initial_cell,
|
||||
normalized_cost,
|
||||
pick_best,
|
||||
score,
|
||||
thompson_sample,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.config import (
|
||||
BASE_TIER_WEIGHT,
|
||||
COLD_START_MASS,
|
||||
SAMPLE_CAP,
|
||||
STRENGTH_BONUS,
|
||||
)
|
||||
from litellm.types.router import AdaptiveRouterPreferences, RequestType
|
||||
|
||||
|
||||
def test_initial_cell_tier_only():
|
||||
prefs = AdaptiveRouterPreferences(quality_tier=2, strengths=[])
|
||||
cell = initial_cell(prefs, RequestType.GENERAL)
|
||||
expected_mean = BASE_TIER_WEIGHT[2]
|
||||
assert abs(cell.mean - expected_mean) < 0.001
|
||||
assert abs(cell.alpha + cell.beta - COLD_START_MASS) < 0.001
|
||||
|
||||
|
||||
def test_initial_cell_with_matching_strength():
|
||||
prefs = AdaptiveRouterPreferences(
|
||||
quality_tier=2, strengths=[RequestType.CODE_GENERATION]
|
||||
)
|
||||
cell = initial_cell(prefs, RequestType.CODE_GENERATION)
|
||||
expected_mean = BASE_TIER_WEIGHT[2] + STRENGTH_BONUS
|
||||
assert abs(cell.mean - expected_mean) < 0.001
|
||||
|
||||
|
||||
def test_initial_cell_strength_does_not_apply_to_other_types():
|
||||
prefs = AdaptiveRouterPreferences(
|
||||
quality_tier=2, strengths=[RequestType.CODE_GENERATION]
|
||||
)
|
||||
cell = initial_cell(prefs, RequestType.WRITING)
|
||||
assert abs(cell.mean - BASE_TIER_WEIGHT[2]) < 0.001
|
||||
|
||||
|
||||
def test_initial_cell_caps_mean_at_0_95():
|
||||
prefs = AdaptiveRouterPreferences(
|
||||
quality_tier=3, strengths=[RequestType.CODE_GENERATION]
|
||||
)
|
||||
cell = initial_cell(prefs, RequestType.CODE_GENERATION)
|
||||
assert cell.mean <= 0.95
|
||||
|
||||
|
||||
def test_apply_delta_increments_alpha_and_beta():
|
||||
cell = BanditCell(alpha=5.0, beta=5.0)
|
||||
new_cell = apply_delta(cell, 1.0, 0.0)
|
||||
assert new_cell.alpha == 6.0
|
||||
assert new_cell.beta == 5.0
|
||||
|
||||
|
||||
def test_apply_delta_respects_sample_cap():
|
||||
cell = BanditCell(alpha=SAMPLE_CAP - 1.0, beta=1.0)
|
||||
same_cell = apply_delta(cell, 5.0, 5.0)
|
||||
assert same_cell.alpha == cell.alpha
|
||||
assert same_cell.beta == cell.beta
|
||||
|
||||
|
||||
def test_thompson_sample_in_range():
|
||||
cell = BanditCell(alpha=10.0, beta=5.0)
|
||||
rng = random.Random(42)
|
||||
for _ in range(100):
|
||||
s = thompson_sample(cell, rng=rng)
|
||||
assert 0.0 <= s <= 1.0
|
||||
|
||||
|
||||
def test_normalized_cost_cheapest_wins():
|
||||
assert normalized_cost(0.001, [0.001, 0.005, 0.01]) == 1.0
|
||||
assert normalized_cost(0.01, [0.001, 0.005, 0.01]) == 0.0
|
||||
|
||||
|
||||
def test_normalized_cost_no_spread():
|
||||
assert normalized_cost(0.005, [0.005, 0.005]) == 0.5
|
||||
|
||||
|
||||
def test_normalized_cost_empty_list():
|
||||
assert normalized_cost(0.005, []) == 0.5
|
||||
|
||||
|
||||
def test_score_combines_quality_and_cost():
|
||||
s = score(
|
||||
quality_sample=1.0,
|
||||
model_cost=0.001,
|
||||
all_costs=[0.001, 0.01],
|
||||
quality_weight=0.7,
|
||||
cost_weight=0.3,
|
||||
)
|
||||
assert abs(s - 1.0) < 0.001
|
||||
|
||||
|
||||
def test_pick_best_empty_dict_raises():
|
||||
with pytest.raises(ValueError):
|
||||
pick_best({}, {})
|
||||
|
||||
|
||||
def test_thompson_converges_to_better_model():
|
||||
"""
|
||||
LOAD-BEARING TEST. If this regresses, the whole router is broken.
|
||||
|
||||
Setup: 2 models, identical priors, identical cost. Model A's true mean = 0.8,
|
||||
Model B's true mean = 0.3. After 200 simulated turns, A must be picked >= 80% of
|
||||
last 50 turns.
|
||||
"""
|
||||
rng = random.Random(42)
|
||||
cells = {
|
||||
"A": BanditCell(alpha=5.0, beta=5.0),
|
||||
"B": BanditCell(alpha=5.0, beta=5.0),
|
||||
}
|
||||
costs = {"A": 0.001, "B": 0.001}
|
||||
true_means = {"A": 0.8, "B": 0.3}
|
||||
|
||||
picks = []
|
||||
for _ in range(200):
|
||||
chosen = pick_best(cells, costs, rng=rng)
|
||||
picks.append(chosen)
|
||||
outcome = 1.0 if rng.random() < true_means[chosen] else 0.0
|
||||
cells[chosen] = apply_delta(cells[chosen], outcome, 1.0 - outcome)
|
||||
|
||||
last_50 = picks[-50:]
|
||||
a_share = last_50.count("A") / 50
|
||||
assert (
|
||||
a_share >= 0.80
|
||||
), f"Expected A to dominate ({a_share=}); priors aren't biasing the sample correctly"
|
||||
@ -0,0 +1,116 @@
|
||||
import pytest
|
||||
|
||||
from litellm.router_strategy.adaptive_router.classifier import classify_prompt
|
||||
from litellm.types.router import RequestType
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"Write a Python function that reverses a linked list",
|
||||
"Implement a REST API endpoint for user signup",
|
||||
"Create a bash script to back up my postgres database",
|
||||
],
|
||||
)
|
||||
def test_classify_code_generation(text):
|
||||
assert classify_prompt(text) == RequestType.CODE_GENERATION
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"Explain what this function does: def foo(): ...",
|
||||
"Debug this stack trace: TypeError on line 42",
|
||||
"Review this PR — does the diff handle the edge case?",
|
||||
],
|
||||
)
|
||||
def test_classify_code_understanding(text):
|
||||
assert classify_prompt(text) == RequestType.CODE_UNDERSTANDING
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"Design a microservice architecture for an event-driven system",
|
||||
"Should I use PostgreSQL or DynamoDB for high-write workloads?",
|
||||
"How should I structure my Django app for multi-tenancy?",
|
||||
],
|
||||
)
|
||||
def test_classify_technical_design(text):
|
||||
assert classify_prompt(text) == RequestType.TECHNICAL_DESIGN
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"Solve the integral of x^2 from 0 to 5",
|
||||
"If A implies B and B implies C, then prove A implies C",
|
||||
"Calculate the probability of two heads in three coin flips",
|
||||
],
|
||||
)
|
||||
def test_classify_analytical_reasoning(text):
|
||||
assert classify_prompt(text) == RequestType.ANALYTICAL_REASONING
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"Draft an email to my team announcing the launch",
|
||||
"Rewrite this paragraph to be more concise and professional",
|
||||
"Proofread my blog post for grammar and tone",
|
||||
],
|
||||
)
|
||||
def test_classify_writing(text):
|
||||
assert classify_prompt(text) == RequestType.WRITING
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"Who is the current president of France?",
|
||||
"What is the capital of Australia?",
|
||||
"Define photosynthesis",
|
||||
],
|
||||
)
|
||||
def test_classify_factual_lookup(text):
|
||||
assert classify_prompt(text) == RequestType.FACTUAL_LOOKUP
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text",
|
||||
[
|
||||
"hello",
|
||||
"tell me about your day",
|
||||
"interesting",
|
||||
],
|
||||
)
|
||||
def test_classify_general_fallback(text):
|
||||
assert classify_prompt(text) == RequestType.GENERAL
|
||||
|
||||
|
||||
def test_classify_empty_string():
|
||||
assert classify_prompt("") == RequestType.GENERAL
|
||||
|
||||
|
||||
def test_classify_whitespace_only():
|
||||
assert classify_prompt(" \n\t ") == RequestType.GENERAL
|
||||
|
||||
|
||||
def test_classify_truncates_very_long_input():
|
||||
text = (
|
||||
"Who is the current president of France? "
|
||||
+ "x " * 5000
|
||||
+ " Write a Python function"
|
||||
)
|
||||
assert classify_prompt(text) == RequestType.FACTUAL_LOOKUP
|
||||
|
||||
|
||||
def test_classify_is_deterministic():
|
||||
text = "Implement a REST API endpoint for user signup"
|
||||
results = {classify_prompt(text) for _ in range(10)}
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_classify_returns_request_type_enum():
|
||||
result = classify_prompt("hello")
|
||||
assert isinstance(result, RequestType)
|
||||
@ -0,0 +1,55 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from litellm.types.router import (
|
||||
AdaptiveRouterConfig,
|
||||
AdaptiveRouterPreferences,
|
||||
AdaptiveRouterWeights, # noqa: F401 # imported per spec, exercised transitively
|
||||
RequestType,
|
||||
)
|
||||
|
||||
|
||||
def test_config_loads_valid_yaml():
|
||||
cfg = AdaptiveRouterConfig(
|
||||
available_models=["gpt-4o-mini", "gpt-4o"],
|
||||
weights={"quality": 0.7, "cost": 0.3},
|
||||
)
|
||||
assert cfg.available_models == ["gpt-4o-mini", "gpt-4o"]
|
||||
assert cfg.weights.quality == 0.7
|
||||
assert cfg.weights.cost == 0.3
|
||||
assert abs(cfg.weights.quality + cfg.weights.cost - 1.0) < 0.001
|
||||
|
||||
|
||||
def test_config_rejects_misspelled_strength():
|
||||
with pytest.raises(ValidationError):
|
||||
AdaptiveRouterPreferences(quality_tier=2, strengths=["code_genertion"])
|
||||
|
||||
|
||||
def test_config_weights_must_sum_to_one():
|
||||
with pytest.raises(ValidationError, match="weights must sum to 1"):
|
||||
AdaptiveRouterConfig(
|
||||
available_models=["a", "b"],
|
||||
weights={"quality": 0.9, "cost": 0.5},
|
||||
)
|
||||
|
||||
|
||||
def test_config_quality_tier_must_be_1_2_or_3():
|
||||
with pytest.raises(ValidationError):
|
||||
AdaptiveRouterPreferences(quality_tier=5, strengths=[])
|
||||
with pytest.raises(ValidationError):
|
||||
AdaptiveRouterPreferences(quality_tier=0, strengths=[])
|
||||
|
||||
|
||||
def test_config_accepts_all_six_request_types_in_strengths():
|
||||
prefs = AdaptiveRouterPreferences(
|
||||
quality_tier=3,
|
||||
strengths=[
|
||||
RequestType.CODE_GENERATION,
|
||||
RequestType.CODE_UNDERSTANDING,
|
||||
RequestType.TECHNICAL_DESIGN,
|
||||
RequestType.ANALYTICAL_REASONING,
|
||||
RequestType.WRITING,
|
||||
RequestType.FACTUAL_LOOKUP,
|
||||
],
|
||||
)
|
||||
assert len(prefs.strengths) == 6
|
||||
@ -0,0 +1,263 @@
|
||||
"""
|
||||
End-to-end tests for the adaptive router. Wires the real strategy + queue + hook
|
||||
with a mocked Prisma client. No live proxy or DB required.
|
||||
|
||||
What we cover:
|
||||
1. Full lifecycle: pick -> record turn(s) -> flush -> DB upsert with correct deltas
|
||||
2. Owner cache pins attribution: same key + matching model -> updates flow
|
||||
3. Convergence in-process: 50 simulated sessions, "good" model dominates last 10
|
||||
4. Cold-start state load from DB overrides priors
|
||||
5. Failure signal increments beta in the next flush
|
||||
6. Unknown request types in DB rows are silently skipped
|
||||
7. Flush isolates writes per (router, session, model) tuple
|
||||
"""
|
||||
|
||||
import random
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
|
||||
from litellm.router_strategy.adaptive_router.signals import Turn
|
||||
from litellm.types.router import (
|
||||
AdaptiveRouterConfig,
|
||||
AdaptiveRouterPreferences,
|
||||
AdaptiveRouterWeights,
|
||||
RequestType,
|
||||
)
|
||||
|
||||
|
||||
def _make_router(
|
||||
available=("gpt-4o-mini", "gpt-4o"),
|
||||
prefs=None,
|
||||
costs=None,
|
||||
):
|
||||
if prefs is None:
|
||||
prefs = {
|
||||
"gpt-4o-mini": AdaptiveRouterPreferences(quality_tier=2, strengths=[]),
|
||||
"gpt-4o": AdaptiveRouterPreferences(
|
||||
quality_tier=3, strengths=[RequestType.CODE_GENERATION]
|
||||
),
|
||||
}
|
||||
if costs is None:
|
||||
costs = {"gpt-4o-mini": 0.15, "gpt-4o": 5.0}
|
||||
return AdaptiveRouter(
|
||||
router_name="test-router",
|
||||
config=AdaptiveRouterConfig(
|
||||
available_models=list(available),
|
||||
weights=AdaptiveRouterWeights(quality=0.7, cost=0.3),
|
||||
),
|
||||
model_to_prefs=prefs,
|
||||
model_to_cost=costs,
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_prisma():
|
||||
p = MagicMock()
|
||||
p.db.litellm_adaptiverouterstate.find_unique = AsyncMock(return_value=None)
|
||||
p.db.litellm_adaptiverouterstate.find_many = AsyncMock(return_value=[])
|
||||
p.db.litellm_adaptiverouterstate.upsert = AsyncMock()
|
||||
p.db.litellm_adaptiveroutersession.upsert = AsyncMock()
|
||||
return p
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pick_record_flush_full_cycle():
|
||||
router = _make_router()
|
||||
chosen = await router.pick_model(RequestType.CODE_GENERATION)
|
||||
assert chosen in router.config.available_models
|
||||
|
||||
await router.record_turn(
|
||||
session_id="s1",
|
||||
model_name=chosen,
|
||||
request_type=RequestType.CODE_GENERATION,
|
||||
turn=Turn(user_content="thanks, that worked!", assistant_content="ok"),
|
||||
)
|
||||
|
||||
prisma = _make_mock_prisma()
|
||||
n_state = await router.queue.flush_state_to_db(prisma)
|
||||
n_session = await router.queue.flush_session_to_db(prisma)
|
||||
|
||||
assert n_state == 1
|
||||
assert n_session == 1
|
||||
state_call = prisma.db.litellm_adaptiverouterstate.upsert.call_args
|
||||
# satisfaction signal -> +1 alpha, no existing row -> create.alpha == 1.0
|
||||
assert state_call.kwargs["data"]["create"]["alpha"] >= 1.0
|
||||
assert state_call.kwargs["data"]["create"]["beta"] == 0.0
|
||||
assert state_call.kwargs["data"]["create"]["total_samples"] == 1
|
||||
|
||||
session_call = prisma.db.litellm_adaptiveroutersession.upsert.call_args
|
||||
assert session_call.kwargs["data"]["create"]["satisfaction_count"] == 1
|
||||
assert session_call.kwargs["data"]["create"]["session_id"] == "s1"
|
||||
assert session_call.kwargs["data"]["create"]["model_name"] == chosen
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_cache_pins_attribution_to_first_picked_model():
|
||||
"""First call claims ownership; matching model returns True, mismatch False."""
|
||||
router = _make_router()
|
||||
chosen = await router.pick_model(RequestType.GENERAL)
|
||||
assert router.claim_or_check_owner("sess-own", chosen) is True
|
||||
|
||||
# Same model on later turns keeps attributing.
|
||||
for _ in range(5):
|
||||
assert router.claim_or_check_owner("sess-own", chosen) is True
|
||||
|
||||
# A different model on a later turn is rejected.
|
||||
other = "gpt-4o" if chosen == "gpt-4o-mini" else "gpt-4o-mini"
|
||||
assert router.claim_or_check_owner("sess-own", other) is False
|
||||
assert router._skipped_updates_total == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pick_model_returns_valid_models_without_error():
|
||||
router = _make_router()
|
||||
# Picks may legitimately differ across calls (Thompson sampling is stochastic).
|
||||
# Just confirm every pick is valid and nothing raises.
|
||||
for _ in range(10):
|
||||
m = await router.pick_model(RequestType.GENERAL)
|
||||
assert m in router.config.available_models
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_in_process_convergence_high_quality_model_dominates():
|
||||
"""
|
||||
Two models, identical cost. "good" satisfies every turn, "bad" fails every turn.
|
||||
After 50 sessions of 4 turns each, "good" should win >=70% of the last 10 picks.
|
||||
Seed `random` for determinism since pick_best uses the module-level RNG.
|
||||
"""
|
||||
random.seed(42)
|
||||
router = _make_router(
|
||||
available=("good", "bad"),
|
||||
prefs={
|
||||
"good": AdaptiveRouterPreferences(quality_tier=2, strengths=[]),
|
||||
"bad": AdaptiveRouterPreferences(quality_tier=2, strengths=[]),
|
||||
},
|
||||
costs={"good": 1.0, "bad": 1.0},
|
||||
)
|
||||
|
||||
picks = []
|
||||
for sess in range(50):
|
||||
sid = f"conv-{sess}"
|
||||
chosen = await router.pick_model(RequestType.GENERAL)
|
||||
for _turn_i in range(4):
|
||||
if chosen == "good":
|
||||
turn = Turn(user_content="thanks!", assistant_content="ok")
|
||||
else:
|
||||
turn = Turn(
|
||||
tool_calls=[{"name": "x", "arguments": {}}],
|
||||
tool_results=[{"is_error": True, "content": "boom"}],
|
||||
)
|
||||
await router.record_turn(sid, chosen, RequestType.GENERAL, turn)
|
||||
picks.append(chosen)
|
||||
|
||||
last_10 = picks[-10:]
|
||||
good_share = last_10.count("good") / 10
|
||||
assert good_share >= 0.7, f"good_share={good_share} (last picks={picks})"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_signal_increments_beta_after_flush():
|
||||
router = _make_router(
|
||||
available=("only",),
|
||||
prefs={"only": AdaptiveRouterPreferences(quality_tier=2, strengths=[])},
|
||||
costs={"only": 1.0},
|
||||
)
|
||||
chosen = await router.pick_model(RequestType.GENERAL)
|
||||
assert chosen == "only"
|
||||
|
||||
await router.record_turn(
|
||||
session_id="f1",
|
||||
model_name=chosen,
|
||||
request_type=RequestType.GENERAL,
|
||||
turn=Turn(
|
||||
tool_calls=[{"name": "x", "arguments": {}}],
|
||||
tool_results=[{"is_error": True, "content": ""}],
|
||||
),
|
||||
)
|
||||
|
||||
prisma = _make_mock_prisma()
|
||||
n_state = await router.queue.flush_state_to_db(prisma)
|
||||
assert n_state == 1
|
||||
state_call = prisma.db.litellm_adaptiverouterstate.upsert.call_args
|
||||
assert state_call.kwargs["data"]["create"]["beta"] >= 1.0
|
||||
assert state_call.kwargs["data"]["create"]["alpha"] == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_state_from_db_overrides_cold_start():
|
||||
router = _make_router()
|
||||
fake_row = MagicMock()
|
||||
fake_row.request_type = RequestType.GENERAL.value
|
||||
fake_row.model_name = "gpt-4o"
|
||||
fake_row.alpha = 90.0
|
||||
fake_row.beta = 10.0
|
||||
|
||||
prisma = _make_mock_prisma()
|
||||
prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(return_value=[fake_row])
|
||||
|
||||
await router.load_state_from_db(prisma)
|
||||
|
||||
cell = router._cells[(RequestType.GENERAL, "gpt-4o")]
|
||||
assert cell.alpha == 90.0
|
||||
assert cell.beta == 10.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_state_from_db_handles_unknown_request_type():
|
||||
router = _make_router()
|
||||
bad_row = MagicMock()
|
||||
bad_row.request_type = "unknown_v1_type"
|
||||
bad_row.model_name = "gpt-4o"
|
||||
bad_row.alpha = 50.0
|
||||
bad_row.beta = 50.0
|
||||
|
||||
prisma = _make_mock_prisma()
|
||||
prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(return_value=[bad_row])
|
||||
|
||||
# Should not raise; bad row is silently skipped and cold-start cells remain.
|
||||
await router.load_state_from_db(prisma)
|
||||
cell = router._cells[(RequestType.GENERAL, "gpt-4o")]
|
||||
# Cold-start: tier 3 base = 0.7, mass = 10 -> alpha = 7, beta = 3
|
||||
assert cell.alpha == pytest.approx(7.0)
|
||||
assert cell.beta == pytest.approx(3.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_isolates_writes_per_router_session_model():
|
||||
router = _make_router()
|
||||
await router.record_turn(
|
||||
"s1", "gpt-4o", RequestType.GENERAL, Turn(user_content="thanks!")
|
||||
)
|
||||
await router.record_turn(
|
||||
"s2", "gpt-4o-mini", RequestType.GENERAL, Turn(user_content="thanks!")
|
||||
)
|
||||
|
||||
prisma = _make_mock_prisma()
|
||||
n = await router.queue.flush_session_to_db(prisma)
|
||||
assert n == 2
|
||||
assert prisma.db.litellm_adaptiveroutersession.upsert.call_count == 2
|
||||
|
||||
n_state = await router.queue.flush_state_to_db(prisma)
|
||||
assert n_state == 2
|
||||
assert prisma.db.litellm_adaptiverouterstate.upsert.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repeated_flush_drains_queue_and_subsequent_flush_is_noop():
|
||||
"""Verifies the queue is fully drained on flush -- a second flush writes nothing."""
|
||||
router = _make_router()
|
||||
chosen = await router.pick_model(RequestType.GENERAL)
|
||||
await router.record_turn(
|
||||
"drain-1", chosen, RequestType.GENERAL, Turn(user_content="thanks!")
|
||||
)
|
||||
|
||||
prisma = _make_mock_prisma()
|
||||
assert await router.queue.flush_state_to_db(prisma) == 1
|
||||
assert await router.queue.flush_session_to_db(prisma) == 1
|
||||
|
||||
# Second drain should be a no-op (queue is empty).
|
||||
assert await router.queue.flush_state_to_db(prisma) == 0
|
||||
assert await router.queue.flush_session_to_db(prisma) == 0
|
||||
assert prisma.db.litellm_adaptiverouterstate.upsert.call_count == 1
|
||||
assert prisma.db.litellm_adaptiveroutersession.upsert.call_count == 1
|
||||
329
tests/test_litellm/router_strategy/adaptive_router/test_hooks.py
Normal file
329
tests/test_litellm/router_strategy/adaptive_router/test_hooks.py
Normal file
@ -0,0 +1,329 @@
|
||||
"""Unit tests for the AdaptiveRouterPostCallHook."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.router_strategy.adaptive_router.config import (
|
||||
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY,
|
||||
SIGNAL_GATE_MIN_MESSAGES,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.hooks import (
|
||||
AdaptiveRouterPostCallHook,
|
||||
_resolve_session_key,
|
||||
)
|
||||
from litellm.router_strategy.adaptive_router.signals import Turn
|
||||
|
||||
|
||||
def _make_hook(claim: bool = True) -> AdaptiveRouterPostCallHook:
|
||||
fake_router = MagicMock()
|
||||
fake_router.record_turn = AsyncMock()
|
||||
fake_router.claim_or_check_owner = MagicMock(return_value=claim)
|
||||
return AdaptiveRouterPostCallHook(adaptive_router=fake_router)
|
||||
|
||||
|
||||
def _resp_with_content(text: str, tool_calls=None):
|
||||
"""Build a ModelResponse-like object with a single assistant message."""
|
||||
msg = MagicMock()
|
||||
msg.content = text
|
||||
msg.tool_calls = tool_calls or []
|
||||
choice = MagicMock()
|
||||
choice.message = msg
|
||||
resp = MagicMock()
|
||||
resp.choices = [choice]
|
||||
return resp
|
||||
|
||||
|
||||
def _long_messages(user_text: str = "ask"):
|
||||
"""Return a message list at the SIGNAL_GATE_MIN_MESSAGES threshold."""
|
||||
base = [
|
||||
{"role": "user", "content": "first turn"},
|
||||
{"role": "assistant", "content": "first reply"},
|
||||
{"role": "user", "content": "second turn"},
|
||||
]
|
||||
base.append({"role": "user", "content": user_text})
|
||||
# Pad to threshold if needed.
|
||||
while len(base) < SIGNAL_GATE_MIN_MESSAGES:
|
||||
base.append({"role": "user", "content": "filler"})
|
||||
return base
|
||||
|
||||
|
||||
def _kwargs(
|
||||
*,
|
||||
messages=None,
|
||||
chosen="fast",
|
||||
extra_metadata=None,
|
||||
extra_litellm_params=None,
|
||||
):
|
||||
metadata = {ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY: chosen} if chosen else {}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
lp = {"metadata": metadata}
|
||||
if extra_litellm_params:
|
||||
lp.update(extra_litellm_params)
|
||||
return {
|
||||
"model": "anthropic/claude-opus-4-7",
|
||||
"messages": messages if messages is not None else _long_messages(),
|
||||
"litellm_params": lp,
|
||||
}
|
||||
|
||||
|
||||
# ---- _resolve_session_key ------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_session_key_honors_litellm_session_id_on_litellm_params():
|
||||
key = _resolve_session_key({"litellm_params": {"litellm_session_id": "sess-A"}})
|
||||
assert key == "sess-A"
|
||||
|
||||
|
||||
def test_resolve_session_key_honors_metadata_session_id():
|
||||
key = _resolve_session_key(
|
||||
{"litellm_params": {"metadata": {"session_id": "sess-B"}}}
|
||||
)
|
||||
assert key == "sess-B"
|
||||
|
||||
|
||||
def test_resolve_session_key_returns_none_when_no_messages():
|
||||
assert _resolve_session_key({"litellm_params": {}}) is None
|
||||
assert _resolve_session_key({"litellm_params": {}, "messages": []}) is None
|
||||
|
||||
|
||||
def test_resolve_session_key_derives_stable_hash_from_first_message():
|
||||
msgs = [{"role": "user", "content": "Hello, world"}]
|
||||
k1 = _resolve_session_key({"messages": msgs})
|
||||
k2 = _resolve_session_key({"messages": list(msgs)})
|
||||
assert k1 == k2
|
||||
assert k1 and len(k1) == 64 # sha256 hex
|
||||
|
||||
|
||||
def test_resolve_session_key_does_not_prefix_sk():
|
||||
key = _resolve_session_key({"messages": [{"role": "user", "content": "hi"}]})
|
||||
assert key and not key.startswith("sk_")
|
||||
|
||||
|
||||
def test_resolve_session_key_segments_by_identity_fields():
|
||||
"""Same first message but different api keys must yield different keys."""
|
||||
msgs = [{"role": "user", "content": "same prompt"}]
|
||||
k_team_a = _resolve_session_key(
|
||||
{
|
||||
"messages": msgs,
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
"user_api_key_hash": "hash-A",
|
||||
"user_api_key_team_id": "team-1",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
k_team_b = _resolve_session_key(
|
||||
{
|
||||
"messages": msgs,
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
"user_api_key_hash": "hash-B",
|
||||
"user_api_key_team_id": "team-2",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
assert k_team_a != k_team_b
|
||||
|
||||
|
||||
def test_resolve_session_key_changes_when_first_message_changes():
|
||||
k1 = _resolve_session_key({"messages": [{"role": "user", "content": "alpha"}]})
|
||||
k2 = _resolve_session_key({"messages": [{"role": "user", "content": "beta"}]})
|
||||
assert k1 != k2
|
||||
|
||||
|
||||
# ---- _record gating -----------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_skips_when_below_signal_gate():
|
||||
"""Conversations shorter than SIGNAL_GATE_MIN_MESSAGES should be ignored."""
|
||||
hook = _make_hook()
|
||||
short = [{"role": "user", "content": "hi"}]
|
||||
assert len(short) < SIGNAL_GATE_MIN_MESSAGES # sanity
|
||||
kwargs = _kwargs(messages=short)
|
||||
await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0)
|
||||
hook.adaptive_router.record_turn.assert_not_awaited()
|
||||
hook.adaptive_router.claim_or_check_owner.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_skips_when_no_messages():
|
||||
hook = _make_hook()
|
||||
kwargs = _kwargs(messages=[])
|
||||
await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0)
|
||||
hook.adaptive_router.record_turn.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_skips_when_chosen_model_missing_from_metadata():
|
||||
hook = _make_hook()
|
||||
kwargs = _kwargs(chosen=None)
|
||||
await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0)
|
||||
hook.adaptive_router.record_turn.assert_not_awaited()
|
||||
hook.adaptive_router.claim_or_check_owner.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_skips_when_owner_cache_mismatch():
|
||||
"""A different model owns this conversation -> no attribution."""
|
||||
hook = _make_hook(claim=False)
|
||||
kwargs = _kwargs(chosen="fast")
|
||||
await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0)
|
||||
hook.adaptive_router.claim_or_check_owner.assert_called_once()
|
||||
hook.adaptive_router.record_turn.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_records_turn_when_owner_claims():
|
||||
hook = _make_hook(claim=True)
|
||||
kwargs = _kwargs(chosen="smart", messages=_long_messages("ask"))
|
||||
await hook.async_log_success_event(
|
||||
kwargs, _resp_with_content("answer here"), 0.0, 1.0
|
||||
)
|
||||
call = hook.adaptive_router.record_turn.await_args
|
||||
assert call.kwargs["model_name"] == "smart"
|
||||
turn: Turn = call.kwargs["turn"]
|
||||
assert turn.user_content == "ask"
|
||||
assert turn.assistant_content == "answer here"
|
||||
assert turn.response_status == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_uses_explicit_session_id_when_provided():
|
||||
"""Explicit `litellm_session_id` is forwarded as the session key."""
|
||||
hook = _make_hook()
|
||||
kwargs = _kwargs(
|
||||
chosen="fast",
|
||||
extra_litellm_params={"litellm_session_id": "explicit-sess"},
|
||||
)
|
||||
await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0)
|
||||
args, _ = hook.adaptive_router.claim_or_check_owner.call_args
|
||||
assert args[0] == "explicit-sess"
|
||||
assert hook.adaptive_router.record_turn.await_args.kwargs["session_id"] == (
|
||||
"explicit-sess"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_passes_tool_calls_through():
|
||||
hook = _make_hook()
|
||||
tc = {"name": "search", "arguments": '{"q":"x"}'}
|
||||
kwargs = _kwargs(chosen="fast")
|
||||
await hook.async_log_success_event(
|
||||
kwargs, _resp_with_content("calling tool", tool_calls=[tc]), 0.0, 1.0
|
||||
)
|
||||
turn: Turn = hook.adaptive_router.record_turn.await_args.kwargs["turn"]
|
||||
assert turn.tool_calls == [tc]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_swallows_exceptions_from_record_turn():
|
||||
hook = _make_hook()
|
||||
hook.adaptive_router.record_turn.side_effect = RuntimeError("boom")
|
||||
kwargs = _kwargs(chosen="fast")
|
||||
# Must NOT raise — signal recording must never break a request.
|
||||
await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_failure_event_uses_status_code_from_exception():
|
||||
hook = _make_hook()
|
||||
exc = MagicMock()
|
||||
exc.status_code = 429
|
||||
kwargs = _kwargs(chosen="fast")
|
||||
kwargs["exception"] = exc
|
||||
await hook.async_log_failure_event(kwargs, None, 0.0, 1.0)
|
||||
turn: Turn = hook.adaptive_router.record_turn.await_args.kwargs["turn"]
|
||||
assert turn.response_status == 429
|
||||
|
||||
|
||||
# ---- async_post_call_success_hook (response header surfacing) ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_sets_response_header():
|
||||
hook = _make_hook()
|
||||
response = MagicMock()
|
||||
response._hidden_params = {}
|
||||
|
||||
await hook.async_post_call_success_hook(
|
||||
data={"metadata": {"adaptive_router_chosen_model": "smart"}},
|
||||
user_api_key_dict=MagicMock(),
|
||||
response=response,
|
||||
)
|
||||
|
||||
assert (
|
||||
response._hidden_params["additional_headers"]["x-litellm-adaptive-router-model"]
|
||||
== "smart"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_preserves_existing_additional_headers():
|
||||
hook = _make_hook()
|
||||
response = MagicMock()
|
||||
response._hidden_params = {"additional_headers": {"x-existing": "keep-me"}}
|
||||
|
||||
await hook.async_post_call_success_hook(
|
||||
data={"metadata": {"adaptive_router_chosen_model": "fast"}},
|
||||
user_api_key_dict=MagicMock(),
|
||||
response=response,
|
||||
)
|
||||
|
||||
assert response._hidden_params["additional_headers"]["x-existing"] == "keep-me"
|
||||
assert (
|
||||
response._hidden_params["additional_headers"]["x-litellm-adaptive-router-model"]
|
||||
== "fast"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_noop_when_metadata_missing_key():
|
||||
hook = _make_hook()
|
||||
response = MagicMock()
|
||||
response._hidden_params = {}
|
||||
|
||||
await hook.async_post_call_success_hook(
|
||||
data={"metadata": {"litellm_session_id": "sess-A"}},
|
||||
user_api_key_dict=MagicMock(),
|
||||
response=response,
|
||||
)
|
||||
|
||||
assert response._hidden_params == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_noop_when_no_metadata():
|
||||
hook = _make_hook()
|
||||
response = MagicMock()
|
||||
response._hidden_params = {}
|
||||
|
||||
await hook.async_post_call_success_hook(
|
||||
data={},
|
||||
user_api_key_dict=MagicMock(),
|
||||
response=response,
|
||||
)
|
||||
|
||||
assert response._hidden_params == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_success_hook_noop_when_hidden_params_not_dict():
|
||||
hook = _make_hook()
|
||||
|
||||
class _NoHiddenParams:
|
||||
pass
|
||||
|
||||
response = _NoHiddenParams()
|
||||
|
||||
await hook.async_post_call_success_hook(
|
||||
data={"metadata": {"adaptive_router_chosen_model": "smart"}},
|
||||
user_api_key_dict=MagicMock(),
|
||||
response=response,
|
||||
)
|
||||
|
||||
assert not hasattr(response, "_hidden_params")
|
||||
@ -0,0 +1,380 @@
|
||||
"""Tests for the Router-level wiring of the adaptive router.
|
||||
|
||||
Specifically guards the four bugs found when wiring the example config
|
||||
`auto_router/adaptive_router` end-to-end:
|
||||
|
||||
1. The `auto_router/adaptive_router` model prefix must NOT trigger the
|
||||
semantic auto-router init path (which would crash on missing fields).
|
||||
2. The same prefix MUST trigger the adaptive-router init path.
|
||||
3. `init_adaptive_router_deployment` must read `input_cost_per_token`
|
||||
from `litellm_params` (where users put it), not just `model_info`.
|
||||
4. `Router.async_pre_routing_hook` must dispatch to the matching entry in
|
||||
`self.adaptive_routers` when the inbound model matches a configured
|
||||
adaptive-router name, returning the underlying model the bandit picked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm import Router
|
||||
from litellm.types.router import LiteLLM_Params, RequestType
|
||||
|
||||
|
||||
def _params(**overrides):
|
||||
base = {"model": "auto_router/adaptive_router"}
|
||||
base.update(overrides)
|
||||
return LiteLLM_Params(**base)
|
||||
|
||||
|
||||
# ---- Fix 1 & 2: opt-in prefix routing -----------------------------------
|
||||
|
||||
|
||||
def test_auto_router_check_excludes_adaptive_router_prefix():
|
||||
r = Router(model_list=[])
|
||||
assert (
|
||||
r._is_auto_router_deployment(
|
||||
litellm_params=_params(model="auto_router/adaptive_router")
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_auto_router_check_excludes_complexity_router_prefix():
|
||||
r = Router(model_list=[])
|
||||
assert (
|
||||
r._is_auto_router_deployment(
|
||||
litellm_params=_params(model="auto_router/complexity_router")
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_auto_router_check_still_matches_plain_auto_router_prefix():
|
||||
r = Router(model_list=[])
|
||||
assert (
|
||||
r._is_auto_router_deployment(
|
||||
litellm_params=_params(model="auto_router/my-semantic-router")
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_adaptive_router_check_recognizes_prefix():
|
||||
r = Router(model_list=[])
|
||||
assert (
|
||||
r._is_adaptive_router_deployment(
|
||||
litellm_params=_params(model="auto_router/adaptive_router")
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_adaptive_router_check_rejects_other_prefixes():
|
||||
r = Router(model_list=[])
|
||||
assert (
|
||||
r._is_adaptive_router_deployment(litellm_params=_params(model="openai/gpt-4o"))
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
# ---- Fix 3: cost field path --------------------------------------------
|
||||
|
||||
|
||||
def test_init_adaptive_router_reads_cost_from_litellm_params():
|
||||
r = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "smart-cheap-router",
|
||||
"litellm_params": {
|
||||
"model": "auto_router/adaptive_router",
|
||||
"adaptive_router_config": {
|
||||
"available_models": ["fast", "smart"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "fast",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"input_cost_per_token": 0.00000015,
|
||||
},
|
||||
"model_info": {
|
||||
"adaptive_router_preferences": {
|
||||
"quality_tier": 2,
|
||||
"strengths": [],
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "smart",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"input_cost_per_token": 0.0000050,
|
||||
},
|
||||
"model_info": {
|
||||
"adaptive_router_preferences": {
|
||||
"quality_tier": 3,
|
||||
"strengths": ["code_generation"],
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
assert "smart-cheap-router" in r.adaptive_routers
|
||||
assert r.adaptive_routers["smart-cheap-router"].model_to_cost == {
|
||||
"fast": 0.00000015,
|
||||
"smart": 0.0000050,
|
||||
}
|
||||
|
||||
|
||||
# ---- Fix 4: pre-routing dispatch ---------------------------------------
|
||||
|
||||
|
||||
def _router_with_adaptive() -> Router:
|
||||
return Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "smart-cheap-router",
|
||||
"litellm_params": {
|
||||
"model": "auto_router/adaptive_router",
|
||||
"adaptive_router_config": {
|
||||
"available_models": ["fast", "smart"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "fast",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"input_cost_per_token": 0.00000015,
|
||||
},
|
||||
"model_info": {
|
||||
"adaptive_router_preferences": {
|
||||
"quality_tier": 2,
|
||||
"strengths": [],
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "smart",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"input_cost_per_token": 0.0000050,
|
||||
},
|
||||
"model_info": {
|
||||
"adaptive_router_preferences": {
|
||||
"quality_tier": 3,
|
||||
"strengths": ["code_generation"],
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_pre_routing_hook_dispatches_to_adaptive_router():
|
||||
r = _router_with_adaptive()
|
||||
ar = r.adaptive_routers["smart-cheap-router"]
|
||||
ar.pick_model = AsyncMock(return_value="smart") # type: ignore[assignment]
|
||||
|
||||
response = await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs={"metadata": {"litellm_session_id": "sess-A"}},
|
||||
messages=[{"role": "user", "content": "Write a Python function"}],
|
||||
)
|
||||
assert response is not None
|
||||
assert response.model == "smart"
|
||||
call = ar.pick_model.await_args # type: ignore[union-attr]
|
||||
# Stateless routing: session_id is no longer passed to pick_model.
|
||||
assert "session_id" not in call.kwargs
|
||||
assert call.kwargs["request_type"] == RequestType.CODE_GENERATION
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_pre_routing_hook_pick_model_not_passed_session_id():
|
||||
r = _router_with_adaptive()
|
||||
ar = r.adaptive_routers["smart-cheap-router"]
|
||||
ar.pick_model = AsyncMock(return_value="fast") # type: ignore[assignment]
|
||||
|
||||
response = await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs={},
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
assert response is not None
|
||||
assert response.model == "fast"
|
||||
assert "session_id" not in ar.pick_model.await_args.kwargs # type: ignore[union-attr]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_pre_routing_hook_returns_none_for_unrelated_model():
|
||||
r = _router_with_adaptive()
|
||||
ar = r.adaptive_routers["smart-cheap-router"]
|
||||
ar.pick_model = AsyncMock() # type: ignore[assignment]
|
||||
response = await r.async_pre_routing_hook(
|
||||
model="some-other-model",
|
||||
request_kwargs={},
|
||||
messages=[{"role": "user", "content": "x"}],
|
||||
)
|
||||
assert response is None
|
||||
ar.pick_model.assert_not_awaited() # type: ignore[union-attr]
|
||||
|
||||
|
||||
# ---- Response header surfacing -----------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_pre_routing_hook_stashes_chosen_model_in_metadata():
|
||||
"""
|
||||
The adaptive-router branch must record the chosen logical model on
|
||||
`request_kwargs["metadata"]` so `_acompletion` can surface it as the
|
||||
`x-litellm-adaptive-router-model` response header.
|
||||
"""
|
||||
r = _router_with_adaptive()
|
||||
r.adaptive_routers["smart-cheap-router"].pick_model = AsyncMock( # type: ignore[assignment]
|
||||
return_value="smart"
|
||||
)
|
||||
|
||||
request_kwargs: dict = {"metadata": {"litellm_session_id": "sess-A"}}
|
||||
await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs=request_kwargs,
|
||||
messages=[{"role": "user", "content": "Write a Python function"}],
|
||||
)
|
||||
assert request_kwargs["metadata"]["adaptive_router_chosen_model"] == "smart"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_pre_routing_hook_creates_metadata_when_missing():
|
||||
"""If no metadata was passed in, the hook should create one to stash the chosen model."""
|
||||
r = _router_with_adaptive()
|
||||
r.adaptive_routers["smart-cheap-router"].pick_model = AsyncMock( # type: ignore[assignment]
|
||||
return_value="fast"
|
||||
)
|
||||
|
||||
request_kwargs: dict = {}
|
||||
await r.async_pre_routing_hook(
|
||||
model="smart-cheap-router",
|
||||
request_kwargs=request_kwargs,
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
assert request_kwargs["metadata"]["adaptive_router_chosen_model"] == "fast"
|
||||
|
||||
|
||||
# ---- Multi-router support ----------------------------------------------
|
||||
|
||||
|
||||
def test_two_adaptive_routers_can_coexist_on_one_router():
|
||||
r = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "cheap-router",
|
||||
"litellm_params": {
|
||||
"model": "auto_router/adaptive_router",
|
||||
"adaptive_router_config": {"available_models": ["fast"]},
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "premium-router",
|
||||
"litellm_params": {
|
||||
"model": "auto_router/adaptive_router",
|
||||
"adaptive_router_config": {"available_models": ["smart"]},
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "fast",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"input_cost_per_token": 0.00000015,
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "smart",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"input_cost_per_token": 0.0000050,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
assert set(r.adaptive_routers.keys()) == {"cheap-router", "premium-router"}
|
||||
assert r.adaptive_routers["cheap-router"].config.available_models == ["fast"]
|
||||
assert r.adaptive_routers["premium-router"].config.available_models == ["smart"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_pre_routing_hook_dispatches_to_correct_router_when_multiple():
|
||||
"""Each adaptive router only handles its own router_name."""
|
||||
r = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "cheap-router",
|
||||
"litellm_params": {
|
||||
"model": "auto_router/adaptive_router",
|
||||
"adaptive_router_config": {"available_models": ["fast"]},
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "premium-router",
|
||||
"litellm_params": {
|
||||
"model": "auto_router/adaptive_router",
|
||||
"adaptive_router_config": {"available_models": ["smart"]},
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "fast",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"input_cost_per_token": 0.00000015,
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "smart",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"input_cost_per_token": 0.0000050,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
cheap = r.adaptive_routers["cheap-router"]
|
||||
premium = r.adaptive_routers["premium-router"]
|
||||
cheap.pick_model = AsyncMock(return_value="fast") # type: ignore[assignment]
|
||||
premium.pick_model = AsyncMock(return_value="smart") # type: ignore[assignment]
|
||||
|
||||
cheap_response = await r.async_pre_routing_hook(
|
||||
model="cheap-router",
|
||||
request_kwargs={},
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
premium_response = await r.async_pre_routing_hook(
|
||||
model="premium-router",
|
||||
request_kwargs={},
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert cheap_response is not None and cheap_response.model == "fast"
|
||||
assert premium_response is not None and premium_response.model == "smart"
|
||||
cheap.pick_model.assert_awaited_once() # type: ignore[union-attr]
|
||||
premium.pick_model.assert_awaited_once() # type: ignore[union-attr]
|
||||
|
||||
|
||||
def test_init_adaptive_router_rejects_duplicate_model_name():
|
||||
"""Two adaptive-router deployments with the same model_name must error."""
|
||||
from litellm.types.router import AdaptiveRouterConfig, Deployment
|
||||
|
||||
r = Router(model_list=[])
|
||||
cfg = {"available_models": ["fast"]}
|
||||
deployment = Deployment(
|
||||
model_name="dup-router",
|
||||
litellm_params=LiteLLM_Params(
|
||||
model="auto_router/adaptive_router",
|
||||
adaptive_router_config=cfg,
|
||||
),
|
||||
model_info={"id": "x"},
|
||||
)
|
||||
r.init_adaptive_router_deployment(deployment=deployment)
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
r.init_adaptive_router_deployment(deployment=deployment)
|
||||
@ -0,0 +1,112 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.router_strategy.adaptive_router.config import TOOL_CALL_HISTORY_MAX
|
||||
from litellm.router_strategy.adaptive_router.signals import (
|
||||
SessionState,
|
||||
SignalDelta,
|
||||
Turn,
|
||||
apply_turn,
|
||||
)
|
||||
|
||||
FIXTURE_DIR = Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
def _load(name: str) -> list:
|
||||
return json.loads((FIXTURE_DIR / f"{name}.json").read_text())
|
||||
|
||||
|
||||
def _replay(turns: list) -> Tuple[SessionState, List[SignalDelta]]:
|
||||
state = SessionState(
|
||||
session_id="s",
|
||||
router_name="r",
|
||||
model_name="m",
|
||||
classified_type="general",
|
||||
)
|
||||
deltas: List[SignalDelta] = []
|
||||
for t in turns:
|
||||
deltas.append(
|
||||
apply_turn(
|
||||
state,
|
||||
Turn(
|
||||
user_content=t.get("user_content"),
|
||||
assistant_content=t.get("assistant_content"),
|
||||
tool_calls=t.get("tool_calls", []),
|
||||
tool_results=t.get("tool_results", []),
|
||||
response_status=t.get("response_status"),
|
||||
),
|
||||
)
|
||||
)
|
||||
return state, deltas
|
||||
|
||||
|
||||
def test_clean_satisfaction_fires_satisfaction_only():
|
||||
state, _ = _replay(_load("clean_satisfaction"))
|
||||
assert state.satisfaction_count >= 1
|
||||
assert state.failure_count == 0
|
||||
assert state.disengagement_count == 0
|
||||
|
||||
|
||||
def test_misalignment_fires_on_rephrase():
|
||||
state, _ = _replay(_load("misalignment_rephrase"))
|
||||
assert state.misalignment_count >= 1
|
||||
|
||||
|
||||
def test_stagnation_fires_on_repeated_assistant():
|
||||
state, _ = _replay(_load("stagnation_repeat"))
|
||||
assert state.stagnation_count >= 1
|
||||
|
||||
|
||||
def test_disengagement_fires_on_giveup():
|
||||
state, _ = _replay(_load("disengagement_giveup"))
|
||||
assert state.disengagement_count >= 1
|
||||
|
||||
|
||||
def test_failure_fires_on_tool_error():
|
||||
state, _ = _replay(_load("failure_tool_error"))
|
||||
assert state.failure_count == 1
|
||||
|
||||
|
||||
def test_loop_fires_on_repeated_tool():
|
||||
state, _ = _replay(_load("loop_same_tool"))
|
||||
assert state.loop_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fixture", ["exhaustion_429", "exhaustion_context_overflow"])
|
||||
def test_exhaustion_fires_on_infra_signal(fixture):
|
||||
state, _ = _replay(_load(fixture))
|
||||
assert state.exhaustion_count >= 1
|
||||
|
||||
|
||||
def test_no_signals_on_clean_session():
|
||||
state, _ = _replay(_load("clean_no_signals"))
|
||||
assert state.misalignment_count == 0
|
||||
assert state.stagnation_count == 0
|
||||
assert state.disengagement_count == 0
|
||||
assert state.failure_count == 0
|
||||
assert state.loop_count == 0
|
||||
assert state.exhaustion_count == 0
|
||||
|
||||
|
||||
def test_mixed_failure_then_satisfaction():
|
||||
state, _ = _replay(_load("mixed_failure_then_satisfaction"))
|
||||
assert state.failure_count >= 1
|
||||
assert state.satisfaction_count >= 1
|
||||
|
||||
|
||||
def test_apply_turn_is_o1_does_not_grow_history_unbounded():
|
||||
state = SessionState(
|
||||
session_id="s",
|
||||
router_name="r",
|
||||
model_name="m",
|
||||
classified_type="general",
|
||||
)
|
||||
for i in range(100):
|
||||
apply_turn(
|
||||
state,
|
||||
Turn(tool_calls=[{"name": f"tool_{i}", "arguments": {}}]),
|
||||
)
|
||||
assert len(state.tool_call_history) <= TOOL_CALL_HISTORY_MAX
|
||||
@ -0,0 +1,196 @@
|
||||
"""Tests for the GET /adaptive_router/state introspection endpoint and the
|
||||
underlying `AdaptiveRouter.get_state_snapshot()` helper."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
|
||||
from litellm.router_strategy.adaptive_router.bandit import BanditCell, apply_delta
|
||||
from litellm.types.router import (
|
||||
AdaptiveRouterConfig,
|
||||
AdaptiveRouterPreferences,
|
||||
RequestType,
|
||||
)
|
||||
|
||||
|
||||
def _make_router(name: str = "r1") -> AdaptiveRouter:
|
||||
cfg = AdaptiveRouterConfig(available_models=["fast", "smart"])
|
||||
prefs = {
|
||||
"fast": AdaptiveRouterPreferences(quality_tier=1, strengths=[]),
|
||||
"smart": AdaptiveRouterPreferences(
|
||||
quality_tier=3, strengths=[RequestType.CODE_GENERATION]
|
||||
),
|
||||
}
|
||||
costs = {"fast": 0.0001, "smart": 0.001}
|
||||
return AdaptiveRouter(
|
||||
router_name=name,
|
||||
config=cfg,
|
||||
model_to_prefs=prefs,
|
||||
model_to_cost=costs,
|
||||
)
|
||||
|
||||
|
||||
# ---- snapshot helper ---------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_state_snapshot_returns_cell_per_request_type_per_model():
|
||||
r = _make_router()
|
||||
snap = await r.get_state_snapshot()
|
||||
|
||||
# Top-level shape
|
||||
assert snap["router_name"] == "r1"
|
||||
assert snap["available_models"] == ["fast", "smart"]
|
||||
assert snap["weights"] == {"quality": 0.7, "cost": 0.3}
|
||||
assert snap["model_costs"] == {"fast": 0.0001, "smart": 0.001}
|
||||
assert snap["owner_cache_live"] == 0
|
||||
assert snap["skipped_updates_total"] == 0
|
||||
assert set(snap["queue"].keys()) == {
|
||||
"state_pending",
|
||||
"session_pending",
|
||||
"max_state_seen",
|
||||
"max_session_seen",
|
||||
}
|
||||
|
||||
# 7 request types x 2 models = 14 cells
|
||||
assert len(snap["cells"]) == len(list(RequestType)) * 2
|
||||
for cell in snap["cells"]:
|
||||
assert set(cell.keys()) == {
|
||||
"request_type",
|
||||
"model",
|
||||
"alpha",
|
||||
"beta",
|
||||
"samples",
|
||||
"quality_mean",
|
||||
}
|
||||
assert cell["model"] in {"fast", "smart"}
|
||||
assert cell["request_type"] in {rt.value for rt in RequestType}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_state_snapshot_quality_mean_matches_alpha_over_total():
|
||||
r = _make_router()
|
||||
|
||||
# Manually mutate one cell to a known state so the math is verifiable.
|
||||
key = (RequestType.CODE_GENERATION, "smart")
|
||||
r._cells[key] = apply_delta(r._cells[key], delta_alpha=10.0, delta_beta=0.0)
|
||||
expected = r._cells[key]
|
||||
expected_mean = expected.alpha / (expected.alpha + expected.beta)
|
||||
|
||||
snap = await r.get_state_snapshot()
|
||||
cell = next(
|
||||
c
|
||||
for c in snap["cells"]
|
||||
if c["request_type"] == "code_generation" and c["model"] == "smart"
|
||||
)
|
||||
assert cell["alpha"] == expected.alpha
|
||||
assert cell["beta"] == expected.beta
|
||||
assert cell["samples"] == expected.alpha + expected.beta
|
||||
assert cell["quality_mean"] == pytest.approx(expected_mean)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_state_snapshot_counts_only_live_owner_cache_entries():
|
||||
r = _make_router()
|
||||
now = time.time()
|
||||
r._owner_cache["live-1"] = ("fast", now + 3600)
|
||||
r._owner_cache["live-2"] = ("smart", now + 3600)
|
||||
r._owner_cache["expired-1"] = ("fast", now - 1)
|
||||
|
||||
snap = await r.get_state_snapshot()
|
||||
assert snap["owner_cache_live"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_state_snapshot_exposes_skipped_updates_total():
|
||||
r = _make_router()
|
||||
r._skipped_updates_total = 7
|
||||
snap = await r.get_state_snapshot()
|
||||
assert snap["skipped_updates_total"] == 7
|
||||
|
||||
|
||||
# ---- endpoint --------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_returns_404_when_no_adaptive_router(monkeypatch):
|
||||
"""When llm_router is set but has no adaptive routers configured, return 404."""
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
fake_router = MagicMock()
|
||||
fake_router.adaptive_routers = {}
|
||||
monkeypatch.setattr(proxy_server, "llm_router", fake_router)
|
||||
|
||||
admin = UserAPIKeyAuth(api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await proxy_server.get_adaptive_router_state(user_api_key_dict=admin)
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_returns_404_when_llm_router_is_none(monkeypatch):
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
monkeypatch.setattr(proxy_server, "llm_router", None)
|
||||
|
||||
admin = UserAPIKeyAuth(api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await proxy_server.get_adaptive_router_state(user_api_key_dict=admin)
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_rejects_non_admin_role(monkeypatch):
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
fake_router = MagicMock()
|
||||
fake_router.adaptive_routers = {"r1": _make_router()}
|
||||
monkeypatch.setattr(proxy_server, "llm_router", fake_router)
|
||||
|
||||
non_admin = UserAPIKeyAuth(
|
||||
api_key="sk-user", user_role=LitellmUserRoles.INTERNAL_USER
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await proxy_server.get_adaptive_router_state(user_api_key_dict=non_admin)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_returns_snapshot_list_for_admin(monkeypatch):
|
||||
"""Single configured router still returns the {"routers": [...]} list shape."""
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
fake_router = MagicMock()
|
||||
fake_router.adaptive_routers = {"r1": _make_router("r1")}
|
||||
monkeypatch.setattr(proxy_server, "llm_router", fake_router)
|
||||
|
||||
admin = UserAPIKeyAuth(api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN)
|
||||
result = await proxy_server.get_adaptive_router_state(user_api_key_dict=admin)
|
||||
assert list(result.keys()) == ["routers"]
|
||||
assert len(result["routers"]) == 1
|
||||
snap = result["routers"][0]
|
||||
assert snap["router_name"] == "r1"
|
||||
assert snap["available_models"] == ["fast", "smart"]
|
||||
assert len(snap["cells"]) == len(list(RequestType)) * 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_returns_one_snapshot_per_router(monkeypatch):
|
||||
"""With multiple adaptive routers configured, return one snapshot per router."""
|
||||
from litellm.proxy import proxy_server
|
||||
|
||||
fake_router = MagicMock()
|
||||
fake_router.adaptive_routers = {
|
||||
"r1": _make_router("r1"),
|
||||
"r2": _make_router("r2"),
|
||||
}
|
||||
monkeypatch.setattr(proxy_server, "llm_router", fake_router)
|
||||
|
||||
admin = UserAPIKeyAuth(api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN)
|
||||
result = await proxy_server.get_adaptive_router_state(user_api_key_dict=admin)
|
||||
names = sorted(s["router_name"] for s in result["routers"])
|
||||
assert names == ["r1", "r2"]
|
||||
6
uv.lock
generated
6
uv.lock
generated
@ -10,7 +10,7 @@ resolution-markers = [
|
||||
]
|
||||
|
||||
[options]
|
||||
exclude-newer = "2026-04-13T16:35:18.496811Z"
|
||||
exclude-newer = "2026-04-15T20:11:16.497522Z"
|
||||
exclude-newer-span = "P3D"
|
||||
|
||||
[manifest]
|
||||
@ -3767,7 +3767,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.83.8"
|
||||
version = "1.83.9"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiohttp" },
|
||||
@ -4114,7 +4114,7 @@ source = { editable = "enterprise" }
|
||||
|
||||
[[package]]
|
||||
name = "litellm-proxy-extras"
|
||||
version = "0.4.65"
|
||||
version = "0.4.66"
|
||||
source = { editable = "litellm-proxy-extras" }
|
||||
|
||||
[[package]]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user