diff --git a/litellm/llms/litellm_proxy/skills/handler.py b/litellm/llms/litellm_proxy/skills/handler.py index 48c02660a0..37aabd8b47 100644 --- a/litellm/llms/litellm_proxy/skills/handler.py +++ b/litellm/llms/litellm_proxy/skills/handler.py @@ -5,13 +5,11 @@ This module contains the actual database operations for skills CRUD. Used by the transformation layer and skills injection hook. """ -import time import uuid -from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional from litellm._logging import verbose_logger -from litellm.llms.litellm_proxy.skills.store import LiteLLMSkillsStore +from litellm.caching.in_memory_cache import InMemoryCache from litellm.proxy._types import LiteLLM_SkillsTable, NewSkillRequest, UserAPIKeyAuth from litellm.proxy.common_utils.resource_ownership import ( get_primary_resource_owner_scope, @@ -21,88 +19,37 @@ from litellm.proxy.common_utils.resource_ownership import ( ) # Skills are looked up on every chat completion that has skills enabled -# (`LiteLLMSkillsHandler.fetch_skill_from_db` in the injection hook). Cache -# the Prisma skill row for a short window so the hot path doesn't issue a DB -# round-trip per request. Same shape as `_byok_cred_cache` and the container -# ownership cache: (value, monotonic_timestamp). `None` is cached as a true -# negative ("skill does not exist") so repeated misses also avoid the DB. -_SKILL_CACHE: "OrderedDict[str, Tuple[Optional[Any], float]]" = OrderedDict() -_SKILL_CACHE_TTL = 60 # seconds -_SKILL_CACHE_MAX_SIZE = 10000 - - -def _read_skill_cache(skill_id: str) -> Tuple[bool, Optional[Any]]: - """Return (hit, value). hit=False means caller must consult the DB.""" - entry = _SKILL_CACHE.get(skill_id) - if entry is None: - return False, None - value, timestamp = entry - if time.monotonic() - timestamp > _SKILL_CACHE_TTL: - _SKILL_CACHE.pop(skill_id, None) - return False, None - _SKILL_CACHE.move_to_end(skill_id) - return True, value - - -def _write_skill_cache(skill_id: str, skill: Optional[Any]) -> None: - # LRU eviction (popitem(last=False)) instead of full ``clear()`` — - # see container ownership cache for rationale. - if skill_id in _SKILL_CACHE: - _SKILL_CACHE.move_to_end(skill_id) - _SKILL_CACHE[skill_id] = (skill, time.monotonic()) - while len(_SKILL_CACHE) > _SKILL_CACHE_MAX_SIZE: - _SKILL_CACHE.popitem(last=False) - - -def _invalidate_skill_cache(skill_id: str) -> None: - """Drop the cache entry after a write so the next read sees the new row.""" - _SKILL_CACHE.pop(skill_id, None) - - -def _user_can_access_skill_owner( - owner: Optional[str], - user_api_key_dict: Optional[UserAPIKeyAuth], -) -> bool: - # Pre-isolation skills with no ``created_by`` are admin-only — same - # rule as untracked containers. Owners need to either re-create via - # the now-tracked flow or have an admin assign ``created_by`` on the - # row. - return user_can_access_resource_owner(owner, user_api_key_dict) +# (`SkillsInjectionHook` calls ``fetch_skill_from_db``). 60s LRU/TTL cache +# absorbs the hot read before it reaches Prisma. ``_NEGATIVE_SKILL_SENTINEL`` +# lets us cache a true "skill does not exist" so repeated misses also +# avoid the DB — ``InMemoryCache`` returns ``None`` indistinguishably for +# "miss" and "cached as None". +_NEGATIVE_SKILL_SENTINEL = "__litellm_skill_not_found__" +_SKILL_CACHE = InMemoryCache(max_size_in_memory=10000, default_ttl=60) def _prisma_skill_to_litellm(prisma_skill) -> LiteLLM_SkillsTable: - """ - Convert a Prisma skill record to LiteLLM_SkillsTable. + """Convert a Prisma skill record to LiteLLM_SkillsTable. - Handles Base64 decoding of file_content field. + Handles Base64 decoding of file_content field — model_dump() converts + Base64 fields to base64-encoded strings. """ import base64 data = prisma_skill.model_dump() - # Decode Base64 file_content back to bytes - # model_dump() converts Base64 field to base64-encoded string if data.get("file_content") is not None: if isinstance(data["file_content"], str): data["file_content"] = base64.b64decode(data["file_content"]) - elif isinstance(data["file_content"], bytes): - # Already bytes, no conversion needed - pass return LiteLLM_SkillsTable(**data) class LiteLLMSkillsHandler: - """ - Handler for LiteLLM database-backed skills operations. - - This class provides static methods for CRUD operations on skills - stored in the LiteLLM proxy database (LiteLLM_SkillsTable). - """ + """CRUD for skills stored in ``litellm_skillstable``.""" @staticmethod async def _get_prisma_client(): - """Get the prisma client from proxy server.""" from litellm.proxy.proxy_server import prisma_client if prisma_client is None: @@ -118,28 +65,16 @@ class LiteLLMSkillsHandler: user_id: Optional[str] = None, user_api_key_dict: Optional[UserAPIKeyAuth] = None, ) -> LiteLLM_SkillsTable: - """ - Create a new skill in the LiteLLM database. - - Args: - data: NewSkillRequest with skill details - user_id: Optional user ID for tracking - - Returns: - LiteLLM_SkillsTable record - """ prisma_client = await LiteLLMSkillsHandler._get_prisma_client() - store = LiteLLMSkillsStore(prisma_client) skill_id = f"litellm_skill_{uuid.uuid4()}" owner = get_primary_resource_owner_scope(user_api_key_dict) or user_id if owner is None: - # Caller has no identity scope (no user_id / team_id / org_id / - # api_key / token). Stamping a placeholder would let any two - # identity-less callers see each other's skills via the shared - # owner — the cross-tenant primitive we avoid. ValueError keeps - # this module FastAPI-free per the project layering rule - # (litellm_proxy provider integrations live outside proxy/). + # Identity-less callers (no user_id / team_id / org_id / + # api_key / token) can't be uniquely stamped on the row. + # Stamping a placeholder would let any two such callers see + # each other's skills via the shared owner. ValueError keeps + # this module FastAPI-free per the project layering rule. raise ValueError( "Unable to record skill ownership: caller has no identity scope." ) @@ -154,13 +89,11 @@ class LiteLLMSkillsHandler: "updated_by": owner, } - # Handle metadata if data.metadata is not None: from litellm.litellm_core_utils.safe_json_dumps import safe_dumps skill_data["metadata"] = safe_dumps(data.metadata) - # Handle file content - wrap bytes in Base64 for Prisma if data.file_content is not None: from prisma.fields import Base64 @@ -174,8 +107,7 @@ class LiteLLMSkillsHandler: f"LiteLLMSkillsHandler: Creating skill {skill_id} with title={data.display_title}" ) - new_skill = await store.create_skill(skill_data) - + new_skill = await prisma_client.db.litellm_skillstable.create(data=skill_data) return _prisma_skill_to_litellm(new_skill) @staticmethod @@ -184,18 +116,7 @@ class LiteLLMSkillsHandler: offset: int = 0, user_api_key_dict: Optional[UserAPIKeyAuth] = None, ) -> List[LiteLLM_SkillsTable]: - """ - List skills from the LiteLLM database. - - Args: - limit: Maximum number of skills to return - offset: Number of skills to skip - - Returns: - List of LiteLLM_SkillsTable records - """ prisma_client = await LiteLLMSkillsHandler._get_prisma_client() - store = LiteLLMSkillsStore(prisma_client) verbose_logger.debug( f"LiteLLMSkillsHandler: Listing skills with limit={limit}, offset={offset}" @@ -212,26 +133,29 @@ class LiteLLMSkillsHandler: return [] find_many_kwargs["where"] = {"created_by": {"in": owner_scopes}} - skills = await store.list_skills(find_many_kwargs) - + skills = await prisma_client.db.litellm_skillstable.find_many( + **find_many_kwargs + ) return [_prisma_skill_to_litellm(s) for s in skills] @staticmethod async def _load_skill(skill_id: str) -> Optional[Any]: - """Cache-first read of the Prisma skill row. - - Caching here keeps `fetch_skill_from_db` (called per chat completion in - the skills injection hook) off the DB. Owner-scope filtering happens - on the cached row, so the cache is per-skill and not per-caller. + """Cache-first read of the Prisma skill row. Owner-scope filtering + happens on the cached row, so the cache is per-skill not per-caller. """ - cached_hit, cached_skill = _read_skill_cache(skill_id) - if cached_hit: - return cached_skill + cached = _SKILL_CACHE.get_cache(skill_id) + if cached == _NEGATIVE_SKILL_SENTINEL: + return None + if cached is not None: + return cached prisma_client = await LiteLLMSkillsHandler._get_prisma_client() - store = LiteLLMSkillsStore(prisma_client) - skill = await store.find_skill(skill_id) - _write_skill_cache(skill_id, skill) + skill = await prisma_client.db.litellm_skillstable.find_unique( + where={"skill_id": skill_id} + ) + _SKILL_CACHE.set_cache( + skill_id, skill if skill is not None else _NEGATIVE_SKILL_SENTINEL + ) return skill @staticmethod @@ -239,26 +163,12 @@ class LiteLLMSkillsHandler: skill_id: str, user_api_key_dict: Optional[UserAPIKeyAuth] = None, ) -> LiteLLM_SkillsTable: - """ - Get a skill by ID from the LiteLLM database. - - Args: - skill_id: The skill ID to retrieve - - Returns: - LiteLLM_SkillsTable record - - Raises: - ValueError: If skill not found - """ verbose_logger.debug(f"LiteLLMSkillsHandler: Getting skill {skill_id}") skill = await LiteLLMSkillsHandler._load_skill(skill_id) - - if skill is None: - raise ValueError(f"Skill not found: {skill_id}") - - if not _user_can_access_skill_owner( + # Same "not found" message for both "missing" and "cross-tenant" + # so callers can't enumerate skill IDs they don't own. + if skill is None or not user_can_access_resource_owner( getattr(skill, "created_by", None), user_api_key_dict ): raise ValueError(f"Skill not found: {skill_id}") @@ -270,37 +180,17 @@ class LiteLLMSkillsHandler: skill_id: str, user_api_key_dict: Optional[UserAPIKeyAuth] = None, ) -> Dict[str, str]: - """ - Delete a skill by ID from the LiteLLM database. - - Args: - skill_id: The skill ID to delete - - Returns: - Dict with id and type of deleted skill - - Raises: - ValueError: If skill not found - """ prisma_client = await LiteLLMSkillsHandler._get_prisma_client() - store = LiteLLMSkillsStore(prisma_client) - verbose_logger.debug(f"LiteLLMSkillsHandler: Deleting skill {skill_id}") - # Check if skill exists skill = await LiteLLMSkillsHandler._load_skill(skill_id) - - if skill is None: - raise ValueError(f"Skill not found: {skill_id}") - - if not _user_can_access_skill_owner( + if skill is None or not user_can_access_resource_owner( getattr(skill, "created_by", None), user_api_key_dict ): raise ValueError(f"Skill not found: {skill_id}") - # Delete the skill - await store.delete_skill(skill_id) - _invalidate_skill_cache(skill_id) + await prisma_client.db.litellm_skillstable.delete(where={"skill_id": skill_id}) + _SKILL_CACHE.set_cache(skill_id, _NEGATIVE_SKILL_SENTINEL) return {"id": skill_id, "type": "skill_deleted"} @@ -309,22 +199,11 @@ class LiteLLMSkillsHandler: skill_id: str, user_api_key_dict: Optional[UserAPIKeyAuth] = None, ) -> Optional[LiteLLM_SkillsTable]: - """ - Fetch a skill from the database (used by skills injection hook). - - This is a convenience method that returns None instead of raising - an exception if the skill is not found. - - Args: - skill_id: The skill ID to fetch - - Returns: - LiteLLM_SkillsTable or None if not found - """ + """Skills-injection-hook helper: returns None instead of raising on + not-found / not-authorized so the hook can silently skip.""" try: return await LiteLLMSkillsHandler.get_skill( - skill_id, - user_api_key_dict=user_api_key_dict, + skill_id, user_api_key_dict=user_api_key_dict ) except ValueError: return None diff --git a/litellm/llms/litellm_proxy/skills/store.py b/litellm/llms/litellm_proxy/skills/store.py deleted file mode 100644 index 6e69a7a962..0000000000 --- a/litellm/llms/litellm_proxy/skills/store.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any, Dict, List, Optional - - -class LiteLLMSkillsStore: - def __init__(self, prisma_client: Any): - self.prisma_client = prisma_client - - @property - def _table(self) -> Any: - return self.prisma_client.db.litellm_skillstable - - async def create_skill(self, data: Dict[str, Any]) -> Any: - return await self._table.create(data=data) - - async def list_skills(self, find_many_kwargs: Dict[str, Any]) -> List[Any]: - return await self._table.find_many(**find_many_kwargs) - - async def find_skill(self, skill_id: str) -> Optional[Any]: - return await self._table.find_unique(where={"skill_id": skill_id}) - - async def delete_skill(self, skill_id: str) -> None: - await self._table.delete(where={"skill_id": skill_id}) diff --git a/litellm/proxy/container_endpoints/ownership.py b/litellm/proxy/container_endpoints/ownership.py index b5b340fe2b..137366c955 100644 --- a/litellm/proxy/container_endpoints/ownership.py +++ b/litellm/proxy/container_endpoints/ownership.py @@ -1,10 +1,9 @@ -import time -from collections import OrderedDict from typing import Any, Dict, List, Optional, Set, Tuple from fastapi import HTTPException from litellm._logging import verbose_proxy_logger +from litellm.caching.in_memory_cache import InMemoryCache from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.common_utils.resource_ownership import ( get_primary_resource_owner_scope, @@ -12,75 +11,26 @@ from litellm.proxy.common_utils.resource_ownership import ( is_proxy_admin, user_can_access_resource_owner, ) -from litellm.proxy.container_endpoints.ownership_store import ( - CONTAINER_OBJECT_PURPOSE, - ContainerOwnershipStore, -) from litellm.responses.utils import ResponsesAPIRequestUtils -MAX_IN_MEMORY_CONTAINER_OWNERS = 10000 -_IN_MEMORY_CONTAINER_OWNERS: "OrderedDict[str, str]" = OrderedDict() +CONTAINER_OBJECT_PURPOSE = "container" -# Short-lived cache keeps every container access check from hitting the DB -# (`_get_container_owner` is invoked on retrieve / delete / list / file-content -# paths). Mirrors the `_byok_cred_cache` pattern in mcp_server/server.py: -# (value, monotonic_timestamp) tuples, TTL'd, LRU-evicted at capacity, -# invalidated by writes. A ``None`` value caches "untracked" so repeated -# negative lookups also avoid DB. -_CONTAINER_OWNER_CACHE: "OrderedDict[str, Tuple[Optional[str], float]]" = OrderedDict() -_CONTAINER_OWNER_CACHE_TTL = 60 # seconds -_CONTAINER_OWNER_CACHE_MAX_SIZE = 10000 - - -def _read_container_owner_cache(model_object_id: str) -> Tuple[bool, Optional[str]]: - """Return (hit, value). hit=False means caller must consult the DB.""" - entry = _CONTAINER_OWNER_CACHE.get(model_object_id) - if entry is None: - return False, None - value, timestamp = entry - if time.monotonic() - timestamp > _CONTAINER_OWNER_CACHE_TTL: - _CONTAINER_OWNER_CACHE.pop(model_object_id, None) - return False, None - _CONTAINER_OWNER_CACHE.move_to_end(model_object_id) - return True, value - - -def _write_container_owner_cache(model_object_id: str, owner: Optional[str]) -> None: - # LRU eviction (popitem(last=False)) instead of full ``clear()`` — a - # full clear at capacity converts a steady-state cached workload into - # a periodic full-DB-load oscillation as the cache repopulates from - # zero and clears again. - if model_object_id in _CONTAINER_OWNER_CACHE: - _CONTAINER_OWNER_CACHE.move_to_end(model_object_id) - _CONTAINER_OWNER_CACHE[model_object_id] = (owner, time.monotonic()) - while len(_CONTAINER_OWNER_CACHE) > _CONTAINER_OWNER_CACHE_MAX_SIZE: - _CONTAINER_OWNER_CACHE.popitem(last=False) - - -def _invalidate_container_owner_cache(model_object_id: str) -> None: - """Drop a cache entry after a write so the next read sees the new owner.""" - _CONTAINER_OWNER_CACHE.pop(model_object_id, None) - - -def _remember_container_owner(model_object_id: str, owner: str) -> None: - existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id) - if existing_owner is not None: - _IN_MEMORY_CONTAINER_OWNERS.move_to_end(model_object_id) - _IN_MEMORY_CONTAINER_OWNERS[model_object_id] = owner - while len(_IN_MEMORY_CONTAINER_OWNERS) > MAX_IN_MEMORY_CONTAINER_OWNERS: - _IN_MEMORY_CONTAINER_OWNERS.popitem(last=False) +# 60s LRU/TTL cache absorbs every container access check before it reaches +# Prisma. ``_NEGATIVE_OWNER_SENTINEL`` lets us cache a true "untracked" +# answer so repeated misses also avoid the DB — ``InMemoryCache`` returns +# ``None`` indistinguishably for "miss" and "cached as None". +_NEGATIVE_OWNER_SENTINEL = "__litellm_container_no_owner__" +_CONTAINER_OWNER_CACHE = InMemoryCache(max_size_in_memory=10000, default_ttl=60) def _container_model_object_id( - original_container_id: str, - custom_llm_provider: str, + original_container_id: str, custom_llm_provider: str ) -> str: return f"{CONTAINER_OBJECT_PURPOSE}:{custom_llm_provider}:{original_container_id}" def decode_container_id_for_ownership( - container_id: str, - custom_llm_provider: str, + container_id: str, custom_llm_provider: str ) -> Tuple[str, str]: decoded = ResponsesAPIRequestUtils._decode_container_id(container_id) original_container_id = decoded.get("response_id", container_id) @@ -91,9 +41,7 @@ def decode_container_id_for_ownership( def get_container_forwarding_params( - container_id: str, - original_container_id: str, - custom_llm_provider: str, + container_id: str, original_container_id: str, custom_llm_provider: str ) -> Dict[str, str]: params = { "container_id": original_container_id, @@ -147,122 +95,93 @@ async def record_container_owner( ) return response if owner is None: - # Caller has no identity (no user_id / team_id / org_id / api_key / - # token) we can stamp on the row. Recording with a placeholder - # would collapse every such caller into a single shared owner — - # the cross-tenant data-access primitive we explicitly avoid. - # Reject with 403 rather than fall back to a sentinel. + # Identity-less callers (no user_id / team_id / org_id / api_key / + # token) can't be uniquely stamped on the row. Stamping a + # placeholder would collapse every such caller into a shared + # owner — the cross-tenant primitive we explicitly avoid. raise HTTPException( status_code=403, detail="Unable to record container ownership: caller has no identity scope.", ) original_container_id, resolved_provider = decode_container_id_for_ownership( - container_id, - custom_llm_provider, + container_id, custom_llm_provider ) model_object_id = _container_model_object_id( - original_container_id, - resolved_provider, + original_container_id, resolved_provider ) file_object = _dump_response(response) file_object["custom_llm_provider"] = resolved_provider file_object["provider_container_id"] = original_container_id - try: - prisma_client = await _get_prisma_client() - if prisma_client is None: - existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id) - if existing_owner is not None and not user_can_access_resource_owner( - existing_owner, user_api_key_dict - ): - raise HTTPException(status_code=403, detail="Forbidden") - _remember_container_owner(model_object_id, owner) - _invalidate_container_owner_cache(model_object_id) - return response - - store = ContainerOwnershipStore(prisma_client) - existing = await store.find_by_model_object_id(model_object_id) - if existing is not None: - if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE: - raise HTTPException(status_code=500, detail="Unable to track container") - if not user_can_access_resource_owner( - getattr(existing, "created_by", None), user_api_key_dict - ): - raise HTTPException(status_code=403, detail="Forbidden") - await store.update_owner_record( - model_object_id=model_object_id, - data={ - "unified_object_id": container_id, - "file_object": file_object, - "updated_by": owner, - }, - ) - else: - await store.create_owner_record( - data={ - "unified_object_id": container_id, - "model_object_id": model_object_id, - "file_object": file_object, - "file_purpose": CONTAINER_OBJECT_PURPOSE, - "created_by": owner, - "updated_by": owner, - } - ) - except HTTPException: - raise - except Exception as e: + prisma_client = await _get_prisma_client() + if prisma_client is None: verbose_proxy_logger.warning( - "Failed to persist container ownership for container_id=%s; " - "falling back to in-process tracking: %s", - model_object_id, - e, + "Skipping container ownership tracking because prisma_client is None" ) - existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id) - if existing_owner is not None and not user_can_access_resource_owner( - existing_owner, user_api_key_dict + return response + + table = prisma_client.db.litellm_managedobjecttable + existing = await table.find_unique(where={"model_object_id": model_object_id}) + if existing is not None: + if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE: + raise HTTPException(status_code=500, detail="Unable to track container") + if not user_can_access_resource_owner( + getattr(existing, "created_by", None), user_api_key_dict ): raise HTTPException(status_code=403, detail="Forbidden") - _remember_container_owner(model_object_id, owner) + await table.update( + where={"model_object_id": model_object_id}, + data={ + "unified_object_id": container_id, + "file_object": file_object, + "updated_by": owner, + }, + ) + else: + await table.create( + data={ + "unified_object_id": container_id, + "model_object_id": model_object_id, + "file_object": file_object, + "file_purpose": CONTAINER_OBJECT_PURPOSE, + "created_by": owner, + "updated_by": owner, + } + ) - _invalidate_container_owner_cache(model_object_id) + _CONTAINER_OWNER_CACHE.set_cache(model_object_id, owner) return response async def _get_container_owner( - original_container_id: str, - custom_llm_provider: str, + original_container_id: str, custom_llm_provider: str ) -> Optional[str]: model_object_id = _container_model_object_id( - original_container_id, - custom_llm_provider, + original_container_id, custom_llm_provider ) - cached_hit, cached_value = _read_container_owner_cache(model_object_id) - if cached_hit: - return cached_value + cached = _CONTAINER_OWNER_CACHE.get_cache(model_object_id) + if cached == _NEGATIVE_OWNER_SENTINEL: + return None + if cached is not None: + return cached - try: - prisma_client = await _get_prisma_client() - if prisma_client is None: - owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id) - _write_container_owner_cache(model_object_id, owner) - return owner + prisma_client = await _get_prisma_client() + if prisma_client is None: + return None - owner = await ContainerOwnershipStore(prisma_client).get_owner(model_object_id) - if owner is None: - owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id) - _write_container_owner_cache(model_object_id, owner) - return owner - except Exception as e: - verbose_proxy_logger.warning( - "Failed to load container ownership for container_id=%s; " - "falling back to in-process tracking: %s", - model_object_id, - e, - ) - # Don't cache transient DB errors — let the next request retry. - return _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id) + row = await prisma_client.db.litellm_managedobjecttable.find_first( + where={ + "model_object_id": model_object_id, + "file_purpose": CONTAINER_OBJECT_PURPOSE, + } + ) + owner = getattr(row, "created_by", None) if row is not None else None + _CONTAINER_OWNER_CACHE.set_cache( + model_object_id, owner if owner is not None else _NEGATIVE_OWNER_SENTINEL + ) + return owner async def assert_user_can_access_container( @@ -271,17 +190,15 @@ async def assert_user_can_access_container( custom_llm_provider: str, ) -> Tuple[str, str]: original_container_id, resolved_provider = decode_container_id_for_ownership( - container_id, - custom_llm_provider, + container_id, custom_llm_provider ) if is_proxy_admin(user_api_key_dict): return original_container_id, resolved_provider - # Untracked containers (no ownership row) are admin-only. Pre-isolation - # rows that pre-date this enforcement need the admin to either re-create - # via the now-tracked flow or explicitly assign ``created_by`` on the - # ``litellm_managedobjecttable`` row. + # Untracked rows (no ownership) are admin-only. Pre-isolation rows + # that pre-date this enforcement need an admin to either re-create + # via the now-tracked flow or assign ``created_by`` on the row. owner = await _get_container_owner(original_container_id, resolved_provider) if not user_can_access_resource_owner(owner, user_api_key_dict): raise HTTPException(status_code=403, detail="Forbidden") @@ -327,35 +244,26 @@ def _set_container_list_data( async def _get_allowed_container_ids( user_api_key_dict: UserAPIKeyAuth, - custom_llm_provider: str, ) -> Set[str]: owner_scopes = get_resource_owner_scopes(user_api_key_dict) if not owner_scopes: return set() - in_memory_allowed_ids = { - model_object_id - for model_object_id, owner in _IN_MEMORY_CONTAINER_OWNERS.items() - if owner in owner_scopes - } - try: - prisma_client = await _get_prisma_client() - if prisma_client is None: - return in_memory_allowed_ids + prisma_client = await _get_prisma_client() + if prisma_client is None: + return set() - db_allowed_ids = await ContainerOwnershipStore( - prisma_client - ).list_model_object_ids_for_owners( - owner_scopes=owner_scopes, - ) - return in_memory_allowed_ids | db_allowed_ids - except Exception as e: - verbose_proxy_logger.warning( - "Failed to load allowed container ids; falling back to in-process " - "tracking: %s", - e, - ) - return in_memory_allowed_ids + rows = await prisma_client.db.litellm_managedobjecttable.find_many( + where={ + "file_purpose": CONTAINER_OBJECT_PURPOSE, + "created_by": {"in": owner_scopes}, + } + ) + return { + row.model_object_id + for row in rows + if getattr(row, "model_object_id", None) is not None + } async def filter_container_list_response( @@ -370,18 +278,14 @@ async def filter_container_list_response( if data is None: return response - allowed_container_ids = await _get_allowed_container_ids( - user_api_key_dict, - custom_llm_provider, - ) + allowed_container_ids = await _get_allowed_container_ids(user_api_key_dict) filtered: List[Any] = [] for item in data: container_id = _get_response_id(item) if container_id is None: continue original_container_id, resolved_provider = decode_container_id_for_ownership( - container_id, - custom_llm_provider, + container_id, custom_llm_provider ) if ( _container_model_object_id(original_container_id, resolved_provider) @@ -390,7 +294,5 @@ async def filter_container_list_response( filtered.append(item) return _set_container_list_data( - response, - filtered, - removed_filtered_items=len(filtered) != len(data), + response, filtered, removed_filtered_items=len(filtered) != len(data) ) diff --git a/litellm/proxy/container_endpoints/ownership_store.py b/litellm/proxy/container_endpoints/ownership_store.py deleted file mode 100644 index b3c406f618..0000000000 --- a/litellm/proxy/container_endpoints/ownership_store.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Any, Dict, List, Optional, Set - -CONTAINER_OBJECT_PURPOSE = "container" - - -class ContainerOwnershipStore: - def __init__(self, prisma_client: Any): - self.prisma_client = prisma_client - - @property - def _table(self) -> Any: - return self.prisma_client.db.litellm_managedobjecttable - - async def find_by_model_object_id(self, model_object_id: str) -> Optional[Any]: - return await self._table.find_unique(where={"model_object_id": model_object_id}) - - async def create_owner_record(self, data: Dict[str, Any]) -> None: - await self._table.create(data=data) - - async def update_owner_record( - self, - model_object_id: str, - data: Dict[str, Any], - ) -> None: - await self._table.update( - where={"model_object_id": model_object_id}, - data=data, - ) - - async def get_owner(self, model_object_id: str) -> Optional[str]: - row = await self._table.find_first( - where={ - "model_object_id": model_object_id, - "file_purpose": CONTAINER_OBJECT_PURPOSE, - } - ) - if row is None: - return None - return getattr(row, "created_by", None) - - async def list_model_object_ids_for_owners( - self, - owner_scopes: List[str], - ) -> Set[str]: - rows = await self._table.find_many( - where={ - "file_purpose": CONTAINER_OBJECT_PURPOSE, - "created_by": {"in": owner_scopes}, - } - ) - return { - row.model_object_id - for row in rows - if getattr(row, "model_object_id", None) is not None - } diff --git a/tests/test_litellm/containers/test_container_proxy_ownership.py b/tests/test_litellm/containers/test_container_proxy_ownership.py index 850c5a96e5..b046fa1536 100644 --- a/tests/test_litellm/containers/test_container_proxy_ownership.py +++ b/tests/test_litellm/containers/test_container_proxy_ownership.py @@ -12,12 +12,12 @@ from litellm.types.containers.main import ContainerListResponse, ContainerObject @pytest.fixture(autouse=True) -def clear_in_memory_container_owners(): - ownership._IN_MEMORY_CONTAINER_OWNERS.clear() - ownership._CONTAINER_OWNER_CACHE.clear() +def clear_container_owner_cache(): + ownership._CONTAINER_OWNER_CACHE.cache_dict.clear() + ownership._CONTAINER_OWNER_CACHE.ttl_dict.clear() yield - ownership._IN_MEMORY_CONTAINER_OWNERS.clear() - ownership._CONTAINER_OWNER_CACHE.clear() + ownership._CONTAINER_OWNER_CACHE.cache_dict.clear() + ownership._CONTAINER_OWNER_CACHE.ttl_dict.clear() def _container(container_id: str) -> ContainerObject: @@ -159,7 +159,6 @@ async def test_should_reject_record_for_identityless_proxy_auth(monkeypatch): ) assert exc.value.status_code == 403 assert "identity scope" in str(exc.value.detail) - assert ownership._IN_MEMORY_CONTAINER_OWNERS == {} @pytest.mark.asyncio @@ -186,106 +185,6 @@ async def test_should_skip_owner_record_when_provider_response_has_no_id(monkeyp table.create.assert_not_awaited() -@pytest.mark.asyncio -async def test_should_fallback_to_memory_when_persistent_owner_record_fails( - monkeypatch, -): - table = AsyncMock() - table.find_unique.side_effect = Exception("db unavailable") - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - auth = UserAPIKeyAuth(user_id="user-1") - - await ownership.record_container_owner( - response=_container("cntr_provider"), - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert ( - ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_provider"] - == "user-1" - ) - - -@pytest.mark.asyncio -async def test_should_track_container_owner_in_memory_without_prisma(monkeypatch): - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=None), - ) - auth = UserAPIKeyAuth(user_id="user-1") - - await ownership.record_container_owner( - response=_container("cntr_provider"), - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - original_id, provider = await ownership.assert_user_can_access_container( - container_id="cntr_provider", - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert original_id == "cntr_provider" - assert provider == "openai" - - -@pytest.mark.asyncio -async def test_should_bound_in_memory_container_owner_tracking(monkeypatch): - monkeypatch.setattr(ownership, "MAX_IN_MEMORY_CONTAINER_OWNERS", 2) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=None), - ) - auth = UserAPIKeyAuth(user_id="user-1") - - for container_id in ("cntr_1", "cntr_2", "cntr_3"): - await ownership.record_container_owner( - response=_container(container_id), - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert list(ownership._IN_MEMORY_CONTAINER_OWNERS.keys()) == [ - "container:openai:cntr_2", - "container:openai:cntr_3", - ] - - -@pytest.mark.asyncio -async def test_should_deny_container_access_for_different_owner(monkeypatch): - table = AsyncMock() - table.find_first.return_value = SimpleNamespace(created_by="user-2") - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - auth = UserAPIKeyAuth(user_id="user-1") - - with pytest.raises(HTTPException) as exc: - await ownership.assert_user_can_access_container( - container_id="cntr_provider", - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert exc.value.status_code == 403 - - @pytest.mark.asyncio async def test_should_deny_untracked_container_access_by_default(monkeypatch): monkeypatch.setattr( @@ -305,103 +204,6 @@ async def test_should_deny_untracked_container_access_by_default(monkeypatch): assert exc.value.status_code == 403 -@pytest.mark.asyncio -async def test_should_fallback_to_memory_when_owner_lookup_fails(monkeypatch): - table = AsyncMock() - table.find_first.side_effect = Exception("db unavailable") - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1" - auth = UserAPIKeyAuth(user_id="user-1") - - original_id, provider = await ownership.assert_user_can_access_container( - container_id="cntr_owned", - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert original_id == "cntr_owned" - assert provider == "openai" - - -@pytest.mark.asyncio -async def test_should_use_memory_owner_when_db_recovers_without_row(monkeypatch): - table = AsyncMock() - table.find_first.return_value = None - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1" - auth = UserAPIKeyAuth(user_id="user-1") - - original_id, provider = await ownership.assert_user_can_access_container( - container_id="cntr_owned", - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert original_id == "cntr_owned" - assert provider == "openai" - - -@pytest.mark.asyncio -async def test_should_fail_closed_when_owner_lookup_fails_without_memory(monkeypatch): - table = AsyncMock() - table.find_first.side_effect = Exception("db unavailable") - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - auth = UserAPIKeyAuth(user_id="user-1") - - with pytest.raises(HTTPException) as exc: - await ownership.assert_user_can_access_container( - container_id="cntr_owned", - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert exc.value.status_code == 403 - - -@pytest.mark.asyncio -async def test_untracked_container_is_admin_only(monkeypatch): - """Pre-isolation containers with no ownership row are admin-only. - Non-admin callers see them as 403, with no opt-out flag re-opening - the cross-tenant access primitive.""" - from fastapi import HTTPException - - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=None), - ) - auth = UserAPIKeyAuth(user_id="user-1") - - with pytest.raises(HTTPException) as exc: - await ownership.assert_user_can_access_container( - container_id="cntr_untracked", - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - assert exc.value.status_code == 403 - - @pytest.mark.asyncio async def test_should_not_reassign_existing_container_to_different_owner(monkeypatch): table = AsyncMock() @@ -536,165 +338,6 @@ async def test_should_clear_dict_has_more_when_filtered_container_list_is_empty( assert filtered["has_more"] is False -@pytest.mark.asyncio -async def test_should_filter_container_list_with_in_memory_ownership(monkeypatch): - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=None), - ) - auth = UserAPIKeyAuth(user_id="user-1") - - await ownership.record_container_owner( - response=_container("cntr_owned"), - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - response = ContainerListResponse( - object="list", - data=[_container("cntr_owned"), _container("cntr_other")], - has_more=False, - ) - - filtered = await ownership.filter_container_list_response( - response=response, - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert [item.id for item in filtered.data] == ["cntr_owned"] - - -@pytest.mark.asyncio -async def test_should_filter_container_list_with_memory_when_db_lookup_fails( - monkeypatch, -): - table = AsyncMock() - table.find_many.side_effect = Exception("db unavailable") - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1" - auth = UserAPIKeyAuth(user_id="user-1") - response = ContainerListResponse( - object="list", - data=[_container("cntr_owned"), _container("cntr_other")], - has_more=True, - ) - - filtered = await ownership.filter_container_list_response( - response=response, - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert [item.id for item in filtered.data] == ["cntr_owned"] - assert filtered.has_more is False - - -@pytest.mark.asyncio -async def test_should_include_memory_container_list_when_db_recovers_without_row( - monkeypatch, -): - table = AsyncMock() - table.find_many.return_value = [] - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1" - auth = UserAPIKeyAuth(user_id="user-1") - response = ContainerListResponse( - object="list", - data=[_container("cntr_owned"), _container("cntr_other")], - has_more=True, - ) - - filtered = await ownership.filter_container_list_response( - response=response, - user_api_key_dict=auth, - custom_llm_provider="openai", - ) - - assert [item.id for item in filtered.data] == ["cntr_owned"] - assert filtered.has_more is False - - -@pytest.mark.asyncio -async def test_should_validate_owner_and_forward_decoded_id_for_proxy_forwarding( - monkeypatch, -): - from litellm.proxy.container_endpoints import handler_factory - - proxy_server_stub = SimpleNamespace( - general_settings={}, - llm_router=None, - proxy_config=None, - proxy_logging_obj=None, - select_data_generator=None, - user_api_base=None, - user_max_tokens=None, - user_model=None, - user_request_timeout=None, - user_temperature=None, - version="test", - ) - monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_server_stub) - - captured = {} - - class FakeProcessor: - def __init__(self, data): - captured["data"] = data - - async def base_process_llm_request(self, **kwargs): - return captured["data"] - - async def _handle_llm_api_exception(self, **kwargs): - raise kwargs["e"] - - monkeypatch.setattr( - handler_factory, - "ProxyBaseLLMRequestProcessing", - FakeProcessor, - ) - access_check = AsyncMock(return_value=("cntr_provider", "azure")) - monkeypatch.setattr( - handler_factory, - "assert_user_can_access_container", - access_check, - ) - encoded_id = ResponsesAPIRequestUtils._build_container_id( - custom_llm_provider="azure", - model_id="router-gpt", - container_id="cntr_provider", - ) - - result = await handler_factory._process_request( - request=SimpleNamespace(query_params={}, headers={}), - fastapi_response=SimpleNamespace(), - user_api_key_dict=UserAPIKeyAuth(user_id="user-1"), - route_type="alist_container_files", - path_params={"container_id": encoded_id}, - ) - - access_check.assert_awaited_once() - assert access_check.await_args.kwargs["container_id"] == encoded_id - assert result["container_id"] == "cntr_provider" - assert result["custom_llm_provider"] == "azure" - assert result["model_id"] == "router-gpt" - - @pytest.mark.asyncio async def test_should_validate_owner_and_forward_decoded_id_for_multipart_upload( monkeypatch, @@ -1158,93 +801,3 @@ async def test_get_container_owner_caches_negative_lookups(monkeypatch): assert await ownership._get_container_owner("cntr_x", "openai") is None assert await ownership._get_container_owner("cntr_x", "openai") is None assert table.find_first.await_count == 1 - - -@pytest.mark.asyncio -async def test_record_container_owner_invalidates_cache(monkeypatch): - """A recorded owner must drop the cached value so the next read re-fetches. - - Otherwise a stale `None` from a prior negative lookup would survive the - create and the new owner would be invisible until the TTL elapses. - """ - # Seed the cache with a stale negative result. - ownership._write_container_owner_cache("container:openai:cntr_new", None) - cached_hit, cached_value = ownership._read_container_owner_cache( - "container:openai:cntr_new" - ) - assert cached_hit and cached_value is None - - table = AsyncMock() - table.find_unique.return_value = None - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - - await ownership.record_container_owner( - response=_container("cntr_new"), - user_api_key_dict=UserAPIKeyAuth(user_id="user-1"), - custom_llm_provider="openai", - ) - - # Invalidation drops the entry — next read goes to the DB. - cached_hit, _ = ownership._read_container_owner_cache("container:openai:cntr_new") - assert not cached_hit - - -@pytest.mark.asyncio -async def test_get_container_owner_does_not_cache_on_db_error(monkeypatch): - """DB errors must skip caching so transient failures don't pin a `None`.""" - table = AsyncMock() - table.find_first.side_effect = Exception("db unavailable") - prisma_client = SimpleNamespace( - db=SimpleNamespace(litellm_managedobjecttable=table) - ) - monkeypatch.setattr( - ownership, - "_get_prisma_client", - AsyncMock(return_value=prisma_client), - ) - - result = await ownership._get_container_owner("cntr_x", "openai") - assert result is None - cached_hit, _ = ownership._read_container_owner_cache("container:openai:cntr_x") - assert not cached_hit - - -def test_container_owner_cache_expires_after_ttl(monkeypatch): - """Entries past the TTL count as misses so writes elsewhere are eventually - visible to this process.""" - monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_TTL", 0.0) - ownership._write_container_owner_cache("k", "user-1") - cached_hit, _ = ownership._read_container_owner_cache("k") - # TTL of 0 means anything in the cache is already stale. - assert not cached_hit - - -def test_container_owner_cache_evicts_when_at_capacity(monkeypatch): - """The cache must not grow unbounded; reaching capacity LRU-evicts the - oldest entry, not the entire cache.""" - monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_MAX_SIZE", 2) - ownership._write_container_owner_cache("a", "user-a") - ownership._write_container_owner_cache("b", "user-b") - ownership._write_container_owner_cache("c", "user-c") - # ``a`` was the oldest and is dropped; ``b`` and ``c`` survive. - assert list(ownership._CONTAINER_OWNER_CACHE.keys()) == ["b", "c"] - - -def test_container_owner_cache_read_marks_as_recently_used(monkeypatch): - """Reading an entry should reset its position so a subsequent eviction - drops a less-recently-used entry instead of the just-touched one.""" - monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_MAX_SIZE", 2) - ownership._write_container_owner_cache("a", "user-a") - ownership._write_container_owner_cache("b", "user-b") - # Touch ``a`` so it becomes the most-recently-used. - ownership._read_container_owner_cache("a") - ownership._write_container_owner_cache("c", "user-c") - # ``b`` is the LRU at this point; ``a`` and ``c`` survive. - assert list(ownership._CONTAINER_OWNER_CACHE.keys()) == ["a", "c"] diff --git a/tests/test_litellm/llms/litellm_proxy/test_skills_ownership.py b/tests/test_litellm/llms/litellm_proxy/test_skills_ownership.py index c6dd6c4e33..3ffba9723b 100644 --- a/tests/test_litellm/llms/litellm_proxy/test_skills_ownership.py +++ b/tests/test_litellm/llms/litellm_proxy/test_skills_ownership.py @@ -19,9 +19,11 @@ from litellm.skills import main as skills_main @pytest.fixture(autouse=True) def clear_skill_cache(): - skills_handler._SKILL_CACHE.clear() + skills_handler._SKILL_CACHE.cache_dict.clear() + skills_handler._SKILL_CACHE.ttl_dict.clear() yield - skills_handler._SKILL_CACHE.clear() + skills_handler._SKILL_CACHE.cache_dict.clear() + skills_handler._SKILL_CACHE.ttl_dict.clear() def _skill(skill_id: str, created_by: str | None) -> LiteLLM_SkillsTable: @@ -440,99 +442,75 @@ async def test_should_scope_skill_injection_fetch_to_authenticated_user(monkeypa @pytest.mark.asyncio async def test_load_skill_uses_cache_after_first_db_hit(monkeypatch): - """`fetch_skill_from_db` is hit per-chat-completion; the cache absorbs + """``fetch_skill_from_db`` runs per chat-completion; the cache absorbs repeats so we don't issue a Prisma query on every request.""" fake_skill = Mock(created_by="user-1", skill_id="litellm_skill_a") - store_factory = AsyncMock() - store = Mock() - store.find_skill = AsyncMock(return_value=fake_skill) + table = AsyncMock() + table.find_unique = AsyncMock(return_value=fake_skill) + prisma_client = type( + "Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()} + )() monkeypatch.setattr( skills_handler.LiteLLMSkillsHandler, "_get_prisma_client", - AsyncMock(return_value=store_factory), - ) - monkeypatch.setattr( - skills_handler, - "LiteLLMSkillsStore", - Mock(return_value=store), + AsyncMock(return_value=prisma_client), ) - first = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a") - second = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a") - third = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a") - - assert first is fake_skill - assert second is fake_skill - assert third is fake_skill - assert store.find_skill.await_count == 1 + for _ in range(3): + assert ( + await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a") + is fake_skill + ) + assert table.find_unique.await_count == 1 @pytest.mark.asyncio async def test_load_skill_caches_negative_lookups(monkeypatch): - """Missing skills must cache as `None` so repeated lookups skip the DB.""" - store_factory = AsyncMock() - store = Mock() - store.find_skill = AsyncMock(return_value=None) + """Missing skills cache as the negative sentinel so repeated misses skip + the DB and the caller still sees ``None``.""" + table = AsyncMock() + table.find_unique = AsyncMock(return_value=None) + prisma_client = type( + "Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()} + )() monkeypatch.setattr( skills_handler.LiteLLMSkillsHandler, "_get_prisma_client", - AsyncMock(return_value=store_factory), - ) - monkeypatch.setattr( - skills_handler, - "LiteLLMSkillsStore", - Mock(return_value=store), + AsyncMock(return_value=prisma_client), ) assert await skills_handler.LiteLLMSkillsHandler._load_skill("missing") is None assert await skills_handler.LiteLLMSkillsHandler._load_skill("missing") is None - assert store.find_skill.await_count == 1 + assert table.find_unique.await_count == 1 @pytest.mark.asyncio async def test_delete_skill_invalidates_cache(monkeypatch): - """After delete, the next read must consult the DB rather than the cached - pre-delete row.""" + """After delete, the next read should not see the pre-delete cached row.""" fake_skill = Mock(created_by="user-1", skill_id="litellm_skill_a") - store = Mock() - store.find_skill = AsyncMock(return_value=fake_skill) - store.delete_skill = AsyncMock() + table = AsyncMock() + table.find_unique = AsyncMock(return_value=fake_skill) + table.delete = AsyncMock() + prisma_client = type( + "Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()} + )() monkeypatch.setattr( skills_handler.LiteLLMSkillsHandler, "_get_prisma_client", - AsyncMock(return_value=Mock()), - ) - monkeypatch.setattr( - skills_handler, - "LiteLLMSkillsStore", - Mock(return_value=store), + AsyncMock(return_value=prisma_client), ) # Prime the cache via the read path. await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a") - cached_hit, _ = skills_handler._read_skill_cache("litellm_skill_a") - assert cached_hit + assert skills_handler._SKILL_CACHE.get_cache("litellm_skill_a") is fake_skill auth = UserAPIKeyAuth(user_id="user-1") await skills_handler.LiteLLMSkillsHandler.delete_skill( "litellm_skill_a", user_api_key_dict=auth ) - cached_hit_after, _ = skills_handler._read_skill_cache("litellm_skill_a") - assert not cached_hit_after - - -def test_skill_cache_expires_after_ttl(monkeypatch): - monkeypatch.setattr(skills_handler, "_SKILL_CACHE_TTL", 0.0) - skills_handler._write_skill_cache("k", Mock()) - cached_hit, _ = skills_handler._read_skill_cache("k") - assert not cached_hit - - -def test_skill_cache_evicts_when_at_capacity(monkeypatch): - monkeypatch.setattr(skills_handler, "_SKILL_CACHE_MAX_SIZE", 2) - skills_handler._write_skill_cache("a", Mock()) - skills_handler._write_skill_cache("b", Mock()) - skills_handler._write_skill_cache("c", Mock()) - # ``a`` was the oldest and is LRU-evicted; ``b`` and ``c`` survive. - assert list(skills_handler._SKILL_CACHE.keys()) == ["b", "c"] + # Post-delete, the cache holds the negative sentinel — not the stale row. + assert ( + skills_handler._SKILL_CACHE.get_cache("litellm_skill_a") + == skills_handler._NEGATIVE_SKILL_SENTINEL + )