Litellm krrish staging 04 20 2026 (#26138)
* feat(router): add auto_router/quality_router for quality-tier routing (#25987) * feat(router): add auto_router/quality_router for quality-tier routing Adds a new auto-router type that routes a request to a model at a target quality tier. The quality tier is inferred by re-using the existing ComplexityRouter's classification, then mapped through an admin-configured complexity_to_quality table. Each candidate model declares its own quality_tier in model_info.litellm_routing_preferences. Resolution strategy: exact tier match, else round up to the next higher tier, else fall back to default_model. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com> * feat(quality_router): add capability-based filtering Each deployment can declare a `capabilities: List[str]` field in `model_info.litellm_routing_preferences` (e.g. ["vision", "function_calling"]). Requests can pass `litellm_capabilities` in `request_kwargs` to require specific capabilities — the router will only route to deployments whose declared capabilities are a superset. Resolution still walks tier (exact → round up), but at each tier filters by capability before picking. Falls back to default_model only when it also satisfies the required capabilities; otherwise raises rather than silently routing to a model that lacks a required capability. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com> * feat(quality_router): expose routing decision in response headers For transparency, expose the QualityRouter's routing decision in the proxy response headers: x-litellm-quality-router-model → picked model_name (e.g. "haiku-vision") x-litellm-quality-router-tier → resolved quality tier (e.g. "1") x-litellm-quality-router-complexity → ComplexityTier name (e.g. "SIMPLE") Mechanism: the pre-routing hook stashes the decision in request_kwargs["metadata"]["quality_router_decision"]. After the call returns, Router.set_response_headers lifts the decision into response._hidden_params["additional_headers"] alongside the existing x-litellm-model-group / x-litellm-model-id headers. Existing metadata keys (trace_id, user_id, etc.) are preserved. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com> * feat(quality_router): replace capabilities with keyword override Drops the capability-based filtering in favor of a keyword-based override for v0: - RoutingPreferences.keywords: List[str] (replaces capabilities) — each deployment can declare substring keywords. - If any declared keyword (case-insensitive) appears in the user message, the router short-circuits the complexity-classification flow and routes to the matching deployment. - Tiebreaker for overlapping keyword matches: quality_tier DESC, then cheapest model_info.input_cost_per_token ASC. Unpriced models lose ties to priced ones. Decision metadata + headers now expose the override: x-litellm-quality-router-via → "keyword" | "quality_tier" x-litellm-quality-router-keyword → matched keyword (only on keyword route) x-litellm-quality-router-complexity → complexity tier (only on tier route) Removes: - request_kwargs["litellm_capabilities"] reading - _model_capabilities, _model_supports_capabilities, _first_capable_model_at_tier, capability filter in _resolve_model_for_quality_tier Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com> * feat(quality_router): add explicit `order` to RoutingPreferences Adds an explicit priority field to RoutingPreferences for resolving collisions deterministically: RoutingPreferences.order: Optional[int] # lower wins; unset = +inf Used as the PRIMARY tiebreaker in two places: 1. Keyword overlap: when multiple deployments declare the same matching keyword, sort by (order ASC, quality_tier DESC, input_cost_per_token ASC, model_name ASC). Explicit always beats implicit. 2. Tier resolution: when multiple deployments share a quality tier, `_resolve_model_for_quality_tier` picks the one with the lowest order. The tier list is now sorted at index-build time. This lets admins make routing decisions explicit when the natural quality-and-price ordering would pick the wrong model. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com> * feat(quality_router): reorder tiebreak to (quality, order, price) Changes the tiebreak ordering so quality_tier always wins first, then explicit `order` is used to break ties within the same tier, then price breaks the rest: 1. quality_tier DESC ← best model wins first 2. order ASC ← explicit priority within a tier 3. input_cost_per_token ASC 4. model_name ASC Previously `order` was the primary key — that meant a tier-2 model with `order=1` would beat a tier-3 model with no `order`, which is the wrong default. Now `order` only resolves collisions among same-tier candidates. Tier resolution (within a single tier) keeps the same key minus quality: (order ASC, cost ASC, name). Test renames + flips: - test_explicit_order_overrides_quality_tier → test_quality_wins_over_explicit_order - new: test_order_breaks_tie_within_same_quality_tier Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com> * fix(quality_router): resolve Greptile review feedback Addresses four P1 findings from PR review plus test coverage: 1. set_model_list missing quality_routers reset - Hot-reloading the Router would leave stale QualityRouter instances pointing at the old model_list. `set_model_list` now clears `self.quality_routers` alongside the other indices. 2. Round-down fallback before default_model - `_resolve_model_for_quality_tier` now rounds DOWN to the closest lower tier after round-up fails, before falling back to `default_model`. Degrades gracefully rather than jumping straight off-tier. 3. RoutingPreferences validation bypass - `_build_tier_index` now instantiates `RoutingPreferences(**prefs)` so invalid shapes (e.g. non-int quality_tier) raise a clear ValueError instead of silently succeeding. 4. Config-ordering dependency - `_tier_to_models` is now built lazily on first access. Previously, eager construction in `__init__` meant a QualityRouter deployment had to appear AFTER all its referenced models in config.yaml, because `Router._create_deployment` populates `model_list` incrementally. Any `available_models` defined after the router entry would silently be reported as missing. Also adds 6 new tests covering each fix: - test_invalid_quality_tier_type_raises_clear_error - test_router_can_be_instantiated_before_its_targets_exist - test_set_model_list_clears_quality_routers_registry - test_rounds_down_when_no_higher_tier_exists - test_rounds_down_prefers_closest_lower_tier - test_prefers_round_up_over_round_down Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com> * style: apply black 24.10.0 formatting to pre-existing offenders Unblocks the LiteLLM Linting check for this PR — these 12 files are already failing `black --check` on main (the lint workflow only runs on PRs, so main drifts). No behavior changes; formatting-only. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Update litellm/router.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --------- Co-authored-by: Claude Opus 4 (1M context) <noreply@anthropic.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Support /v1/responses in complexity router (#26137) * feat(proxy): add --reload flag for uvicorn hot reload (dev only) Opt-in CLI flag, off by default, no env var. Only affects the uvicorn run path; gunicorn/hypercorn paths and prod (which doesn't pass the flag) are unaffected. * Feature/add audio support for scaleway (#26110) * feat(scaleway): add SCALEWAY to LlmProviders enum * feat(scaleway): add audio transcription config and dispatch wiring Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * test(scaleway): add behavior tests for audio transcription config Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * chore(scaleway): advertise audio_transcriptions in endpoint-support JSON * docs(scaleway): document audio transcription support * fix(scaleway): address PR review — plain-text response_format + missing-key fail-fast Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * test(scaleway): cover new response paths, drop gettysburg.wav coupling Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> * Prompt Compression - add it to the proxy (#25729) * refactor: new agentic loop event hook simplifies how to create logic for tool based multi llm calls * fix: compress - make it work on anthropic input as well * fix(compress.py): working prompt compression for claude code ensures claude code messages can run through proxy easily * docs: add agentic loop hook guide * docs: add agentic_loop_hook to sidebar * fix: fix multiple arguments error * fix: fix tool call loop for compression on streaming /v1/messages * fix: fix linting errors * fix: fix ci/cd errors * feat(litellm_pre_call_utils.py): use claude code session for litellm session id allows claude code logs to be stitched together, making it easy to know they were all part of the same conversation * fix: suppress incorrect mypy warning rE: module * revert: drop PR's changes to litellm/proxy/_experimental/out/ Restores the 34 HTML files under _experimental/out/ to their pre-PR paths (X/index.html -> X.html). All renames are R100 (content unchanged); no other files are touched. * fix: address greptile review comments on PR #25729 - Skip ``kwargs["tools"] = []`` injection when compression is a no-op — Anthropic Messages rejects empty tool arrays on requests that did not originally declare tools. - Move agentic-loop safety guards (fingerprint cycle / max depth) out of the per-callback try/except so they propagate instead of being swallowed by the generic exception handler. Extracted _check_agentic_loop_safety. - Gate generic ``x-<vendor>-session-id`` capture behind the LITELLM_CAPTURE_VENDOR_SESSION_HEADERS env var (off by default) to preserve backwards compatibility; explicit x-litellm-* headers are unaffected. - Fix monkeypatch target in pre-call-hook test to patch the actual module-level binding (litellm.integrations.compression_interception.handler.compress). - Add regression tests for empty-tools skip and opt-in session capture. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * revert: drop LITELLM_CAPTURE_VENDOR_SESSION_HEADERS flag Generic x-<vendor>-session-id header capture is a new feature and only runs *after* the explicit x-litellm-trace-id / x-litellm-session-id checks, so it does not change behavior for any existing caller that was already using the LiteLLM headers — no backwards-incompatibility to gate. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor(compress): replace input_type with CallTypes call_type Drop the bespoke ``CompressionInputType`` literal and use the existing ``litellm.types.utils.CallTypes`` enum instead. ``litellm.compress()`` now takes ``call_type: Union[CallTypes, str]`` (default ``CallTypes.completion``) — no new concept to learn, and the enum is already the way the rest of the codebase talks about request shapes. Supported values: ``completion`` / ``acompletion`` (OpenAI chat-completions shape) and ``anthropic_messages`` (Anthropic structured content blocks). Updated: compress(), the compression_interception handler, tests, docs, and the two eval scripts. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> * Support /v1/responses in complexity router Adds cross-format support to the complexity router via the guardrail translation handler dispatch. Adds get_structured_messages to base translation plus OpenAI chat, Responses, and Anthropic handlers. Auto-router helper _extract_text_from_messages handles tool-call and multimodal messages. Widens async_pre_routing_hook messages type to Dict[str, Any]. Fixes https://github.com/BerriAI/litellm/issues/25134 * chore: apply black formatting * fix: fallback to trying each handler when route inference fails --------- Co-authored-by: Ryan Crabbe <ryan@berri.ai> Co-authored-by: nhyy244 <106547304+nhyy244@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> * test: cover _is_quality_router_deployment and init_quality_router_deployment * fix: reset auto_routers on set_model_list to prevent hot-reload ValueError * style: apply black formatting to websearch_interception and agentic_streaming_iterator --------- Co-authored-by: yuneng-jiang <yuneng@berri.ai> Co-authored-by: Claude Opus 4 (1M context) <noreply@anthropic.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: Ryan Crabbe <ryan@berri.ai> Co-authored-by: nhyy244 <106547304+nhyy244@users.noreply.github.com>
This commit is contained in:
parent
7979044c76
commit
e7bc316db0
@ -240,7 +240,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
||||
self,
|
||||
model: str,
|
||||
request_kwargs: Dict,
|
||||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
) -> Optional[PreRoutingHookResponse]:
|
||||
|
||||
@ -34,6 +34,7 @@ from litellm.types.llms.anthropic import (
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
@ -67,6 +68,32 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
super().__init__()
|
||||
self.adapter = LiteLLMAnthropicMessagesAdapter()
|
||||
|
||||
def _translate_to_openai(self, data: dict) -> ChatCompletionRequest:
|
||||
"""Translate Anthropic request to OpenAI chat completion format."""
|
||||
(
|
||||
chat_completion_compatible_request,
|
||||
_tool_name_mapping,
|
||||
) = LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai(
|
||||
anthropic_message_request=cast(AnthropicMessagesRequest, data.copy())
|
||||
)
|
||||
return chat_completion_compatible_request
|
||||
|
||||
def get_structured_messages(self, data: dict) -> Optional[List[AllMessageValues]]:
|
||||
"""
|
||||
Convert Anthropic messages request data to OpenAI-spec structured messages.
|
||||
|
||||
Uses the Anthropic-to-OpenAI adapter to translate message format.
|
||||
"""
|
||||
messages = data.get("messages")
|
||||
if messages is None:
|
||||
return None
|
||||
chat_completion_compatible_request = self._translate_to_openai(data)
|
||||
result = cast(
|
||||
List[AllMessageValues],
|
||||
chat_completion_compatible_request.get("messages", []),
|
||||
)
|
||||
return result if result else None
|
||||
|
||||
async def process_input_messages(
|
||||
self,
|
||||
data: dict,
|
||||
@ -82,13 +109,7 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
|
||||
skip_system = effective_skip_system_message_for_guardrail(guardrail_to_apply)
|
||||
|
||||
(
|
||||
chat_completion_compatible_request,
|
||||
_tool_name_mapping,
|
||||
) = LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai(
|
||||
# Use a shallow copy to avoid mutating request data (pop on litellm_metadata).
|
||||
anthropic_message_request=cast(AnthropicMessagesRequest, data.copy())
|
||||
)
|
||||
chat_completion_compatible_request = self._translate_to_openai(data)
|
||||
|
||||
structured_messages = cast(
|
||||
List[AllMessageValues],
|
||||
@ -103,8 +124,6 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
chat_completion_compatible_request.get("tools", [])
|
||||
)
|
||||
task_mappings: List[Tuple[int, Optional[int]]] = []
|
||||
# Track (message_index, content_index) for each text
|
||||
# content_index is None for string content, int for list content
|
||||
|
||||
# Step 1: Extract all text content and images
|
||||
for msg_idx, message in enumerate(messages):
|
||||
|
||||
@ -5,6 +5,7 @@ if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class BaseTranslation(ABC):
|
||||
@ -101,6 +102,16 @@ class BaseTranslation(ABC):
|
||||
"""
|
||||
return responses_so_far
|
||||
|
||||
def get_structured_messages(self, data: dict) -> Optional[List["AllMessageValues"]]:
|
||||
"""
|
||||
Convert request data to OpenAI-spec structured messages.
|
||||
|
||||
Override in subclasses for format-specific conversion.
|
||||
|
||||
Returns None if no convertible content is found.
|
||||
"""
|
||||
return None
|
||||
|
||||
def extract_request_tool_names(self, data: dict) -> List[str]:
|
||||
"""
|
||||
Extract tool names from the request body for allowlist/policy checks.
|
||||
|
||||
@ -48,6 +48,17 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
Methods can be overridden to customize behavior for different message formats.
|
||||
"""
|
||||
|
||||
def get_structured_messages(self, data: dict) -> Optional[List[AllMessageValues]]:
|
||||
"""
|
||||
Convert chat completions request data to OpenAI-spec structured messages.
|
||||
|
||||
Messages are already in OpenAI format, so this is a simple extraction.
|
||||
"""
|
||||
messages = data.get("messages")
|
||||
if messages is None:
|
||||
return None
|
||||
return cast(List[AllMessageValues], messages)
|
||||
|
||||
async def process_input_messages(
|
||||
self,
|
||||
data: dict,
|
||||
@ -68,9 +79,6 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
tool_calls_to_check: List[ChatCompletionToolParam] = []
|
||||
text_task_mappings: List[Tuple[int, Optional[int]]] = []
|
||||
tool_call_task_mappings: List[Tuple[int, int]] = []
|
||||
# text_task_mappings: Track (message_index, content_index) for each text
|
||||
# content_index is None for string content, int for list content
|
||||
# tool_call_task_mappings: Track (message_index, tool_call_index) for each tool call
|
||||
|
||||
# Step 1: Extract all text content, images, and tool calls
|
||||
for msg_idx, message in enumerate(messages):
|
||||
@ -92,12 +100,12 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
inputs["images"] = images_to_check
|
||||
if tool_calls_to_check:
|
||||
inputs["tool_calls"] = tool_calls_to_check # type: ignore
|
||||
if messages:
|
||||
msg_list = cast(List[AllMessageValues], messages)
|
||||
structured_messages = self.get_structured_messages(data)
|
||||
if structured_messages:
|
||||
inputs["structured_messages"] = (
|
||||
openai_messages_without_system(msg_list)
|
||||
openai_messages_without_system(structured_messages)
|
||||
if skip_system
|
||||
else msg_list
|
||||
else structured_messages
|
||||
)
|
||||
# Pass tools (function definitions) to the guardrail
|
||||
tools = data.get("tools")
|
||||
|
||||
@ -43,6 +43,7 @@ from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
LiteLLMCompletionResponsesConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
@ -70,6 +71,24 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
Methods can be overridden to customize behavior for different message formats.
|
||||
"""
|
||||
|
||||
def get_structured_messages(self, data: dict) -> Optional[List[AllMessageValues]]:
|
||||
"""
|
||||
Convert Responses API request data to OpenAI-spec structured messages.
|
||||
|
||||
Transforms `input` (string or ResponseInputParam) and optional
|
||||
`instructions` into chat completion messages.
|
||||
"""
|
||||
input_data = data.get("input")
|
||||
if input_data is None:
|
||||
return None
|
||||
messages = (
|
||||
LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
|
||||
input=input_data,
|
||||
responses_api_request=data,
|
||||
)
|
||||
)
|
||||
return cast(List[AllMessageValues], messages) if messages else None
|
||||
|
||||
async def process_input_messages(
|
||||
self,
|
||||
data: dict,
|
||||
@ -86,12 +105,7 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
if input_data is None:
|
||||
return data
|
||||
|
||||
structured_messages = (
|
||||
LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
|
||||
input=input_data,
|
||||
responses_api_request=data,
|
||||
)
|
||||
)
|
||||
structured_messages = self.get_structured_messages(data)
|
||||
|
||||
# Handle simple string input
|
||||
if isinstance(input_data, str):
|
||||
|
||||
@ -30,13 +30,13 @@ model_list:
|
||||
id: claude-sonnet-4-custom-pricing
|
||||
input_cost_per_token: 0.0003 # 100x standard ($0.000003)
|
||||
output_cost_per_token: 0.0015 # 100x standard ($0.000015)
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["compression_interception"]
|
||||
compression_interception_params:
|
||||
enabled: true
|
||||
compression_trigger: 100000
|
||||
# # optional:
|
||||
# # embedding_model: "text-embedding-3-small"
|
||||
# # embedding_model_params:
|
||||
# # dimensions: 512
|
||||
- model_name: my-auto
|
||||
litellm_params:
|
||||
model: auto_router/complexity_router
|
||||
complexity_router_config:
|
||||
tiers:
|
||||
SIMPLE: "gpt-4.1-mini"
|
||||
COMPLEX: claude-sonnet-4-6
|
||||
tier_boundaries:
|
||||
simple_medium: 0.30
|
||||
complexity_router_default_model: small-model
|
||||
|
||||
@ -200,12 +200,16 @@ if TYPE_CHECKING:
|
||||
from litellm.router_strategy.complexity_router.complexity_router import (
|
||||
ComplexityRouter,
|
||||
)
|
||||
from litellm.router_strategy.quality_router.quality_router import (
|
||||
QualityRouter,
|
||||
)
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
AutoRouter = Any
|
||||
ComplexityRouter = Any
|
||||
QualityRouter = Any
|
||||
PreRoutingHookResponse = Any
|
||||
|
||||
|
||||
@ -464,6 +468,7 @@ class Router:
|
||||
) # {"TEAM_ID": PatternMatchRouter}
|
||||
self.auto_routers: Dict[str, "AutoRouter"] = {}
|
||||
self.complexity_routers: Dict[str, "ComplexityRouter"] = {}
|
||||
self.quality_routers: Dict[str, "QualityRouter"] = {}
|
||||
|
||||
# Initialize model_group_alias early since it's used in set_model_list
|
||||
self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
|
||||
@ -5884,7 +5889,7 @@ class Router:
|
||||
response = await response
|
||||
## PROCESS RESPONSE HEADERS
|
||||
response = await self.set_response_headers(
|
||||
response=response, model_group=model_group
|
||||
response=response, model_group=model_group, request_kwargs=kwargs
|
||||
)
|
||||
|
||||
return response
|
||||
@ -6814,6 +6819,8 @@ class Router:
|
||||
"""
|
||||
if litellm_params.model.startswith("auto_router/complexity_router"):
|
||||
return False # This is handled by complexity_router
|
||||
if litellm_params.model.startswith("auto_router/quality_router"):
|
||||
return False # This is handled by quality_router
|
||||
if litellm_params.model.startswith("auto_router/"):
|
||||
return True
|
||||
return False
|
||||
@ -6920,6 +6927,58 @@ class Router:
|
||||
)
|
||||
self.complexity_routers[deployment.model_name] = complexity_router
|
||||
|
||||
def _is_quality_router_deployment(self, litellm_params: LiteLLM_Params) -> bool:
|
||||
"""
|
||||
Check if the deployment is a quality-router deployment.
|
||||
|
||||
Returns True if the litellm_params model starts with "auto_router/quality_router".
|
||||
"""
|
||||
if litellm_params.model.startswith("auto_router/quality_router"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def init_quality_router_deployment(self, deployment: Deployment):
|
||||
"""
|
||||
Initialize the quality-router deployment.
|
||||
|
||||
Resolves the default model from either `quality_router_default_model` or
|
||||
`quality_router_config["default_model"]`, then instantiates the
|
||||
QualityRouter and stores it in `self.quality_routers`.
|
||||
"""
|
||||
# Import here to mirror the AutoRouter / ComplexityRouter init pattern
|
||||
# and avoid circular imports.
|
||||
from litellm.router_strategy.quality_router.quality_router import (
|
||||
QualityRouter,
|
||||
)
|
||||
|
||||
quality_router_config: Optional[dict] = (
|
||||
deployment.litellm_params.quality_router_config
|
||||
)
|
||||
|
||||
default_model: Optional[str] = (
|
||||
deployment.litellm_params.quality_router_default_model
|
||||
)
|
||||
if default_model is None and quality_router_config:
|
||||
default_model = quality_router_config.get("default_model")
|
||||
|
||||
if default_model is None:
|
||||
raise ValueError(
|
||||
"quality_router_default_model is required for quality-router deployments, "
|
||||
"or set default_model in quality_router_config. Please configure it in the litellm_params"
|
||||
)
|
||||
|
||||
quality_router: QualityRouter = QualityRouter(
|
||||
model_name=deployment.model_name,
|
||||
default_model=default_model,
|
||||
litellm_router_instance=self,
|
||||
quality_router_config=quality_router_config,
|
||||
)
|
||||
if deployment.model_name in self.quality_routers:
|
||||
raise ValueError(
|
||||
f"Quality-router deployment {deployment.model_name} already exists. Please use a different model name."
|
||||
)
|
||||
self.quality_routers[deployment.model_name] = quality_router
|
||||
|
||||
def deployment_is_active_for_environment(self, deployment: Deployment) -> bool:
|
||||
"""
|
||||
Function to check if a llm deployment is active for a given environment. Allows using the same config.yaml across multople environments
|
||||
@ -6966,6 +7025,11 @@ class Router:
|
||||
self.model_id_to_deployment_index_map = {} # Reset the index
|
||||
self.model_name_to_deployment_indices = {} # Reset the model_name index
|
||||
self.team_model_to_deployment_indices = {} # Reset the team_model index
|
||||
# Reset per-strategy router registries so hot-reload doesn't leave
|
||||
# stale routers pointing at the old model_list.
|
||||
self.quality_routers = {}
|
||||
self.complexity_routers = {}
|
||||
self.auto_routers = {}
|
||||
self._invalidate_model_group_info_cache()
|
||||
self._invalidate_access_groups_cache()
|
||||
# we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
|
||||
@ -7140,6 +7204,12 @@ class Router:
|
||||
):
|
||||
self.init_complexity_router_deployment(deployment=deployment)
|
||||
|
||||
#########################################################
|
||||
# Check if this is a quality-router deployment
|
||||
#########################################################
|
||||
if self._is_quality_router_deployment(litellm_params=deployment.litellm_params):
|
||||
self.init_quality_router_deployment(deployment=deployment)
|
||||
|
||||
return deployment
|
||||
|
||||
def _initialize_deployment_for_pass_through(
|
||||
@ -8143,7 +8213,10 @@ class Router:
|
||||
return returned_dict
|
||||
|
||||
async def set_response_headers(
|
||||
self, response: Any, model_group: Optional[str] = None
|
||||
self,
|
||||
response: Any,
|
||||
model_group: Optional[str] = None,
|
||||
request_kwargs: Optional[dict] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Add the most accurate rate limit headers for a given model response.
|
||||
@ -8164,6 +8237,45 @@ class Router:
|
||||
|
||||
additional_headers = response._hidden_params["additional_headers"] # type: ignore
|
||||
|
||||
# Lift QualityRouter routing decision into response headers for
|
||||
# transparency. The decision is stashed in request_kwargs.metadata
|
||||
# by QualityRouter.async_pre_routing_hook.
|
||||
metadata = (
|
||||
(request_kwargs.get("metadata") or {})
|
||||
if isinstance(request_kwargs, dict)
|
||||
else {}
|
||||
)
|
||||
decision = (
|
||||
metadata.get("quality_router_decision")
|
||||
if isinstance(metadata, dict)
|
||||
else None
|
||||
)
|
||||
if isinstance(decision, dict):
|
||||
# Only emit headers for fields that have a meaningful value.
|
||||
# `complexity_tier` and `matched_keyword` are mutually exclusive
|
||||
# (the keyword path short-circuits classification), so each
|
||||
# request emits one or the other but not both.
|
||||
if decision.get("routed_model") is not None:
|
||||
additional_headers["x-litellm-quality-router-model"] = str(
|
||||
decision["routed_model"]
|
||||
)
|
||||
if decision.get("quality_tier") is not None:
|
||||
additional_headers["x-litellm-quality-router-tier"] = str(
|
||||
decision["quality_tier"]
|
||||
)
|
||||
if decision.get("routed_via") is not None:
|
||||
additional_headers["x-litellm-quality-router-via"] = str(
|
||||
decision["routed_via"]
|
||||
)
|
||||
if decision.get("matched_keyword") is not None:
|
||||
additional_headers["x-litellm-quality-router-keyword"] = str(
|
||||
decision["matched_keyword"]
|
||||
)
|
||||
if decision.get("complexity_tier") is not None:
|
||||
additional_headers["x-litellm-quality-router-complexity"] = str(
|
||||
decision["complexity_tier"]
|
||||
)
|
||||
|
||||
if (
|
||||
"x-ratelimit-remaining-tokens" not in additional_headers
|
||||
and "x-ratelimit-remaining-requests" not in additional_headers
|
||||
@ -8708,8 +8820,6 @@ class Router:
|
||||
and self.routing_strategy == "latency-based-routing"
|
||||
):
|
||||
_settings_to_return[var] = self.lowestlatency_logger.routing_args.json()
|
||||
elif var == "routing_strategy_args":
|
||||
_settings_to_return[var] = None
|
||||
return _settings_to_return
|
||||
|
||||
def update_settings(self, **kwargs):
|
||||
@ -9620,7 +9730,7 @@ class Router:
|
||||
self,
|
||||
model: str,
|
||||
request_kwargs: Dict,
|
||||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
) -> Optional[PreRoutingHookResponse]:
|
||||
@ -9653,6 +9763,18 @@ class Router:
|
||||
specific_deployment=specific_deployment,
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# Check if any quality-router should be used
|
||||
#########################################################
|
||||
if model in self.quality_routers:
|
||||
return await self.quality_routers[model].async_pre_routing_hook(
|
||||
model=model,
|
||||
request_kwargs=request_kwargs,
|
||||
messages=messages,
|
||||
input=input,
|
||||
specific_deployment=specific_deployment,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def get_available_deployment(
|
||||
|
||||
@ -82,11 +82,34 @@ class AutoRouter(CustomLogger):
|
||||
)
|
||||
return auto_router_routes
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_from_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Extract text content from the last user message for routing.
|
||||
|
||||
Handles tool-call conversations (where the last message may be an
|
||||
assistant or tool message with non-string content) and multimodal
|
||||
messages (where content is a list of content blocks).
|
||||
"""
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content")
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, list):
|
||||
return " ".join(
|
||||
block.get("text", "")
|
||||
for block in content
|
||||
if isinstance(block, dict) and block.get("type") == "text"
|
||||
)
|
||||
return str(content)
|
||||
return ""
|
||||
|
||||
async def async_pre_routing_hook(
|
||||
self,
|
||||
model: str,
|
||||
request_kwargs: Dict,
|
||||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
) -> Optional["PreRoutingHookResponse"]:
|
||||
@ -120,8 +143,7 @@ class AutoRouter(CustomLogger):
|
||||
auto_sync=self.auto_sync_value,
|
||||
)
|
||||
|
||||
user_message: Dict[str, str] = messages[-1]
|
||||
message_content: str = user_message.get("content", "")
|
||||
message_content = self._extract_text_from_messages(messages)
|
||||
route_choice: Optional[Union[RouteChoice, List[RouteChoice]]] = self.routelayer(
|
||||
text=message_content
|
||||
)
|
||||
|
||||
@ -332,45 +332,68 @@ class ComplexityRouter(CustomLogger):
|
||||
f"No model configured for tier {tier_key} and no default_model set"
|
||||
)
|
||||
|
||||
async def async_pre_routing_hook(
|
||||
def _resolve_messages(
|
||||
self,
|
||||
model: str,
|
||||
messages: Optional[List[Dict[str, Any]]],
|
||||
request_kwargs: Dict,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
) -> Optional["PreRoutingHookResponse"]:
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Pre-routing hook called before the routing decision.
|
||||
Resolve messages from the request, converting from other formats if needed.
|
||||
|
||||
Classifies the request by complexity and returns the appropriate model.
|
||||
|
||||
Args:
|
||||
model: The original model name requested.
|
||||
request_kwargs: The request kwargs.
|
||||
messages: The messages in the request.
|
||||
input: Optional input for embeddings.
|
||||
specific_deployment: Whether a specific deployment was requested.
|
||||
|
||||
Returns:
|
||||
PreRoutingHookResponse with the routed model, or None if no routing needed.
|
||||
Uses the guardrail translation handler dispatch to convert Responses API
|
||||
``input`` (or other non-chat-completions formats) into OpenAI-spec messages.
|
||||
"""
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
if messages:
|
||||
return messages
|
||||
|
||||
if messages is None or len(messages) == 0:
|
||||
verbose_router_logger.debug(
|
||||
"ComplexityRouter: No messages provided, skipping routing"
|
||||
)
|
||||
return None
|
||||
from litellm.litellm_core_utils.api_route_to_call_types import (
|
||||
get_call_types_for_route,
|
||||
)
|
||||
from litellm.llms import load_guardrail_translation_mappings
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# Extract the last user message and the last system prompt
|
||||
mappings = load_guardrail_translation_mappings()
|
||||
call_type: Optional[CallTypes] = None
|
||||
|
||||
# 1. Try route-based inference from proxy metadata
|
||||
route = request_kwargs.get("litellm_metadata", {}).get(
|
||||
"user_api_key_request_route"
|
||||
)
|
||||
if route:
|
||||
call_types_list = get_call_types_for_route(route)
|
||||
if call_types_list:
|
||||
for ct in call_types_list:
|
||||
if ct in mappings:
|
||||
call_type = ct
|
||||
break
|
||||
|
||||
# 2. Fallback: try each mapped handler until one produces messages
|
||||
handlers_to_try: List[Any] = []
|
||||
if call_type is not None and call_type in mappings:
|
||||
handlers_to_try.append(mappings[call_type]())
|
||||
else:
|
||||
handlers_to_try.extend(handler_cls() for handler_cls in mappings.values())
|
||||
|
||||
for handler in handlers_to_try:
|
||||
structured = handler.get_structured_messages(request_kwargs)
|
||||
if structured:
|
||||
return [
|
||||
msg if isinstance(msg, dict) else msg.model_dump() # type: ignore
|
||||
for msg in structured
|
||||
]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_user_message_and_system_prompt(
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Extract the last user message text and last system prompt from messages."""
|
||||
user_message: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
for msg in reversed(messages):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content") or ""
|
||||
# content may be a list of content parts (e.g. [{"type": "text", "text": "..."}])
|
||||
if isinstance(content, list):
|
||||
text_parts = [
|
||||
part.get("text", "")
|
||||
@ -383,6 +406,52 @@ class ComplexityRouter(CustomLogger):
|
||||
user_message = content
|
||||
elif role == "system" and system_prompt is None:
|
||||
system_prompt = content
|
||||
if user_message is not None and system_prompt is not None:
|
||||
break
|
||||
|
||||
return user_message, system_prompt
|
||||
|
||||
async def async_pre_routing_hook(
|
||||
self,
|
||||
model: str,
|
||||
request_kwargs: Dict,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
) -> Optional["PreRoutingHookResponse"]:
|
||||
"""
|
||||
Pre-routing hook called before the routing decision.
|
||||
|
||||
Classifies the request by complexity and returns the appropriate model.
|
||||
Supports chat completions (messages), Responses API (input), and other
|
||||
formats via the guardrail translation handler dispatch.
|
||||
|
||||
Args:
|
||||
model: The original model name requested.
|
||||
request_kwargs: The request kwargs.
|
||||
messages: The messages in the request.
|
||||
input: Optional input for Responses API or embeddings.
|
||||
specific_deployment: Whether a specific deployment was requested.
|
||||
|
||||
Returns:
|
||||
PreRoutingHookResponse with the routed model, or None if no routing needed.
|
||||
"""
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
|
||||
resolved_messages = self._resolve_messages(messages, request_kwargs)
|
||||
|
||||
if not resolved_messages:
|
||||
verbose_router_logger.debug(
|
||||
"ComplexityRouter: No messages could be resolved, skipping routing"
|
||||
)
|
||||
return None
|
||||
|
||||
# Determine whether the original request used messages directly
|
||||
has_original_messages = messages is not None and len(messages) > 0
|
||||
|
||||
user_message, system_prompt = self._extract_user_message_and_system_prompt(
|
||||
resolved_messages
|
||||
)
|
||||
|
||||
if user_message is None:
|
||||
verbose_router_logger.debug(
|
||||
@ -391,13 +460,10 @@ class ComplexityRouter(CustomLogger):
|
||||
return PreRoutingHookResponse(
|
||||
model=self.config.default_model
|
||||
or self.get_model_for_tier(ComplexityTier.MEDIUM),
|
||||
messages=messages,
|
||||
messages=messages if has_original_messages else None,
|
||||
)
|
||||
|
||||
# Classify the request
|
||||
tier, score, signals = self.classify(user_message, system_prompt)
|
||||
|
||||
# Get the model for this tier
|
||||
routed_model = self.get_model_for_tier(tier)
|
||||
|
||||
verbose_router_logger.info(
|
||||
@ -407,5 +473,5 @@ class ComplexityRouter(CustomLogger):
|
||||
|
||||
return PreRoutingHookResponse(
|
||||
model=routed_model,
|
||||
messages=messages,
|
||||
messages=messages if has_original_messages else None,
|
||||
)
|
||||
|
||||
21
litellm/router_strategy/quality_router/__init__.py
Normal file
21
litellm/router_strategy/quality_router/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Quality-tier auto-router.
|
||||
|
||||
Re-uses the ComplexityRouter's classification to decide a request's complexity,
|
||||
then maps that complexity to an admin-configured quality tier and resolves the
|
||||
target model from each candidate's `model_info.litellm_routing_preferences`.
|
||||
"""
|
||||
|
||||
from .config import (
|
||||
DEFAULT_COMPLEXITY_TO_QUALITY,
|
||||
QualityRouterConfig,
|
||||
RoutingPreferences,
|
||||
)
|
||||
from .quality_router import QualityRouter
|
||||
|
||||
__all__ = [
|
||||
"QualityRouter",
|
||||
"QualityRouterConfig",
|
||||
"RoutingPreferences",
|
||||
"DEFAULT_COMPLEXITY_TO_QUALITY",
|
||||
]
|
||||
74
litellm/router_strategy/quality_router/config.py
Normal file
74
litellm/router_strategy/quality_router/config.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""
|
||||
Configuration models for the QualityRouter.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# Default mapping from ComplexityTier name (string) to quality tier (int).
|
||||
# Higher tier = higher capability requirement.
|
||||
DEFAULT_COMPLEXITY_TO_QUALITY: Dict[str, int] = {
|
||||
"SIMPLE": 1,
|
||||
"MEDIUM": 2,
|
||||
"COMPLEX": 3,
|
||||
"REASONING": 4,
|
||||
}
|
||||
|
||||
|
||||
class QualityRouterConfig(BaseModel):
|
||||
"""Configuration for the QualityRouter."""
|
||||
|
||||
available_models: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"List of candidate model names this router may route to. Each model "
|
||||
"must declare its quality_tier in model_info.litellm_routing_preferences."
|
||||
),
|
||||
)
|
||||
|
||||
default_model: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Fallback model when no quality tier resolves.",
|
||||
)
|
||||
|
||||
complexity_to_quality: Dict[str, int] = Field(
|
||||
default_factory=lambda: DEFAULT_COMPLEXITY_TO_QUALITY.copy(),
|
||||
description="Mapping from ComplexityTier name to quality tier (int).",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class RoutingPreferences(BaseModel):
|
||||
"""Per-deployment routing preferences declared on model_info."""
|
||||
|
||||
quality_tier: int = Field(
|
||||
...,
|
||||
description="The quality tier this deployment satisfies.",
|
||||
)
|
||||
|
||||
keywords: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Substring keywords (case-insensitive) that, when present in the "
|
||||
"user message, route the request to this deployment. See `order` "
|
||||
"for explicit collision handling, otherwise ties fall through to "
|
||||
"(highest quality_tier, then cheapest model_info.input_cost_per_token)."
|
||||
),
|
||||
)
|
||||
|
||||
order: Optional[int] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Explicit priority used to break ties between deployments at the "
|
||||
"same quality tier. Lower values win. Applies both to keyword "
|
||||
"collisions and to picking between multiple deployments at the "
|
||||
"same quality_tier. Tiebreak order is "
|
||||
"(quality_tier DESC, order ASC, input_cost_per_token ASC, "
|
||||
"model_name ASC) — quality always wins first, then explicit "
|
||||
"order, then price."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
446
litellm/router_strategy/quality_router/quality_router.py
Normal file
446
litellm/router_strategy/quality_router/quality_router.py
Normal file
@ -0,0 +1,446 @@
|
||||
"""
|
||||
Quality-tier Auto Router.
|
||||
|
||||
Routes a request to a model at a target quality tier. The quality tier is
|
||||
inferred by re-using the existing ComplexityRouter's classification, then
|
||||
mapped through an admin-configured `complexity_to_quality` table. Each
|
||||
candidate model declares its own `quality_tier` in
|
||||
`model_info.litellm_routing_preferences`.
|
||||
|
||||
Optional keyword override: deployments may also declare `keywords` in
|
||||
`litellm_routing_preferences`. If any declared keyword appears in the user
|
||||
message (case-insensitive substring match), the router short-circuits the
|
||||
complexity-classification flow and routes to the matching deployment. When
|
||||
multiple deployments match, ties are broken by (highest quality_tier first,
|
||||
then cheapest `model_info.input_cost_per_token`).
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.router_strategy.complexity_router.complexity_router import (
|
||||
ComplexityRouter,
|
||||
)
|
||||
|
||||
from .config import QualityRouterConfig, RoutingPreferences
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
else:
|
||||
Router = Any
|
||||
PreRoutingHookResponse = Any
|
||||
|
||||
|
||||
class QualityRouter(CustomLogger):
|
||||
"""
|
||||
Routes requests to a model at a target quality tier, with an optional
|
||||
keyword override.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
litellm_router_instance: "Router",
|
||||
default_model: Optional[str] = None,
|
||||
quality_router_config: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.litellm_router_instance = litellm_router_instance
|
||||
|
||||
if quality_router_config:
|
||||
self.config = QualityRouterConfig(**quality_router_config)
|
||||
else:
|
||||
self.config = QualityRouterConfig()
|
||||
|
||||
# Explicit default_model arg overrides anything in the config dict.
|
||||
if default_model:
|
||||
self.config.default_model = default_model
|
||||
|
||||
# Internal scorer — re-use the existing rule-based classifier.
|
||||
self._scorer = ComplexityRouter(
|
||||
model_name=f"{model_name}::scorer",
|
||||
litellm_router_instance=litellm_router_instance,
|
||||
)
|
||||
|
||||
# Per-model indices populated alongside the tier index. `_model_keywords`
|
||||
# stores keywords lowercased so we can substring-match against the
|
||||
# lowercased user message in O(total-keyword-count). `_model_quality`,
|
||||
# `_model_cost`, and `_model_order` drive tiebreaking — `_model_order`
|
||||
# is the explicit priority (lower wins, unset = +inf).
|
||||
self._model_keywords: Dict[str, List[str]] = {}
|
||||
self._model_quality: Dict[str, int] = {}
|
||||
self._model_cost: Dict[str, Optional[float]] = {}
|
||||
self._model_order: Dict[str, Optional[int]] = {}
|
||||
|
||||
# Tier → models index. Built lazily on first access so the QualityRouter
|
||||
# deployment does NOT need to appear after all its referenced models in
|
||||
# the config — when `_build_tier_index` runs eagerly in `__init__`, the
|
||||
# router instance's `model_list` is still being assembled incrementally
|
||||
# by `_create_deployment`, and any `available_models` defined AFTER the
|
||||
# router entry in config.yaml would silently be reported as missing.
|
||||
self._tier_to_models_cache: Optional[Dict[int, List[str]]] = None
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"QualityRouter initialized for {model_name} with "
|
||||
f"available_models={self.config.available_models}, "
|
||||
f"default_model={self.config.default_model}"
|
||||
)
|
||||
|
||||
@property
|
||||
def _tier_to_models(self) -> Dict[int, List[str]]:
|
||||
"""Lazy tier→models index; built on first access."""
|
||||
if self._tier_to_models_cache is None:
|
||||
self._tier_to_models_cache = self._build_tier_index()
|
||||
return self._tier_to_models_cache
|
||||
|
||||
def _get_routing_preferences(self, deployment: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract litellm_routing_preferences from a deployment, handling both
|
||||
dict-shaped and Pydantic-object-shaped deployments.
|
||||
"""
|
||||
# Dict-shaped deployment.
|
||||
if isinstance(deployment, dict):
|
||||
model_info = deployment.get("model_info") or {}
|
||||
if isinstance(model_info, dict):
|
||||
return model_info.get("litellm_routing_preferences")
|
||||
# Pydantic ModelInfo nested in a dict.
|
||||
return getattr(model_info, "litellm_routing_preferences", None)
|
||||
|
||||
# Pydantic-object deployment.
|
||||
model_info = getattr(deployment, "model_info", None)
|
||||
if model_info is None:
|
||||
return None
|
||||
if isinstance(model_info, dict):
|
||||
return model_info.get("litellm_routing_preferences")
|
||||
return getattr(model_info, "litellm_routing_preferences", None)
|
||||
|
||||
def _get_deployment_input_cost(self, deployment: Any) -> Optional[float]:
|
||||
"""
|
||||
Extract `input_cost_per_token` from a deployment's model_info.
|
||||
|
||||
Returns None when not declared — None is treated as "infinite cost"
|
||||
for the cheapest-tiebreak ordering, so unpriced models lose ties to
|
||||
priced ones. (Admins who want a model to win on price must declare it.)
|
||||
"""
|
||||
if isinstance(deployment, dict):
|
||||
model_info = deployment.get("model_info") or {}
|
||||
else:
|
||||
model_info = getattr(deployment, "model_info", None) or {}
|
||||
|
||||
if isinstance(model_info, dict):
|
||||
cost = model_info.get("input_cost_per_token")
|
||||
else:
|
||||
cost = getattr(model_info, "input_cost_per_token", None)
|
||||
|
||||
if cost is None:
|
||||
return None
|
||||
try:
|
||||
return float(cost)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def _get_deployment_model_name(self, deployment: Any) -> Optional[str]:
|
||||
"""Extract `model_name` from a dict- or object-shaped deployment."""
|
||||
if isinstance(deployment, dict):
|
||||
return deployment.get("model_name")
|
||||
return getattr(deployment, "model_name", None)
|
||||
|
||||
def _build_tier_index(self) -> Dict[int, List[str]]:
|
||||
"""
|
||||
Build {quality_tier: [model_name, ...]} for every model in
|
||||
`available_models`, plus side indices `_model_keywords`,
|
||||
`_model_quality`, and `_model_cost`. Raises if any listed model is
|
||||
missing `litellm_routing_preferences`.
|
||||
"""
|
||||
model_list = getattr(self.litellm_router_instance, "model_list", None) or []
|
||||
available = set(self.config.available_models)
|
||||
|
||||
# Track which available models we've matched so we can error on missing.
|
||||
seen: Dict[str, bool] = {name: False for name in available}
|
||||
tier_to_models: Dict[int, List[str]] = {}
|
||||
|
||||
for deployment in model_list:
|
||||
name = self._get_deployment_model_name(deployment)
|
||||
if name is None or name not in available:
|
||||
continue
|
||||
|
||||
raw_prefs = self._get_routing_preferences(deployment)
|
||||
if raw_prefs is None:
|
||||
raise ValueError(
|
||||
f"QualityRouter: model '{name}' is listed in available_models "
|
||||
f"but has no model_info.litellm_routing_preferences"
|
||||
)
|
||||
|
||||
# Validate via the Pydantic model so we get a clear error for
|
||||
# missing quality_tier, wrong types, etc. This also means
|
||||
# `RoutingPreferences` is the single source of truth for the
|
||||
# accepted shape — readers relied on raw dicts before.
|
||||
try:
|
||||
if isinstance(raw_prefs, RoutingPreferences):
|
||||
prefs = raw_prefs
|
||||
elif isinstance(raw_prefs, dict):
|
||||
prefs = RoutingPreferences(**raw_prefs)
|
||||
else:
|
||||
# A Pydantic object of some other shape — coerce via its dict.
|
||||
prefs = RoutingPreferences(
|
||||
**(
|
||||
raw_prefs.model_dump()
|
||||
if hasattr(raw_prefs, "model_dump")
|
||||
else dict(raw_prefs)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"QualityRouter: model '{name}' has invalid "
|
||||
f"litellm_routing_preferences: {e}"
|
||||
) from e
|
||||
|
||||
tier_int = int(prefs.quality_tier)
|
||||
tier_to_models.setdefault(tier_int, []).append(name)
|
||||
self._model_keywords[name] = [str(k).lower() for k in prefs.keywords if k]
|
||||
self._model_quality[name] = tier_int
|
||||
self._model_cost[name] = self._get_deployment_input_cost(deployment)
|
||||
self._model_order[name] = prefs.order
|
||||
seen[name] = True
|
||||
|
||||
missing = [name for name, found in seen.items() if not found]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"QualityRouter: the following available_models are not present in "
|
||||
f"the router's model_list (or are missing routing preferences): {missing}"
|
||||
)
|
||||
|
||||
# Sort each tier's model list so `_resolve_model_for_quality_tier`
|
||||
# (which picks index [0]) honors (order ASC, cost ASC, name ASC).
|
||||
# Quality is moot within a single tier; keep parity with the keyword
|
||||
# tiebreak by ordering on (order, cost, name) here.
|
||||
for models in tier_to_models.values():
|
||||
models.sort(key=lambda n: (self._order_key(n), self._cost_key(n), n))
|
||||
|
||||
return tier_to_models
|
||||
|
||||
def _order_key(self, model_name: str) -> float:
|
||||
"""`order` lookup as a float — unset becomes +inf so explicit wins."""
|
||||
order = self._model_order.get(model_name)
|
||||
return float(order) if order is not None else math.inf
|
||||
|
||||
def _cost_key(self, model_name: str) -> float:
|
||||
"""`input_cost_per_token` as a float — unset becomes +inf."""
|
||||
cost = self._model_cost.get(model_name)
|
||||
return float(cost) if cost is not None else math.inf
|
||||
|
||||
def _keyword_override(self, user_message: str) -> Optional[Tuple[str, str]]:
|
||||
"""
|
||||
Find a deployment whose declared keywords appear in `user_message`.
|
||||
|
||||
Returns (model_name, matched_keyword) or None when no keyword matches.
|
||||
When multiple deployments match, sorts by:
|
||||
1. quality_tier DESC (best quality always wins first)
|
||||
2. `order` ASC (explicit priority — unset = +inf so explicit wins
|
||||
within the same tier)
|
||||
3. input_cost_per_token ASC (unpriced = +inf so priced wins)
|
||||
4. model_name ASC (deterministic stability)
|
||||
"""
|
||||
# Touch the lazy index so `_model_keywords` / `_model_quality` /
|
||||
# `_model_cost` / `_model_order` are populated.
|
||||
_ = self._tier_to_models
|
||||
|
||||
text = user_message.lower()
|
||||
|
||||
matches: List[Tuple[str, str]] = [] # (model_name, matched_keyword)
|
||||
for model_name, keywords in self._model_keywords.items():
|
||||
for kw in keywords:
|
||||
if kw and kw in text:
|
||||
matches.append((model_name, kw))
|
||||
break # one match per model is enough
|
||||
|
||||
if not matches:
|
||||
return None
|
||||
|
||||
def sort_key(match: Tuple[str, str]) -> Tuple[int, float, float, str]:
|
||||
name = match[0]
|
||||
quality = self._model_quality.get(name, 0)
|
||||
order_val = self._order_key(name)
|
||||
cost = self._model_cost.get(name)
|
||||
cost_val = cost if cost is not None else math.inf
|
||||
# Negate quality so higher tier sorts first under ASC sort.
|
||||
return (-quality, order_val, cost_val, name)
|
||||
|
||||
matches.sort(key=sort_key)
|
||||
return matches[0]
|
||||
|
||||
def _resolve_model_for_quality_tier(self, tier: int) -> str:
|
||||
"""
|
||||
Resolve a quality tier to a concrete model name.
|
||||
|
||||
Strategy:
|
||||
1. Exact tier match → first model registered at that tier.
|
||||
2. Round UP to the next higher tier that has a model (closer to a
|
||||
request we might lack capacity for).
|
||||
3. Round DOWN to the closest lower tier that has a model (degrade
|
||||
gracefully instead of jumping straight to `default_model`,
|
||||
which may be off-tier).
|
||||
4. Fall back to `config.default_model`.
|
||||
5. Otherwise raise.
|
||||
"""
|
||||
tier_index = self._tier_to_models
|
||||
if tier in tier_index and tier_index[tier]:
|
||||
return tier_index[tier][0]
|
||||
|
||||
# Round up.
|
||||
higher_tiers = sorted(t for t in tier_index if t > tier)
|
||||
for t in higher_tiers:
|
||||
if tier_index[t]:
|
||||
return tier_index[t][0]
|
||||
|
||||
# Round down — closest lower tier first.
|
||||
lower_tiers = sorted((t for t in tier_index if t < tier), reverse=True)
|
||||
for t in lower_tiers:
|
||||
if tier_index[t]:
|
||||
return tier_index[t][0]
|
||||
|
||||
if self.config.default_model:
|
||||
return self.config.default_model
|
||||
|
||||
raise ValueError(
|
||||
f"QualityRouter: no model available for quality tier {tier} and "
|
||||
f"no default_model configured"
|
||||
)
|
||||
|
||||
def _stash_decision(
|
||||
self,
|
||||
request_kwargs: Optional[Dict[str, Any]],
|
||||
decision: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Stash the routing decision in request_kwargs.metadata so the Router can
|
||||
lift it into response headers (`x-litellm-quality-router-*`). The same
|
||||
dict object flows from here through to `make_call.set_response_headers`.
|
||||
"""
|
||||
if request_kwargs is None:
|
||||
return
|
||||
metadata = request_kwargs.setdefault("metadata", {})
|
||||
if isinstance(metadata, dict):
|
||||
metadata["quality_router_decision"] = decision
|
||||
|
||||
async def async_pre_routing_hook(
|
||||
self,
|
||||
model: str,
|
||||
request_kwargs: Dict,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
) -> Optional["PreRoutingHookResponse"]:
|
||||
"""Try keyword override first; fall back to complexity-tier routing."""
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
|
||||
if messages is None or len(messages) == 0:
|
||||
verbose_router_logger.debug(
|
||||
"QualityRouter: No messages provided, skipping routing"
|
||||
)
|
||||
return None
|
||||
|
||||
# Extract last user message and last system prompt — same rules as
|
||||
# ComplexityRouter.async_pre_routing_hook.
|
||||
user_message: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
for msg in reversed(messages):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content") or ""
|
||||
if isinstance(content, list):
|
||||
text_parts = [
|
||||
part.get("text", "")
|
||||
for part in content
|
||||
if isinstance(part, dict) and part.get("type") == "text"
|
||||
]
|
||||
content = " ".join(text_parts).strip()
|
||||
if isinstance(content, str) and content:
|
||||
if role == "user" and user_message is None:
|
||||
user_message = content
|
||||
elif role == "system" and system_prompt is None:
|
||||
system_prompt = content
|
||||
|
||||
if user_message is None:
|
||||
verbose_router_logger.debug(
|
||||
"QualityRouter: No user message found, routing to default model"
|
||||
)
|
||||
if not self.config.default_model:
|
||||
raise ValueError(
|
||||
"QualityRouter: no user message and no default_model configured"
|
||||
)
|
||||
return PreRoutingHookResponse(
|
||||
model=self.config.default_model,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Try keyword override first — it short-circuits complexity classification.
|
||||
keyword_match = self._keyword_override(user_message)
|
||||
if keyword_match is not None:
|
||||
routed_model, matched_keyword = keyword_match
|
||||
verbose_router_logger.info(
|
||||
f"QualityRouter: keyword override matched='{matched_keyword}' "
|
||||
f"routed_model={routed_model} "
|
||||
f"(quality_tier={self._model_quality.get(routed_model)}, "
|
||||
f"input_cost_per_token={self._model_cost.get(routed_model)})"
|
||||
)
|
||||
self._stash_decision(
|
||||
request_kwargs,
|
||||
{
|
||||
"router_model_name": self.model_name,
|
||||
"routed_model": routed_model,
|
||||
"routed_via": "keyword",
|
||||
"matched_keyword": matched_keyword,
|
||||
"quality_tier": self._model_quality.get(routed_model),
|
||||
"complexity_tier": None,
|
||||
},
|
||||
)
|
||||
return PreRoutingHookResponse(
|
||||
model=routed_model,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# No keyword match → complexity classification flow.
|
||||
complexity_tier, score, signals = self._scorer.classify(
|
||||
user_message, system_prompt
|
||||
)
|
||||
complexity_name = (
|
||||
complexity_tier.value
|
||||
if hasattr(complexity_tier, "value")
|
||||
else str(complexity_tier)
|
||||
)
|
||||
|
||||
quality_tier = self.config.complexity_to_quality.get(complexity_name)
|
||||
if quality_tier is None:
|
||||
raise ValueError(
|
||||
f"QualityRouter: complexity tier '{complexity_name}' not present "
|
||||
f"in complexity_to_quality mapping {self.config.complexity_to_quality}"
|
||||
)
|
||||
|
||||
routed_model = self._resolve_model_for_quality_tier(int(quality_tier))
|
||||
|
||||
verbose_router_logger.info(
|
||||
f"QualityRouter: complexity={complexity_name}, score={score:.3f}, "
|
||||
f"signals={signals}, quality_tier={quality_tier}, "
|
||||
f"routed_model={routed_model}"
|
||||
)
|
||||
|
||||
self._stash_decision(
|
||||
request_kwargs,
|
||||
{
|
||||
"router_model_name": self.model_name,
|
||||
"routed_model": routed_model,
|
||||
"routed_via": "quality_tier",
|
||||
"matched_keyword": None,
|
||||
"quality_tier": int(quality_tier),
|
||||
"complexity_tier": complexity_name,
|
||||
},
|
||||
)
|
||||
|
||||
return PreRoutingHookResponse(
|
||||
model=routed_model,
|
||||
messages=messages,
|
||||
)
|
||||
@ -221,6 +221,10 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
|
||||
complexity_router_config: Optional[Dict] = None
|
||||
complexity_router_default_model: Optional[str] = None
|
||||
|
||||
# quality-router params
|
||||
quality_router_config: Optional[Dict] = None
|
||||
quality_router_default_model: Optional[str] = None
|
||||
|
||||
# Batch/File API Params
|
||||
s3_bucket_name: Optional[str] = None
|
||||
s3_encryption_key_id: Optional[str] = None
|
||||
|
||||
@ -891,6 +891,61 @@ class TestOpenAIChatCompletionsHandlerStreamingOutput:
|
||||
assert result == responses_so_far
|
||||
|
||||
|
||||
class TestGetStructuredMessages:
|
||||
"""Test the get_structured_messages method."""
|
||||
|
||||
def test_should_return_messages_from_chat_completions_request(self):
|
||||
"""Test that messages are returned from a chat completions request."""
|
||||
handler = OpenAIChatCompletionsHandler()
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
assert result[0]["role"] == "system"
|
||||
assert result[1]["role"] == "user"
|
||||
|
||||
def test_should_return_none_when_no_messages(self):
|
||||
"""Test that None is returned when no messages key exists."""
|
||||
handler = OpenAIChatCompletionsHandler()
|
||||
data = {"model": "gpt-4"}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is None
|
||||
|
||||
def test_should_return_none_for_none_messages(self):
|
||||
"""Test that None is returned when messages is explicitly None."""
|
||||
handler = OpenAIChatCompletionsHandler()
|
||||
data = {"messages": None}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is None
|
||||
|
||||
def test_should_handle_multimodal_content(self):
|
||||
"""Test that messages with multimodal content are returned."""
|
||||
handler = OpenAIChatCompletionsHandler()
|
||||
data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/image.png"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0]["content"], list)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the tests
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@ -995,3 +995,63 @@ class TestOpenAIResponsesHandlerStreamingOutputProcessing:
|
||||
|
||||
# Should return the responses
|
||||
assert result == responses_so_far
|
||||
|
||||
|
||||
class TestGetStructuredMessages:
|
||||
"""Test the get_structured_messages method for Responses API handler."""
|
||||
|
||||
def test_should_convert_string_input_to_messages(self):
|
||||
"""Test that a simple string input is converted to OpenAI messages."""
|
||||
handler = OpenAIResponsesHandler()
|
||||
data = {"input": "What is the capital of France?"}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is not None
|
||||
assert len(result) >= 1
|
||||
found_user = False
|
||||
for msg in result:
|
||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
||||
found_user = True
|
||||
break
|
||||
assert found_user, f"Expected a user message, got: {result}"
|
||||
|
||||
def test_should_convert_list_input_to_messages(self):
|
||||
"""Test that list input (ResponseInputParam) is converted to OpenAI messages."""
|
||||
handler = OpenAIResponsesHandler()
|
||||
data = {
|
||||
"input": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is not None
|
||||
assert len(result) >= 3
|
||||
|
||||
def test_should_include_instructions_as_system_message(self):
|
||||
"""Test that instructions are included as a system message."""
|
||||
handler = OpenAIResponsesHandler()
|
||||
data = {
|
||||
"input": "Roll a d20",
|
||||
"instructions": "You are a helpful dungeon master.",
|
||||
}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is not None
|
||||
has_system = any(
|
||||
isinstance(msg, dict) and msg.get("role") == "system" for msg in result
|
||||
)
|
||||
assert has_system, f"Expected system message from instructions, got: {result}"
|
||||
|
||||
def test_should_return_none_when_no_input(self):
|
||||
"""Test that None is returned when input key is missing."""
|
||||
handler = OpenAIResponsesHandler()
|
||||
data = {"model": "gpt-4o"}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is None
|
||||
|
||||
def test_should_return_none_for_none_input(self):
|
||||
"""Test that None is returned when input is explicitly None."""
|
||||
handler = OpenAIResponsesHandler()
|
||||
data = {"input": None}
|
||||
result = handler.get_structured_messages(data)
|
||||
assert result is None
|
||||
|
||||
@ -12,7 +12,148 @@ sys.path.insert(
|
||||
|
||||
from litellm.router_strategy.auto_router.auto_router import AutoRouter
|
||||
|
||||
pytestmark = pytest.mark.skip(reason="Skipping auto router tests - beta feature")
|
||||
pytestmark_skip_beta = pytest.mark.skip(
|
||||
reason="Skipping auto router tests - beta feature"
|
||||
)
|
||||
|
||||
|
||||
class TestExtractTextFromMessages:
|
||||
"""Tests for AutoRouter._extract_text_from_messages (no semantic_router dependency)."""
|
||||
|
||||
def test_should_extract_content_from_simple_user_message(self):
|
||||
messages = [{"role": "user", "content": "Hello world"}]
|
||||
result = AutoRouter._extract_text_from_messages(messages)
|
||||
assert result == "Hello world"
|
||||
|
||||
def test_should_extract_last_user_message_from_tool_call_conversation(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "NYC"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc123",
|
||||
"content": "72°F and sunny",
|
||||
},
|
||||
{"role": "user", "content": "Now tell me about London"},
|
||||
]
|
||||
result = AutoRouter._extract_text_from_messages(messages)
|
||||
assert result == "Now tell me about London"
|
||||
|
||||
def test_should_find_user_message_when_last_message_is_assistant_with_tool_calls(
|
||||
self,
|
||||
):
|
||||
messages = [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc",
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
result = AutoRouter._extract_text_from_messages(messages)
|
||||
assert result == "What's the weather?"
|
||||
|
||||
def test_should_find_user_message_when_last_message_is_tool_response(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc",
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "72°F and sunny",
|
||||
},
|
||||
]
|
||||
result = AutoRouter._extract_text_from_messages(messages)
|
||||
assert result == "What's the weather?"
|
||||
|
||||
def test_should_handle_multimodal_content_list(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/img.png"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
result = AutoRouter._extract_text_from_messages(messages)
|
||||
assert result == "What's in this image?"
|
||||
|
||||
def test_should_handle_multimodal_content_with_multiple_text_blocks(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "First part"},
|
||||
{"type": "text", "text": "Second part"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/img.png"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
result = AutoRouter._extract_text_from_messages(messages)
|
||||
assert result == "First part Second part"
|
||||
|
||||
def test_should_return_empty_string_when_user_content_is_none(self):
|
||||
messages = [{"role": "user", "content": None}]
|
||||
result = AutoRouter._extract_text_from_messages(messages)
|
||||
assert result == ""
|
||||
|
||||
def test_should_return_empty_string_when_no_user_messages(self):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc",
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
result = AutoRouter._extract_text_from_messages(messages)
|
||||
assert result == ""
|
||||
|
||||
def test_should_return_empty_string_for_empty_messages_list(self):
|
||||
result = AutoRouter._extract_text_from_messages([])
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -41,6 +182,7 @@ def mock_route_choice():
|
||||
return mock_choice
|
||||
|
||||
|
||||
@pytestmark_skip_beta
|
||||
class TestAutoRouter:
|
||||
"""Test class for AutoRouter methods."""
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ Tests the rule-based complexity scoring and tier assignment logic.
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -828,3 +828,222 @@ class TestRouterComplexityDeploymentMethods:
|
||||
)
|
||||
router.init_complexity_router_deployment(deployment)
|
||||
assert "auto_router/complexity_router/test-router" in router.complexity_routers
|
||||
|
||||
|
||||
class TestAsyncPreRoutingHookMultiFormat:
|
||||
"""Test async_pre_routing_hook with multiple input formats."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_route_with_chat_completions_messages(self, complexity_router):
|
||||
"""Test routing with standard chat completions messages."""
|
||||
result = await complexity_router.async_pre_routing_hook(
|
||||
model="test-model",
|
||||
request_kwargs={},
|
||||
messages=[{"role": "user", "content": "What is 2+2?"}],
|
||||
)
|
||||
assert result is not None
|
||||
assert result.model is not None
|
||||
assert result.messages is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_route_with_responses_api_string_input(
|
||||
self, complexity_router
|
||||
):
|
||||
"""Test routing with Responses API string input via handler dispatch."""
|
||||
from litellm.llms.openai.responses.guardrail_translation.handler import (
|
||||
OpenAIResponsesHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
mock_mappings = {CallTypes.responses: OpenAIResponsesHandler}
|
||||
|
||||
with patch(
|
||||
"litellm.llms.load_guardrail_translation_mappings",
|
||||
return_value=mock_mappings,
|
||||
):
|
||||
result = await complexity_router.async_pre_routing_hook(
|
||||
model="test-model",
|
||||
request_kwargs={"input": "What is the capital of France?"},
|
||||
messages=None,
|
||||
input="What is the capital of France?",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.model is not None
|
||||
# messages should be None since the original request didn't have messages
|
||||
assert result.messages is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_route_with_responses_api_list_input(self, complexity_router):
|
||||
"""Test routing with Responses API list input via handler dispatch."""
|
||||
from litellm.llms.openai.responses.guardrail_translation.handler import (
|
||||
OpenAIResponsesHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
mock_mappings = {CallTypes.responses: OpenAIResponsesHandler}
|
||||
|
||||
list_input = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Write a Python function to sort a list using merge sort",
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"litellm.llms.load_guardrail_translation_mappings",
|
||||
return_value=mock_mappings,
|
||||
):
|
||||
result = await complexity_router.async_pre_routing_hook(
|
||||
model="test-model",
|
||||
request_kwargs={"input": list_input},
|
||||
messages=None,
|
||||
input=list_input,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.model is not None
|
||||
assert result.messages is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_use_route_based_inference(self, complexity_router):
|
||||
"""Test that route-based call type inference is used when available."""
|
||||
from litellm.llms.openai.responses.guardrail_translation.handler import (
|
||||
OpenAIResponsesHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
mock_mappings = {CallTypes.responses: OpenAIResponsesHandler}
|
||||
|
||||
with patch(
|
||||
"litellm.llms.load_guardrail_translation_mappings",
|
||||
return_value=mock_mappings,
|
||||
):
|
||||
result = await complexity_router.async_pre_routing_hook(
|
||||
model="test-model",
|
||||
request_kwargs={
|
||||
"input": "Roll 2d4+1",
|
||||
"litellm_metadata": {
|
||||
"user_api_key_request_route": "/v1/responses",
|
||||
},
|
||||
},
|
||||
messages=None,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.model is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_return_none_when_no_messages_or_input(
|
||||
self, complexity_router
|
||||
):
|
||||
"""Test that None is returned when neither messages nor input is available."""
|
||||
result = await complexity_router.async_pre_routing_hook(
|
||||
model="test-model",
|
||||
request_kwargs={},
|
||||
messages=None,
|
||||
input=None,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_prefer_original_messages_over_conversion(
|
||||
self, complexity_router
|
||||
):
|
||||
"""Test that original messages are used when both messages and input are available."""
|
||||
messages = [{"role": "user", "content": "What is 2+2?"}]
|
||||
result = await complexity_router.async_pre_routing_hook(
|
||||
model="test-model",
|
||||
request_kwargs={"input": "This should be ignored"},
|
||||
messages=messages,
|
||||
)
|
||||
assert result is not None
|
||||
assert result.messages == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_include_instructions_in_classification(
|
||||
self, complexity_router
|
||||
):
|
||||
"""Test that Responses API instructions influence classification via system message."""
|
||||
from litellm.llms.openai.responses.guardrail_translation.handler import (
|
||||
OpenAIResponsesHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
mock_mappings = {CallTypes.responses: OpenAIResponsesHandler}
|
||||
|
||||
with patch(
|
||||
"litellm.llms.load_guardrail_translation_mappings",
|
||||
return_value=mock_mappings,
|
||||
):
|
||||
result = await complexity_router.async_pre_routing_hook(
|
||||
model="test-model",
|
||||
request_kwargs={
|
||||
"input": "Write merge sort",
|
||||
"instructions": "You are an expert Python developer. Use advanced algorithms and optimize for performance.",
|
||||
},
|
||||
messages=None,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.model is not None
|
||||
|
||||
|
||||
class TestExtractUserMessageAndSystemPrompt:
|
||||
"""Test the _extract_user_message_and_system_prompt static method."""
|
||||
|
||||
def test_should_extract_user_message(self):
|
||||
"""Test extraction of the last user message."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
user_msg, sys_prompt = ComplexityRouter._extract_user_message_and_system_prompt(
|
||||
messages
|
||||
)
|
||||
assert user_msg == "How are you?"
|
||||
assert sys_prompt == "You are helpful."
|
||||
|
||||
def test_should_handle_no_user_message(self):
|
||||
"""Test when there is no user message."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
]
|
||||
user_msg, sys_prompt = ComplexityRouter._extract_user_message_and_system_prompt(
|
||||
messages
|
||||
)
|
||||
assert user_msg is None
|
||||
assert sys_prompt == "You are helpful."
|
||||
|
||||
def test_should_handle_multipart_content(self):
|
||||
"""Test extraction from multipart content messages."""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/img.png"},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
user_msg, sys_prompt = ComplexityRouter._extract_user_message_and_system_prompt(
|
||||
messages
|
||||
)
|
||||
assert user_msg == "Describe this image"
|
||||
assert sys_prompt is None
|
||||
|
||||
def test_should_handle_empty_messages(self):
|
||||
"""Test with empty messages list."""
|
||||
user_msg, sys_prompt = ComplexityRouter._extract_user_message_and_system_prompt(
|
||||
[]
|
||||
)
|
||||
assert user_msg is None
|
||||
assert sys_prompt is None
|
||||
|
||||
1033
tests/test_litellm/router_strategy/test_quality_router.py
Normal file
1033
tests/test_litellm/router_strategy/test_quality_router.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user