Merge pull request #26161 from BerriAI/litellm_access-group-routing-fix
fix(router): constrain same-name deployment routing by access groups
This commit is contained in:
commit
d2015f0baf
@ -9351,6 +9351,52 @@ class Router:
|
||||
"""
|
||||
return [m for m in self.model_list if m["litellm_params"]["model"] == model]
|
||||
|
||||
def _try_early_resolve_deployments_for_model_not_in_names(
    self, model: str, request_team_id: Optional[str]
) -> Optional[Tuple[str, Union[List, Dict]]]:
    """
    Resolve deployments for a model that is not in ``self.model_names``.

    The lookups run in this order and the first non-empty result wins:

    1. team deployments from ``self._get_all_deployments`` (only when a
       team id is supplied),
    2. the global ``self.pattern_router``,
    3. the team's entry in ``self.team_pattern_routers`` (only when a team
       id is supplied and registered),
    4. ``self.default_deployment``, returned as a copy whose
       ``litellm_params["model"]`` is replaced with ``model``.

    Args:
        model: Requested model string.
        request_team_id: Team id of the caller, or ``None``.

    Returns:
        ``(model, deployments)`` where ``deployments`` is whatever the
        matching lookup produced (a dict in the default-deployment case),
        or ``None`` when ``model`` is in ``self.model_names`` or nothing
        matched.
    """
    if model in self.model_names:
        return None

    # Check for team-specific deployments by team_public_model_name.
    # This intentionally takes priority over team pattern routers below,
    # so that named team deployments shadow wildcard/pattern routes.
    if request_team_id is not None:
        team_deployments = self._get_all_deployments(
            model_name=model, team_id=request_team_id
        )
        if team_deployments:
            return model, team_deployments

    # Global wildcard / pattern routes.
    pattern_deployments = self.pattern_router.get_deployments_by_pattern(
        model=model,
    )
    if pattern_deployments:
        return model, pattern_deployments

    # Team-scoped wildcard / pattern routes.
    if request_team_id is not None and request_team_id in self.team_pattern_routers:
        pattern_deployments = self.team_pattern_routers[
            request_team_id
        ].get_deployments_by_pattern(
            model=model,
        )
        if pattern_deployments:
            return model, pattern_deployments

    if self.default_deployment is not None:
        # Shallow copy with nested litellm_params copy (100x+ faster than deepcopy).
        # Copying both levels keeps the stored default deployment unmodified
        # when the model name is overwritten below.
        updated_deployment = self.default_deployment.copy()
        updated_deployment["litellm_params"] = self.default_deployment[
            "litellm_params"
        ].copy()
        updated_deployment["litellm_params"]["model"] = model
        return model, updated_deployment

    return None
|
||||
|
||||
def _common_checks_available_deployment(
|
||||
self,
|
||||
model: str,
|
||||
@ -9393,56 +9439,52 @@ class Router:
|
||||
if _model_from_alias is not None:
|
||||
model = _model_from_alias
|
||||
|
||||
if model not in self.model_names:
|
||||
# Check for team-specific deployments by team_public_model_name.
|
||||
# This intentionally takes priority over team pattern routers below,
|
||||
# so that named team deployments shadow wildcard/pattern routes.
|
||||
if request_team_id is not None:
|
||||
team_deployments = self._get_all_deployments(
|
||||
model_name=model, team_id=request_team_id
|
||||
)
|
||||
if team_deployments:
|
||||
return model, team_deployments
|
||||
|
||||
# check if provider/ specific wildcard routing use pattern matching
|
||||
pattern_deployments = self.pattern_router.get_deployments_by_pattern(
|
||||
model=model,
|
||||
)
|
||||
|
||||
if pattern_deployments:
|
||||
return model, pattern_deployments
|
||||
|
||||
if (
|
||||
request_team_id is not None
|
||||
and request_team_id in self.team_pattern_routers
|
||||
):
|
||||
pattern_deployments = self.team_pattern_routers[
|
||||
request_team_id
|
||||
].get_deployments_by_pattern(
|
||||
model=model,
|
||||
)
|
||||
if pattern_deployments:
|
||||
return model, pattern_deployments
|
||||
|
||||
# check if default deployment is set
|
||||
if self.default_deployment is not None:
|
||||
# Shallow copy with nested litellm_params copy (100x+ faster than deepcopy)
|
||||
updated_deployment = self.default_deployment.copy()
|
||||
updated_deployment["litellm_params"] = self.default_deployment[
|
||||
"litellm_params"
|
||||
].copy()
|
||||
updated_deployment["litellm_params"]["model"] = model
|
||||
return model, updated_deployment
|
||||
early = self._try_early_resolve_deployments_for_model_not_in_names(
|
||||
model=model, request_team_id=request_team_id
|
||||
)
|
||||
if early is not None:
|
||||
return early
|
||||
|
||||
## get healthy deployments
|
||||
### get all deployments
|
||||
healthy_deployments = self._get_all_deployments(
|
||||
model_name=model, team_id=request_team_id
|
||||
)
|
||||
_pre_model_access_group_filter_len = len(healthy_deployments)
|
||||
healthy_deployments = self._filter_deployments_by_model_access_groups(
|
||||
model=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
request_kwargs=request_kwargs,
|
||||
request_team_id=request_team_id,
|
||||
)
|
||||
_access_group_filter_emptied_candidates = (
|
||||
_pre_model_access_group_filter_len > 0 and len(healthy_deployments) == 0
|
||||
)
|
||||
|
||||
if len(healthy_deployments) == 0:
|
||||
# check if the user sent in a deployment name instead
|
||||
healthy_deployments = self._get_deployment_by_litellm_model(model=model)
|
||||
# Do not fall back when access-group filtering removed every candidate;
|
||||
# _get_deployment_by_litellm_model does not re-apply that filter.
|
||||
if _pre_model_access_group_filter_len == 0:
|
||||
_litellm_model_deployments = self._get_deployment_by_litellm_model(
|
||||
model=model
|
||||
)
|
||||
healthy_deployments = self._filter_deployments_by_model_access_groups(
|
||||
model=model,
|
||||
healthy_deployments=_litellm_model_deployments,
|
||||
request_kwargs=request_kwargs,
|
||||
request_team_id=request_team_id,
|
||||
)
|
||||
# If the litellm-model lookup produced candidates that access-group
|
||||
# filtering then removed, treat this the same as the by-name path
|
||||
# being emptied: prevent default-model fallback from bypassing the
|
||||
# restriction (the fallback model may have no access_groups and
|
||||
# would short-circuit the filter).
|
||||
if (
|
||||
len(_litellm_model_deployments) > 0
|
||||
and len(healthy_deployments) == 0
|
||||
):
|
||||
_access_group_filter_emptied_candidates = True
|
||||
|
||||
if verbose_router_logger.isEnabledFor(logging.DEBUG):
|
||||
verbose_router_logger.debug(
|
||||
@ -9451,7 +9493,13 @@ class Router:
|
||||
|
||||
if len(healthy_deployments) == 0:
|
||||
# Check for default fallbacks if no deployments are found for the requested model
|
||||
if self._has_default_fallbacks():
|
||||
# Do not fall back to another model when access-group filtering removed every
|
||||
# candidate for the requested name: re-filtering the fallback model can be a
|
||||
# no-op when it has no access_groups, incorrectly serving a different model.
|
||||
if (
|
||||
self._has_default_fallbacks()
|
||||
and not _access_group_filter_emptied_candidates
|
||||
):
|
||||
fallback_model = self._get_first_default_fallback()
|
||||
if fallback_model:
|
||||
verbose_router_logger.info(
|
||||
@ -9462,6 +9510,14 @@ class Router:
|
||||
healthy_deployments = self._get_all_deployments(
|
||||
model_name=model, team_id=request_team_id
|
||||
)
|
||||
healthy_deployments = (
|
||||
self._filter_deployments_by_model_access_groups(
|
||||
model=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
request_kwargs=request_kwargs,
|
||||
request_team_id=request_team_id,
|
||||
)
|
||||
)
|
||||
|
||||
# If still no deployments after checking for fallbacks, raise an error
|
||||
if len(healthy_deployments) == 0:
|
||||
@ -9487,6 +9543,70 @@ class Router:
|
||||
|
||||
return model, healthy_deployments
|
||||
|
||||
def _filter_deployments_by_model_access_groups(
    self,
    model: str,
    healthy_deployments: List,
    request_kwargs: Optional[Dict],
    request_team_id: Optional[str],
) -> List:
    """
    Restrict candidate deployments to caller-authorized model access groups.

    Filtering is applied only when all of the following hold; in every
    other case ``healthy_deployments`` is returned unchanged:

    - ``healthy_deployments`` is non-empty and ``request_kwargs`` is given;
    - ``user_api_key_auth`` is present under ``metadata`` or
      ``litellm_metadata`` in ``request_kwargs``;
    - the union of its ``models`` and ``team_models`` is non-empty and
      contains none of ``model``, ``"*"`` or ``"all-proxy-models"``;
    - ``self.get_model_access_groups`` reports access groups for the model
      and at least one of them is named in the caller's allowed models.

    Args:
        model: Requested model name.
        healthy_deployments: Candidate deployment dicts.
        request_kwargs: Request kwargs that may carry the caller's auth
            object in their metadata.
        request_team_id: Team id forwarded to ``get_model_access_groups``.

    Returns:
        The deployments whose ``model_info["access_groups"]`` overlaps the
        caller's allowed access groups (possibly an empty list), or the
        input list itself when filtering does not apply.
    """
    if not healthy_deployments or request_kwargs is None:
        return healthy_deployments

    metadata = request_kwargs.get("metadata") or {}
    litellm_metadata = request_kwargs.get("litellm_metadata") or {}
    user_api_key_auth = metadata.get("user_api_key_auth") or litellm_metadata.get(
        "user_api_key_auth"
    )
    if user_api_key_auth is None:
        return healthy_deployments

    object_models = set(getattr(user_api_key_auth, "models", []) or [])
    object_team_models = set(getattr(user_api_key_auth, "team_models", []) or [])
    allowed_models = object_models | object_team_models
    if not allowed_models:
        return healthy_deployments

    # If caller has direct model/wildcard/all-proxy access, do not constrain
    # deployment choice by access group.
    if (
        model in allowed_models
        or "*" in allowed_models
        or "all-proxy-models" in allowed_models
    ):
        return healthy_deployments

    access_groups_for_model = self.get_model_access_groups(
        model_name=model, team_id=request_team_id
    )
    if not access_groups_for_model:
        return healthy_deployments

    allowed_access_groups = set(access_groups_for_model.keys()) & allowed_models
    if not allowed_access_groups:
        # No overlap means this request was not authorized via model access
        # group membership for this model, so do not force group filtering.
        return healthy_deployments

    # Keep only deployments tagged with at least one allowed access group;
    # deployments with missing or empty ``access_groups`` are dropped.
    return [
        deployment
        for deployment in healthy_deployments
        if set((deployment.get("model_info") or {}).get("access_groups", []) or [])
        & allowed_access_groups
    ]
|
||||
|
||||
async def async_get_healthy_deployments(
|
||||
self,
|
||||
model: str,
|
||||
@ -10007,6 +10127,7 @@ class Router:
|
||||
messages=messages,
|
||||
input=input,
|
||||
specific_deployment=specific_deployment,
|
||||
request_kwargs=request_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(healthy_deployments, dict):
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
|
||||
from litellm import Router
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.router_utils.common_utils import (
|
||||
_deployment_supports_web_search,
|
||||
add_model_file_id_mappings,
|
||||
@ -365,6 +366,49 @@ def test_invalidate_model_group_info_cache():
|
||||
assert router._cached_get_model_group_info.cache_info().currsize == 0
|
||||
|
||||
|
||||
def test_filter_deployments_by_model_access_groups_access_group_only_key():
    """
    Access-group-only keys should only route to deployments in allowed groups,
    even when multiple deployments share the same public model name.
    """
    # Two deployments share the public name "gpt-5" but sit in different groups.
    router = Router(
        model_list=[
            {
                "model_name": "gpt-5",
                "litellm_params": {"model": "openai/gpt-5.1", "api_key": "key-1"},
                "model_info": {"access_groups": ["AG1"]},
            },
            {
                "model_name": "gpt-5",
                "litellm_params": {"model": "openai/gpt-4o", "api_key": "key-2"},
                "model_info": {"access_groups": ["AG2"]},
            },
        ]
    )

    # The key lists only the access group, never the model name itself.
    scoped_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team-2",
        models=["AG2"],
        team_models=["AG2"],
    )

    filtered = router._filter_deployments_by_model_access_groups(
        model="gpt-5",
        healthy_deployments=router._get_all_deployments(model_name="gpt-5"),
        request_kwargs={
            "metadata": {
                "user_api_key_team_id": "team-2",
                "user_api_key_auth": scoped_key,
            }
        },
        request_team_id="team-2",
    )

    assert len(filtered) == 1
    assert filtered[0].get("model_info", {}).get("access_groups") == ["AG2"]
|
||||
|
||||
|
||||
class TestAddModelFileIdMappings:
|
||||
"""Test cases for add_model_file_id_mappings.
|
||||
|
||||
|
||||
@ -3267,3 +3267,433 @@ async def test_multiregion_team_failover_between_regions():
|
||||
"response from us-east-1",
|
||||
"response from us-west-2",
|
||||
]
|
||||
|
||||
|
||||
def test_access_group_scoped_key_filters_deployments_with_same_public_model():
    """
    If a key can access a model only via access group membership,
    router candidate deployments for that public model should be constrained
    to deployments in the allowed access group.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "openai/gpt-5.1",
                    "api_key": "key1",
                    "mock_response": "response-via-AG1",
                },
                "model_info": {"access_groups": ["AG1"]},
            },
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "openai/gpt-4o",
                    "api_key": "key2",
                    "mock_response": "response-via-AG2",
                },
                "model_info": {"access_groups": ["AG2"]},
            },
        ]
    )

    scoped_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG2"],
        team_models=["AG2"],
    )

    # Candidate selection must already exclude the AG1 deployment.
    _model, deployments = router._common_checks_available_deployment(
        model="gpt-5",
        request_kwargs={
            "metadata": {
                "user_api_key_team_id": "team2",
                "user_api_key_auth": scoped_key,
            }
        },
    )

    assert len(deployments) == 1
    assert deployments[0].get("model_info", {}).get("access_groups") == ["AG2"]

    # Repeat the end-to-end call so that a stray pick of the AG1 deployment
    # would show up in the set of observed mock responses.
    seen = set()
    for _ in range(20):
        response = router.completion(
            model="gpt-5",
            messages=[{"role": "user", "content": "hello"}],
            metadata={"user_api_key_team_id": "team2", "user_api_key_auth": scoped_key},
        )
        seen.add(response.choices[0].message.content)

    assert seen == {"response-via-AG2"}
|
||||
|
||||
|
||||
def test_explicit_model_access_does_not_force_access_group_filtering():
    """
    If a key has explicit model access in addition to access group entries,
    do not force access-group-only filtering for deployment selection.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "openai/gpt-5.1",
                    "api_key": "key1",
                    "mock_response": "response-via-AG1",
                },
                "model_info": {"access_groups": ["AG1"]},
            },
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "openai/gpt-4o",
                    "api_key": "key2",
                    "mock_response": "response-via-AG2",
                },
                "model_info": {"access_groups": ["AG2"]},
            },
        ]
    )

    # The key names the model "gpt-5" directly as well as the AG2 group.
    explicit_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG2", "gpt-5"],
        team_models=["AG2", "gpt-5"],
    )

    _model, deployments = router._common_checks_available_deployment(
        model="gpt-5",
        request_kwargs={
            "metadata": {
                "user_api_key_team_id": "team2",
                "user_api_key_auth": explicit_key,
            }
        },
    )

    # Both deployments must remain candidates.
    deployment_groups = [
        d.get("model_info", {}).get("access_groups") for d in deployments
    ]
    assert ["AG1"] in deployment_groups
    assert ["AG2"] in deployment_groups
