chore: simplify ownership tracking — drop thin stores, in-memory fallback, hand-rolled cache

Substantial reduction (~765 LOC) without changing the security
boundary:

* Drop ContainerOwnershipStore and LiteLLMSkillsStore — both were
  one-method-per-Prisma-call wrappers. Inline the calls instead,
  matching the established pattern in vector_store_endpoints,
  agent_endpoints, and mcp_server/db.py.

* Drop the prisma_client is None in-memory fallback. Production
  deploys always have Prisma; running ownership-critical paths on a
  process-local dict is a security footgun in the dev-mode case it
  was meant to support, and complicates every code path with a
  branch. Fail-secure: skip recording if Prisma is unavailable, and
  treat reads as "not found" (admin-only).

* Drop the hand-rolled module-level cache. Replace with the existing
  litellm.caching.in_memory_cache.InMemoryCache, which already has
  TTL + max-size + eviction tested in its own module. Sentinel string
  for negative caching since InMemoryCache can't disambiguate "miss"
  from "cached as None".

* Tests: drop coverage for removed code paths (in-memory fallback,
  hand-rolled cache internals). Keep tests for actual behavior (cache
  hit-rate, negative caching, owner check, list filtering,
  identity-less reject, admin bypass).
This commit is contained in:
user 2026-05-05 00:23:32 +00:00
parent 12fe945e7b
commit 6ce84effe1
No known key found for this signature in database
6 changed files with 180 additions and 945 deletions

View File

