diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 0cbccfc18a..0be476469a 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -561,7 +561,7 @@ async def _setup_new_team_model_assignment( async def _get_team_deployments( - team_id: str, prisma_client: PrismaClient + team_id: str, prisma_client: PrismaClient, table: Optional[Any] = None ) -> List[LiteLLM_ProxyModelTable]: """ Fetch all deployments for a given team_id from the database. @@ -572,9 +572,13 @@ async def _get_team_deployments( Note: prisma-client-py 0.11.0 does not support JSON path filtering, so we filter by the model_name prefix (team models use "model_name_{team_id}_*") and confirm team_id in model_info with Python-side filtering. + + Pass ``table`` (a transaction's proxy-model table) to run the read inside an + existing transaction. """ prefix = f"model_name_{team_id}_" - response = await ModelRepository(prisma_client).table.find_many( + table = table or ModelRepository(prisma_client).table + response = await table.find_many( where={ "model_name": {"startswith": prefix}, } @@ -596,6 +600,42 @@ async def _get_team_deployments( return result +async def delete_team_models( + team_ids: List[str], + prisma_client: PrismaClient, + llm_router: Optional[Any], +) -> List[str]: + """ + Delete every BYOK model owned by the given teams, from the DB and the router. + + The DB rows are removed inside a single transaction, so deletion is atomic + across all team_ids. Each team's rows are deleted by the exact model_ids read + in the same transaction, which keeps the deleted set identical to the set + handed to the router. The router is synced only after the transaction commits, + so a rollback can never leave a DB row whose deployment the router already dropped. + + Returns the model_ids that were deleted. 
+ """ + deleted_model_ids: List[str] = [] + async with prisma_client.db.tx() as tx: + for team_id in team_ids: + rows = await _get_team_deployments( + team_id, prisma_client, table=tx.litellm_proxymodeltable + ) + model_ids = [row.model_id for row in rows] + if model_ids: + await tx.litellm_proxymodeltable.delete_many( + where={"model_id": {"in": model_ids}} + ) + deleted_model_ids.extend(model_ids) + + if llm_router is not None: + for model_id in deleted_model_ids: + llm_router.delete_deployment(id=model_id) + + return deleted_model_ids + + async def _get_team_public_model_names( team_id: str, prisma_client: PrismaClient, diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index f2eafcbf83..0e69de87ce 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -3289,6 +3289,20 @@ async def delete_team( await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key") + ## DELETE ASSOCIATED BYOK MODELS + # Runs before the team rows are deleted so a mid-flight failure never leaves + # the team gone with its models orphaned. 
+ from litellm.proxy.management_endpoints.model_management_endpoints import ( + delete_team_models, + ) + from litellm.proxy.proxy_server import llm_router + + await delete_team_models( + team_ids=data.team_ids, + prisma_client=prisma_client, + llm_router=llm_router, + ) + # ## DELETE TEAM MEMBERSHIPS for team_row in team_rows: ### get all team members diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index 2ba604e5da..6a81b1b613 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -24,6 +24,7 @@ from litellm.proxy.management_endpoints.model_management_endpoints import ( ModelManagementAuthChecks, _get_team_deployments, clear_cache, + delete_team_models, ) from litellm.proxy.utils import PrismaClient from litellm.types.router import Deployment, LiteLLM_Params, updateDeployment @@ -2184,6 +2185,148 @@ class TestGetTeamDeployments: assert result[0] is dep1 +def _model_row(model_id: str, team_id: str): + row = MagicMock() + row.model_id = model_id + row.model_name = f"model_name_{team_id}_{model_id}" + row.model_info = {"team_id": team_id} + return row + + +class _TxProxyModelTable: + """Transactional proxy-model table that records the order of DB writes.""" + + def __init__(self, rows, events): + self._rows = list(rows) + self.events = events + + async def find_many(self, where): + prefix = where["model_name"]["startswith"] + return [r for r in self._rows if r.model_name.startswith(prefix)] + + async def delete_many(self, where): + ids = list(where["model_id"]["in"]) + self.events.append(("delete_many", tuple(ids))) + self._rows = [r for r in self._rows if r.model_id not in ids] + return len(ids) + + +class _TxPrismaClient: + """Minimal prisma stub whose ``db.tx()`` yields a transaction and records commit.""" + + 
def __init__(self, rows): + self.events: list = [] + self._table = _TxProxyModelTable(rows, self.events) + tx = MagicMock() + tx.litellm_proxymodeltable = self._table + outer = self + + class _TxCM: + async def __aenter__(self): + return tx + + async def __aexit__(self, *exc): + outer.events.append(("commit",)) + return False + + self.db = MagicMock() + self.db.tx = MagicMock(return_value=_TxCM()) + + +class _RecordingRouter: + def __init__(self, events): + self.events = events + self.deleted: list = [] + + def delete_deployment(self, id): # noqa: A002 - matches router signature + self.events.append(("router", id)) + self.deleted.append(id) + + +class TestDeleteTeamModels: + """delete_team_models must remove every team's BYOK models in one transaction + and sync the in-memory router only after that transaction commits.""" + + @pytest.mark.asyncio + async def test_deletes_all_teams_models_and_syncs_router(self): + rows = [_model_row("a1", "team_a"), _model_row("b1", "team_b")] + prisma = _TxPrismaClient(rows) + router = _RecordingRouter(prisma.events) + + deleted = await delete_team_models( + team_ids=["team_a", "team_b"], + prisma_client=prisma, + llm_router=router, + ) + + assert sorted(deleted) == ["a1", "b1"] + assert sorted(router.deleted) == ["a1", "b1"] + + @pytest.mark.asyncio + async def test_router_sync_happens_after_commit(self): + """Race-safety: the router is touched only once the DB transaction has + committed, so a rollback can never leave a row whose deployment is already gone.""" + rows = [_model_row("a1", "team_a"), _model_row("b1", "team_b")] + prisma = _TxPrismaClient(rows) + router = _RecordingRouter(prisma.events) + + await delete_team_models( + team_ids=["team_a", "team_b"], prisma_client=prisma, llm_router=router + ) + + commit_idx = prisma.events.index(("commit",)) + router_indices = [i for i, e in enumerate(prisma.events) if e[0] == "router"] + delete_indices = [ + i for i, e in enumerate(prisma.events) if e[0] == "delete_many" + ] + assert 
router_indices, "router was never synced" + assert all(i > commit_idx for i in router_indices) + assert all(i < commit_idx for i in delete_indices) + + @pytest.mark.asyncio + async def test_only_owning_team_models_deleted(self): + """A row sharing the prefix but a different model_info.team_id is left alone.""" + mine = _model_row("a1", "team_a") + intruder = MagicMock() + intruder.model_id = "x9" + intruder.model_name = "model_name_team_a_x9" + intruder.model_info = {"team_id": "someone_else"} + prisma = _TxPrismaClient([mine, intruder]) + router = _RecordingRouter(prisma.events) + + deleted = await delete_team_models( + team_ids=["team_a"], prisma_client=prisma, llm_router=router + ) + + assert deleted == ["a1"] + assert router.deleted == ["a1"] + + @pytest.mark.asyncio + async def test_no_models_no_writes(self): + prisma = _TxPrismaClient([]) + router = _RecordingRouter(prisma.events) + + deleted = await delete_team_models( + team_ids=["team_a"], prisma_client=prisma, llm_router=router + ) + + assert deleted == [] + assert router.deleted == [] + assert not any(e[0] == "delete_many" for e in prisma.events) + + @pytest.mark.asyncio + async def test_missing_router_is_safe(self): + rows = [_model_row("a1", "team_a")] + prisma = _TxPrismaClient(rows) + + deleted = await delete_team_models( + team_ids=["team_a"], prisma_client=prisma, llm_router=None + ) + + assert deleted == ["a1"] + assert any(e[0] == "delete_many" for e in prisma.events) + + def _build_db_model_for_blocked_test(): from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo diff --git a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py index 400ed28780..06adcf8070 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py @@ -6350,6 +6350,14 @@ async def 
test_delete_team_persists_deleted_teams(monkeypatch): mock_find_many_keys = AsyncMock(return_value=[]) mock_prisma_client.db.litellm_verificationtoken.find_many = mock_find_many_keys + # delete_team now deletes team BYOK models inside a transaction; this team has none. + mock_tx = AsyncMock() + mock_tx.litellm_proxymodeltable.find_many = AsyncMock(return_value=[]) + mock_tx_cm = MagicMock() + mock_tx_cm.__aenter__ = AsyncMock(return_value=mock_tx) + mock_tx_cm.__aexit__ = AsyncMock(return_value=False) + mock_prisma_client.db.tx = MagicMock(return_value=mock_tx_cm) + monkeypatch.setattr( "litellm.proxy.proxy_server.prisma_client", mock_prisma_client, @@ -8499,6 +8507,8 @@ async def test_new_team_encrypts_callback_vars( assert cv["langfuse_secret_key"] != "sk-real" recovered = decrypt_callback_vars(metadata)["logging"][0]["callback_vars"] assert recovered["langfuse_secret_key"] == "sk-real" + + def _non_admin_auth(): return UserAPIKeyAuth( user_id="u-team-admin", user_role=LitellmUserRoles.INTERNAL_USER