|
||||
|
||||
|
||||
def test_access_group_filter_empty_does_not_bypass_via_litellm_model_fallback(
    monkeypatch: pytest.MonkeyPatch,
):
    """
    When access-group filtering removes all candidates, _get_deployment_by_litellm_model
    must not run: it does not re-apply access groups and could return blocked deployments
    that share the same litellm_params.model as the request model string.

    ``get_model_access_groups`` is patched to expose AG1 for the public model (so the
    access-group filter runs with a non-empty allowed set) while every deployment
    returned for that name is AG2-only — filtered to empty. Without the guard, the
    litellm-model fallback would return both rows because ``litellm_params.model`` matches.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "gpt-5",
                    "api_key": "key1",
                    "mock_response": "blocked-dep-1",
                },
                "model_info": {"access_groups": ["AG2"]},
            },
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "gpt-5",
                    "api_key": "key2",
                    "mock_response": "blocked-dep-2",
                },
                "model_info": {"access_groups": ["AG2"]},
            },
        ]
    )

    orig_groups = router.get_model_access_groups

    def fake_get_model_access_groups(
        model_name=None, model_access_group=None, team_id=None
    ):
        # Report AG1 for "gpt-5" although no deployment carries it; every
        # other lookup is delegated to the real implementation.
        if model_name == "gpt-5" and model_access_group is None:
            return {"AG1": ["gpt-5"], "AG2": ["gpt-5"]}
        return orig_groups(
            model_name=model_name,
            model_access_group=model_access_group,
            team_id=team_id,
        )

    monkeypatch.setattr(router, "get_model_access_groups", fake_get_model_access_groups)

    scoped_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG1"],
        team_models=["AG1"],
    )

    with pytest.raises(litellm.BadRequestError):
        router._common_checks_available_deployment(
            model="gpt-5",
            request_kwargs={
                "metadata": {
                    "user_api_key_team_id": "team2",
                    "user_api_key_auth": scoped_key,
                }
            },
        )
|
||||
|
||||
|
||||
def test_access_group_block_does_not_silently_use_default_fallback_model(
    monkeypatch: pytest.MonkeyPatch,
):
    """
    When access-group filtering empties candidates for model X, the router must not use
    ``fallbacks`` default ``*`` routing to model Y: Y may have no ``access_groups``, so
    ``_filter_deployments_by_model_access_groups`` would not constrain Y and the caller
    would be served despite being blocked from X.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "gpt-5",
                    "api_key": "key1",
                    "mock_response": "blocked-dep-1",
                },
                "model_info": {"access_groups": ["AG2"]},
            },
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "gpt-5",
                    "api_key": "key2",
                    "mock_response": "blocked-dep-2",
                },
                "model_info": {"access_groups": ["AG2"]},
            },
            # Default-fallback target with no access_groups at all.
            {
                "model_name": "gpt-4-fallback",
                "litellm_params": {
                    "model": "gpt-4",
                    "api_key": "fallback-key",
                    "mock_response": "should-not-reach",
                },
            },
        ],
        fallbacks=[{"*": ["gpt-4-fallback"]}],
    )

    orig_groups = router.get_model_access_groups

    def fake_get_model_access_groups(
        model_name=None, model_access_group=None, team_id=None
    ):
        # Report AG1 for "gpt-5" although no deployment carries it; every
        # other lookup is delegated to the real implementation.
        if model_name == "gpt-5" and model_access_group is None:
            return {"AG1": ["gpt-5"], "AG2": ["gpt-5"]}
        return orig_groups(
            model_name=model_name,
            model_access_group=model_access_group,
            team_id=team_id,
        )

    monkeypatch.setattr(router, "get_model_access_groups", fake_get_model_access_groups)

    scoped_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG1"],
        team_models=["AG1"],
    )

    with pytest.raises(litellm.BadRequestError):
        router._common_checks_available_deployment(
            model="gpt-5",
            request_kwargs={
                "metadata": {
                    "user_api_key_team_id": "team2",
                    "user_api_key_auth": scoped_key,
                }
            },
        )
