fix(proxy): align /v1/model/info with router deployments (#30025)
* fix(proxy): align /v1/model/info with router deployments Return router model_list entries (including team-scoped models) with team access metadata instead of wildcard-expanded names from get_complete_model_list. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(proxy): gate v1 team filter and honor key allowlists Only apply get_all_team_and_direct_access_models for admin or user-bound keys, then intersect with key/team model restrictions to avoid empty lists for service tokens and metadata leaks for restricted keys. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(proxy): skip v1 team filter when user row is missing Require a DB-backed user before applying team-access filtering on /v1/model/info, and skip the trailing filter in get_all_team_and_direct_access_models when user context cannot be resolved. Co-authored-by: Cursor <cursoragent@cursor.com> * Revert "fix(proxy): skip v1 team filter when user row is missing" This reverts commit 74e1fbd77a981103cd9a4ed1cbdd662f5cbcf209. * fix(proxy): restore legacy v1 model access filtering Keep /v1/model/info on key/team allowlists instead of DB team-membership filtering, while still listing router deployments for team-scoped models. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(proxy): drop A2A agent entries from public /v1/model/info list * fix(proxy): scope team BYOK rows on /v1/model/info to caller's teams Listing the full router model_list let any authenticated key without explicit model restrictions enumerate other teams' BYOK deployments (public name, team_id, api_base) via /v1/model/info. Reuse the existing _get_caller_byok_team_scope check so non-admin callers only see global deployments plus their own team's BYOK rows; admins keep the full view. --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: mateo-berri <277851410+mateo-berri@users.noreply.github.com>
This commit is contained in:
parent
4def6916da
commit
6068bb7781
@ -11140,6 +11140,22 @@ async def _get_caller_byok_team_scope(
|
||||
return set(user_row.teams or [])
|
||||
|
||||
|
||||
def _byok_row_outside_caller_teams(
|
||||
model_info_dict: Dict[str, Any], allowed_team_ids: Optional[Set[str]]
|
||||
) -> bool:
|
||||
"""Whether a team BYOK row belongs to a team the caller is not a member of.
|
||||
|
||||
`team_id` is only set on team BYOK rows; non-team rows fall through
|
||||
unaffected. `allowed_team_ids is None` means no scoping (e.g. admins).
|
||||
"""
|
||||
if allowed_team_ids is None:
|
||||
return False
|
||||
team_id = model_info_dict.get("team_id")
|
||||
if team_id is None:
|
||||
return False
|
||||
return team_id not in allowed_team_ids
|
||||
|
||||
|
||||
# Hard cap on rows the DB-side BYOK search may pull when results need to be
|
||||
# sorted across the full match set. Without this, an authenticated caller
|
||||
# can hit `/v2/model/info?search=<broad>&sortBy=<field>` and force the
|
||||
@ -11261,15 +11277,7 @@ async def _apply_search_filter_to_models(
|
||||
)
|
||||
|
||||
def _is_byok_outside_caller_teams(model_info_dict: Dict[str, Any]) -> bool:
|
||||
# `team_id` is only set on team BYOK rows. Non-team rows fall
|
||||
# through unaffected — they are gated by other paths (router
|
||||
# membership, direct_access, include_team_models).
|
||||
if allowed_team_ids is None:
|
||||
return False
|
||||
team_id = model_info_dict.get("team_id")
|
||||
if team_id is None:
|
||||
return False
|
||||
return team_id not in allowed_team_ids
|
||||
return _byok_row_outside_caller_teams(model_info_dict, allowed_team_ids)
|
||||
|
||||
def _model_matches_search(m: Dict[str, Any]) -> bool:
|
||||
# Team BYOK models persist an internal `model_name`
|
||||
@ -12409,6 +12417,72 @@ async def model_metrics_exceptions(
|
||||
return {"data": response, "exception_types": list(exception_types)}
|
||||
|
||||
|
||||
def _deployment_matches_allowed_model_names(
|
||||
model: Dict[str, Any], allowed_model_names: Set[str]
|
||||
) -> bool:
|
||||
"""Match a router deployment against allowed public model names.
|
||||
|
||||
Team-scoped rows store an internal routing key in ``model_name``; callers
|
||||
with key/team restrictions still refer to the public name in
|
||||
``model_info.team_public_model_name``.
|
||||
"""
|
||||
if model.get("model_name") in allowed_model_names:
|
||||
return True
|
||||
model_info = model.get("model_info")
|
||||
if not isinstance(model_info, dict):
|
||||
return False
|
||||
team_public_model_name = model_info.get("team_public_model_name")
|
||||
return (
|
||||
isinstance(team_public_model_name, str)
|
||||
and team_public_model_name in allowed_model_names
|
||||
)
|
||||
|
||||
|
||||
def _get_v1_model_info_allowed_model_names(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
llm_router: Router,
|
||||
) -> Optional[Set[str]]:
|
||||
"""Return key/team allowlisted public model names, or None if unrestricted."""
|
||||
model_access_groups = llm_router.get_model_access_groups()
|
||||
proxy_model_list = llm_router.get_model_names()
|
||||
key_models = get_key_models(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_model_list=proxy_model_list,
|
||||
model_access_groups=model_access_groups,
|
||||
)
|
||||
team_models = get_team_models(
|
||||
team_models=user_api_key_dict.team_models,
|
||||
proxy_model_list=proxy_model_list,
|
||||
model_access_groups=model_access_groups,
|
||||
)
|
||||
if not key_models and not team_models:
|
||||
return None
|
||||
return set(
|
||||
get_complete_model_list(
|
||||
key_models=key_models,
|
||||
team_models=team_models,
|
||||
proxy_model_list=proxy_model_list,
|
||||
user_model=user_model,
|
||||
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
||||
llm_router=llm_router,
|
||||
return_wildcard_routes=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _filter_v1_model_info_deployments(
|
||||
all_models: List[dict],
|
||||
allowed_model_names: Optional[Set[str]],
|
||||
) -> List[dict]:
|
||||
if allowed_model_names is None:
|
||||
return all_models
|
||||
return [
|
||||
model
|
||||
for model in all_models
|
||||
if _deployment_matches_allowed_model_names(model, allowed_model_names)
|
||||
]
|
||||
|
||||
|
||||
def _translate_model_name_for_response(model: dict) -> dict:
|
||||
"""For team-scoped DB rows, replace `model_name` with the public name
|
||||
in `model_info.team_public_model_name` before returning. The DB column
|
||||
@ -12578,49 +12652,42 @@ async def model_info_v1( # noqa: PLR0915
|
||||
)
|
||||
return {"data": [_deployment_info_dict]}
|
||||
|
||||
all_models: List[dict] = []
|
||||
model_access_groups: Dict[str, List[str]] = defaultdict(list)
|
||||
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
|
||||
if llm_router is None:
|
||||
proxy_model_list = []
|
||||
else:
|
||||
proxy_model_list = llm_router.get_model_names()
|
||||
model_access_groups = llm_router.get_model_access_groups()
|
||||
key_models = get_key_models(
|
||||
# Return router deployments (same source as /v2/model/info), not wildcard-
|
||||
# expanded model names from get_complete_model_list(). Team-scoped rows
|
||||
# use internal routing keys (model_name_{team_id}_{uuid}) and were omitted
|
||||
# when v1 resolved models only via public model_name strings.
|
||||
all_models: List[dict] = copy.deepcopy(llm_router.model_list)
|
||||
allowed_model_names = _get_v1_model_info_allowed_model_names(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_model_list=proxy_model_list,
|
||||
model_access_groups=model_access_groups,
|
||||
)
|
||||
team_models = get_team_models(
|
||||
team_models=user_api_key_dict.team_models,
|
||||
proxy_model_list=proxy_model_list,
|
||||
model_access_groups=model_access_groups,
|
||||
)
|
||||
all_models_str = get_complete_model_list(
|
||||
key_models=key_models,
|
||||
team_models=team_models,
|
||||
proxy_model_list=proxy_model_list,
|
||||
user_model=user_model,
|
||||
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
if len(all_models_str) > 0:
|
||||
_relevant_models = []
|
||||
for model in all_models_str:
|
||||
router_models = llm_router.get_model_list(model_name=model)
|
||||
if router_models is not None:
|
||||
_relevant_models.extend(router_models)
|
||||
if llm_model_list is not None:
|
||||
all_models = copy.deepcopy(_relevant_models) # type: ignore
|
||||
else:
|
||||
all_models = []
|
||||
all_models = _filter_v1_model_info_deployments(
|
||||
all_models=all_models,
|
||||
allowed_model_names=allowed_model_names,
|
||||
)
|
||||
|
||||
# Reassign each entry: _get_proxy_model_info returns a (possibly new)
|
||||
# dict via _translate_model_name_for_response, which does NOT mutate in
|
||||
# place. Binding only the loop variable would drop the public-name swap
|
||||
# for team-scoped rows and leak the internal routing key (#28382).
|
||||
all_models = [_get_proxy_model_info(model=model) for model in all_models]
|
||||
# Team BYOK deployments carry an internal routing key and other teams'
|
||||
# public name/team_id/api_base; drop the ones the caller cannot access so
|
||||
# listing the full router model_list does not leak cross-team metadata.
|
||||
allowed_team_ids = await _get_caller_byok_team_scope(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
all_models = [
|
||||
model
|
||||
for model in all_models
|
||||
if not _byok_row_outside_caller_teams(
|
||||
model.get("model_info") or {}, allowed_team_ids
|
||||
)
|
||||
]
|
||||
|
||||
all_models = [
|
||||
_translate_model_name_for_response(
|
||||
_enrich_model_info_with_litellm_data(model=model, llm_router=llm_router)
|
||||
)
|
||||
for model in all_models
|
||||
]
|
||||
|
||||
verbose_proxy_logger.debug("all_models: %s", all_models)
|
||||
return {"data": all_models}
|
||||
|
||||
@ -151,21 +151,24 @@ async def test_model_info_v2_translates_team_model_name(monkeypatch):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_info_v1_list_path_translates_team_model_name(monkeypatch):
|
||||
"""/v1/model/info list path (no litellm_model_id) must surface the public
|
||||
name. Covers the list comprehension that assigns _get_proxy_model_info's
|
||||
return back into all_models (#28382 review)."""
|
||||
"""/v1/model/info list path (no litellm_model_id) must include team-scoped
|
||||
deployments from the router model list and surface the public name (#28382)."""
|
||||
team_row = _team_row()
|
||||
global_row = {
|
||||
"model_name": "gpt-4o",
|
||||
"litellm_params": {"model": "gpt-4o"},
|
||||
"model_info": {"id": "normal-id-1", "db_model": False},
|
||||
}
|
||||
router = MagicMock()
|
||||
router.get_model_names.return_value = ["team-claude-sonnet"]
|
||||
router.model_list = [team_row, global_row]
|
||||
router.get_model_names.return_value = ["gpt-4o"]
|
||||
router.get_model_access_groups.return_value = {}
|
||||
router.get_model_list.return_value = [_team_row()]
|
||||
|
||||
monkeypatch.setattr(ps, "user_model", None)
|
||||
monkeypatch.setattr(ps, "llm_model_list", [_team_row()])
|
||||
monkeypatch.setattr(ps, "llm_model_list", router.model_list)
|
||||
monkeypatch.setattr(ps, "llm_router", router)
|
||||
monkeypatch.setattr(ps, "get_key_models", lambda **kw: [])
|
||||
monkeypatch.setattr(ps, "get_team_models", lambda **kw: [])
|
||||
monkeypatch.setattr(
|
||||
ps, "get_complete_model_list", lambda **kw: ["team-claude-sonnet"]
|
||||
ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model
|
||||
)
|
||||
|
||||
admin = UserAPIKeyAuth(
|
||||
@ -176,3 +179,167 @@ async def test_model_info_v1_list_path_translates_team_model_name(monkeypatch):
|
||||
names = [m["model_name"] for m in resp["data"]]
|
||||
assert "team-claude-sonnet" in names
|
||||
assert "model_name_team-abc-123_4a6b8" not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_info_v1_unrestricted_key_returns_all_deployments(monkeypatch):
|
||||
"""Unrestricted keys must see all router deployments (legacy v1 access logic)."""
|
||||
deployment = {
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {"model": "gpt-4"},
|
||||
"model_info": {"id": "global-id-1", "db_model": False},
|
||||
}
|
||||
router = MagicMock()
|
||||
router.model_list = [deployment]
|
||||
router.get_model_names.return_value = ["gpt-4"]
|
||||
router.get_model_access_groups.return_value = {}
|
||||
|
||||
monkeypatch.setattr(ps, "user_model", None)
|
||||
monkeypatch.setattr(ps, "llm_model_list", router.model_list)
|
||||
monkeypatch.setattr(ps, "llm_router", router)
|
||||
monkeypatch.setattr(
|
||||
ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model
|
||||
)
|
||||
|
||||
caller = UserAPIKeyAuth(
|
||||
user_id="user-1",
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
models=[],
|
||||
team_models=[],
|
||||
)
|
||||
resp = await ps.model_info_v1(user_api_key_dict=caller, litellm_model_id=None)
|
||||
|
||||
assert [m["model_name"] for m in resp["data"]] == ["gpt-4"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_info_v1_restricted_key_filters_deployments(monkeypatch):
|
||||
"""Key-level model allowlists must filter router deployments."""
|
||||
team_row = _team_row()
|
||||
global_row = {
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {"model": "gpt-4"},
|
||||
"model_info": {"id": "global-id-1", "db_model": False},
|
||||
}
|
||||
router = MagicMock()
|
||||
router.model_list = [team_row, global_row]
|
||||
router.get_model_names.return_value = ["gpt-4", "team-claude-sonnet"]
|
||||
router.get_model_access_groups.return_value = {}
|
||||
|
||||
monkeypatch.setattr(ps, "user_model", None)
|
||||
monkeypatch.setattr(ps, "llm_model_list", router.model_list)
|
||||
monkeypatch.setattr(ps, "llm_router", router)
|
||||
monkeypatch.setattr(
|
||||
ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model
|
||||
)
|
||||
|
||||
caller = UserAPIKeyAuth(
|
||||
user_id="user-1",
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
models=["gpt-4"],
|
||||
team_models=[],
|
||||
)
|
||||
resp = await ps.model_info_v1(user_api_key_dict=caller, litellm_model_id=None)
|
||||
|
||||
assert [m["model_name"] for m in resp["data"]] == ["gpt-4"]
|
||||
|
||||
|
||||
def _other_team_row() -> dict:
|
||||
return {
|
||||
"model_name": "model_name_team-other_9f2c1",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-5.2-low-rpm-testing",
|
||||
"api_base": "https://team-other-private.example.com",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "byok-id-other",
|
||||
"team_id": "team-other",
|
||||
"team_public_model_name": "team-claude-sonnet",
|
||||
"db_model": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_info_v1_unrestricted_key_hides_other_team_byok(monkeypatch):
|
||||
"""Unrestricted non-admin keys must not enumerate other teams' BYOK
|
||||
deployments, but must still see global models and their own team's."""
|
||||
team_row = _team_row()
|
||||
other_team_row = _other_team_row()
|
||||
global_row = {
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {"model": "gpt-4"},
|
||||
"model_info": {"id": "global-id-1", "db_model": False},
|
||||
}
|
||||
router = MagicMock()
|
||||
router.model_list = [team_row, other_team_row, global_row]
|
||||
router.get_model_names.return_value = ["gpt-4"]
|
||||
router.get_model_access_groups.return_value = {}
|
||||
|
||||
prisma_client = MagicMock()
|
||||
caller_user_row = MagicMock()
|
||||
caller_user_row.teams = ["team-abc-123"]
|
||||
prisma_client.db.litellm_usertable.find_unique = AsyncMock(
|
||||
return_value=caller_user_row
|
||||
)
|
||||
|
||||
monkeypatch.setattr(ps, "user_model", None)
|
||||
monkeypatch.setattr(ps, "llm_model_list", router.model_list)
|
||||
monkeypatch.setattr(ps, "llm_router", router)
|
||||
monkeypatch.setattr(ps, "prisma_client", prisma_client)
|
||||
monkeypatch.setattr(
|
||||
ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model
|
||||
)
|
||||
|
||||
caller = UserAPIKeyAuth(
|
||||
user_id="user-1",
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
models=[],
|
||||
team_models=[],
|
||||
)
|
||||
resp = await ps.model_info_v1(user_api_key_dict=caller, litellm_model_id=None)
|
||||
|
||||
returned_ids = {m["model_info"]["id"] for m in resp["data"]}
|
||||
assert returned_ids == {"global-id-1", "byok-id-1"}
|
||||
assert "byok-id-other" not in returned_ids
|
||||
names = [m["model_name"] for m in resp["data"]]
|
||||
assert "team-claude-sonnet" in names
|
||||
assert "gpt-4" in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_info_v1_service_key_hides_all_team_byok(monkeypatch):
|
||||
"""A key without a resolvable user (e.g. CI/service token) sees only
|
||||
global deployments, never any team-scoped BYOK rows."""
|
||||
team_row = _team_row()
|
||||
other_team_row = _other_team_row()
|
||||
global_row = {
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {"model": "gpt-4"},
|
||||
"model_info": {"id": "global-id-1", "db_model": False},
|
||||
}
|
||||
router = MagicMock()
|
||||
router.model_list = [team_row, other_team_row, global_row]
|
||||
router.get_model_names.return_value = ["gpt-4"]
|
||||
router.get_model_access_groups.return_value = {}
|
||||
|
||||
prisma_client = MagicMock()
|
||||
|
||||
monkeypatch.setattr(ps, "user_model", None)
|
||||
monkeypatch.setattr(ps, "llm_model_list", router.model_list)
|
||||
monkeypatch.setattr(ps, "llm_router", router)
|
||||
monkeypatch.setattr(ps, "prisma_client", prisma_client)
|
||||
monkeypatch.setattr(
|
||||
ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model
|
||||
)
|
||||
|
||||
caller = UserAPIKeyAuth(
|
||||
user_id=None,
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
team_id="team-abc-123",
|
||||
models=[],
|
||||
team_models=[],
|
||||
)
|
||||
resp = await ps.model_info_v1(user_api_key_dict=caller, litellm_model_id=None)
|
||||
|
||||
assert [m["model_info"]["id"] for m in resp["data"]] == ["global-id-1"]
|
||||
|
||||
@ -146,9 +146,9 @@ class TestModelInfoEndpointWithRouter:
|
||||
deployment_dict = deployment.model_dump(exclude_none=True)
|
||||
|
||||
mock_router = MagicMock()
|
||||
mock_router.model_list = [deployment_dict]
|
||||
mock_router.get_model_names.return_value = ["model1"]
|
||||
mock_router.get_model_access_groups.return_value = {}
|
||||
mock_router.get_model_list.return_value = [deployment_dict]
|
||||
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key="sk-test")
|
||||
|
||||
@ -156,6 +156,7 @@ class TestModelInfoEndpointWithRouter:
|
||||
patch("litellm.proxy.proxy_server.llm_router", mock_router),
|
||||
patch("litellm.proxy.proxy_server.llm_model_list", [deployment_dict]),
|
||||
patch("litellm.proxy.proxy_server.user_model", None),
|
||||
patch("litellm.proxy.proxy_server.prisma_client", None),
|
||||
patch("litellm.proxy.proxy_server.get_key_models", return_value=["model1"]),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.get_team_models", return_value=["model1"]
|
||||
|
||||
@ -3840,14 +3840,15 @@ async def test_model_info_v1_oci_secrets_not_leaked():
|
||||
|
||||
# Mock the llm_router to return our test data
|
||||
mock_router = MagicMock()
|
||||
mock_router.model_list = [mock_model_data]
|
||||
mock_router.get_model_names.return_value = ["oci-grok-test"]
|
||||
mock_router.get_model_access_groups.return_value = {}
|
||||
mock_router.get_model_list.return_value = [mock_model_data]
|
||||
|
||||
# Mock global variables
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.llm_router", mock_router),
|
||||
patch("litellm.proxy.proxy_server.llm_model_list", [mock_model_data]),
|
||||
patch("litellm.proxy.proxy_server.prisma_client", None),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.general_settings",
|
||||
{"infer_model_from_keys": False},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user