[Feat] - Ishaan main merge branch (#23596)

* fix(bedrock): respect s3_region_name for batch file uploads (#23569)

* fix(bedrock): respect s3_region_name for batch file uploads (GovCloud fix)

* fix: s3_region_name always wins over aws_region_name for S3 signing (Greptile feedback)

* fix: _filter_headers_for_aws_signature - Bedrock KB (#23571)

* fix: _filter_headers_for_aws_signature

* fix: filter None header values in all post-signing re-merge paths

Addresses Greptile feedback: None-valued headers were being filtered
during SigV4 signing but re-merged back into the final headers dict
afterward, which would cause downstream HTTP client failures.

Made-with: Cursor

* feat(router): tag_regex routing — route by User-Agent regex without per-developer tag config (#23594)

* feat(router): add tag_regex support for header-based routing

Adds a new `tag_regex` field to litellm_params that lets operators route
requests based on regex patterns matched against request headers — primarily
User-Agent — without requiring per-developer tag configuration.

Use case: route all Claude Code traffic (User-Agent: claude-code/x.y.z) to
a dedicated deployment by setting:

  tag_regex:
    - "^User-Agent: claude-code\\/"

in the deployment's litellm_params. Works alongside existing `tags` routing;
exact tag match takes precedence over regex match. Unmatched requests fall
through to deployments tagged `default`.

The matched deployment, pattern, and user_agent are recorded in
`metadata["tag_routing"]` so they flow through to SpendLogs automatically.

* fix(tag_regex): address backwards-compat, metadata overwrite, and warning noise

Three issues from code review:

1. Backwards-compat: `has_tag_filter` was widened to activate on any non-empty
   User-Agent, which would raise ValueError for existing deployments using plain
   tags without a `default` fallback. Fix: only activate header-based regex
   filtering when at least one candidate deployment has `tag_regex` configured.

2. Metadata overwrite: `metadata["tag_routing"]` was overwritten for every
   matching deployment in the loop, leaving inaccurate provenance when multiple
   deployments match. Fix: write only for the first match.

3. Warning noise: an invalid regex pattern logged one warning per header string
   rather than once per pattern. Fix: compile first (catching re.error once),
   then iterate over header strings.

Also adds two new tests covering these cases, and adds docs page for
tag_regex routing with a Claude Code walk-through.

* refactor(tag_regex): remove unnecessary _healthy_list copy

* docs: merge tag_regex section into tag_routing.md, remove standalone page

- Add ## Regex-based tag routing (tag_regex) section to existing
  tag_routing.md instead of a separate page
- Remove tag_regex_routing.md standalone doc (odd UX to have a separate
  page for a sub-feature)
- Remove proxy/tag_regex_routing from sidebars.js
- Add match_any=False debug warning in tag_based_routing.py when regex
  routing fires under strict mode (regex always uses OR semantics)

* fix(tag_regex): address greptile review - security docs, strict-mode enforcement, validation order

- Strengthen security note in tag_routing.md: explicitly state User-Agent
  is client-supplied and can be set to any value; frame tag_regex as a
  traffic classification hint, not an access-control mechanism
- Move tag_regex startup validation before _add_deployment() so an invalid
  pattern never leaves partial router state
- Enforce match_any=False strict-tag policy: when a deployment has both
  tags and tag_regex and the strict tag check fails, skip the regex fallback
  rather than silently bypassing the operator's intent
- Extract per-deployment match logic into _match_deployment() helper to
  keep get_deployments_for_tag() readable
- Add two new tests: strict-mode blocks regex fallback, regex-only
  deployment still matches under match_any=False

* fix(ci): apply Black formatting to 14 files and stabilize flaky caplog tests

- Run Black formatter on 14 files that were failing the lint check
- Replace caplog-based assertions in TestAliasConflicts with
  unittest.mock.patch on verbose_logger.warning for xdist compatibility
- The caplog fixture can produce empty text in pytest-xdist workers
  in certain CI environments, causing flaky test failures

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
This commit is contained in:
Ishaan Jaff 2026-03-14 09:40:00 -07:00 committed by GitHub
parent 25ee2fb3f9
commit b87d1f8dad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 959 additions and 73 deletions

View File

@ -209,6 +209,106 @@ Expect to see the following response header when this works
x-litellm-model-id: default-model
```
## Regex-based tag routing (`tag_regex`)
Use `tag_regex` to route requests based on regex patterns matched against request headers, without requiring clients to pass a tag explicitly. This is useful when clients already send a recognisable header, such as `User-Agent`.
**Use case: route all Claude Code traffic to dedicated AWS accounts**
Claude Code always sends `User-Agent: claude-code/<version>`. With `tag_regex` you can route that traffic to a dedicated deployment automatically — no per-developer configuration needed.
### 1. Config
```yaml
model_list:
# Claude Code traffic → dedicated deployment, matched by User-Agent
- model_name: claude-sonnet
litellm_params:
model: bedrock/converse/anthropic-claude-sonnet-4-6
aws_region_name: us-east-1
aws_role_name: arn:aws:iam::111122223333:role/LiteLLMClaudeCode
tag_regex:
- "^User-Agent: claude-code\\/" # matches claude-code/1.x, 2.x, etc.
model_info:
id: claude-code-deployment
# All other traffic falls back to the default deployment
- model_name: claude-sonnet
litellm_params:
model: bedrock/converse/anthropic-claude-sonnet-4-6
aws_region_name: us-east-1
aws_role_name: arn:aws:iam::444455556666:role/LiteLLMDefault
tags:
- default
model_info:
id: regular-deployment
router_settings:
enable_tag_filtering: true
tag_filtering_match_any: true
general_settings:
master_key: sk-1234
```
### 2. Verify routing
Claude Code sets `User-Agent: claude-code/<version>` automatically — no client config needed:
```shell
# Claude Code request (User-Agent set automatically by Claude Code)
curl http://localhost:4000/v1/chat/completions \
-H "Authorization: Bearer sk-1234" \
-H "User-Agent: claude-code/1.2.3" \
-d '{"model": "claude-sonnet", "messages": [{"role": "user", "content": "hi"}]}'
# → x-litellm-model-id: claude-code-deployment
# Any other client (no matching User-Agent) → default deployment
curl http://localhost:4000/v1/chat/completions \
-H "Authorization: Bearer sk-1234" \
-d '{"model": "claude-sonnet", "messages": [{"role": "user", "content": "hi"}]}'
# → x-litellm-model-id: regular-deployment
```
### How matching works
| Priority | Condition | Result |
|----------|-----------|--------|
| 1 | Request has `tags` AND deployment has `tags` | Exact tag match (respects `match_any` setting) |
| 2 | Deployment has `tag_regex` AND request has a `User-Agent` | Regex match (always OR logic — any pattern match suffices) |
| 3 | Deployment has `tags: [default]` | Default fallback |
| 4 | No default set | All healthy deployments returned |
`tag_regex` always uses OR semantics — `tag_filtering_match_any=False` applies only to exact tag matching, not to regex patterns.
### Observability
When a regex matches, `tag_routing` is written into request metadata and flows to SpendLogs:
```json
{
"tag_routing": {
"matched_via": "tag_regex",
"matched_value": "^User-Agent: claude-code\\/",
"user_agent": "claude-code/1.2.3",
"request_tags": []
}
}
```
### Security note
:::caution
**`User-Agent` is a client-supplied header and can be set to any value.** Any API consumer can send `User-Agent: claude-code/1.0` regardless of whether they are actually using Claude Code.
Do not rely on `tag_regex` routing to enforce access controls or spend limits — use [team/key-based routing](./users) for that. `tag_regex` is a **traffic classification hint** (useful for billing visibility, capacity planning, and routing convenience), not a security boundary.
:::
---
## ✨ Team based tag routing (Enterprise)
LiteLLM Proxy supports team-based tag routing, allowing you to associate specific tags with teams and route requests accordingly. Example **Team A can access gpt-4 deployment A, Team B can access gpt-4 deployment B** (LLM Access Control For Teams)

View File

@ -398,6 +398,7 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
ResponseOutputMessage,
ResponseReasoningItem,
)
try:
from openai.types.responses.response_output_item import (
ResponseApplyPatchToolCall,
@ -460,7 +461,9 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
accumulated_tool_calls.append(tool_call_dict)
tool_call_index += 1
elif ResponseApplyPatchToolCall is not None and isinstance(item, ResponseApplyPatchToolCall):
elif ResponseApplyPatchToolCall is not None and isinstance(
item, ResponseApplyPatchToolCall
):
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)

View File

@ -2680,7 +2680,9 @@ def anthropic_messages_pt( # noqa: PLR0915
_content_is_list = "content" in assistant_content_block and isinstance(
assistant_content_block["content"], list
)
_content_list = assistant_content_block.get("content") if _content_is_list else None
_content_list = (
assistant_content_block.get("content") if _content_is_list else None
)
_list_has_thinking = False
if _content_is_list and _content_list is not None:
for _item in _content_list:

View File

@ -79,7 +79,9 @@ class AnthropicFilesConfig(BaseFilesConfig):
return AnthropicError(
status_code=status_code,
message=error_message,
headers=cast(httpx.Headers, headers) if isinstance(headers, dict) else headers,
headers=cast(httpx.Headers, headers)
if isinstance(headers, dict)
else headers,
)
def validate_environment(

View File

@ -144,9 +144,7 @@ class BaseModelResponseIterator:
# Skip empty lines (common in SSE streams between events).
# Only apply to str chunks — non-string objects (e.g. Pydantic
# BaseModel events from the Responses API) must pass through.
if isinstance(str_line, str) and (
not str_line or not str_line.strip()
):
if isinstance(str_line, str) and (not str_line or not str_line.strip()):
continue
# chunk is a str at this point
@ -184,9 +182,7 @@ class BaseModelResponseIterator:
# Skip empty lines (common in SSE streams between events).
# Only apply to str chunks — non-string objects (e.g. Pydantic
# BaseModel events from the Responses API) must pass through.
if isinstance(str_line, str) and (
not str_line or not str_line.strip()
):
if isinstance(str_line, str) and (not str_line or not str_line.strip()):
continue
# chunk is a str at this point

View File

@ -1268,7 +1268,8 @@ class BaseAWSLLM:
# Add back all original headers (including forwarded ones) after signature calculation
for header_name, header_value in headers.items():
request.headers[header_name] = header_value
if header_value is not None:
request.headers[header_name] = header_value
if (
extra_headers is not None and "Authorization" in extra_headers
@ -1298,6 +1299,8 @@ class BaseAWSLLM:
}
for header_name, header_value in headers.items():
if header_value is None:
continue
header_lower = header_name.lower()
if (
header_lower in aws_headers
@ -1393,7 +1396,8 @@ class BaseAWSLLM:
# Add back original headers after signing. Only headers in SignedHeaders
# are integrity-protected; forwarded headers (x-forwarded-*) must remain unsigned.
for header_name, header_value in headers.items():
request_headers_dict[header_name] = header_value
if header_value is not None:
request_headers_dict[header_name] = header_value
if (
headers is not None and "Authorization" in headers
): # prevent sigv4 from overwriting the auth header

View File

@ -173,7 +173,12 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
"S3 bucket_name is required. Set 's3_bucket_name' in litellm_params or AWS_S3_BUCKET_NAME env var"
)
aws_region_name = self._get_aws_region_name(optional_params, model)
s3_region_name = litellm_params.get("s3_region_name") or optional_params.get(
"s3_region_name"
)
aws_region_name = s3_region_name or self._get_aws_region_name(
optional_params, model
)
file_data = data.get("file")
purpose = data.get("purpose")
@ -398,6 +403,15 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
data=create_file_data,
)
# s3_region_name always wins for S3 operations (same priority as in
# get_complete_file_url above). Overwrite aws_region_name unconditionally
# so the SigV4 region matches the URL region, avoiding SignatureDoesNotMatch.
s3_region_name = litellm_params.get("s3_region_name") or optional_params.get(
"s3_region_name"
)
if s3_region_name:
optional_params = {**optional_params, "aws_region_name": s3_region_name}
# Sign the request and return a pre-signed request object
signed_headers, signed_body = self._sign_s3_request(
content=file_content,

View File

@ -201,7 +201,9 @@ class BlackForestLabsImageEditConfig(BaseImageEditConfig):
return image
elif isinstance(image, list):
# If it's a list, take the first image
return self._read_image_bytes(image[0], depth=depth + 1, max_depth=max_depth)
return self._read_image_bytes(
image[0], depth=depth + 1, max_depth=max_depth
)
elif isinstance(image, str):
if image.startswith(("http://", "https://")):
# Download image from URL

View File

@ -71,7 +71,9 @@ class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
result: List[Any] = []
for item in input:
if isinstance(item, dict) and "type" not in item:
new_item = dict(item) # convert to plain dict to avoid TypedDict checking
new_item = dict(
item
) # convert to plain dict to avoid TypedDict checking
new_item["type"] = "message"
result.append(new_item)
else:

View File

@ -378,9 +378,7 @@ if MCP_AVAILABLE:
# Resolve a server name to its UUID if needed
_name_resolved = None
if server_id not in allowed_server_ids:
_name_resolved = global_mcp_server_manager.get_mcp_server_by_name(
server_id
)
_name_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id)
if _name_resolved is not None and _name_resolved.server_id in set(
allowed_server_ids
):
@ -442,9 +440,7 @@ if MCP_AVAILABLE:
extra_headers=user_oauth_extra_headers,
)
except Exception as e:
verbose_logger.exception(
f"Error getting tools from {server.name}: {e}"
)
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
return {
"tools": [],
"error": "server_error",
@ -473,7 +469,9 @@ if MCP_AVAILABLE:
_name_resolved = None
if server_id not in allowed_server_ids:
_name_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id)
if _name_resolved is not None and _name_resolved.server_id in set(allowed_server_ids):
if _name_resolved is not None and _name_resolved.server_id in set(
allowed_server_ids
):
server_id = _name_resolved.server_id
if server_id not in allowed_server_ids:
@ -518,7 +516,9 @@ if MCP_AVAILABLE:
server_auth_header = _get_server_auth_header(
server, mcp_server_auth_headers, mcp_auth_header
)
user_oauth_extra_headers = await _get_user_oauth_extra_headers(server, user_api_key_dict)
user_oauth_extra_headers = await _get_user_oauth_extra_headers(
server, user_api_key_dict
)
try:
list_tools_result = await _get_tools_for_single_server(
@ -529,9 +529,7 @@ if MCP_AVAILABLE:
extra_headers=user_oauth_extra_headers,
)
except Exception as e:
verbose_logger.exception(
f"Error getting tools from {server.name}: {e}"
)
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
return {
"tools": [],
"error": "server_error",
@ -905,7 +903,9 @@ if MCP_AVAILABLE:
try:
client_id, client_secret, scopes = _extract_credentials(request)
_oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = (
_oauth2_flow: Optional[
Literal["client_credentials", "authorization_code"]
] = (
"client_credentials"
if client_id and client_secret and request.token_url
else None

View File

@ -64,7 +64,11 @@ from litellm.proxy.common_utils.http_parsing_utils import (
populate_request_with_path_params,
)
from litellm.proxy.common_utils.realtime_utils import _realtime_request_body
from litellm.proxy.utils import PrismaClient, ProxyLogging, normalize_route_for_root_path
from litellm.proxy.utils import (
PrismaClient,
ProxyLogging,
normalize_route_for_root_path,
)
from litellm.secret_managers.main import get_secret_bool
from litellm.types.services import ServiceTypes

View File

@ -106,9 +106,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
if (self.output_parse_pii or self.apply_to_output) and not logging_only:
current_hook = self.event_hook
if isinstance(current_hook, str) and current_hook != "post_call":
self.event_hook = cast(List[GuardrailEventHooks], [current_hook, "post_call"])
self.event_hook = cast(
List[GuardrailEventHooks], [current_hook, "post_call"]
)
elif isinstance(current_hook, list) and "post_call" not in current_hook:
self.event_hook = cast(List[GuardrailEventHooks], current_hook + ["post_call"])
self.event_hook = cast(
List[GuardrailEventHooks], current_hook + ["post_call"]
)
self.pii_entities_config: Dict[Union[PiiEntityType, str], PiiAction] = (
pii_entities_config or {}
)

View File

@ -1838,9 +1838,7 @@ async def _validate_update_key_data(
# Check team limits if key has a team_id (from request or existing key)
team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
_team_id_to_check = data.team_id or getattr(
existing_key_row, "team_id", None
)
_team_id_to_check = data.team_id or getattr(existing_key_row, "team_id", None)
if _team_id_to_check is not None:
team_obj = await get_team_object(
team_id=_team_id_to_check,
@ -1910,9 +1908,7 @@ async def _validate_update_key_data(
if team_obj is None:
raise HTTPException(
status_code=500,
detail={
"error": "Team object not found for team change validation"
},
detail={"error": "Team object not found for team change validation"},
)
await validate_key_team_change(
key=existing_key_row,

View File

@ -846,12 +846,16 @@ async def get_generic_sso_response(
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
additional_generic_sso_headers_dict = _parse_generic_sso_headers()
code_verifier: Optional[str] = None # assigned inside try; initialized for type tracking
code_verifier: Optional[
str
] = None # assigned inside try; initialized for type tracking
try:
token_exchange_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters(
request=request,
generic_include_client_id=generic_include_client_id,
token_exchange_params = (
await SSOAuthenticationHandler.prepare_token_exchange_parameters(
request=request,
generic_include_client_id=generic_include_client_id,
)
)
# Extract code_verifier (and the cache key for deferred deletion) before calling fastapi-sso
@ -915,7 +919,9 @@ async def get_generic_sso_response(
# Assign directly rather than relying on nonlocal mutation so that Pyright
# can track that received_response is non-None from this point on.
received_response = {
k: v for k, v in combined_response.items() if k not in _OAUTH_TOKEN_FIELDS
k: v
for k, v in combined_response.items()
if k not in _OAUTH_TOKEN_FIELDS
}
# In the PKCE path verify_and_process is skipped, so generic_sso.access_token
# is never set. Read the token directly from the exchange response instead so
@ -2598,7 +2604,9 @@ class SSOAuthenticationHandler:
state,
)
else:
verbose_proxy_logger.debug("PKCE code_verifier retrieved from cache")
verbose_proxy_logger.debug(
"PKCE code_verifier retrieved from cache"
)
elif isinstance(cached_data, str):
# Handle legacy format (plain string) for backward compatibility
code_verifier = cached_data
@ -2647,7 +2655,9 @@ class SSOAuthenticationHandler:
In strict mode (PKCE_STRICT_CACHE_MISS=true) raises ProxyException.
Otherwise logs a warning and returns (token exchange proceeds without verifier).
"""
active_cache = redis_usage_cache if redis_usage_cache is not None else user_api_key_cache
active_cache = (
redis_usage_cache if redis_usage_cache is not None else user_api_key_cache
)
strict_cache_miss = (
os.getenv("PKCE_STRICT_CACHE_MISS", "false").lower() == "true"
)

View File

@ -5233,7 +5233,7 @@ def normalize_route_for_root_path(route: str) -> Optional[str]:
root_path = get_server_root_path()
if root_path and root_path != "/":
if route.startswith(root_path + "/"):
return route[len(root_path):]
return route[len(root_path) :]
return None
return route

View File

@ -415,7 +415,9 @@ class LiteLLMCompletionResponsesConfig:
if isinstance(new_msg, dict)
else getattr(new_msg, "tool_calls", None)
)
new_tcs: list = _raw_tcs if isinstance(_raw_tcs, list) else []
new_tcs: list = (
_raw_tcs if isinstance(_raw_tcs, list) else []
)
for tc in new_tcs:
LiteLLMCompletionResponsesConfig._add_tool_call_to_assistant(
last_msg, tc

View File

@ -14,6 +14,7 @@ import hashlib
import inspect
import json
import logging
import re
import threading
import time
import traceback
@ -1551,7 +1552,9 @@ class Router:
# Drain any fire-and-forget tasks (e.g. alerting hooks)
# scheduled via asyncio.create_task during acompletion.
pending = asyncio.all_tasks()
pending.discard(asyncio.current_task())
current = asyncio.current_task()
if current is not None:
pending.discard(current)
if pending:
await asyncio.gather(*pending, return_exceptions=True)
@ -6542,6 +6545,18 @@ class Router:
)
return None
# Validate tag_regex patterns BEFORE adding the deployment so we never
# have partially-initialised router state if a pattern is invalid.
_tag_regex = deployment.litellm_params.get("tag_regex") or []
for pattern in _tag_regex:
try:
re.compile(pattern)
except re.error as exc:
raise ValueError(
f"Invalid regex in tag_regex for model '{deployment.model_name}': "
f"{pattern!r}{exc}"
) from exc
deployment = self._add_deployment(deployment=deployment)
model = deployment.to_json(exclude_none=True)

View File

@ -6,6 +6,7 @@ Use this to route requests between Teams
- If no default_deployments are set, return all deployments
"""
import re
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from litellm._logging import verbose_logger
@ -19,6 +20,29 @@ else:
LitellmRouter = Any
def _is_valid_deployment_tag_regex(
tag_regexes: List[str],
header_strings: List[str],
) -> Optional[str]:
"""
Test compiled regex patterns against "Header-Name: value" strings.
Returns the first matching pattern string, or None if nothing matches.
Compiles each pattern once (re's LRU cache) and logs invalid patterns once
per pattern, not once per header string.
"""
for pattern in tag_regexes:
try:
compiled = re.compile(pattern)
except re.error:
verbose_logger.warning("tag_regex: invalid pattern %r — skipping", pattern)
continue
for header_str in header_strings:
if compiled.search(header_str):
return pattern
return None
def is_valid_deployment_tag(
deployment_tags: List[str], request_tags: List[str], match_any: bool = True
) -> bool:
@ -47,6 +71,54 @@ def is_valid_deployment_tag(
return False
def _match_deployment(
deployment: Any,
request_tags: Optional[List[str]],
header_strings: List[str],
match_any: bool,
) -> Optional[Dict[str, str]]:
"""
Determine whether *deployment* matches the current request.
Returns {"matched_via": ..., "matched_value": ...} if the deployment
should be included, or None if it should be excluded.
Priority:
1. Exact tag match (respects match_any semantics).
2. Regex match skipped when match_any=False and the tag check already
ran and failed, so the regex cannot override strict-tag policy.
"""
litellm_params = deployment.get("litellm_params", {})
deployment_tags: Optional[List[str]] = litellm_params.get("tags")
deployment_tag_regex: Optional[List[str]] = litellm_params.get("tag_regex")
# 1. Exact tag match (existing behaviour).
if deployment_tags and request_tags:
if is_valid_deployment_tag(deployment_tags, request_tags, match_any):
matched_value = next(
(t for t in deployment_tags if t in set(request_tags)),
deployment_tags[0],
)
return {"matched_via": "tags", "matched_value": matched_value}
# 2. Regex match against request headers.
# When match_any=False and the deployment has both plain tags and tag_regex,
# the strict tag check has already failed (step 1 returned None). Allow
# the regex to fire only when the deployment has NO plain tags, so we never
# use regex as a backdoor around the operator's strict-tag policy.
strict_tag_check_failed = (
not match_any and bool(deployment_tags) and bool(request_tags)
)
if deployment_tag_regex and header_strings and not strict_tag_check_failed:
regex_match = _is_valid_deployment_tag_regex(
deployment_tag_regex, header_strings
)
if regex_match is not None:
return {"matched_via": "tag_regex", "matched_value": regex_match}
return None
async def get_deployments_for_tag(
llm_router_instance: LitellmRouter,
model: str, # used to raise the correct error
@ -83,30 +155,63 @@ async def get_deployments_for_tag(
request_tags = metadata.get("tags")
match_any = llm_router_instance.tag_filtering_match_any
new_healthy_deployments = []
default_deployments = []
if request_tags:
verbose_logger.debug(
"get_deployments_for_tag routing: router_keys: %s", request_tags
)
# example this can be router_keys=["free", "custom"]
for deployment in healthy_deployments:
deployment_litellm_params = deployment.get("litellm_params")
deployment_tags = deployment_litellm_params.get("tags")
# Build header strings for regex matching from what the proxy already stores.
# Currently we match against User-Agent; format matches "^User-Agent: claude-code/..."
user_agent = metadata.get("user_agent", "")
header_strings: List[str] = [f"User-Agent: {user_agent}"] if user_agent else []
verbose_logger.debug(
"deployment: %s, deployment_router_keys: %s",
deployment,
deployment_tags,
new_healthy_deployments: List[Any] = []
default_deployments: List[Any] = []
# Only activate header-based regex filtering when at least one deployment in
# the candidate set has tag_regex configured. This preserves existing
# behaviour for operators who use plain tags: a request that carries a
# User-Agent (all proxy requests do) but targets deployments with no
# tag_regex will continue to use the original tag-only code path.
has_regex_deployments = any(
d.get("litellm_params", {}).get("tag_regex") for d in healthy_deployments
)
has_tag_filter = bool(request_tags) or (
bool(header_strings) and has_regex_deployments
)
if has_tag_filter:
verbose_logger.debug(
"get_deployments_for_tag routing: request_tags=%s user_agent=%s",
request_tags,
user_agent,
)
for deployment in healthy_deployments:
deployment_tags = deployment.get("litellm_params", {}).get("tags")
match_result = _match_deployment(
deployment=deployment,
request_tags=request_tags,
header_strings=header_strings,
match_any=match_any,
)
if deployment_tags is None:
continue
if is_valid_deployment_tag(deployment_tags, request_tags, match_any):
if match_result is not None:
verbose_logger.debug(
"tag routing match: deployment=%s matched_via=%s matched_value=%s",
deployment.get("model_name"),
match_result["matched_via"],
match_result["matched_value"],
)
# Record provenance in metadata so it flows to SpendLogs.
# Written only for the first match — load balancer selects one
# deployment from new_healthy_deployments, so overwriting on
# subsequent matches would produce misleading observability data.
if "tag_routing" not in metadata:
metadata["tag_routing"] = {
"matched_deployment": deployment.get("model_name"),
"matched_via": match_result["matched_via"],
"matched_value": match_result["matched_value"],
"request_tags": request_tags or [],
"user_agent": user_agent,
}
new_healthy_deployments.append(deployment)
if "default" in deployment_tags:
if deployment_tags and "default" in deployment_tags:
default_deployments.append(deployment)
if len(new_healthy_deployments) == 0 and len(default_deployments) == 0:

View File

@ -198,6 +198,11 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
model_info: Optional[Dict] = None
mock_response: Optional[Union[str, ModelResponse, Exception, Any]] = None
# tag-based routing
tags: Optional[List[str]] = None
# regex patterns matched against request headers for tag routing
tag_regex: Optional[List[str]] = None
# auto-router params
auto_router_config_path: Optional[str] = None
auto_router_config: Optional[str] = None
@ -334,6 +339,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
# routing params
# use this for tag-based routing
tags: Optional[List[str]]
# regex patterns matched against request headers (e.g. "^User-Agent:\\s*claude-code\\/")
tag_regex: Optional[List[str]]
# deployment budgets
max_budget: Optional[float]

View File

@ -272,6 +272,164 @@ class TestBedrockFilesTransformation:
assert "messages" in model_input
assert "max_tokens" in model_input
def test_get_complete_file_url_respects_s3_region_name(self):
"""
s3_region_name in litellm_params must be used when building the S3 URL.
Previously the code fell back to us-west-2 even when s3_region_name was set,
breaking GovCloud (us-gov-west-1) deployments.
"""
from litellm.llms.bedrock.files.transformation import BedrockFilesConfig
config = BedrockFilesConfig()
jsonl_content = json.dumps(
{
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "bedrock/amazon.nova-pro-v1:0",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
).encode()
create_file_data = {
"file": ("batch.jsonl", jsonl_content, "application/jsonl"),
"purpose": "batch",
}
litellm_params = {
"s3_bucket_name": "litellm-batch-352026",
"s3_region_name": "us-gov-west-1",
}
url = config.get_complete_file_url(
api_base=None,
api_key=None,
model="amazon.nova-pro-v1:0",
optional_params={},
litellm_params=litellm_params,
data=create_file_data,
)
assert "us-gov-west-1" in url, (
f"Expected us-gov-west-1 in URL but got: {url}"
)
assert "us-west-2" not in url, (
f"us-west-2 must not appear when s3_region_name is set, got: {url}"
)
assert "litellm-batch-352026" in url
def test_transform_create_file_request_injects_s3_region_for_signing(self):
"""
When s3_region_name is provided, transform_create_file_request must pass
that region to _sign_s3_request so SigV4 signatures use the correct region.
"""
from unittest.mock import patch
from litellm.llms.bedrock.files.transformation import BedrockFilesConfig
config = BedrockFilesConfig()
jsonl_content = json.dumps(
{
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "bedrock/amazon.nova-pro-v1:0",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
).encode()
create_file_data = {
"file": ("batch.jsonl", jsonl_content, "application/jsonl"),
"purpose": "batch",
}
litellm_params = {
"s3_bucket_name": "litellm-batch-352026",
"s3_region_name": "us-gov-west-1",
}
captured_optional_params: dict = {}
def fake_sign(content, api_base, optional_params):
captured_optional_params.update(optional_params)
return {"Authorization": "fake"}, content
with patch.object(config, "_sign_s3_request", side_effect=fake_sign):
config.transform_create_file_request(
model="amazon.nova-pro-v1:0",
create_file_data=create_file_data,
optional_params={},
litellm_params=litellm_params,
)
assert captured_optional_params.get("aws_region_name") == "us-gov-west-1", (
"s3_region_name must be forwarded as aws_region_name for SigV4 signing"
)
def test_s3_region_name_wins_over_aws_region_name_for_signing(self):
"""
When both s3_region_name and aws_region_name are set to different values,
s3_region_name must win for signing (same as for the URL). Otherwise the
SigV4 signature would be computed against a different region than the URL,
causing SignatureDoesNotMatch from AWS.
"""
from unittest.mock import patch
from litellm.llms.bedrock.files.transformation import BedrockFilesConfig
config = BedrockFilesConfig()
jsonl_content = json.dumps(
{
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "bedrock/amazon.nova-pro-v1:0",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
).encode()
create_file_data = {
"file": ("batch.jsonl", jsonl_content, "application/jsonl"),
"purpose": "batch",
}
litellm_params = {
"s3_bucket_name": "litellm-batch-352026",
"s3_region_name": "us-gov-west-1",
}
# aws_region_name set to something different — s3_region_name must still win
optional_params = {"aws_region_name": "us-east-1"}
captured_optional_params: dict = {}
def fake_sign(content, api_base, optional_params):
captured_optional_params.update(optional_params)
return {"Authorization": "fake"}, content
with patch.object(config, "_sign_s3_request", side_effect=fake_sign):
config.transform_create_file_request(
model="amazon.nova-pro-v1:0",
create_file_data=create_file_data,
optional_params=optional_params,
litellm_params=litellm_params,
)
assert captured_optional_params.get("aws_region_name") == "us-gov-west-1", (
"s3_region_name must override aws_region_name for SigV4 signing"
)
def test_openai_passthrough_still_works(self):
"""
Regression test: ensure OpenAI-compatible models (e.g. gpt-oss)

View File

@ -14,15 +14,16 @@ from datetime import datetime, timedelta, timezone
from typing import Any, Dict
from unittest.mock import MagicMock, patch
from botocore.awsrequest import AWSPreparedRequest, AWSRequest
from botocore.credentials import Credentials
from botocore.awsrequest import AWSRequest, AWSPreparedRequest
import litellm
from litellm.caching.caching import DualCache
from litellm.llms.bedrock.base_aws_llm import (
AwsAuthError,
BaseAWSLLM,
Boto3CredentialsInfo,
)
from litellm.caching.caching import DualCache
# Global variable for the base_aws_llm.py file path
@ -1519,6 +1520,83 @@ def test_is_already_running_as_role_invalid_target_arn():
assert base_aws_llm._is_already_running_as_role("not-a-valid-arn") is False
def test_filter_headers_skips_none_values():
"""
Test that _filter_headers_for_aws_signature skips headers with None values.
Reproduces the issue where botocore's SigV4Auth crashes with
'NoneType' object has no attribute 'split' when a header value is None.
"""
llm = BaseAWSLLM()
headers = {
"Content-Type": "application/json",
"x-amz-security-token": None,
"x-amzn-bedrock-kb-session-id": None,
"host": None,
"x-amz-date": "20240101T000000Z",
"x-custom-header": None,
}
filtered = llm._filter_headers_for_aws_signature(headers)
assert filtered["Content-Type"] == "application/json"
assert filtered["x-amz-date"] == "20240101T000000Z"
assert "x-amz-security-token" not in filtered
assert "x-amzn-bedrock-kb-session-id" not in filtered
assert "host" not in filtered
# Non-AWS headers are excluded regardless
assert "x-custom-header" not in filtered
def test_sign_request_with_none_header_values():
"""
End-to-end test that _sign_request does not crash when headers contain
None values for x-amz-* keys.
This reproduces the Bedrock KB GovCloud issue where SigV4 signing failed
with 'NoneType' object has no attribute 'split'.
Also verifies that None-valued headers are NOT re-merged into the
returned headers dict (which would cause downstream HTTP client failures).
"""
llm = BaseAWSLLM()
mock_credentials = Credentials("test_key", "test_secret")
headers_with_nones = {
"Content-Type": "application/json",
"x-amzn-trace-id": None,
"x-forwarded-for": None,
}
with patch.object(
llm, "get_credentials", return_value=mock_credentials
), patch.object(
llm, "_get_aws_region_name", return_value="us-gov-west-1"
):
result_headers, result_body = llm._sign_request(
service_name="bedrock",
headers=headers_with_nones,
optional_params={
"aws_access_key_id": "test_key",
"aws_secret_access_key": "test_secret",
"aws_region_name": "us-gov-west-1",
},
request_data={"retrievalQuery": {"text": "test query"}},
api_base="https://bedrock-agent-runtime.us-gov-west-1.amazonaws.com/knowledgebases/KB123/retrieve",
)
assert "Authorization" in result_headers
assert result_body is not None
# None-valued headers must NOT appear in the returned headers
for header_name, header_value in result_headers.items():
assert header_value is not None, (
f"Header '{header_name}' has None value in returned headers"
)
def test_is_already_running_as_role_ssl_verify_passed():
"""
Test that ssl_verify parameter is correctly passed to the STS client.

View File

@ -0,0 +1,375 @@
"""
Unit tests for tag_regex routing.
Tests _is_valid_deployment_tag_regex() and get_deployments_for_tag() with tag_regex
patterns, verifying that regex-based header matching works correctly alongside
existing tag-based routing.
"""
import os
import sys
import pytest
sys.path.insert(0, os.path.abspath("../.."))
from unittest.mock import MagicMock
from litellm.router_strategy import tag_based_routing
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
_is_valid_deployment_tag_regex = tag_based_routing._is_valid_deployment_tag_regex
# ---------------------------------------------------------------------------
# _is_valid_deployment_tag_regex unit tests
# ---------------------------------------------------------------------------
def test_regex_matches_claude_code_user_agent():
"""^User-Agent: claude-code/ matches a claude-code UA string."""
result = _is_valid_deployment_tag_regex(
tag_regexes=[r"^User-Agent: claude-code\/"],
header_strings=["User-Agent: claude-code/1.2.3"],
)
assert result == r"^User-Agent: claude-code\/"
def test_regex_no_match_for_other_ua():
"""Pattern does not match a non-claude-code User-Agent."""
result = _is_valid_deployment_tag_regex(
tag_regexes=[r"^User-Agent: claude-code\/"],
header_strings=["User-Agent: Mozilla/5.0 (browser)"],
)
assert result is None
def test_regex_returns_first_matching_pattern():
"""When multiple patterns are provided, returns the first match."""
result = _is_valid_deployment_tag_regex(
tag_regexes=[r"^User-Agent: cursor\/", r"^User-Agent: claude-code\/"],
header_strings=["User-Agent: claude-code/2.0.0"],
)
assert result == r"^User-Agent: claude-code\/"
def test_regex_empty_inputs_return_none():
"""Empty lists return None without errors."""
assert _is_valid_deployment_tag_regex([], ["User-Agent: claude-code/1.0"]) is None
assert _is_valid_deployment_tag_regex([r"^User-Agent: claude-code\/"], []) is None
def test_invalid_regex_skipped_does_not_raise():
"""An invalid regex pattern is skipped (warning logged) — no exception raised."""
result = _is_valid_deployment_tag_regex(
tag_regexes=["[invalid(regex"],
header_strings=["User-Agent: claude-code/1.0"],
)
assert result is None
def test_regex_matches_version_range():
"""Semver-aware pattern matches multiple versions."""
pattern = r"^User-Agent: claude-code\/\d"
for ua in ["claude-code/1.0", "claude-code/2.0.0-beta.1", "claude-code/99.0"]:
result = _is_valid_deployment_tag_regex(
tag_regexes=[pattern],
header_strings=[f"User-Agent: {ua}"],
)
assert result == pattern, f"Expected match for UA: {ua}"
# ---------------------------------------------------------------------------
# get_deployments_for_tag integration tests
# ---------------------------------------------------------------------------
CLAUDE_CODE_DEPLOYMENT = {
"model_name": "claude-sonnet",
"litellm_params": {
"model": "openai/claude-code-deployment",
"api_key": "fake",
"mock_response": "cc",
"tag_regex": [r"^User-Agent: claude-code\/"],
},
"model_info": {"id": "claude-code-deployment"},
}
REGULAR_DEPLOYMENT = {
"model_name": "claude-sonnet",
"litellm_params": {
"model": "openai/regular-deployment",
"api_key": "fake",
"mock_response": "regular",
"tags": ["default"],
},
"model_info": {"id": "regular-deployment"},
}
ALL_DEPLOYMENTS = [CLAUDE_CODE_DEPLOYMENT, REGULAR_DEPLOYMENT]
def _make_router_mock(enable_tag_filtering=True, match_any=True):
mock = MagicMock()
mock.enable_tag_filtering = enable_tag_filtering
mock.tag_filtering_match_any = match_any
return mock
@pytest.mark.asyncio
async def test_claude_code_ua_routes_to_cc_deployment():
"""claude-code/x.y.z UA → claude-code-deployment via tag_regex."""
router = _make_router_mock()
result = await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=ALL_DEPLOYMENTS,
request_kwargs={"metadata": {"user_agent": "claude-code/1.2.3"}},
)
assert len(result) == 1
assert result[0]["model_info"]["id"] == "claude-code-deployment"
@pytest.mark.asyncio
async def test_regular_ua_routes_to_default_deployment():
"""Mozilla UA → regular-deployment via default tag fallback."""
router = _make_router_mock()
result = await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=ALL_DEPLOYMENTS,
request_kwargs={"metadata": {"user_agent": "Mozilla/5.0 (browser)"}},
)
assert len(result) == 1
assert result[0]["model_info"]["id"] == "regular-deployment"
@pytest.mark.asyncio
async def test_no_ua_routes_to_default_deployment():
"""No User-Agent → default deployment."""
router = _make_router_mock()
result = await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=ALL_DEPLOYMENTS,
request_kwargs={"metadata": {}},
)
assert len(result) == 1
assert result[0]["model_info"]["id"] == "regular-deployment"
@pytest.mark.asyncio
async def test_tag_routing_metadata_written_for_regex_match():
"""tag_routing metadata block is populated when regex matches."""
router = _make_router_mock()
metadata: dict = {"user_agent": "claude-code/2.0.0-beta.1"}
await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=ALL_DEPLOYMENTS,
request_kwargs={"metadata": metadata},
)
assert "tag_routing" in metadata
tr = metadata["tag_routing"]
assert tr["matched_via"] == "tag_regex"
assert tr["matched_value"] == r"^User-Agent: claude-code\/"
assert tr["user_agent"] == "claude-code/2.0.0-beta.1"
@pytest.mark.asyncio
async def test_tag_filtering_disabled_returns_all_deployments():
"""When enable_tag_filtering is False, all deployments returned regardless of UA."""
router = _make_router_mock(enable_tag_filtering=False)
result = await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=ALL_DEPLOYMENTS,
request_kwargs={"metadata": {"user_agent": "claude-code/1.0"}},
)
assert result == ALL_DEPLOYMENTS
@pytest.mark.asyncio
async def test_explicit_tag_match_takes_precedence_over_regex():
"""A deployment with both tags and tag_regex: exact tag match fires first."""
deployment_with_both = {
"model_name": "claude-sonnet",
"litellm_params": {
"model": "openai/both-deployment",
"api_key": "fake",
"tags": ["premium"],
"tag_regex": [r"^User-Agent: claude-code\/"],
},
"model_info": {"id": "both-deployment"},
}
router = _make_router_mock()
metadata: dict = {
"tags": ["premium"],
"user_agent": "claude-code/1.0",
}
result = await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=[deployment_with_both],
request_kwargs={"metadata": metadata},
)
assert len(result) == 1
tr = metadata.get("tag_routing", {})
assert tr.get("matched_via") == "tags"
@pytest.mark.asyncio
async def test_user_agent_present_no_tag_regex_deployments_does_not_raise():
"""
Backwards-compat: a request that carries a User-Agent but targets plain-tag
deployments (no tag_regex) must NOT raise ValueError it should fall
through to the default/all-deployments path just like before.
"""
plain_tag_only_deployments = [
{
"model_name": "gpt-4",
"litellm_params": {
"model": "openai/premium-deployment",
"api_key": "fake",
"tags": ["premium"],
},
"model_info": {"id": "premium-deployment"},
},
{
"model_name": "gpt-4",
"litellm_params": {
"model": "openai/free-deployment",
"api_key": "fake",
"tags": ["free"],
},
"model_info": {"id": "free-deployment"},
},
]
router = _make_router_mock()
# The request has a User-Agent (as all proxy requests do) but NO tags and
# neither deployment has tag_regex — must not raise, must return all.
result = await get_deployments_for_tag(
llm_router_instance=router,
model="gpt-4",
healthy_deployments=plain_tag_only_deployments,
request_kwargs={"metadata": {"user_agent": "Mozilla/5.0 (any-client)"}},
)
# Falls through to "return healthy_deployments" path unchanged
assert result == plain_tag_only_deployments
@pytest.mark.asyncio
async def test_tag_routing_metadata_not_overwritten_for_multiple_matches():
"""
When multiple deployments match, tag_routing records only the first match
so the provenance reflects what the load balancer likely selected.
"""
deployment_a = {
"model_name": "claude-sonnet",
"litellm_params": {
"model": "openai/cc-deployment-a",
"api_key": "fake",
"tag_regex": [r"^User-Agent: claude-code\/"],
},
"model_info": {"id": "cc-deployment-a"},
}
deployment_b = {
"model_name": "claude-sonnet",
"litellm_params": {
"model": "openai/cc-deployment-b",
"api_key": "fake",
"tag_regex": [r"^User-Agent: claude-code\/"],
},
"model_info": {"id": "cc-deployment-b"},
}
router = _make_router_mock()
metadata: dict = {"user_agent": "claude-code/1.0"}
result = await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=[deployment_a, deployment_b],
request_kwargs={"metadata": metadata},
)
assert len(result) == 2
# tag_routing recorded once and reflects the first match
tr = metadata.get("tag_routing", {})
assert tr.get("matched_deployment") == "claude-sonnet"
assert tr.get("matched_via") == "tag_regex"
@pytest.mark.asyncio
async def test_match_any_false_strict_tag_check_blocks_regex_fallback():
"""
When match_any=False and a deployment has both tags and tag_regex:
if the strict tag check fails (request has a tag NOT present on the
deployment, so req_set is NOT a subset of dep_set), the regex fallback
must NOT fire that would violate the operator's strict-filtering intent.
Semantics of match_any=False: req_set.issubset(dep_set), i.e. every
request tag must appear on the deployment. A request with tags ["vip"]
against a deployment with tags ["premium"] fails because "vip" dep_set.
"""
deployment_strict = {
"model_name": "claude-sonnet",
"litellm_params": {
"model": "openai/strict-deployment",
"api_key": "fake",
"tags": ["premium"],
"tag_regex": [r"^User-Agent: claude-code\/"],
},
"model_info": {"id": "strict-deployment"},
}
default_deployment = {
"model_name": "claude-sonnet",
"litellm_params": {
"model": "openai/default-deployment",
"api_key": "fake",
"tags": ["default"],
},
"model_info": {"id": "default-deployment"},
}
# match_any=False: req_set must be a subset of dep_set.
# Request has "vip" which is NOT in ["premium"], so tag check fails.
# Even though UA matches tag_regex, the deployment must NOT be selected.
router = _make_router_mock(enable_tag_filtering=True, match_any=False)
metadata: dict = {
"tags": ["vip"], # "vip" not in deployment tags → strict check fails
"user_agent": "claude-code/1.0",
}
result = await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=[deployment_strict, default_deployment],
request_kwargs={"metadata": metadata},
)
ids = [d["model_info"]["id"] for d in result]
assert "strict-deployment" not in ids, (
"strict-deployment should not be selected: strict tag check failed "
"and regex must not override the strict policy"
)
@pytest.mark.asyncio
async def test_match_any_false_regex_only_deployment_still_matches():
"""
When match_any=False and a deployment has ONLY tag_regex (no plain tags),
there is no strict tag policy to violate, so the regex check must still fire.
"""
regex_only_deployment = {
"model_name": "claude-sonnet",
"litellm_params": {
"model": "openai/regex-only-deployment",
"api_key": "fake",
"tag_regex": [r"^User-Agent: claude-code\/"],
# no "tags" key at all
},
"model_info": {"id": "regex-only-deployment"},
}
router = _make_router_mock(enable_tag_filtering=True, match_any=False)
result = await get_deployments_for_tag(
llm_router_instance=router,
model="claude-sonnet",
healthy_deployments=[regex_only_deployment],
request_kwargs={"metadata": {"user_agent": "claude-code/1.0"}},
)
assert len(result) == 1
assert result[0]["model_info"]["id"] == "regex-only-deployment"

View File

@ -5,8 +5,9 @@ The ``_expand_model_aliases`` function processes ``aliases`` lists from model
entries, creating shared dict references for alias entries at load time.
"""
import logging
from unittest.mock import patch
from litellm import verbose_logger
from litellm.litellm_core_utils.get_model_cost_map import _expand_model_aliases
@ -118,7 +119,7 @@ class TestExpandModelAliases:
class TestAliasConflicts:
"""Tests for alias conflict detection and handling."""
def test_alias_conflicts_with_canonical_entry(self, caplog):
def test_alias_conflicts_with_canonical_entry(self):
"""Alias that matches an existing canonical entry is skipped with a warning."""
model_cost = {
"model-latest": {
@ -133,14 +134,17 @@ class TestAliasConflicts:
"mode": "chat",
},
}
with caplog.at_level(logging.WARNING, logger="LiteLLM"):
with patch.object(verbose_logger, "warning") as mock_warn:
result = _expand_model_aliases(model_cost)
# The canonical "model-dated" entry is preserved, not overwritten
assert "model-dated" in result
assert "alias conflict" in caplog.text.lower()
# Verify a warning about the alias conflict was logged
mock_warn.assert_called()
warning_messages = " ".join(str(c) for c in mock_warn.call_args_list)
assert "alias conflict" in warning_messages.lower()
def test_duplicate_alias_across_entries(self, caplog):
def test_duplicate_alias_across_entries(self):
"""Same alias claimed by two different entries: second one is skipped."""
model_cost = {
"model-a": {
@ -156,13 +160,16 @@ class TestAliasConflicts:
"mode": "chat",
},
}
with caplog.at_level(logging.WARNING, logger="LiteLLM"):
with patch.object(verbose_logger, "warning") as mock_warn:
result = _expand_model_aliases(model_cost)
# "shared-alias" should point to model-a (first one wins)
assert "shared-alias" in result
assert result["shared-alias"]["input_cost_per_token"] == 1e-06
assert "alias conflict" in caplog.text.lower()
# Verify a warning about the alias conflict was logged
mock_warn.assert_called()
warning_messages = " ".join(str(c) for c in mock_warn.call_args_list)
assert "alias conflict" in warning_messages.lower()
def test_canonical_entry_not_overwritten_by_alias(self):
"""An alias must never overwrite an existing canonical entry's data."""