|
||||
|
||||
|
||||
def test_access_group_block_via_litellm_model_branch_does_not_use_default_fallback(
    monkeypatch: pytest.MonkeyPatch,
):
    """
    When the by-name lookup returns no deployments and the litellm-model fallback
    branch finds candidates that access-group filtering then empties, the router
    must not fall through to default ``fallbacks`` routing — the default fallback
    model may have no ``access_groups`` and would short-circuit the filter,
    silently serving a caller blocked by access-group restrictions.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    router = litellm.Router(
        model_list=[
            # Public name differs from the requested "gpt-5"; only
            # litellm_params.model matches the request string.
            {
                "model_name": "gpt-5-alias",
                "litellm_params": {
                    "model": "gpt-5",
                    "api_key": "key1",
                    "mock_response": "blocked-dep-1",
                },
                "model_info": {"access_groups": ["AG2"]},
            },
            {
                "model_name": "gpt-4-fallback",
                "litellm_params": {
                    "model": "gpt-4",
                    "api_key": "fallback-key",
                    "mock_response": "should-not-reach",
                },
            },
        ],
        fallbacks=[{"*": ["gpt-4-fallback"]}],
    )

    orig_groups = router.get_model_access_groups

    def fake_get_model_access_groups(
        model_name=None, model_access_group=None, team_id=None
    ):
        # Report AG1 and AG2 for "gpt-5"; every other lookup is delegated
        # to the real implementation.
        if model_name == "gpt-5" and model_access_group is None:
            return {"AG1": ["gpt-5"], "AG2": ["gpt-5"]}
        return orig_groups(
            model_name=model_name,
            model_access_group=model_access_group,
            team_id=team_id,
        )

    monkeypatch.setattr(router, "get_model_access_groups", fake_get_model_access_groups)

    scoped_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG1"],
        team_models=["AG1"],
    )

    with pytest.raises(litellm.BadRequestError):
        router._common_checks_available_deployment(
            model="gpt-5",
            request_kwargs={
                "metadata": {
                    "user_api_key_team_id": "team2",
                    "user_api_key_auth": scoped_key,
                }
            },
        )
|
||||
|
||||
|
||||
def test_try_early_resolve_deployments_for_model_not_in_names():
    """
    Direct coverage for ``_try_early_resolve_deployments_for_model_not_in_names``:

    - Returns ``None`` when the requested model is already in ``self.model_names``
      (the by-name lookup path will handle it).
    - Returns ``None`` when there are no team deployments, no pattern matches, and
      no default deployment to fall back to.
    - Returns the pattern-router match when the model matches a wildcard route.
    - Returns the default deployment with the request model substituted in when one
      is configured, without mutating the stored default.
    """
    router_in_names = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "openai/gpt-5",
                    "api_key": "key1",
                },
            },
        ]
    )

    # Case 1: the model is a registered name.
    assert (
        router_in_names._try_early_resolve_deployments_for_model_not_in_names(
            model="gpt-5", request_team_id=None
        )
        is None
    )
    # Case 2: unknown model and nothing to resolve it against.
    assert (
        router_in_names._try_early_resolve_deployments_for_model_not_in_names(
            model="some-unknown-model", request_team_id=None
        )
        is None
    )

    # Case 3: a wildcard route matches the requested model.
    pattern_router = litellm.Router(
        model_list=[
            {
                "model_name": "openai/*",
                "litellm_params": {
                    "model": "openai/*",
                    "api_key": "key-pattern",
                },
            },
        ]
    )

    pattern_result = (
        pattern_router._try_early_resolve_deployments_for_model_not_in_names(
            model="openai/gpt-4o-mini", request_team_id=None
        )
    )
    assert pattern_result is not None
    resolved_model, pattern_deployments = pattern_result
    assert resolved_model == "openai/gpt-4o-mini"
    assert isinstance(pattern_deployments, list) and len(pattern_deployments) == 1

    # Case 4: a default deployment is configured.
    default_router = litellm.Router(
        model_list=[
            {
                "model_name": "named-model",
                "litellm_params": {
                    "model": "openai/gpt-4o",
                    "api_key": "key-named",
                },
            },
        ]
    )
    default_router.default_deployment = {
        "model_name": "default",
        "litellm_params": {
            "model": "openai/will-be-overridden",
            "api_key": "key-default",
        },
    }

    default_result = (
        default_router._try_early_resolve_deployments_for_model_not_in_names(
            model="brand-new-model", request_team_id=None
        )
    )
    assert default_result is not None
    resolved_model, default_deployment = default_result
    assert resolved_model == "brand-new-model"
    assert isinstance(default_deployment, dict)
    assert default_deployment["litellm_params"]["model"] == "brand-new-model"
    # The original default_deployment must not be mutated.
    assert (
        default_router.default_deployment["litellm_params"]["model"]
        == "openai/will-be-overridden"
    )
|
||||
|
||||
Loading…
Reference in New Issue
Block a user