diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 96c9cd1e8f..267e388112 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -11140,6 +11140,22 @@ async def _get_caller_byok_team_scope( return set(user_row.teams or []) +def _byok_row_outside_caller_teams( + model_info_dict: Dict[str, Any], allowed_team_ids: Optional[Set[str]] +) -> bool: + """Whether a team BYOK row belongs to a team the caller is not a member of. + + `team_id` is only set on team BYOK rows; non-team rows fall through + unaffected. `allowed_team_ids is None` means no scoping (e.g. admins). + """ + if allowed_team_ids is None: + return False + team_id = model_info_dict.get("team_id") + if team_id is None: + return False + return team_id not in allowed_team_ids + + # Hard cap on rows the DB-side BYOK search may pull when results need to be # sorted across the full match set. Without this, an authenticated caller # can hit `/v2/model/info?search=&sortBy=` and force the @@ -11261,15 +11277,7 @@ async def _apply_search_filter_to_models( ) def _is_byok_outside_caller_teams(model_info_dict: Dict[str, Any]) -> bool: - # `team_id` is only set on team BYOK rows. Non-team rows fall - # through unaffected — they are gated by other paths (router - # membership, direct_access, include_team_models). - if allowed_team_ids is None: - return False - team_id = model_info_dict.get("team_id") - if team_id is None: - return False - return team_id not in allowed_team_ids + return _byok_row_outside_caller_teams(model_info_dict, allowed_team_ids) def _model_matches_search(m: Dict[str, Any]) -> bool: # Team BYOK models persist an internal `model_name` @@ -12409,6 +12417,72 @@ async def model_metrics_exceptions( return {"data": response, "exception_types": list(exception_types)} +def _deployment_matches_allowed_model_names( + model: Dict[str, Any], allowed_model_names: Set[str] +) -> bool: + """Match a router deployment against allowed public model names. + + Team-scoped rows store an internal routing key in ``model_name``; callers + with key/team restrictions still refer to the public name in + ``model_info.team_public_model_name``. + """ + if model.get("model_name") in allowed_model_names: + return True + model_info = model.get("model_info") + if not isinstance(model_info, dict): + return False + team_public_model_name = model_info.get("team_public_model_name") + return ( + isinstance(team_public_model_name, str) + and team_public_model_name in allowed_model_names + ) + + +def _get_v1_model_info_allowed_model_names( + user_api_key_dict: UserAPIKeyAuth, + llm_router: Router, +) -> Optional[Set[str]]: + """Return key/team allowlisted public model names, or None if unrestricted.""" + model_access_groups = llm_router.get_model_access_groups() + proxy_model_list = llm_router.get_model_names() + key_models = get_key_models( + user_api_key_dict=user_api_key_dict, + proxy_model_list=proxy_model_list, + model_access_groups=model_access_groups, + ) + team_models = get_team_models( + team_models=user_api_key_dict.team_models, + proxy_model_list=proxy_model_list, + model_access_groups=model_access_groups, + ) + if not key_models and not team_models: + return None + return set( + get_complete_model_list( + key_models=key_models, + team_models=team_models, + proxy_model_list=proxy_model_list, + user_model=user_model, + infer_model_from_keys=general_settings.get("infer_model_from_keys", False), + llm_router=llm_router, + return_wildcard_routes=False, + ) + ) + + +def _filter_v1_model_info_deployments( + all_models: List[dict], + allowed_model_names: Optional[Set[str]], +) -> List[dict]: + if allowed_model_names is None: + return all_models + return [ + model + for model in all_models + if _deployment_matches_allowed_model_names(model, allowed_model_names) + ] + + def _translate_model_name_for_response(model: dict) -> dict: """For team-scoped DB rows, replace `model_name` with the public name in `model_info.team_public_model_name` before returning. The DB column @@ -12578,49 +12652,42 @@ async def model_info_v1( # noqa: PLR0915 ) return {"data": [_deployment_info_dict]} - all_models: List[dict] = [] - model_access_groups: Dict[str, List[str]] = defaultdict(list) - ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ## - if llm_router is None: - proxy_model_list = [] - else: - proxy_model_list = llm_router.get_model_names() - model_access_groups = llm_router.get_model_access_groups() - key_models = get_key_models( + # Return router deployments (same source as /v2/model/info), not wildcard- + # expanded model names from get_complete_model_list(). Team-scoped rows + # use internal routing keys (model_name_{team_id}_{uuid}) and were omitted + # when v1 resolved models only via public model_name strings. + all_models: List[dict] = copy.deepcopy(llm_router.model_list) + allowed_model_names = _get_v1_model_info_allowed_model_names( user_api_key_dict=user_api_key_dict, - proxy_model_list=proxy_model_list, - model_access_groups=model_access_groups, - ) - team_models = get_team_models( - team_models=user_api_key_dict.team_models, - proxy_model_list=proxy_model_list, - model_access_groups=model_access_groups, - ) - all_models_str = get_complete_model_list( - key_models=key_models, - team_models=team_models, - proxy_model_list=proxy_model_list, - user_model=user_model, - infer_model_from_keys=general_settings.get("infer_model_from_keys", False), llm_router=llm_router, ) - if len(all_models_str) > 0: - _relevant_models = [] - for model in all_models_str: - router_models = llm_router.get_model_list(model_name=model) - if router_models is not None: - _relevant_models.extend(router_models) - if llm_model_list is not None: - all_models = copy.deepcopy(_relevant_models) # type: ignore - else: - all_models = [] + all_models = _filter_v1_model_info_deployments( + all_models=all_models, + allowed_model_names=allowed_model_names, + ) - # Reassign each entry: _get_proxy_model_info returns a (possibly new) - # dict via _translate_model_name_for_response, which does NOT mutate in - # place. Binding only the loop variable would drop the public-name swap - # for team-scoped rows and leak the internal routing key (#28382). - all_models = [_get_proxy_model_info(model=model) for model in all_models] + # Team BYOK deployments carry an internal routing key and other teams' + # public name/team_id/api_base; drop the ones the caller cannot access so + # listing the full router model_list does not leak cross-team metadata. + allowed_team_ids = await _get_caller_byok_team_scope( + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + ) + all_models = [ + model + for model in all_models + if not _byok_row_outside_caller_teams( + model.get("model_info") or {}, allowed_team_ids + ) + ] + + all_models = [ + _translate_model_name_for_response( + _enrich_model_info_with_litellm_data(model=model, llm_router=llm_router) + ) + for model in all_models + ] verbose_proxy_logger.debug("all_models: %s", all_models) return {"data": all_models} diff --git a/tests/test_litellm/proxy/proxy_server/test_team_model_name_translation.py b/tests/test_litellm/proxy/proxy_server/test_team_model_name_translation.py index 97e5c49491..9757999c85 100644 --- a/tests/test_litellm/proxy/proxy_server/test_team_model_name_translation.py +++ b/tests/test_litellm/proxy/proxy_server/test_team_model_name_translation.py @@ -151,21 +151,24 @@ async def test_model_info_v2_translates_team_model_name(monkeypatch): @pytest.mark.asyncio async def test_model_info_v1_list_path_translates_team_model_name(monkeypatch): - """/v1/model/info list path (no litellm_model_id) must surface the public - name. Covers the list comprehension that assigns _get_proxy_model_info's - return back into all_models (#28382 review).""" + """/v1/model/info list path (no litellm_model_id) must include team-scoped + deployments from the router model list and surface the public name (#28382).""" + team_row = _team_row() + global_row = { + "model_name": "gpt-4o", + "litellm_params": {"model": "gpt-4o"}, + "model_info": {"id": "normal-id-1", "db_model": False}, + } router = MagicMock() - router.get_model_names.return_value = ["team-claude-sonnet"] + router.model_list = [team_row, global_row] + router.get_model_names.return_value = ["gpt-4o"] router.get_model_access_groups.return_value = {} - router.get_model_list.return_value = [_team_row()] monkeypatch.setattr(ps, "user_model", None) - monkeypatch.setattr(ps, "llm_model_list", [_team_row()]) + monkeypatch.setattr(ps, "llm_model_list", router.model_list) monkeypatch.setattr(ps, "llm_router", router) - monkeypatch.setattr(ps, "get_key_models", lambda **kw: []) - monkeypatch.setattr(ps, "get_team_models", lambda **kw: []) monkeypatch.setattr( - ps, "get_complete_model_list", lambda **kw: ["team-claude-sonnet"] + ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model ) admin = UserAPIKeyAuth( @@ -176,3 +179,167 @@ async def test_model_info_v1_list_path_translates_team_model_name(monkeypatch): names = [m["model_name"] for m in resp["data"]] assert "team-claude-sonnet" in names assert "model_name_team-abc-123_4a6b8" not in names + + +@pytest.mark.asyncio +async def test_model_info_v1_unrestricted_key_returns_all_deployments(monkeypatch): + """Unrestricted keys must see all router deployments (legacy v1 access logic).""" + deployment = { + "model_name": "gpt-4", + "litellm_params": {"model": "gpt-4"}, + "model_info": {"id": "global-id-1", "db_model": False}, + } + router = MagicMock() + router.model_list = [deployment] + router.get_model_names.return_value = ["gpt-4"] + router.get_model_access_groups.return_value = {} + + monkeypatch.setattr(ps, "user_model", None) + monkeypatch.setattr(ps, "llm_model_list", router.model_list) + monkeypatch.setattr(ps, "llm_router", router) + monkeypatch.setattr( + ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model + ) + + caller = UserAPIKeyAuth( + user_id="user-1", + user_role=LitellmUserRoles.INTERNAL_USER, + models=[], + team_models=[], + ) + resp = await ps.model_info_v1(user_api_key_dict=caller, litellm_model_id=None) + + assert [m["model_name"] for m in resp["data"]] == ["gpt-4"] + + +@pytest.mark.asyncio +async def test_model_info_v1_restricted_key_filters_deployments(monkeypatch): + """Key-level model allowlists must filter router deployments.""" + team_row = _team_row() + global_row = { + "model_name": "gpt-4", + "litellm_params": {"model": "gpt-4"}, + "model_info": {"id": "global-id-1", "db_model": False}, + } + router = MagicMock() + router.model_list = [team_row, global_row] + router.get_model_names.return_value = ["gpt-4", "team-claude-sonnet"] + router.get_model_access_groups.return_value = {} + + monkeypatch.setattr(ps, "user_model", None) + monkeypatch.setattr(ps, "llm_model_list", router.model_list) + monkeypatch.setattr(ps, "llm_router", router) + monkeypatch.setattr( + ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model + ) + + caller = UserAPIKeyAuth( + user_id="user-1", + user_role=LitellmUserRoles.INTERNAL_USER, + models=["gpt-4"], + team_models=[], + ) + resp = await ps.model_info_v1(user_api_key_dict=caller, litellm_model_id=None) + + assert [m["model_name"] for m in resp["data"]] == ["gpt-4"] + + +def _other_team_row() -> dict: + return { + "model_name": "model_name_team-other_9f2c1", + "litellm_params": { + "model": "azure/gpt-5.2-low-rpm-testing", + "api_base": "https://team-other-private.example.com", + }, + "model_info": { + "id": "byok-id-other", + "team_id": "team-other", + "team_public_model_name": "team-claude-sonnet", + "db_model": True, + }, + } + + +@pytest.mark.asyncio +async def test_model_info_v1_unrestricted_key_hides_other_team_byok(monkeypatch): + """Unrestricted non-admin keys must not enumerate other teams' BYOK + deployments, but must still see global models and their own team's.""" + team_row = _team_row() + other_team_row = _other_team_row() + global_row = { + "model_name": "gpt-4", + "litellm_params": {"model": "gpt-4"}, + "model_info": {"id": "global-id-1", "db_model": False}, + } + router = MagicMock() + router.model_list = [team_row, other_team_row, global_row] + router.get_model_names.return_value = ["gpt-4"] + router.get_model_access_groups.return_value = {} + + prisma_client = MagicMock() + caller_user_row = MagicMock() + caller_user_row.teams = ["team-abc-123"] + prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=caller_user_row + ) + + monkeypatch.setattr(ps, "user_model", None) + monkeypatch.setattr(ps, "llm_model_list", router.model_list) + monkeypatch.setattr(ps, "llm_router", router) + monkeypatch.setattr(ps, "prisma_client", prisma_client) + monkeypatch.setattr( + ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model + ) + + caller = UserAPIKeyAuth( + user_id="user-1", + user_role=LitellmUserRoles.INTERNAL_USER, + models=[], + team_models=[], + ) + resp = await ps.model_info_v1(user_api_key_dict=caller, litellm_model_id=None) + + returned_ids = {m["model_info"]["id"] for m in resp["data"]} + assert returned_ids == {"global-id-1", "byok-id-1"} + assert "byok-id-other" not in returned_ids + names = [m["model_name"] for m in resp["data"]] + assert "team-claude-sonnet" in names + assert "gpt-4" in names + + +@pytest.mark.asyncio +async def test_model_info_v1_service_key_hides_all_team_byok(monkeypatch): + """A key without a resolvable user (e.g. CI/service token) sees only + global deployments, never any team-scoped BYOK rows.""" + team_row = _team_row() + other_team_row = _other_team_row() + global_row = { + "model_name": "gpt-4", + "litellm_params": {"model": "gpt-4"}, + "model_info": {"id": "global-id-1", "db_model": False}, + } + router = MagicMock() + router.model_list = [team_row, other_team_row, global_row] + router.get_model_names.return_value = ["gpt-4"] + router.get_model_access_groups.return_value = {} + + prisma_client = MagicMock() + + monkeypatch.setattr(ps, "user_model", None) + monkeypatch.setattr(ps, "llm_model_list", router.model_list) + monkeypatch.setattr(ps, "llm_router", router) + monkeypatch.setattr(ps, "prisma_client", prisma_client) + monkeypatch.setattr( + ps, "_enrich_model_info_with_litellm_data", lambda model, **kw: model + ) + + caller = UserAPIKeyAuth( + user_id=None, + user_role=LitellmUserRoles.INTERNAL_USER, + team_id="team-abc-123", + models=[], + team_models=[], + ) + resp = await ps.model_info_v1(user_api_key_dict=caller, litellm_model_id=None) + + assert [m["model_info"]["id"] for m in resp["data"]] == ["global-id-1"] diff --git a/tests/test_litellm/proxy/test_model_info_default_limits.py b/tests/test_litellm/proxy/test_model_info_default_limits.py index 641199c96f..8111a7af00 100644 --- a/tests/test_litellm/proxy/test_model_info_default_limits.py +++ b/tests/test_litellm/proxy/test_model_info_default_limits.py @@ -146,9 +146,9 @@ class TestModelInfoEndpointWithRouter: deployment_dict = deployment.model_dump(exclude_none=True) mock_router = MagicMock() + mock_router.model_list = [deployment_dict] mock_router.get_model_names.return_value = ["model1"] mock_router.get_model_access_groups.return_value = {} - mock_router.get_model_list.return_value = [deployment_dict] user_api_key_dict = UserAPIKeyAuth(api_key="sk-test") @@ -156,6 +156,7 @@ class TestModelInfoEndpointWithRouter: patch("litellm.proxy.proxy_server.llm_router", mock_router), patch("litellm.proxy.proxy_server.llm_model_list", [deployment_dict]), patch("litellm.proxy.proxy_server.user_model", None), + patch("litellm.proxy.proxy_server.prisma_client", None), patch("litellm.proxy.proxy_server.get_key_models", return_value=["model1"]), patch( "litellm.proxy.proxy_server.get_team_models", return_value=["model1"] diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index 9f2c5ffd61..9eaccdfcbc 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -3840,14 +3840,15 @@ async def test_model_info_v1_oci_secrets_not_leaked(): # Mock the llm_router to return our test data mock_router = MagicMock() + mock_router.model_list = [mock_model_data] mock_router.get_model_names.return_value = ["oci-grok-test"] mock_router.get_model_access_groups.return_value = {} - mock_router.get_model_list.return_value = [mock_model_data] # Mock global variables with ( patch("litellm.proxy.proxy_server.llm_router", mock_router), patch("litellm.proxy.proxy_server.llm_model_list", [mock_model_data]), + patch("litellm.proxy.proxy_server.prisma_client", None), patch( "litellm.proxy.proxy_server.general_settings", {"infer_model_from_keys": False},