@ -5,13 +5,11 @@ This module contains the actual database operations for skills CRUD.
Used by the transformation layer and skills injection hook.
"""
import time
import uuid
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.llms.litellm_proxy.skills.store import LiteLLMSkillsStore
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.proxy._types import LiteLLM_SkillsTable, NewSkillRequest, UserAPIKeyAuth
from litellm.proxy.common_utils.resource_ownership import (
get_primary_resource_owner_scope,
@ -21,88 +19,37 @@ from litellm.proxy.common_utils.resource_ownership import (
)
# Skills are looked up on every chat completion that has skills enabled
# (`LiteLLMSkillsHandler.fetch_skill_from_db` in the injection hook). Cache
# the Prisma skill row for a short window so the hot path doesn't issue a DB
# round-trip per request. Same shape as `_byok_cred_cache` and the container
# ownership cache: (value, monotonic_timestamp). `None` is cached as a true
# negative ("skill does not exist") so repeated misses also avoid the DB.
_SKILL_CACHE: "OrderedDict[str, Tuple[Optional[Any], float]]" = OrderedDict()
_SKILL_CACHE_TTL = 60 # seconds
_SKILL_CACHE_MAX_SIZE = 10000
def _read_skill_cache(skill_id: str) -> Tuple[bool, Optional[Any]]:
"""Return (hit, value). hit=False means caller must consult the DB."""
entry = _SKILL_CACHE.get(skill_id)
if entry is None:
return False, None
value, timestamp = entry
if time.monotonic() - timestamp > _SKILL_CACHE_TTL:
_SKILL_CACHE.pop(skill_id, None)
return False, None
_SKILL_CACHE.move_to_end(skill_id)
return True, value
def _write_skill_cache(skill_id: str, skill: Optional[Any]) -> None:
# LRU eviction (popitem(last=False)) instead of full ``clear()`` —
# see container ownership cache for rationale.
if skill_id in _SKILL_CACHE:
_SKILL_CACHE.move_to_end(skill_id)
_SKILL_CACHE[skill_id] = (skill, time.monotonic())
while len(_SKILL_CACHE) > _SKILL_CACHE_MAX_SIZE:
_SKILL_CACHE.popitem(last=False)
def _invalidate_skill_cache(skill_id: str) -> None:
"""Drop the cache entry after a write so the next read sees the new row."""
_SKILL_CACHE.pop(skill_id, None)
def _user_can_access_skill_owner(
owner: Optional[str],
user_api_key_dict: Optional[UserAPIKeyAuth],
) -> bool:
# Pre-isolation skills with no ``created_by`` are admin-only — same
# rule as untracked containers. Owners need to either re-create via
# the now-tracked flow or have an admin assign ``created_by`` on the
# row.
return user_can_access_resource_owner(owner, user_api_key_dict)
# (`SkillsInjectionHook` calls ``fetch_skill_from_db``). 60s LRU/TTL cache
# absorbs the hot read before it reaches Prisma. ``_NEGATIVE_SKILL_SENTINEL``
# lets us cache a true "skill does not exist" so repeated misses also
# avoid the DB — ``InMemoryCache`` returns ``None`` indistinguishably for
# "miss" and "cached as None".
_NEGATIVE_SKILL_SENTINEL = "__litellm_skill_not_found__"
_SKILL_CACHE = InMemoryCache(max_size_in_memory=10000, default_ttl=60)
def _prisma_skill_to_litellm(prisma_skill) -> LiteLLM_SkillsTable:
"""
Convert a Prisma skill record to LiteLLM_SkillsTable.
"""Convert a Prisma skill record to LiteLLM_SkillsTable.
Handles Base64 decoding of file_content field.
Handles Base64 decoding of file_content field model_dump() converts
Base64 fields to base64-encoded strings.
"""
import base64
data = prisma_skill.model_dump()
# Decode Base64 file_content back to bytes
# model_dump() converts Base64 field to base64-encoded string
if data.get("file_content") is not None:
if isinstance(data["file_content"], str):
data["file_content"] = base64.b64decode(data["file_content"])
elif isinstance(data["file_content"], bytes):
# Already bytes, no conversion needed
pass
return LiteLLM_SkillsTable(**data)
class LiteLLMSkillsHandler:
"""
Handler for LiteLLM database-backed skills operations.
This class provides static methods for CRUD operations on skills
stored in the LiteLLM proxy database (LiteLLM_SkillsTable).
"""
"""CRUD for skills stored in ``litellm_skillstable``."""
@staticmethod
async def _get_prisma_client():
"""Get the prisma client from proxy server."""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
@ -118,28 +65,16 @@ class LiteLLMSkillsHandler:
user_id: Optional[str] = None,
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
) -> LiteLLM_SkillsTable:
"""
Create a new skill in the LiteLLM database.
Args:
data: NewSkillRequest with skill details
user_id: Optional user ID for tracking
Returns:
LiteLLM_SkillsTable record
"""
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
store = LiteLLMSkillsStore(prisma_client)
skill_id = f"litellm_skill_{uuid.uuid4()}"
owner = get_primary_resource_owner_scope(user_api_key_dict) or user_id
if owner is None:
# Caller has no identity scope (no user_id / team_id / org_id /
# api_key / token). Stamping a placeholder would let any two
# identity-less callers see each other's skills via the shared
# owner — the cross-tenant primitive we avoid. ValueError keeps
# this module FastAPI-free per the project layering rule
# (litellm_proxy provider integrations live outside proxy/).
# Identity-less callers (no user_id / team_id / org_id /
# api_key / token) can't be uniquely stamped on the row.
# Stamping a placeholder would let any two such callers see
# each other's skills via the shared owner. ValueError keeps
# this module FastAPI-free per the project layering rule.
raise ValueError(
"Unable to record skill ownership: caller has no identity scope."
)
@ -154,13 +89,11 @@ class LiteLLMSkillsHandler:
"updated_by": owner,
}
# Handle metadata
if data.metadata is not None:
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
skill_data["metadata"] = safe_dumps(data.metadata)
# Handle file content - wrap bytes in Base64 for Prisma
if data.file_content is not None:
from prisma.fields import Base64
@ -174,8 +107,7 @@ class LiteLLMSkillsHandler:
f"LiteLLMSkillsHandler: Creating skill {skill_id} with title={data.display_title}"
)
new_skill = await store.create_skill(skill_data)
new_skill = await prisma_client.db.litellm_skillstable.create(data=skill_data)
return _prisma_skill_to_litellm(new_skill)
@staticmethod
@ -184,18 +116,7 @@ class LiteLLMSkillsHandler:
offset: int = 0,
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
) -> List[LiteLLM_SkillsTable]:
"""
List skills from the LiteLLM database.
Args:
limit: Maximum number of skills to return
offset: Number of skills to skip
Returns:
List of LiteLLM_SkillsTable records
"""
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
store = LiteLLMSkillsStore(prisma_client)
verbose_logger.debug(
f"LiteLLMSkillsHandler: Listing skills with limit={limit}, offset={offset}"
@ -212,26 +133,29 @@ class LiteLLMSkillsHandler:
return []
find_many_kwargs["where"] = {"created_by": {"in": owner_scopes}}
skills = await store.list_skills(find_many_kwargs)
skills = await prisma_client.db.litellm_skillstable.find_many(
**find_many_kwargs
)
return [_prisma_skill_to_litellm(s) for s in skills]
@staticmethod
async def _load_skill(skill_id: str) -> Optional[Any]:
"""Cache-first read of the Prisma skill row.
Caching here keeps `fetch_skill_from_db` (called per chat completion in
the skills injection hook) off the DB. Owner-scope filtering happens
on the cached row, so the cache is per-skill and not per-caller.
"""Cache-first read of the Prisma skill row. Owner-scope filtering
happens on the cached row, so the cache is per-skill not per-caller.
"""
cached_hit, cached_skill = _read_skill_cache(skill_id)
if cached_hit:
return cached_skill
cached = _SKILL_CACHE.get_cache(skill_id)
if cached == _NEGATIVE_SKILL_SENTINEL:
return None
if cached is not None:
return cached
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
store = LiteLLMSkillsStore(prisma_client)
skill = await store.find_skill(skill_id)
_write_skill_cache(skill_id, skill)
skill = await prisma_client.db.litellm_skillstable.find_unique(
where={"skill_id": skill_id}
)
_SKILL_CACHE.set_cache(
skill_id, skill if skill is not None else _NEGATIVE_SKILL_SENTINEL
)
return skill
@staticmethod
@ -239,26 +163,12 @@ class LiteLLMSkillsHandler:
skill_id: str,
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
) -> LiteLLM_SkillsTable:
"""
Get a skill by ID from the LiteLLM database.
Args:
skill_id: The skill ID to retrieve
Returns:
LiteLLM_SkillsTable record
Raises:
ValueError: If skill not found
"""
verbose_logger.debug(f"LiteLLMSkillsHandler: Getting skill {skill_id}")
skill = await LiteLLMSkillsHandler._load_skill(skill_id)
if skill is None:
raise ValueError(f"Skill not found: {skill_id}")
if not _user_can_access_skill_owner(
# Same "not found" message for both "missing" and "cross-tenant"
# so callers can't enumerate skill IDs they don't own.
if skill is None or not user_can_access_resource_owner(
getattr(skill, "created_by", None), user_api_key_dict
):
raise ValueError(f"Skill not found: {skill_id}")
@ -270,37 +180,17 @@ class LiteLLMSkillsHandler:
skill_id: str,
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
) -> Dict[str, str]:
"""
Delete a skill by ID from the LiteLLM database.
Args:
skill_id: The skill ID to delete
Returns:
Dict with id and type of deleted skill
Raises:
ValueError: If skill not found
"""
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
store = LiteLLMSkillsStore(prisma_client)
verbose_logger.debug(f"LiteLLMSkillsHandler: Deleting skill {skill_id}")
# Check if skill exists
skill = await LiteLLMSkillsHandler._load_skill(skill_id)
if skill is None:
raise ValueError(f"Skill not found: {skill_id}")
if not _user_can_access_skill_owner(
if skill is None or not user_can_access_resource_owner(
getattr(skill, "created_by", None), user_api_key_dict
):
raise ValueError(f"Skill not found: {skill_id}")
# Delete the skill
await store.delete_skill(skill_id)
_invalidate_skill_cache(skill_id)
await prisma_client.db.litellm_skillstable.delete(where={"skill_id": skill_id})
_SKILL_CACHE.set_cache(skill_id, _NEGATIVE_SKILL_SENTINEL)
return {"id": skill_id, "type": "skill_deleted"}
@ -309,22 +199,11 @@ class LiteLLMSkillsHandler:
skill_id: str,
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
) -> Optional[LiteLLM_SkillsTable]:
"""
Fetch a skill from the database (used by skills injection hook).
This is a convenience method that returns None instead of raising
an exception if the skill is not found.
Args:
skill_id: The skill ID to fetch
Returns:
LiteLLM_SkillsTable or None if not found
"""
"""Skills-injection-hook helper: returns None instead of raising on
not-found / not-authorized so the hook can silently skip."""
try:
return await LiteLLMSkillsHandler.get_skill(
skill_id,
user_api_key_dict=user_api_key_dict,
skill_id, user_api_key_dict=user_api_key_dict
)
except ValueError:
return None

View File

@ -1,22 +0,0 @@
from typing import Any, Dict, List, Optional
class LiteLLMSkillsStore:
def __init__(self, prisma_client: Any):
self.prisma_client = prisma_client
@property
def _table(self) -> Any:
return self.prisma_client.db.litellm_skillstable
async def create_skill(self, data: Dict[str, Any]) -> Any:
return await self._table.create(data=data)
async def list_skills(self, find_many_kwargs: Dict[str, Any]) -> List[Any]:
return await self._table.find_many(**find_many_kwargs)
async def find_skill(self, skill_id: str) -> Optional[Any]:
return await self._table.find_unique(where={"skill_id": skill_id})
async def delete_skill(self, skill_id: str) -> None:
await self._table.delete(where={"skill_id": skill_id})

View File

@ -1,10 +1,9 @@
import time
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set, Tuple
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.common_utils.resource_ownership import (
get_primary_resource_owner_scope,
@ -12,75 +11,26 @@ from litellm.proxy.common_utils.resource_ownership import (
is_proxy_admin,
user_can_access_resource_owner,
)
from litellm.proxy.container_endpoints.ownership_store import (
CONTAINER_OBJECT_PURPOSE,
ContainerOwnershipStore,
)
from litellm.responses.utils import ResponsesAPIRequestUtils
MAX_IN_MEMORY_CONTAINER_OWNERS = 10000
_IN_MEMORY_CONTAINER_OWNERS: "OrderedDict[str, str]" = OrderedDict()
CONTAINER_OBJECT_PURPOSE = "container"
# Short-lived cache keeps every container access check from hitting the DB
# (`_get_container_owner` is invoked on retrieve / delete / list / file-content
# paths). Mirrors the `_byok_cred_cache` pattern in mcp_server/server.py:
# (value, monotonic_timestamp) tuples, TTL'd, LRU-evicted at capacity,
# invalidated by writes. A ``None`` value caches "untracked" so repeated
# negative lookups also avoid DB.
_CONTAINER_OWNER_CACHE: "OrderedDict[str, Tuple[Optional[str], float]]" = OrderedDict()
_CONTAINER_OWNER_CACHE_TTL = 60 # seconds
_CONTAINER_OWNER_CACHE_MAX_SIZE = 10000
def _read_container_owner_cache(model_object_id: str) -> Tuple[bool, Optional[str]]:
"""Return (hit, value). hit=False means caller must consult the DB."""
entry = _CONTAINER_OWNER_CACHE.get(model_object_id)
if entry is None:
return False, None
value, timestamp = entry
if time.monotonic() - timestamp > _CONTAINER_OWNER_CACHE_TTL:
_CONTAINER_OWNER_CACHE.pop(model_object_id, None)
return False, None
_CONTAINER_OWNER_CACHE.move_to_end(model_object_id)
return True, value
def _write_container_owner_cache(model_object_id: str, owner: Optional[str]) -> None:
# LRU eviction (popitem(last=False)) instead of full ``clear()`` — a
# full clear at capacity converts a steady-state cached workload into
# a periodic full-DB-load oscillation as the cache repopulates from
# zero and clears again.
if model_object_id in _CONTAINER_OWNER_CACHE:
_CONTAINER_OWNER_CACHE.move_to_end(model_object_id)
_CONTAINER_OWNER_CACHE[model_object_id] = (owner, time.monotonic())
while len(_CONTAINER_OWNER_CACHE) > _CONTAINER_OWNER_CACHE_MAX_SIZE:
_CONTAINER_OWNER_CACHE.popitem(last=False)
def _invalidate_container_owner_cache(model_object_id: str) -> None:
"""Drop a cache entry after a write so the next read sees the new owner."""
_CONTAINER_OWNER_CACHE.pop(model_object_id, None)
def _remember_container_owner(model_object_id: str, owner: str) -> None:
existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
if existing_owner is not None:
_IN_MEMORY_CONTAINER_OWNERS.move_to_end(model_object_id)
_IN_MEMORY_CONTAINER_OWNERS[model_object_id] = owner
while len(_IN_MEMORY_CONTAINER_OWNERS) > MAX_IN_MEMORY_CONTAINER_OWNERS:
_IN_MEMORY_CONTAINER_OWNERS.popitem(last=False)
# 60s LRU/TTL cache absorbs every container access check before it reaches
# Prisma. ``_NEGATIVE_OWNER_SENTINEL`` lets us cache a true "untracked"
# answer so repeated misses also avoid the DB — ``InMemoryCache`` returns
# ``None`` indistinguishably for "miss" and "cached as None".
_NEGATIVE_OWNER_SENTINEL = "__litellm_container_no_owner__"
_CONTAINER_OWNER_CACHE = InMemoryCache(max_size_in_memory=10000, default_ttl=60)
def _container_model_object_id(
original_container_id: str,
custom_llm_provider: str,
original_container_id: str, custom_llm_provider: str
) -> str:
return f"{CONTAINER_OBJECT_PURPOSE}:{custom_llm_provider}:{original_container_id}"
def decode_container_id_for_ownership(
container_id: str,
custom_llm_provider: str,
container_id: str, custom_llm_provider: str
) -> Tuple[str, str]:
decoded = ResponsesAPIRequestUtils._decode_container_id(container_id)
original_container_id = decoded.get("response_id", container_id)
@ -91,9 +41,7 @@ def decode_container_id_for_ownership(
def get_container_forwarding_params(
container_id: str,
original_container_id: str,
custom_llm_provider: str,
container_id: str, original_container_id: str, custom_llm_provider: str
) -> Dict[str, str]:
params = {
"container_id": original_container_id,
@ -147,122 +95,93 @@ async def record_container_owner(
)
return response
if owner is None:
# Caller has no identity (no user_id / team_id / org_id / api_key /
# token) we can stamp on the row. Recording with a placeholder
# would collapse every such caller into a single shared owner —
# the cross-tenant data-access primitive we explicitly avoid.
# Reject with 403 rather than fall back to a sentinel.
# Identity-less callers (no user_id / team_id / org_id / api_key /
# token) can't be uniquely stamped on the row. Stamping a
# placeholder would collapse every such caller into a shared
# owner — the cross-tenant primitive we explicitly avoid.
raise HTTPException(
status_code=403,
detail="Unable to record container ownership: caller has no identity scope.",
)
original_container_id, resolved_provider = decode_container_id_for_ownership(
container_id,
custom_llm_provider,
container_id, custom_llm_provider
)
model_object_id = _container_model_object_id(
original_container_id,
resolved_provider,
original_container_id, resolved_provider
)
file_object = _dump_response(response)
file_object["custom_llm_provider"] = resolved_provider
file_object["provider_container_id"] = original_container_id
try:
prisma_client = await _get_prisma_client()
if prisma_client is None:
existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
if existing_owner is not None and not user_can_access_resource_owner(
existing_owner, user_api_key_dict
):
raise HTTPException(status_code=403, detail="Forbidden")
_remember_container_owner(model_object_id, owner)
_invalidate_container_owner_cache(model_object_id)
return response
store = ContainerOwnershipStore(prisma_client)
existing = await store.find_by_model_object_id(model_object_id)
if existing is not None:
if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE:
raise HTTPException(status_code=500, detail="Unable to track container")
if not user_can_access_resource_owner(
getattr(existing, "created_by", None), user_api_key_dict
):
raise HTTPException(status_code=403, detail="Forbidden")
await store.update_owner_record(
model_object_id=model_object_id,
data={
"unified_object_id": container_id,
"file_object": file_object,
"updated_by": owner,
},
)
else:
await store.create_owner_record(
data={
"unified_object_id": container_id,
"model_object_id": model_object_id,
"file_object": file_object,
"file_purpose": CONTAINER_OBJECT_PURPOSE,
"created_by": owner,
"updated_by": owner,
}
)
except HTTPException:
raise
except Exception as e:
prisma_client = await _get_prisma_client()
if prisma_client is None:
verbose_proxy_logger.warning(
"Failed to persist container ownership for container_id=%s; "
"falling back to in-process tracking: %s",
model_object_id,
e,
"Skipping container ownership tracking because prisma_client is None"
)
existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
if existing_owner is not None and not user_can_access_resource_owner(
existing_owner, user_api_key_dict
return response
table = prisma_client.db.litellm_managedobjecttable
existing = await table.find_unique(where={"model_object_id": model_object_id})
if existing is not None:
if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE:
raise HTTPException(status_code=500, detail="Unable to track container")
if not user_can_access_resource_owner(
getattr(existing, "created_by", None), user_api_key_dict
):
raise HTTPException(status_code=403, detail="Forbidden")
_remember_container_owner(model_object_id, owner)
await table.update(
where={"model_object_id": model_object_id},
data={
"unified_object_id": container_id,
"file_object": file_object,
"updated_by": owner,
},
)
else:
await table.create(
data={
"unified_object_id": container_id,
"model_object_id": model_object_id,
"file_object": file_object,
"file_purpose": CONTAINER_OBJECT_PURPOSE,
"created_by": owner,
"updated_by": owner,
}
)
_invalidate_container_owner_cache(model_object_id)
_CONTAINER_OWNER_CACHE.set_cache(model_object_id, owner)
return response
async def _get_container_owner(
original_container_id: str,
custom_llm_provider: str,
original_container_id: str, custom_llm_provider: str
) -> Optional[str]:
model_object_id = _container_model_object_id(
original_container_id,
custom_llm_provider,
original_container_id, custom_llm_provider
)
cached_hit, cached_value = _read_container_owner_cache(model_object_id)
if cached_hit:
return cached_value
cached = _CONTAINER_OWNER_CACHE.get_cache(model_object_id)
if cached == _NEGATIVE_OWNER_SENTINEL:
return None
if cached is not None:
return cached
try:
prisma_client = await _get_prisma_client()
if prisma_client is None:
owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
_write_container_owner_cache(model_object_id, owner)
return owner
prisma_client = await _get_prisma_client()
if prisma_client is None:
return None
owner = await ContainerOwnershipStore(prisma_client).get_owner(model_object_id)
if owner is None:
owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
_write_container_owner_cache(model_object_id, owner)
return owner
except Exception as e:
verbose_proxy_logger.warning(
"Failed to load container ownership for container_id=%s; "
"falling back to in-process tracking: %s",
model_object_id,
e,
)
# Don't cache transient DB errors — let the next request retry.
return _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
row = await prisma_client.db.litellm_managedobjecttable.find_first(
where={
"model_object_id": model_object_id,
"file_purpose": CONTAINER_OBJECT_PURPOSE,
}
)
owner = getattr(row, "created_by", None) if row is not None else None
_CONTAINER_OWNER_CACHE.set_cache(
model_object_id, owner if owner is not None else _NEGATIVE_OWNER_SENTINEL
)
return owner
async def assert_user_can_access_container(
@ -271,17 +190,15 @@ async def assert_user_can_access_container(
custom_llm_provider: str,
) -> Tuple[str, str]:
original_container_id, resolved_provider = decode_container_id_for_ownership(
container_id,
custom_llm_provider,
container_id, custom_llm_provider
)
if is_proxy_admin(user_api_key_dict):
return original_container_id, resolved_provider
# Untracked containers (no ownership row) are admin-only. Pre-isolation
# rows that pre-date this enforcement need the admin to either re-create
# via the now-tracked flow or explicitly assign ``created_by`` on the
# ``litellm_managedobjecttable`` row.
# Untracked rows (no ownership) are admin-only. Pre-isolation rows
# that pre-date this enforcement need an admin to either re-create
# via the now-tracked flow or assign ``created_by`` on the row.
owner = await _get_container_owner(original_container_id, resolved_provider)
if not user_can_access_resource_owner(owner, user_api_key_dict):
raise HTTPException(status_code=403, detail="Forbidden")
@ -327,35 +244,26 @@ def _set_container_list_data(
async def _get_allowed_container_ids(
user_api_key_dict: UserAPIKeyAuth,
custom_llm_provider: str,
) -> Set[str]:
owner_scopes = get_resource_owner_scopes(user_api_key_dict)
if not owner_scopes:
return set()
in_memory_allowed_ids = {
model_object_id
for model_object_id, owner in _IN_MEMORY_CONTAINER_OWNERS.items()
if owner in owner_scopes
}
try:
prisma_client = await _get_prisma_client()
if prisma_client is None:
return in_memory_allowed_ids
prisma_client = await _get_prisma_client()
if prisma_client is None:
return set()
db_allowed_ids = await ContainerOwnershipStore(
prisma_client
).list_model_object_ids_for_owners(
owner_scopes=owner_scopes,
)
return in_memory_allowed_ids | db_allowed_ids
except Exception as e:
verbose_proxy_logger.warning(
"Failed to load allowed container ids; falling back to in-process "
"tracking: %s",
e,
)
return in_memory_allowed_ids
rows = await prisma_client.db.litellm_managedobjecttable.find_many(
where={
"file_purpose": CONTAINER_OBJECT_PURPOSE,
"created_by": {"in": owner_scopes},
}
)
return {
row.model_object_id
for row in rows
if getattr(row, "model_object_id", None) is not None
}
async def filter_container_list_response(
@ -370,18 +278,14 @@ async def filter_container_list_response(
if data is None:
return response
allowed_container_ids = await _get_allowed_container_ids(
user_api_key_dict,
custom_llm_provider,
)
allowed_container_ids = await _get_allowed_container_ids(user_api_key_dict)
filtered: List[Any] = []
for item in data:
container_id = _get_response_id(item)
if container_id is None:
continue
original_container_id, resolved_provider = decode_container_id_for_ownership(
container_id,
custom_llm_provider,
container_id, custom_llm_provider
)
if (
_container_model_object_id(original_container_id, resolved_provider)
@ -390,7 +294,5 @@ async def filter_container_list_response(
filtered.append(item)
return _set_container_list_data(
response,
filtered,
removed_filtered_items=len(filtered) != len(data),
response, filtered, removed_filtered_items=len(filtered) != len(data)
)

View File

@ -1,55 +0,0 @@
from typing import Any, Dict, List, Optional, Set
CONTAINER_OBJECT_PURPOSE = "container"
class ContainerOwnershipStore:
def __init__(self, prisma_client: Any):
self.prisma_client = prisma_client
@property
def _table(self) -> Any:
return self.prisma_client.db.litellm_managedobjecttable
async def find_by_model_object_id(self, model_object_id: str) -> Optional[Any]:
return await self._table.find_unique(where={"model_object_id": model_object_id})
async def create_owner_record(self, data: Dict[str, Any]) -> None:
await self._table.create(data=data)
async def update_owner_record(
self,
model_object_id: str,
data: Dict[str, Any],
) -> None:
await self._table.update(
where={"model_object_id": model_object_id},
data=data,
)
async def get_owner(self, model_object_id: str) -> Optional[str]:
row = await self._table.find_first(
where={
"model_object_id": model_object_id,
"file_purpose": CONTAINER_OBJECT_PURPOSE,
}
)
if row is None:
return None
return getattr(row, "created_by", None)
async def list_model_object_ids_for_owners(
self,
owner_scopes: List[str],
) -> Set[str]:
rows = await self._table.find_many(
where={
"file_purpose": CONTAINER_OBJECT_PURPOSE,
"created_by": {"in": owner_scopes},
}
)
return {
row.model_object_id
for row in rows
if getattr(row, "model_object_id", None) is not None
}

View File

@ -12,12 +12,12 @@ from litellm.types.containers.main import ContainerListResponse, ContainerObject
@pytest.fixture(autouse=True)
def clear_in_memory_container_owners():
ownership._IN_MEMORY_CONTAINER_OWNERS.clear()
ownership._CONTAINER_OWNER_CACHE.clear()
def clear_container_owner_cache():
ownership._CONTAINER_OWNER_CACHE.cache_dict.clear()
ownership._CONTAINER_OWNER_CACHE.ttl_dict.clear()
yield
ownership._IN_MEMORY_CONTAINER_OWNERS.clear()
ownership._CONTAINER_OWNER_CACHE.clear()
ownership._CONTAINER_OWNER_CACHE.cache_dict.clear()
ownership._CONTAINER_OWNER_CACHE.ttl_dict.clear()
def _container(container_id: str) -> ContainerObject:
@ -159,7 +159,6 @@ async def test_should_reject_record_for_identityless_proxy_auth(monkeypatch):
)
assert exc.value.status_code == 403
assert "identity scope" in str(exc.value.detail)
assert ownership._IN_MEMORY_CONTAINER_OWNERS == {}
@pytest.mark.asyncio
@ -186,106 +185,6 @@ async def test_should_skip_owner_record_when_provider_response_has_no_id(monkeyp
table.create.assert_not_awaited()
@pytest.mark.asyncio
async def test_should_fallback_to_memory_when_persistent_owner_record_fails(
monkeypatch,
):
table = AsyncMock()
table.find_unique.side_effect = Exception("db unavailable")
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
auth = UserAPIKeyAuth(user_id="user-1")
await ownership.record_container_owner(
response=_container("cntr_provider"),
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert (
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_provider"]
== "user-1"
)
@pytest.mark.asyncio
async def test_should_track_container_owner_in_memory_without_prisma(monkeypatch):
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=None),
)
auth = UserAPIKeyAuth(user_id="user-1")
await ownership.record_container_owner(
response=_container("cntr_provider"),
user_api_key_dict=auth,
custom_llm_provider="openai",
)
original_id, provider = await ownership.assert_user_can_access_container(
container_id="cntr_provider",
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert original_id == "cntr_provider"
assert provider == "openai"
@pytest.mark.asyncio
async def test_should_bound_in_memory_container_owner_tracking(monkeypatch):
monkeypatch.setattr(ownership, "MAX_IN_MEMORY_CONTAINER_OWNERS", 2)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=None),
)
auth = UserAPIKeyAuth(user_id="user-1")
for container_id in ("cntr_1", "cntr_2", "cntr_3"):
await ownership.record_container_owner(
response=_container(container_id),
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert list(ownership._IN_MEMORY_CONTAINER_OWNERS.keys()) == [
"container:openai:cntr_2",
"container:openai:cntr_3",
]
@pytest.mark.asyncio
async def test_should_deny_container_access_for_different_owner(monkeypatch):
table = AsyncMock()
table.find_first.return_value = SimpleNamespace(created_by="user-2")
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
auth = UserAPIKeyAuth(user_id="user-1")
with pytest.raises(HTTPException) as exc:
await ownership.assert_user_can_access_container(
container_id="cntr_provider",
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert exc.value.status_code == 403
@pytest.mark.asyncio
async def test_should_deny_untracked_container_access_by_default(monkeypatch):
monkeypatch.setattr(
@ -305,103 +204,6 @@ async def test_should_deny_untracked_container_access_by_default(monkeypatch):
assert exc.value.status_code == 403
@pytest.mark.asyncio
async def test_should_fallback_to_memory_when_owner_lookup_fails(monkeypatch):
table = AsyncMock()
table.find_first.side_effect = Exception("db unavailable")
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1"
auth = UserAPIKeyAuth(user_id="user-1")
original_id, provider = await ownership.assert_user_can_access_container(
container_id="cntr_owned",
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert original_id == "cntr_owned"
assert provider == "openai"
@pytest.mark.asyncio
async def test_should_use_memory_owner_when_db_recovers_without_row(monkeypatch):
table = AsyncMock()
table.find_first.return_value = None
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1"
auth = UserAPIKeyAuth(user_id="user-1")
original_id, provider = await ownership.assert_user_can_access_container(
container_id="cntr_owned",
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert original_id == "cntr_owned"
assert provider == "openai"
@pytest.mark.asyncio
async def test_should_fail_closed_when_owner_lookup_fails_without_memory(monkeypatch):
table = AsyncMock()
table.find_first.side_effect = Exception("db unavailable")
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
auth = UserAPIKeyAuth(user_id="user-1")
with pytest.raises(HTTPException) as exc:
await ownership.assert_user_can_access_container(
container_id="cntr_owned",
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert exc.value.status_code == 403
@pytest.mark.asyncio
async def test_untracked_container_is_admin_only(monkeypatch):
"""Pre-isolation containers with no ownership row are admin-only.
Non-admin callers see them as 403, with no opt-out flag re-opening
the cross-tenant access primitive."""
from fastapi import HTTPException
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=None),
)
auth = UserAPIKeyAuth(user_id="user-1")
with pytest.raises(HTTPException) as exc:
await ownership.assert_user_can_access_container(
container_id="cntr_untracked",
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert exc.value.status_code == 403
@pytest.mark.asyncio
async def test_should_not_reassign_existing_container_to_different_owner(monkeypatch):
table = AsyncMock()
@ -536,165 +338,6 @@ async def test_should_clear_dict_has_more_when_filtered_container_list_is_empty(
assert filtered["has_more"] is False
@pytest.mark.asyncio
async def test_should_filter_container_list_with_in_memory_ownership(monkeypatch):
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=None),
)
auth = UserAPIKeyAuth(user_id="user-1")
await ownership.record_container_owner(
response=_container("cntr_owned"),
user_api_key_dict=auth,
custom_llm_provider="openai",
)
response = ContainerListResponse(
object="list",
data=[_container("cntr_owned"), _container("cntr_other")],
has_more=False,
)
filtered = await ownership.filter_container_list_response(
response=response,
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert [item.id for item in filtered.data] == ["cntr_owned"]
@pytest.mark.asyncio
async def test_should_filter_container_list_with_memory_when_db_lookup_fails(
monkeypatch,
):
table = AsyncMock()
table.find_many.side_effect = Exception("db unavailable")
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1"
auth = UserAPIKeyAuth(user_id="user-1")
response = ContainerListResponse(
object="list",
data=[_container("cntr_owned"), _container("cntr_other")],
has_more=True,
)
filtered = await ownership.filter_container_list_response(
response=response,
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert [item.id for item in filtered.data] == ["cntr_owned"]
assert filtered.has_more is False
@pytest.mark.asyncio
async def test_should_include_memory_container_list_when_db_recovers_without_row(
monkeypatch,
):
table = AsyncMock()
table.find_many.return_value = []
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1"
auth = UserAPIKeyAuth(user_id="user-1")
response = ContainerListResponse(
object="list",
data=[_container("cntr_owned"), _container("cntr_other")],
has_more=True,
)
filtered = await ownership.filter_container_list_response(
response=response,
user_api_key_dict=auth,
custom_llm_provider="openai",
)
assert [item.id for item in filtered.data] == ["cntr_owned"]
assert filtered.has_more is False
@pytest.mark.asyncio
async def test_should_validate_owner_and_forward_decoded_id_for_proxy_forwarding(
monkeypatch,
):
from litellm.proxy.container_endpoints import handler_factory
proxy_server_stub = SimpleNamespace(
general_settings={},
llm_router=None,
proxy_config=None,
proxy_logging_obj=None,
select_data_generator=None,
user_api_base=None,
user_max_tokens=None,
user_model=None,
user_request_timeout=None,
user_temperature=None,
version="test",
)
monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_server_stub)
captured = {}
class FakeProcessor:
def __init__(self, data):
captured["data"] = data
async def base_process_llm_request(self, **kwargs):
return captured["data"]
async def _handle_llm_api_exception(self, **kwargs):
raise kwargs["e"]
monkeypatch.setattr(
handler_factory,
"ProxyBaseLLMRequestProcessing",
FakeProcessor,
)
access_check = AsyncMock(return_value=("cntr_provider", "azure"))
monkeypatch.setattr(
handler_factory,
"assert_user_can_access_container",
access_check,
)
encoded_id = ResponsesAPIRequestUtils._build_container_id(
custom_llm_provider="azure",
model_id="router-gpt",
container_id="cntr_provider",
)
result = await handler_factory._process_request(
request=SimpleNamespace(query_params={}, headers={}),
fastapi_response=SimpleNamespace(),
user_api_key_dict=UserAPIKeyAuth(user_id="user-1"),
route_type="alist_container_files",
path_params={"container_id": encoded_id},
)
access_check.assert_awaited_once()
assert access_check.await_args.kwargs["container_id"] == encoded_id
assert result["container_id"] == "cntr_provider"
assert result["custom_llm_provider"] == "azure"
assert result["model_id"] == "router-gpt"
@pytest.mark.asyncio
async def test_should_validate_owner_and_forward_decoded_id_for_multipart_upload(
monkeypatch,
@ -1158,93 +801,3 @@ async def test_get_container_owner_caches_negative_lookups(monkeypatch):
assert await ownership._get_container_owner("cntr_x", "openai") is None
assert await ownership._get_container_owner("cntr_x", "openai") is None
assert table.find_first.await_count == 1
@pytest.mark.asyncio
async def test_record_container_owner_invalidates_cache(monkeypatch):
"""A recorded owner must drop the cached value so the next read re-fetches.
Otherwise a stale `None` from a prior negative lookup would survive the
create and the new owner would be invisible until the TTL elapses.
"""
# Seed the cache with a stale negative result.
ownership._write_container_owner_cache("container:openai:cntr_new", None)
cached_hit, cached_value = ownership._read_container_owner_cache(
"container:openai:cntr_new"
)
assert cached_hit and cached_value is None
table = AsyncMock()
table.find_unique.return_value = None
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
await ownership.record_container_owner(
response=_container("cntr_new"),
user_api_key_dict=UserAPIKeyAuth(user_id="user-1"),
custom_llm_provider="openai",
)
# Invalidation drops the entry — next read goes to the DB.
cached_hit, _ = ownership._read_container_owner_cache("container:openai:cntr_new")
assert not cached_hit
@pytest.mark.asyncio
async def test_get_container_owner_does_not_cache_on_db_error(monkeypatch):
"""DB errors must skip caching so transient failures don't pin a `None`."""
table = AsyncMock()
table.find_first.side_effect = Exception("db unavailable")
prisma_client = SimpleNamespace(
db=SimpleNamespace(litellm_managedobjecttable=table)
)
monkeypatch.setattr(
ownership,
"_get_prisma_client",
AsyncMock(return_value=prisma_client),
)
result = await ownership._get_container_owner("cntr_x", "openai")
assert result is None
cached_hit, _ = ownership._read_container_owner_cache("container:openai:cntr_x")
assert not cached_hit
def test_container_owner_cache_expires_after_ttl(monkeypatch):
"""Entries past the TTL count as misses so writes elsewhere are eventually
visible to this process."""
monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_TTL", 0.0)
ownership._write_container_owner_cache("k", "user-1")
cached_hit, _ = ownership._read_container_owner_cache("k")
# TTL of 0 means anything in the cache is already stale.
assert not cached_hit
def test_container_owner_cache_evicts_when_at_capacity(monkeypatch):
"""The cache must not grow unbounded; reaching capacity LRU-evicts the
oldest entry, not the entire cache."""
monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_MAX_SIZE", 2)
ownership._write_container_owner_cache("a", "user-a")
ownership._write_container_owner_cache("b", "user-b")
ownership._write_container_owner_cache("c", "user-c")
# ``a`` was the oldest and is dropped; ``b`` and ``c`` survive.
assert list(ownership._CONTAINER_OWNER_CACHE.keys()) == ["b", "c"]
def test_container_owner_cache_read_marks_as_recently_used(monkeypatch):
"""Reading an entry should reset its position so a subsequent eviction
drops a less-recently-used entry instead of the just-touched one."""
monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_MAX_SIZE", 2)
ownership._write_container_owner_cache("a", "user-a")
ownership._write_container_owner_cache("b", "user-b")
# Touch ``a`` so it becomes the most-recently-used.
ownership._read_container_owner_cache("a")
ownership._write_container_owner_cache("c", "user-c")
# ``b`` is the LRU at this point; ``a`` and ``c`` survive.
assert list(ownership._CONTAINER_OWNER_CACHE.keys()) == ["a", "c"]

View File

@ -19,9 +19,11 @@ from litellm.skills import main as skills_main
@pytest.fixture(autouse=True)
def clear_skill_cache():
skills_handler._SKILL_CACHE.clear()
skills_handler._SKILL_CACHE.cache_dict.clear()
skills_handler._SKILL_CACHE.ttl_dict.clear()
yield
skills_handler._SKILL_CACHE.clear()
skills_handler._SKILL_CACHE.cache_dict.clear()
skills_handler._SKILL_CACHE.ttl_dict.clear()
def _skill(skill_id: str, created_by: str | None) -> LiteLLM_SkillsTable:
@ -440,99 +442,75 @@ async def test_should_scope_skill_injection_fetch_to_authenticated_user(monkeypa
@pytest.mark.asyncio
async def test_load_skill_uses_cache_after_first_db_hit(monkeypatch):
"""`fetch_skill_from_db` is hit per-chat-completion; the cache absorbs
"""``fetch_skill_from_db`` runs per chat-completion; the cache absorbs
repeats so we don't issue a Prisma query on every request."""
fake_skill = Mock(created_by="user-1", skill_id="litellm_skill_a")
store_factory = AsyncMock()
store = Mock()
store.find_skill = AsyncMock(return_value=fake_skill)
table = AsyncMock()
table.find_unique = AsyncMock(return_value=fake_skill)
prisma_client = type(
"Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()}
)()
monkeypatch.setattr(
skills_handler.LiteLLMSkillsHandler,
"_get_prisma_client",
AsyncMock(return_value=store_factory),
)
monkeypatch.setattr(
skills_handler,
"LiteLLMSkillsStore",
Mock(return_value=store),
AsyncMock(return_value=prisma_client),
)
first = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
second = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
third = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
assert first is fake_skill
assert second is fake_skill
assert third is fake_skill
assert store.find_skill.await_count == 1
for _ in range(3):
assert (
await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
is fake_skill
)
assert table.find_unique.await_count == 1
@pytest.mark.asyncio
async def test_load_skill_caches_negative_lookups(monkeypatch):
"""Missing skills must cache as `None` so repeated lookups skip the DB."""
store_factory = AsyncMock()
store = Mock()
store.find_skill = AsyncMock(return_value=None)
"""Missing skills cache as the negative sentinel so repeated misses skip
the DB and the caller still sees ``None``."""
table = AsyncMock()
table.find_unique = AsyncMock(return_value=None)
prisma_client = type(
"Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()}
)()
monkeypatch.setattr(
skills_handler.LiteLLMSkillsHandler,
"_get_prisma_client",
AsyncMock(return_value=store_factory),
)
monkeypatch.setattr(
skills_handler,
"LiteLLMSkillsStore",
Mock(return_value=store),
AsyncMock(return_value=prisma_client),
)
assert await skills_handler.LiteLLMSkillsHandler._load_skill("missing") is None
assert await skills_handler.LiteLLMSkillsHandler._load_skill("missing") is None
assert store.find_skill.await_count == 1
assert table.find_unique.await_count == 1
@pytest.mark.asyncio
async def test_delete_skill_invalidates_cache(monkeypatch):
"""After delete, the next read must consult the DB rather than the cached
pre-delete row."""
"""After delete, the next read should not see the pre-delete cached row."""
fake_skill = Mock(created_by="user-1", skill_id="litellm_skill_a")
store = Mock()
store.find_skill = AsyncMock(return_value=fake_skill)
store.delete_skill = AsyncMock()
table = AsyncMock()
table.find_unique = AsyncMock(return_value=fake_skill)
table.delete = AsyncMock()
prisma_client = type(
"Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()}
)()
monkeypatch.setattr(
skills_handler.LiteLLMSkillsHandler,
"_get_prisma_client",
AsyncMock(return_value=Mock()),
)
monkeypatch.setattr(
skills_handler,
"LiteLLMSkillsStore",
Mock(return_value=store),
AsyncMock(return_value=prisma_client),
)
# Prime the cache via the read path.
await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
cached_hit, _ = skills_handler._read_skill_cache("litellm_skill_a")
assert cached_hit
assert skills_handler._SKILL_CACHE.get_cache("litellm_skill_a") is fake_skill
auth = UserAPIKeyAuth(user_id="user-1")
await skills_handler.LiteLLMSkillsHandler.delete_skill(
"litellm_skill_a", user_api_key_dict=auth
)
cached_hit_after, _ = skills_handler._read_skill_cache("litellm_skill_a")
assert not cached_hit_after
def test_skill_cache_expires_after_ttl(monkeypatch):
monkeypatch.setattr(skills_handler, "_SKILL_CACHE_TTL", 0.0)
skills_handler._write_skill_cache("k", Mock())
cached_hit, _ = skills_handler._read_skill_cache("k")
assert not cached_hit
def test_skill_cache_evicts_when_at_capacity(monkeypatch):
monkeypatch.setattr(skills_handler, "_SKILL_CACHE_MAX_SIZE", 2)
skills_handler._write_skill_cache("a", Mock())
skills_handler._write_skill_cache("b", Mock())
skills_handler._write_skill_cache("c", Mock())
# ``a`` was the oldest and is LRU-evicted; ``b`` and ``c`` survive.
assert list(skills_handler._SKILL_CACHE.keys()) == ["b", "c"]
# Post-delete, the cache holds the negative sentinel — not the stale row.
assert (
skills_handler._SKILL_CACHE.get_cache("litellm_skill_a")
== skills_handler._NEGATIVE_SKILL_SENTINEL
)