fix(team-management): delete a team's BYOK models when the team is deleted (#29977)

A team's BYOK models (rows in LiteLLM_ProxyModelTable with model_info.team_id set)
were left orphaned when the team was deleted; they lingered in the database and kept
showing on the Models + Endpoints page. delete_team now removes them via a new
delete_team_models helper that deletes the rows in one transaction and syncs the
in-memory router only after that transaction commits, run before the team rows are
deleted so a mid-flight failure never leaves the team gone with its models orphaned
This commit is contained in:
yuneng-jiang 2026-06-08 16:55:35 -07:00 committed by GitHub
parent bac2590b39
commit c24a3603d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 209 additions and 2 deletions

View File

@ -561,7 +561,7 @@ async def _setup_new_team_model_assignment(
async def _get_team_deployments(
team_id: str, prisma_client: PrismaClient
team_id: str, prisma_client: PrismaClient, table: Optional[Any] = None
) -> List[LiteLLM_ProxyModelTable]:
"""
Fetch all deployments for a given team_id from the database.
@ -572,9 +572,13 @@ async def _get_team_deployments(
Note: prisma-client-py 0.11.0 does not support JSON path filtering, so we filter
by the model_name prefix (team models use "model_name_{team_id}_*") and confirm
team_id in model_info with Python-side filtering.
Pass ``table`` (a transaction's proxy-model table) to run the read inside an
existing transaction.
"""
prefix = f"model_name_{team_id}_"
response = await ModelRepository(prisma_client).table.find_many(
table = table or ModelRepository(prisma_client).table
response = await table.find_many(
where={
"model_name": {"startswith": prefix},
}
@ -596,6 +600,42 @@ async def _get_team_deployments(
return result
async def delete_team_models(
team_ids: List[str],
prisma_client: PrismaClient,
llm_router: Optional[Any],
) -> List[str]:
"""
Delete every BYOK model owned by the given teams, from the DB and the router.
The DB rows are removed inside a single transaction, so deletion is atomic
across all team_ids. Each team's rows are deleted by the exact model_ids read
in the same transaction, which keeps the deleted set identical to the set
handed to the router. The router is synced only after the transaction commits,
so a rollback can never leave a deployment live in the router without its row.
Returns the model_ids that were deleted.
"""
deleted_model_ids: List[str] = []
async with prisma_client.db.tx() as tx:
for team_id in team_ids:
rows = await _get_team_deployments(
team_id, prisma_client, table=tx.litellm_proxymodeltable
)
model_ids = [row.model_id for row in rows]
if model_ids:
await tx.litellm_proxymodeltable.delete_many(
where={"model_id": {"in": model_ids}}
)
deleted_model_ids.extend(model_ids)
if llm_router is not None:
for model_id in deleted_model_ids:
llm_router.delete_deployment(id=model_id)
return deleted_model_ids
async def _get_team_public_model_names(
team_id: str,
prisma_client: PrismaClient,

View File

@ -3289,6 +3289,20 @@ async def delete_team(
await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key")
## DELETE ASSOCIATED BYOK MODELS
# Runs before the team rows are deleted so a mid-flight failure never leaves
# the team gone with its models orphaned.
from litellm.proxy.management_endpoints.model_management_endpoints import (
delete_team_models,
)
from litellm.proxy.proxy_server import llm_router
await delete_team_models(
team_ids=data.team_ids,
prisma_client=prisma_client,
llm_router=llm_router,
)
# ## DELETE TEAM MEMBERSHIPS
for team_row in team_rows:
### get all team members

View File

@ -24,6 +24,7 @@ from litellm.proxy.management_endpoints.model_management_endpoints import (
ModelManagementAuthChecks,
_get_team_deployments,
clear_cache,
delete_team_models,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.router import Deployment, LiteLLM_Params, updateDeployment
@ -2184,6 +2185,148 @@ class TestGetTeamDeployments:
assert result[0] is dep1
def _model_row(model_id: str, team_id: str):
row = MagicMock()
row.model_id = model_id
row.model_name = f"model_name_{team_id}_{model_id}"
row.model_info = {"team_id": team_id}
return row
class _TxProxyModelTable:
"""Transactional proxy-model table that records the order of DB writes."""
def __init__(self, rows, events):
self._rows = list(rows)
self.events = events
async def find_many(self, where):
prefix = where["model_name"]["startswith"]
return [r for r in self._rows if r.model_name.startswith(prefix)]
async def delete_many(self, where):
ids = list(where["model_id"]["in"])
self.events.append(("delete_many", tuple(ids)))
self._rows = [r for r in self._rows if r.model_id not in ids]
return len(ids)
class _TxPrismaClient:
"""Minimal prisma stub whose ``db.tx()`` yields a transaction and records commit."""
def __init__(self, rows):
self.events: list = []
self._table = _TxProxyModelTable(rows, self.events)
tx = MagicMock()
tx.litellm_proxymodeltable = self._table
outer = self
class _TxCM:
async def __aenter__(self):
return tx
async def __aexit__(self, *exc):
outer.events.append(("commit",))
return False
self.db = MagicMock()
self.db.tx = MagicMock(return_value=_TxCM())
class _RecordingRouter:
def __init__(self, events):
self.events = events
self.deleted: list = []
def delete_deployment(self, id): # noqa: A002 - matches router signature
self.events.append(("router", id))
self.deleted.append(id)
class TestDeleteTeamModels:
"""delete_team_models must remove every team's BYOK models in one transaction
and sync the in-memory router only after that transaction commits."""
@pytest.mark.asyncio
async def test_deletes_all_teams_models_and_syncs_router(self):
rows = [_model_row("a1", "team_a"), _model_row("b1", "team_b")]
prisma = _TxPrismaClient(rows)
router = _RecordingRouter(prisma.events)
deleted = await delete_team_models(
team_ids=["team_a", "team_b"],
prisma_client=prisma,
llm_router=router,
)
assert sorted(deleted) == ["a1", "b1"]
assert sorted(router.deleted) == ["a1", "b1"]
@pytest.mark.asyncio
async def test_router_sync_happens_after_commit(self):
"""Race-safety: the router is touched only once the DB transaction has
committed, so a rollback can never leave a deployment without its row."""
rows = [_model_row("a1", "team_a"), _model_row("b1", "team_b")]
prisma = _TxPrismaClient(rows)
router = _RecordingRouter(prisma.events)
await delete_team_models(
team_ids=["team_a", "team_b"], prisma_client=prisma, llm_router=router
)
commit_idx = prisma.events.index(("commit",))
router_indices = [i for i, e in enumerate(prisma.events) if e[0] == "router"]
delete_indices = [
i for i, e in enumerate(prisma.events) if e[0] == "delete_many"
]
assert router_indices, "router was never synced"
assert all(i > commit_idx for i in router_indices)
assert all(i < commit_idx for i in delete_indices)
@pytest.mark.asyncio
async def test_only_owning_team_models_deleted(self):
"""A row sharing the prefix but a different model_info.team_id is left alone."""
mine = _model_row("a1", "team_a")
intruder = MagicMock()
intruder.model_id = "x9"
intruder.model_name = "model_name_team_a_x9"
intruder.model_info = {"team_id": "someone_else"}
prisma = _TxPrismaClient([mine, intruder])
router = _RecordingRouter(prisma.events)
deleted = await delete_team_models(
team_ids=["team_a"], prisma_client=prisma, llm_router=router
)
assert deleted == ["a1"]
assert router.deleted == ["a1"]
@pytest.mark.asyncio
async def test_no_models_no_writes(self):
prisma = _TxPrismaClient([])
router = _RecordingRouter(prisma.events)
deleted = await delete_team_models(
team_ids=["team_a"], prisma_client=prisma, llm_router=router
)
assert deleted == []
assert router.deleted == []
assert not any(e[0] == "delete_many" for e in prisma.events)
@pytest.mark.asyncio
async def test_missing_router_is_safe(self):
rows = [_model_row("a1", "team_a")]
prisma = _TxPrismaClient(rows)
deleted = await delete_team_models(
team_ids=["team_a"], prisma_client=prisma, llm_router=None
)
assert deleted == ["a1"]
assert any(e[0] == "delete_many" for e in prisma.events)
def _build_db_model_for_blocked_test():
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo

View File

@ -6350,6 +6350,14 @@ async def test_delete_team_persists_deleted_teams(monkeypatch):
mock_find_many_keys = AsyncMock(return_value=[])
mock_prisma_client.db.litellm_verificationtoken.find_many = mock_find_many_keys
# delete_team now deletes team BYOK models inside a transaction; this team has none.
mock_tx = AsyncMock()
mock_tx.litellm_proxymodeltable.find_many = AsyncMock(return_value=[])
mock_tx_cm = MagicMock()
mock_tx_cm.__aenter__ = AsyncMock(return_value=mock_tx)
mock_tx_cm.__aexit__ = AsyncMock(return_value=False)
mock_prisma_client.db.tx = MagicMock(return_value=mock_tx_cm)
monkeypatch.setattr(
"litellm.proxy.proxy_server.prisma_client",
mock_prisma_client,
@ -8499,6 +8507,8 @@ async def test_new_team_encrypts_callback_vars(
assert cv["langfuse_secret_key"] != "sk-real"
recovered = decrypt_callback_vars(metadata)["logging"][0]["callback_vars"]
assert recovered["langfuse_secret_key"] == "sk-real"
def _non_admin_auth():
return UserAPIKeyAuth(
user_id="u-team-admin", user_role=LitellmUserRoles.INTERNAL_USER