diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260224203854_add_agent_object_permissions_table/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260224203854_add_agent_object_permissions_table/migration.sql new file mode 100644 index 0000000000..78e364d547 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260224203854_add_agent_object_permissions_table/migration.sql @@ -0,0 +1,40 @@ +-- AlterTable +ALTER TABLE "LiteLLM_AgentsTable" ADD COLUMN "object_permission_id" TEXT; + +-- AlterTable +ALTER TABLE "LiteLLM_MCPServerTable" DROP COLUMN "spec_path"; + +-- AlterTable +ALTER TABLE "LiteLLM_VerificationToken" ADD COLUMN "agent_id" TEXT; + +-- CreateTable +CREATE TABLE "LiteLLM_ToolTable" ( + "tool_id" TEXT NOT NULL, + "tool_name" TEXT NOT NULL, + "origin" TEXT, + "call_policy" TEXT NOT NULL DEFAULT 'untrusted', + "call_count" INTEGER NOT NULL DEFAULT 0, + "assignments" JSONB DEFAULT '{}', + "key_hash" TEXT, + "team_id" TEXT, + "key_alias" TEXT, + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "created_by" TEXT, + "updated_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_by" TEXT, + + CONSTRAINT "LiteLLM_ToolTable_pkey" PRIMARY KEY ("tool_id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "LiteLLM_ToolTable_tool_name_key" ON "LiteLLM_ToolTable"("tool_name"); + +-- CreateIndex +CREATE INDEX "LiteLLM_ToolTable_call_policy_idx" ON "LiteLLM_ToolTable"("call_policy"); + +-- CreateIndex +CREATE INDEX "LiteLLM_ToolTable_team_id_idx" ON "LiteLLM_ToolTable"("team_id"); + +-- AddForeignKey +ALTER TABLE "LiteLLM_AgentsTable" ADD CONSTRAINT "LiteLLM_AgentsTable_object_permission_id_fkey" FOREIGN KEY ("object_permission_id") REFERENCES "LiteLLM_ObjectPermissionTable"("object_permission_id") ON DELETE SET NULL ON UPDATE CASCADE; + diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 4af7484148..155cea12ca 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -64,6 +64,8 @@ model LiteLLM_AgentsTable { litellm_params Json? agent_card_params Json agent_access_groups String[] @default([]) + object_permission_id String? + object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id]) created_at DateTime @default(now()) @map("created_at") created_by String updated_at DateTime @default(now()) @updatedAt @map("updated_at") @@ -264,6 +266,7 @@ model LiteLLM_ObjectPermissionTable { organizations LiteLLM_OrganizationTable[] users LiteLLM_UserTable[] end_users LiteLLM_EndUserTable[] + agents_table LiteLLM_AgentsTable[] } // Holds the MCP server configuration @@ -273,7 +276,6 @@ model LiteLLM_MCPServerTable { alias String? description String? url String? - spec_path String? transport String @default("sse") auth_type String? credentials Json? @default("{}") @@ -315,6 +317,7 @@ model LiteLLM_VerificationToken { router_settings Json? @default("{}") user_id String? team_id String? + agent_id String? project_id String? permissions Json @default("{}") max_parallel_requests Int? @@ -1052,6 +1055,26 @@ model LiteLLM_PolicyAttachmentTable { updated_by String? } +// Global tool registry - auto-discovered from LLM responses; admins set call_policy here +model LiteLLM_ToolTable { + tool_id String @id @default(uuid()) + tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space" + origin String? // MCP server name or "user_defined" + call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked" + call_count Int @default(0) // cumulative number of times this tool was seen + assignments Json? @default("{}") + key_hash String? // hash of the virtual key that first called this tool + team_id String? // team that first called this tool + key_alias String? // human-readable alias of the virtual key + created_at DateTime @default(now()) + created_by String? + updated_at DateTime @default(now()) @updatedAt + updated_by String? + + @@index([call_policy]) + @@index([team_id]) +} + //Unified Access Groups table for storing unified access groups model LiteLLM_AccessGroupTable { access_group_id String @id @default(uuid()) diff --git a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py index 60b29b975f..860569d24c 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py +++ b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py @@ -412,6 +412,26 @@ class MCPRequestHandler: ) return [] + ######################################################### + # Check agent permissions if agent_id is set on the key + ######################################################### + if user_api_key_auth and user_api_key_auth.agent_id: + allowed_mcp_servers_for_agent = ( + await MCPRequestHandler._get_allowed_mcp_servers_for_agent( + user_api_key_auth + ) + ) + if len(allowed_mcp_servers_for_agent) > 0: + # Intersect: agent can only use servers allowed by BOTH key/team AND agent config + allowed_mcp_servers = [ + s + for s in allowed_mcp_servers + if s in allowed_mcp_servers_for_agent + ] + verbose_logger.debug( + f"Applied agent intersection filter. Final allowed servers: {allowed_mcp_servers}" + ) + return list(set(allowed_mcp_servers)) except Exception as e: verbose_logger.warning(f"Failed to get allowed MCP servers: {str(e)}") @@ -513,13 +533,33 @@ class MCPRequestHandler: if team_tools: if key_tools: # Both have restrictions → intersection - return list(set(team_tools) & set(key_tools)) + allowed_tools = list(set(team_tools) & set(key_tools)) else: # Only team has restrictions → inherit from team - return team_tools + allowed_tools = team_tools else: # No team restrictions → use key restrictions - return key_tools + allowed_tools = key_tools + + # Intersect with agent's tool permissions if agent_id is set + if user_api_key_auth.agent_id: + # Pre-fetch agent object_permission once to avoid duplicate DB query + agent_obj_perm = await MCPRequestHandler._get_agent_object_permission( + user_api_key_auth + ) + agent_tools = await MCPRequestHandler._get_agent_tool_permissions_for_server( + server_id=server_id, + user_api_key_auth=user_api_key_auth, + agent_object_permission=agent_obj_perm, + ) + if agent_tools is not None: + if allowed_tools is not None: + allowed_tools = list( + set(allowed_tools) & set(agent_tools) + ) + else: + allowed_tools = agent_tools + return allowed_tools except Exception as e: verbose_logger.warning(f"Failed to get allowed tools for server: {str(e)}") @@ -715,6 +755,131 @@ class MCPRequestHandler: ) return [] + @staticmethod + async def _get_agent_object_permission( + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + ): + """ + Fetch the agent's object_permission from the DB (single query). + + Returns the object_permission object or None. + """ + from litellm.proxy.proxy_server import prisma_client + + if not user_api_key_auth or not user_api_key_auth.agent_id: + return None + + if prisma_client is None: + verbose_logger.debug("prisma_client is None") + return None + + try: + agent_row = await prisma_client.db.litellm_agentstable.find_unique( + where={"agent_id": user_api_key_auth.agent_id}, + include={"object_permission": True}, + ) + if agent_row is None or agent_row.object_permission is None: + return None + + return agent_row.object_permission + except Exception as e: + verbose_logger.warning( + f"Failed to get agent object permission: {str(e)}" + ) + return None + + @staticmethod + async def _get_allowed_mcp_servers_for_agent( + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + agent_object_permission=None, + ) -> List[str]: + """ + Get allowed MCP servers for an agent (from the agent's object_permission). + + Returns the MCP servers from the agent's object_permission. + If agent has no object_permission, returns [] (no extra restriction). + + Args: + user_api_key_auth: User auth with agent_id + agent_object_permission: Pre-fetched object_permission to avoid duplicate DB query. + If None, will be fetched from DB. + """ + if not user_api_key_auth or not user_api_key_auth.agent_id: + return [] + + try: + obj_perm = agent_object_permission + if obj_perm is None: + obj_perm = await MCPRequestHandler._get_agent_object_permission( + user_api_key_auth + ) + if obj_perm is None: + return [] + + direct_mcp_servers = getattr(obj_perm, "mcp_servers", None) or [] + if isinstance(direct_mcp_servers, str): + direct_mcp_servers = [] + mcp_access_groups = getattr(obj_perm, "mcp_access_groups", None) or [] + if isinstance(mcp_access_groups, str): + mcp_access_groups = [] + + access_group_servers = ( + await MCPRequestHandler._get_mcp_servers_from_access_groups( + mcp_access_groups + ) + ) + all_servers = list(direct_mcp_servers) + access_group_servers + return list(set(all_servers)) + except Exception as e: + verbose_logger.warning( + f"Failed to get allowed MCP servers for agent: {str(e)}" + ) + return [] + + @staticmethod + async def _get_agent_tool_permissions_for_server( + server_id: str, + user_api_key_auth: Optional[UserAPIKeyAuth] = None, + agent_object_permission=None, + ) -> Optional[List[str]]: + """ + Get allowed tool names for a server from the agent's object_permission. + Returns None if agent has no tool restrictions for this server. + + Args: + server_id: Server ID to check permissions for + user_api_key_auth: User auth with agent_id + agent_object_permission: Pre-fetched object_permission to avoid duplicate DB query. + If None, will be fetched from DB. + """ + if not user_api_key_auth or not user_api_key_auth.agent_id: + return None + + try: + obj_perm = agent_object_permission + if obj_perm is None: + obj_perm = await MCPRequestHandler._get_agent_object_permission( + user_api_key_auth + ) + if obj_perm is None: + return None + + mcp_tool_permissions = getattr( + obj_perm, "mcp_tool_permissions", None + ) + if not mcp_tool_permissions: + return None + if isinstance(mcp_tool_permissions, dict): + tools = mcp_tool_permissions.get(server_id) + else: + tools = None + return list(tools) if tools else None + except Exception as e: + verbose_logger.warning( + f"Failed to get agent tool permissions for server: {str(e)}" + ) + return None + @staticmethod def _get_config_server_ids_for_access_groups( config_mcp_servers, access_groups: List[str] diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 4053d9d077..5430d7a360 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1,60 +1,40 @@ import enum import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, + Optional, Union) import httpx -from pydantic import ( - BaseModel, - ConfigDict, - Field, - Json, - field_validator, - model_validator, -) +from pydantic import (BaseModel, ConfigDict, Field, Json, field_validator, + model_validator) from typing_extensions import Required, TypedDict from litellm._uuid import uuid from litellm.types.integrations.slack_alerting import AlertType -from litellm.types.llms.openai import ( - AllMessageValues, - OpenAIFileObject, - ResponsesAPIResponse, -) -from litellm.types.mcp import ( - MCPAuth, - MCPAuthType, - MCPCredentials, - MCPTransport, - MCPTransportType, -) +from litellm.types.llms.openai import (AllMessageValues, OpenAIFileObject, + ResponsesAPIResponse) +from litellm.types.mcp import (MCPAuth, MCPAuthType, MCPCredentials, + MCPTransport, MCPTransportType) from litellm.types.mcp_server.mcp_server_manager import MCPInfo from litellm.types.router import RouterErrors, UpdateRouterConfig from litellm.types.secret_managers.main import KeyManagementSystem -from litellm.types.utils import ( - CallTypes, - CostBreakdown, - EmbeddingResponse, - GenericBudgetConfigType, - ImageResponse, - LiteLLMBatch, - LiteLLMFineTuningJob, - LiteLLMPydanticObjectBase, - ModelResponse, - ProviderField, - StandardCallbackDynamicParams, - StandardLoggingGuardrailInformation, - StandardLoggingMCPToolCall, - StandardLoggingModelInformation, - StandardLoggingPayloadErrorInformation, - StandardLoggingPayloadStatus, - StandardLoggingVectorStoreRequest, - StandardPassThroughResponseObject, - TextCompletionResponse, -) +from litellm.types.utils import (CallTypes, CostBreakdown, EmbeddingResponse, + GenericBudgetConfigType, ImageResponse, + LiteLLMBatch, LiteLLMFineTuningJob, + LiteLLMPydanticObjectBase, ModelResponse, + ProviderField, StandardCallbackDynamicParams, + StandardLoggingGuardrailInformation, + StandardLoggingMCPToolCall, + StandardLoggingModelInformation, + StandardLoggingPayloadErrorInformation, + StandardLoggingPayloadStatus, + StandardLoggingVectorStoreRequest, + StandardPassThroughResponseObject, + TextCompletionResponse) from litellm.types.videos.main import VideoObject -from .types_utils.utils import get_instance_fn, validate_custom_validate_return_type +from .types_utils.utils import (get_instance_fn, + validate_custom_validate_return_type) if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -2202,6 +2182,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase): config: Dict = {} user_id: Optional[str] = None team_id: Optional[str] = None + agent_id: Optional[str] = None project_id: Optional[str] = None max_parallel_requests: Optional[int] = None metadata: Dict = {} @@ -2379,7 +2360,8 @@ class UserAPIKeyAuth( This is used to track number of requests/spend for health check calls. """ - from litellm.constants import LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME + from litellm.constants import \ + LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME return cls( api_key=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME, @@ -2411,7 +2393,8 @@ class UserAPIKeyAuth( This is used to track actions performed by automated system jobs. """ - from litellm.constants import LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME + from litellm.constants import \ + LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME return cls( api_key=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME, @@ -2802,7 +2785,8 @@ class LiteLLM_AuditLogs(LiteLLMPydanticObjectBase): @model_validator(mode="after") def mask_api_keys(self): - from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker + from litellm.litellm_core_utils.sensitive_data_masker import \ + SensitiveDataMasker masker = SensitiveDataMasker(sensitive_patterns={"key"}) diff --git a/litellm/proxy/agent_endpoints/agent_registry.py b/litellm/proxy/agent_endpoints/agent_registry.py index 0d2df3856a..159c9fb93d 100644 --- a/litellm/proxy/agent_endpoints/agent_registry.py +++ b/litellm/proxy/agent_endpoints/agent_registry.py @@ -5,6 +5,9 @@ from typing import Any, Dict, List, Optional import litellm from litellm.litellm_core_utils.safe_json_dumps import safe_dumps +from litellm.proxy.management_helpers.object_permission_utils import ( + handle_update_object_permission_common, +) from litellm.proxy.utils import PrismaClient from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest @@ -117,20 +120,39 @@ class AgentRegistry: ) agent_card_params: str = safe_dumps(agent_card_params_dict) + # Handle object_permission (MCP tool access for agent) + object_permission_id: Optional[str] = None + if agent.get("object_permission") is not None: + agent_copy = dict(agent) + object_permission_id = await handle_update_object_permission_common( + agent_copy, None, prisma_client + ) + + create_data: Dict[str, Any] = { + "agent_name": agent_name, + "litellm_params": litellm_params, + "agent_card_params": agent_card_params, + "created_by": created_by, + "updated_by": created_by, + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + if object_permission_id is not None: + create_data["object_permission_id"] = object_permission_id + # Create agent in DB created_agent = await prisma_client.db.litellm_agentstable.create( - data={ - "agent_name": agent_name, - "litellm_params": litellm_params, - "agent_card_params": agent_card_params, - "created_by": created_by, - "updated_by": created_by, - "created_at": datetime.now(timezone.utc), - "updated_at": datetime.now(timezone.utc), - } + data=create_data, + include={"object_permission": True}, ) - return AgentResponse(**created_agent.model_dump()) # type: ignore + created_agent_dict = created_agent.model_dump() + if created_agent.object_permission is not None: + try: + created_agent_dict["object_permission"] = created_agent.object_permission.model_dump() + except Exception: + created_agent_dict["object_permission"] = created_agent.object_permission.dict() + return AgentResponse(**created_agent_dict) # type: ignore except Exception as e: raise Exception(f"Error adding agent to DB: {str(e)}") @@ -181,7 +203,7 @@ class AgentRegistry: raise Exception(f"Agent with ID {agent_id} not found") augment_agent = {**existing_agent, **agent} - update_data = {} + update_data: Dict[str, Any] = {} if augment_agent.get("agent_name"): update_data["agent_name"] = augment_agent.get("agent_name") if augment_agent.get("litellm_params"): @@ -192,6 +214,20 @@ class AgentRegistry: update_data["agent_card_params"] = safe_dumps( augment_agent.get("agent_card_params") ) + if agent.get("object_permission") is not None: + agent_copy = dict(augment_agent) + existing_object_permission_id = existing_agent.get( + "object_permission_id" + ) + object_permission_id = ( + await handle_update_object_permission_common( + agent_copy, + existing_object_permission_id, + prisma_client, + ) + ) + if object_permission_id is not None: + update_data["object_permission_id"] = object_permission_id # Patch agent in DB patched_agent = await prisma_client.db.litellm_agentstable.update( where={"agent_id": agent_id}, @@ -200,8 +236,15 @@ class AgentRegistry: "updated_by": updated_by, "updated_at": datetime.now(timezone.utc), }, + include={"object_permission": True}, ) - return AgentResponse(**patched_agent.model_dump()) # type: ignore + patched_agent_dict = patched_agent.model_dump() + if patched_agent.object_permission is not None: + try: + patched_agent_dict["object_permission"] = patched_agent.object_permission.model_dump() + except Exception: + patched_agent_dict["object_permission"] = patched_agent.object_permission.dict() + return AgentResponse(**patched_agent_dict) # type: ignore except Exception as e: raise Exception(f"Error patching agent in DB: {str(e)}") @@ -238,19 +281,47 @@ class AgentRegistry: ) agent_card_params: str = safe_dumps(agent_card_params_dict) + update_data: Dict[str, Any] = { + "agent_name": agent_name, + "litellm_params": litellm_params, + "agent_card_params": agent_card_params, + "updated_by": updated_by, + "updated_at": datetime.now(timezone.utc), + } + if agent.get("object_permission") is not None: + existing_agent = await prisma_client.db.litellm_agentstable.find_unique( + where={"agent_id": agent_id} + ) + existing_object_permission_id = ( + existing_agent.object_permission_id + if existing_agent is not None + else None + ) + agent_copy = dict(agent) + object_permission_id = ( + await handle_update_object_permission_common( + agent_copy, + existing_object_permission_id, + prisma_client, + ) + ) + if object_permission_id is not None: + update_data["object_permission_id"] = object_permission_id + # Update agent in DB updated_agent = await prisma_client.db.litellm_agentstable.update( where={"agent_id": agent_id}, - data={ - "agent_name": agent_name, - "litellm_params": litellm_params, - "agent_card_params": agent_card_params, - "updated_by": updated_by, - "updated_at": datetime.now(timezone.utc), - }, + data=update_data, + include={"object_permission": True}, ) - return AgentResponse(**updated_agent.model_dump()) # type: ignore + updated_agent_dict = updated_agent.model_dump() + if updated_agent.object_permission is not None: + try: + updated_agent_dict["object_permission"] = updated_agent.object_permission.model_dump() + except Exception: + updated_agent_dict["object_permission"] = updated_agent.object_permission.dict() + return AgentResponse(**updated_agent_dict) # type: ignore except Exception as e: raise Exception(f"Error updating agent in DB: {str(e)}") @@ -264,11 +335,19 @@ class AgentRegistry: try: agents_from_db = await prisma_client.db.litellm_agentstable.find_many( order={"created_at": "desc"}, + include={"object_permission": True}, ) agents: List[Dict[str, Any]] = [] for agent in agents_from_db: - agents.append(dict(agent)) + agent_dict = dict(agent) + # object_permission is eagerly loaded via include above + if agent.object_permission is not None: + try: + agent_dict["object_permission"] = agent.object_permission.model_dump() + except Exception: + agent_dict["object_permission"] = agent.object_permission.dict() + agents.append(agent_dict) return agents except Exception as e: diff --git a/litellm/proxy/agent_endpoints/endpoints.py b/litellm/proxy/agent_endpoints/endpoints.py index 4a8d615f0b..b411b81b43 100644 --- a/litellm/proxy/agent_endpoints/endpoints.py +++ b/litellm/proxy/agent_endpoints/endpoints.py @@ -16,6 +16,7 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity from litellm.types.agents import ( AgentConfig, AgentMakePublicResponse, @@ -23,8 +24,6 @@ from litellm.types.agents import ( MakeAgentsPublicRequest, PatchAgentRequest, ) - -from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity from litellm.types.proxy.management_endpoints.common_daily_activity import ( SpendAnalyticsPaginatedResponse, ) @@ -233,11 +232,18 @@ async def get_agent_by_id(agent_id: str): try: agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id) if agent is None: - agent = await prisma_client.db.litellm_agentstable.find_unique( - where={"agent_id": agent_id} + agent_row = await prisma_client.db.litellm_agentstable.find_unique( + where={"agent_id": agent_id}, + include={"object_permission": True}, ) - if agent is not None: - agent = AgentResponse(**agent.model_dump()) # type: ignore + if agent_row is not None: + agent_dict = agent_row.model_dump() + if agent_row.object_permission is not None: + try: + agent_dict["object_permission"] = agent_row.object_permission.model_dump() + except Exception: + agent_dict["object_permission"] = agent_row.object_permission.dict() + agent = AgentResponse(**agent_dict) # type: ignore if agent is None: raise HTTPException( diff --git a/litellm/proxy/hooks/max_iterations_limiter.py b/litellm/proxy/hooks/max_iterations_limiter.py new file mode 100644 index 0000000000..8d481f6b26 --- /dev/null +++ b/litellm/proxy/hooks/max_iterations_limiter.py @@ -0,0 +1,208 @@ +""" +Max Iterations Limiter for LiteLLM Proxy. + +Enforces a per-session cap on the number of LLM calls an agentic loop can make. +Callers send a `session_id` with each request (via `x-litellm-session-id` header +or `metadata.session_id`), and this hook counts calls per session. When the count +exceeds `max_iterations` (configured in key/team metadata), returns 429. + +Works across multiple proxy instances via DualCache (in-memory + Redis). +Follows the same pattern as parallel_request_limiter_v3.py. +""" + +import os +from typing import TYPE_CHECKING, Any, Optional, Union + +from fastapi import HTTPException + +from litellm import DualCache +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + +if TYPE_CHECKING: + from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache + + InternalUsageCache = _InternalUsageCache +else: + InternalUsageCache = Any + + +# Redis Lua script for atomic increment with TTL. +# Returns the new count after increment. +# Only sets EXPIRE on first increment (when count becomes 1). +MAX_ITERATIONS_INCREMENT_SCRIPT = """ +local key = KEYS[1] +local ttl = tonumber(ARGV[1]) + +local current = redis.call('INCR', key) +if current == 1 then + redis.call('EXPIRE', key, ttl) +end + +return current +""" + +# Default TTL for session iteration counters (1 hour) +DEFAULT_MAX_ITERATIONS_TTL = 3600 + + +class _PROXY_MaxIterationsHandler(CustomLogger): + """ + Pre-call hook that enforces max_iterations per session. + + Configuration: + - max_iterations: set in key metadata via /key/generate or /key/update + e.g. metadata={"max_iterations": 25} + - session_id: sent by caller via x-litellm-session-id header or + metadata.session_id in request body + + Cache key pattern: + {session_iterations:}:count + + Multi-instance support: + Uses Redis Lua script for atomic increment (same pattern as + parallel_request_limiter_v3). Falls back to in-memory cache + when Redis is unavailable. + """ + + def __init__(self, internal_usage_cache: InternalUsageCache): + self.internal_usage_cache = internal_usage_cache + self.ttl = int( + os.getenv("LITELLM_MAX_ITERATIONS_TTL", DEFAULT_MAX_ITERATIONS_TTL) + ) + + # Register Lua script with Redis if available (same pattern as v3 limiter) + if self.internal_usage_cache.dual_cache.redis_cache is not None: + self.increment_script = ( + self.internal_usage_cache.dual_cache.redis_cache.async_register_script( + MAX_ITERATIONS_INCREMENT_SCRIPT + ) + ) + else: + self.increment_script = None + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ) -> Optional[Union[Exception, str, dict]]: + """ + Check session iteration count before making the API call. + + Extracts session_id from request metadata and max_iterations from + key metadata. If the session has exceeded max_iterations, raises 429. + """ + # Extract session_id from request data + session_id = self._get_session_id(data) + if session_id is None: + return None + + # Extract max_iterations from key metadata + max_iterations = self._get_max_iterations(user_api_key_dict) + if max_iterations is None: + return None + + verbose_proxy_logger.debug( + "MaxIterationsHandler: session_id=%s, max_iterations=%s", + session_id, + max_iterations, + ) + + # Increment and check + cache_key = self._make_cache_key(session_id) + current_count = await self._increment_and_get(cache_key) + + if current_count > max_iterations: + raise HTTPException( + status_code=429, + detail=( + f"Max iterations exceeded for session {session_id}. " + f"Current count: {current_count}, max_iterations: {max_iterations}." + ), + ) + + verbose_proxy_logger.debug( + "MaxIterationsHandler: session_id=%s, count=%s/%s", + session_id, + current_count, + max_iterations, + ) + + return None + + def _get_session_id(self, data: dict) -> Optional[str]: + """Extract session_id from request metadata.""" + metadata = data.get("metadata") or {} + session_id = metadata.get("session_id") + if session_id is not None: + return str(session_id) + + # Also check litellm_metadata (used for /thread and /assistant endpoints) + litellm_metadata = data.get("litellm_metadata") or {} + session_id = litellm_metadata.get("session_id") + if session_id is not None: + return str(session_id) + + return None + + def _get_max_iterations( + self, user_api_key_dict: UserAPIKeyAuth + ) -> Optional[int]: + """Extract max_iterations from key metadata.""" + metadata = user_api_key_dict.metadata or {} + max_iterations = metadata.get("max_iterations") + if max_iterations is not None: + return int(max_iterations) + return None + + def _make_cache_key(self, session_id: str) -> str: + """ + Create cache key for session iteration counter. + + Uses Redis hash-tag pattern {session_iterations:} so all + keys for a session land on the same Redis Cluster slot. + """ + return f"{{session_iterations:{session_id}}}:count" + + async def _increment_and_get(self, cache_key: str) -> int: + """ + Atomically increment the session counter and return the new value. + + Tries Redis first (via registered Lua script for atomicity across + instances), falls back to in-memory cache. + """ + if self.increment_script is not None: + try: + result = await self.increment_script( + keys=[cache_key], + args=[self.ttl], + ) + return int(result) + except Exception as e: + verbose_proxy_logger.warning( + "MaxIterationsHandler: Redis failed, falling back to in-memory: %s", + str(e), + ) + + # Fallback: in-memory cache + return await self._in_memory_increment(cache_key) + + async def _in_memory_increment(self, cache_key: str) -> int: + """Increment counter in in-memory cache with TTL.""" + current = await self.internal_usage_cache.async_get_cache( + key=cache_key, + litellm_parent_otel_span=None, + local_only=True, + ) + new_value = (int(current) if current is not None else 0) + 1 + await self.internal_usage_cache.async_set_cache( + key=cache_key, + value=new_value, + ttl=self.ttl, + litellm_parent_otel_span=None, + local_only=True, + ) + return new_value diff --git a/litellm/types/agents.py b/litellm/types/agents.py index f4e410a3e2..3ad898b193 100644 --- a/litellm/types/agents.py +++ b/litellm/types/agents.py @@ -167,16 +167,25 @@ class AugmentedAgentCard(AgentCard): is_public: bool +# Object permission shape for agent MCP tool access (mirrors LiteLLM_ObjectPermissionBase) +class AgentObjectPermission(TypedDict, total=False): + mcp_servers: Optional[List[str]] + mcp_access_groups: Optional[List[str]] + mcp_tool_permissions: Optional[Dict[str, List[str]]] + + class AgentConfig(TypedDict, total=False): agent_name: Required[str] agent_card_params: Required[AgentCard] litellm_params: Dict[str, Any] # allow for any future litellm params + object_permission: AgentObjectPermission class PatchAgentRequest(TypedDict, total=False): agent_name: str agent_card_params: AgentCard litellm_params: Dict[str, Any] + object_permission: AgentObjectPermission # Request/Response models for CRUD endpoints @@ -187,6 +196,7 @@ class AgentResponse(BaseModel): agent_name: str litellm_params: Optional[Dict[str, Any]] = None agent_card_params: Dict[str, Any] + object_permission: Optional[Dict[str, Any]] = None created_at: Optional[datetime] = None updated_at: Optional[datetime] = None created_by: Optional[str] = None diff --git a/schema.prisma b/schema.prisma index 4af7484148..155cea12ca 100644 --- a/schema.prisma +++ b/schema.prisma @@ -64,6 +64,8 @@ model LiteLLM_AgentsTable { litellm_params Json? agent_card_params Json agent_access_groups String[] @default([]) + object_permission_id String? + object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id]) created_at DateTime @default(now()) @map("created_at") created_by String updated_at DateTime @default(now()) @updatedAt @map("updated_at") @@ -264,6 +266,7 @@ model LiteLLM_ObjectPermissionTable { organizations LiteLLM_OrganizationTable[] users LiteLLM_UserTable[] end_users LiteLLM_EndUserTable[] + agents_table LiteLLM_AgentsTable[] } // Holds the MCP server configuration @@ -273,7 +276,6 @@ model LiteLLM_MCPServerTable { alias String? description String? url String? - spec_path String? transport String @default("sse") auth_type String? credentials Json? @default("{}") @@ -315,6 +317,7 @@ model LiteLLM_VerificationToken { router_settings Json? @default("{}") user_id String? team_id String? + agent_id String? project_id String? permissions Json @default("{}") max_parallel_requests Int? @@ -1052,6 +1055,26 @@ model LiteLLM_PolicyAttachmentTable { updated_by String? } +// Global tool registry - auto-discovered from LLM responses; admins set call_policy here +model LiteLLM_ToolTable { + tool_id String @id @default(uuid()) + tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space" + origin String? // MCP server name or "user_defined" + call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked" + call_count Int @default(0) // cumulative number of times this tool was seen + assignments Json? @default("{}") + key_hash String? // hash of the virtual key that first called this tool + team_id String? // team that first called this tool + key_alias String? // human-readable alias of the virtual key + created_at DateTime @default(now()) + created_by String? + updated_at DateTime @default(now()) @updatedAt + updated_by String? + + @@index([call_policy]) + @@index([team_id]) +} + //Unified Access Groups table for storing unified access groups model LiteLLM_AccessGroupTable { access_group_id String @id @default(uuid()) diff --git a/scripts/test_agent_mcp_endpoints.sh b/scripts/test_agent_mcp_endpoints.sh new file mode 100755 index 0000000000..93cc68db2e --- /dev/null +++ b/scripts/test_agent_mcp_endpoints.sh @@ -0,0 +1,186 @@ +#!/usr/bin/env bash +# +# Test agent endpoint-level changes for MCP tool permissions (object_permission). +# Requires: proxy running, valid admin API key, curl, jq. +# +# Usage: +# export LITELLM_PROXY_BASE_URL="http://localhost:4000" # optional, default below +# export LITELLM_API_KEY="sk-..." # required +# ./scripts/test_agent_mcp_endpoints.sh +# +set -euo pipefail + +BASE_URL="${LITELLM_PROXY_BASE_URL:-http://localhost:4000}" +API_KEY="${LITELLM_API_KEY:-}" + +if ! command -v jq &>/dev/null; then + echo "Error: jq is required. Install with: brew install jq (macOS) or apt install jq (Linux)" + exit 1 +fi +if [[ -z "$API_KEY" ]]; then + echo "Error: LITELLM_API_KEY is not set. Export it or pass via env." + exit 1 +fi + +AUTH_HEADER="Authorization: Bearer $API_KEY" +AGENT_NAME="test-agent-mcp-$(date +%s)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +pass() { echo -e "${GREEN}PASS${NC}: $*"; } +fail() { echo -e "${RED}FAIL${NC}: $*"; exit 1; } +info() { echo -e "${YELLOW}INFO${NC}: $*"; } + +# --- 1. Create agent with object_permission --- +info "Creating agent with object_permission (mcp_servers, mcp_tool_permissions)..." +CREATE_RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/v1/agents" \ + -H "$AUTH_HEADER" \ + -H "Content-Type: application/json" \ + -d '{ + "agent_name": "'"$AGENT_NAME"'", + "agent_card_params": { + "protocolVersion": "1.0", + "name": "Test MCP Agent", + "description": "Agent for endpoint tests", + "url": "http://localhost:9999/", + "version": "1.0.0", + "defaultInputModes": ["text"], + "defaultOutputModes": ["text"], + "capabilities": {"streaming": true}, + "skills": [] + }, + "object_permission": { + "mcp_servers": ["server_1", "server_2"], + "mcp_access_groups": ["group_a"], + "mcp_tool_permissions": {"server_1": ["tool_a", "tool_b"], "server_2": ["tool_c"]} + } + }') +HTTP_CODE=$(echo "$CREATE_RESP" | tail -n1) +BODY=$(echo "$CREATE_RESP" | sed '$d') +if [[ "$HTTP_CODE" != "200" ]]; then + fail "POST /v1/agents returned $HTTP_CODE. Body: $BODY" +fi +AGENT_ID=$(echo "$BODY" | jq -r '.agent_id') +if [[ -z "$AGENT_ID" || "$AGENT_ID" == "null" ]]; then + fail "POST /v1/agents did not return agent_id. Body: $BODY" +fi +pass "Created agent $AGENT_ID" + +# Check create response includes object_permission +OP=$(echo "$BODY" | jq '.object_permission') +if [[ "$OP" == "null" || -z "$OP" ]]; then + fail "POST /v1/agents response missing object_permission. Body: $BODY" +fi +SERVERS=$(echo "$OP" | jq -r '.mcp_servers | join(",")') +if [[ "$SERVERS" != "server_1,server_2" ]]; then + fail "object_permission.mcp_servers unexpected: $SERVERS" +fi +pass "Create response includes object_permission with mcp_servers and mcp_tool_permissions" + +# --- 2. GET /v1/agents (list) includes object_permission for our agent --- +info "GET /v1/agents and check one agent has object_permission..." +LIST_RESP=$(curl -s -w "\n%{http_code}" -X GET "$BASE_URL/v1/agents" -H "$AUTH_HEADER") +LIST_CODE=$(echo "$LIST_RESP" | tail -n1) +LIST_BODY=$(echo "$LIST_RESP" | sed '$d') +if [[ "$LIST_CODE" != "200" ]]; then + fail "GET /v1/agents returned $LIST_CODE" +fi +AGENT_IN_LIST=$(echo "$LIST_BODY" | jq --arg id "$AGENT_ID" '.[] | select(.agent_id == $id)') +if [[ -z "$AGENT_IN_LIST" ]]; then + fail "GET /v1/agents did not return agent $AGENT_ID (list might be key-scoped)" +fi +OP_LIST=$(echo "$AGENT_IN_LIST" | jq '.object_permission') +if [[ "$OP_LIST" == "null" || -z "$OP_LIST" ]]; then + fail "GET /v1/agents list entry for agent missing object_permission" +fi +pass "GET /v1/agents list includes object_permission for agent" + +# --- 3. GET /v1/agents/{agent_id} returns object_permission --- +info "GET /v1/agents/{agent_id}..." +GET_RESP=$(curl -s -w "\n%{http_code}" -X GET "$BASE_URL/v1/agents/$AGENT_ID" -H "$AUTH_HEADER") +GET_CODE=$(echo "$GET_RESP" | tail -n1) +GET_BODY=$(echo "$GET_RESP" | sed '$d') +if [[ "$GET_CODE" != "200" ]]; then + fail "GET /v1/agents/$AGENT_ID returned $GET_CODE. Body: $GET_BODY" +fi +OP_GET=$(echo "$GET_BODY" | jq '.object_permission') +if [[ "$OP_GET" == "null" || -z "$OP_GET" ]]; then + fail "GET /v1/agents/$AGENT_ID response missing object_permission" +fi +TOOL_PERMS=$(echo "$OP_GET" | jq -r '.mcp_tool_permissions.server_1 | join(",")') +if [[ "$TOOL_PERMS" != "tool_a,tool_b" ]]; then + fail "object_permission.mcp_tool_permissions.server_1 unexpected: $TOOL_PERMS" +fi +pass "GET /v1/agents/{agent_id} returns object_permission with mcp_tool_permissions" + +# --- 4. PATCH /v1/agents/{agent_id} with new object_permission --- +info "PATCH /v1/agents/{agent_id} with updated object_permission..." +PATCH_RESP=$(curl -s -w "\n%{http_code}" -X PATCH "$BASE_URL/v1/agents/$AGENT_ID" \ + -H "$AUTH_HEADER" \ + -H "Content-Type: application/json" \ + -d '{ + "object_permission": { + "mcp_servers": ["server_3"], + "mcp_tool_permissions": {"server_3": ["tool_x"]} + } + }') +PATCH_CODE=$(echo "$PATCH_RESP" | tail -n1) +PATCH_BODY=$(echo "$PATCH_RESP" | sed '$d') +if [[ "$PATCH_CODE" != "200" ]]; then + fail "PATCH /v1/agents/$AGENT_ID returned $PATCH_CODE. Body: $PATCH_BODY" +fi +OP_PATCH=$(echo "$PATCH_BODY" | jq '.object_permission') +if [[ "$OP_PATCH" == "null" || -z "$OP_PATCH" ]]; then + fail "PATCH response missing object_permission" +fi +PATCH_SERVERS=$(echo "$OP_PATCH" | jq -r '.mcp_servers | join(",")') +if [[ "$PATCH_SERVERS" != "server_3" ]]; then + fail "PATCH object_permission.mcp_servers unexpected: $PATCH_SERVERS" +fi +pass "PATCH /v1/agents/{agent_id} updates and returns object_permission" + +# --- 5. Create agent without object_permission; GET should still work --- +info "Creating agent without object_permission..." +AGENT_NAME_2="test-agent-no-mcp-$(date +%s)" +CREATE2_RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/v1/agents" \ + -H "$AUTH_HEADER" \ + -H "Content-Type: application/json" \ + -d '{ + "agent_name": "'"$AGENT_NAME_2"'", + "agent_card_params": { + "protocolVersion": "1.0", + "name": "No MCP Agent", + "description": "No object_permission", + "url": "http://localhost:9999/", + "version": "1.0.0", + "defaultInputModes": ["text"], + "defaultOutputModes": ["text"], + "capabilities": {}, + "skills": [] + } + }') +CODE2=$(echo "$CREATE2_RESP" | tail -n1) +BODY2=$(echo "$CREATE2_RESP" | sed '$d') +if [[ "$CODE2" != "200" ]]; then + fail "POST /v1/agents (no object_permission) returned $CODE2. Body: $BODY2" +fi +AGENT_ID_2=$(echo "$BODY2" | jq -r '.agent_id') +# object_permission may be null or absent +pass "Created agent without object_permission: $AGENT_ID_2" + +# --- 6. Cleanup: delete both agents --- +info "Deleting test agents..." +for AID in "$AGENT_ID" "$AGENT_ID_2"; do + DEL_CODE=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE_URL/v1/agents/$AID" -H "$AUTH_HEADER") + if [[ "$DEL_CODE" != "200" ]]; then + info "DELETE /v1/agents/$AID returned $DEL_CODE (non-fatal)" + fi +done +pass "Cleanup done" + +echo "" +echo -e "${GREEN}All endpoint checks passed.${NC}" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py index c2dbc94f72..b7ae33d1f8 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py @@ -1595,3 +1595,146 @@ async def test_get_allowed_mcp_servers_for_key_prefers_in_memory_permission(): assert set(result) == {"direct-server", "group-server"} mock_get_perm.assert_not_called() mock_access_groups.assert_called_once_with(["grp-alpha"]) + + +@pytest.mark.asyncio +class TestAgentMCPPermissions: + """Test agent-level MCP server and tool permission intersection.""" + + async def test_get_allowed_mcp_servers_agent_intersection(self): + """Key/team allow [server_1, server_2]; agent allows [server_1]. Result = [server_1].""" + user_api_key_auth = UserAPIKeyAuth( + api_key="test-key", + user_id="test-user", + team_id="test-team", + agent_id="agent-123", + ) + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_key" + ) as mock_key: + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_team" + ) as mock_team: + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_agent" + ) as mock_agent: + mock_key.return_value = ["server_1", "server_2"] + mock_team.return_value = [] + mock_agent.return_value = ["server_1"] + result = await MCPRequestHandler.get_allowed_mcp_servers( + user_api_key_auth=user_api_key_auth + ) + assert sorted(result) == ["server_1"] + mock_agent.assert_called_once_with(user_api_key_auth) + + async def test_get_allowed_mcp_servers_agent_no_restriction(self): + """Agent with no object_permission returns []; no intersection applied (inherit key/team).""" + user_api_key_auth = UserAPIKeyAuth( + api_key="test-key", + user_id="test-user", + agent_id="agent-456", + ) + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_key" + ) as mock_key: + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_team" + ) as mock_team: + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_agent" + ) as mock_agent: + mock_key.return_value = ["server_1", "server_2"] + mock_team.return_value = [] + mock_agent.return_value = [] # no agent-level restriction + result = await MCPRequestHandler.get_allowed_mcp_servers( + user_api_key_auth=user_api_key_auth + ) + assert sorted(result) == ["server_1", "server_2"] + mock_agent.assert_called_once_with(user_api_key_auth) + + async def test_get_allowed_mcp_servers_key_team_agent_intersection(self): + """Key allows [1, 2], agent allows [2, 3]. Result = [2].""" + user_api_key_auth = UserAPIKeyAuth( + api_key="test-key", + user_id="test-user", + agent_id="agent-789", + ) + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_key" + ) as mock_key: + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_team" + ) as mock_team: + with patch.object( + MCPRequestHandler, "_get_allowed_mcp_servers_for_agent" + ) as mock_agent: + mock_key.return_value = ["server_1", "server_2"] + mock_team.return_value = [] + mock_agent.return_value = ["server_2", "server_3"] + result = await MCPRequestHandler.get_allowed_mcp_servers( + user_api_key_auth=user_api_key_auth + ) + assert sorted(result) == ["server_2"] + + async def test_get_allowed_tools_for_server_agent_intersection(self): + """Key allows [tool_a, tool_b], agent allows [tool_a]. Result = [tool_a].""" + user_api_key_auth = UserAPIKeyAuth( + api_key="test-key", + user_id="test-user", + agent_id="agent-tools", + ) + key_perm = MagicMock() + key_perm.mcp_tool_permissions = {"server_1": ["tool_a", "tool_b"]} + team_perm = None + with patch.object( + MCPRequestHandler, "_get_key_object_permission", return_value=key_perm + ): + with patch.object( + MCPRequestHandler, "_get_team_object_permission", + new_callable=AsyncMock, + return_value=team_perm, + ): + with patch.object( + MCPRequestHandler, + "_get_agent_tool_permissions_for_server", + new_callable=AsyncMock, + return_value=["tool_a"], + ) as mock_agent_tools: + result = await MCPRequestHandler.get_allowed_tools_for_server( + server_id="server_1", + user_api_key_auth=user_api_key_auth, + ) + assert result == ["tool_a"] + mock_agent_tools.assert_called_once() + call_kwargs = mock_agent_tools.call_args.kwargs + assert call_kwargs["server_id"] == "server_1" + assert call_kwargs["user_api_key_auth"] == user_api_key_auth + + async def test_get_allowed_tools_for_server_agent_no_restriction(self): + """Agent has no tool permissions for server; key/team result is unchanged.""" + user_api_key_auth = UserAPIKeyAuth( + api_key="test-key", + user_id="test-user", + agent_id="agent-no-tools", + ) + key_perm = MagicMock() + key_perm.mcp_tool_permissions = {"server_1": ["tool_a", "tool_b"]} + with patch.object( + MCPRequestHandler, "_get_key_object_permission", return_value=key_perm + ): + with patch.object( + MCPRequestHandler, "_get_team_object_permission", + new_callable=AsyncMock, + return_value=None, + ): + with patch.object( + MCPRequestHandler, + "_get_agent_tool_permissions_for_server", + new_callable=AsyncMock, + return_value=None, + ): + result = await MCPRequestHandler.get_allowed_tools_for_server( + server_id="server_1", + user_api_key_auth=user_api_key_auth, + ) + assert sorted(result) == ["tool_a", "tool_b"] diff --git a/tests/test_litellm/proxy/hooks/test_max_iterations_limiter.py b/tests/test_litellm/proxy/hooks/test_max_iterations_limiter.py new file mode 100644 index 0000000000..deb1c483b8 --- /dev/null +++ b/tests/test_litellm/proxy/hooks/test_max_iterations_limiter.py @@ -0,0 +1,106 @@ +""" +Unit Tests for the max iterations limiter for the proxy. + +Tests that session-scoped iteration counting works correctly: +- Enforces max_iterations per session_id +- Different sessions have independent counters +""" + +import pytest +from fastapi import HTTPException + +from litellm.caching.caching import DualCache +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.hooks.max_iterations_limiter import _PROXY_MaxIterationsHandler +from litellm.proxy.utils import InternalUsageCache + + +@pytest.mark.asyncio +async def test_max_iterations_basic_enforcement(): + """ + Test that max_iterations is enforced per session_id. + + - 3 requests with the same session_id should succeed when max_iterations=3 + - 4th request should raise 429 + """ + local_cache = DualCache() + handler = _PROXY_MaxIterationsHandler( + internal_usage_cache=InternalUsageCache(local_cache), + ) + user_api_key_dict = UserAPIKeyAuth( + api_key="sk-test-key-1234", metadata={"max_iterations": 3} + ) + + # First 3 requests should succeed + for i in range(3): + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={"metadata": {"session_id": "session-abc"}}, + call_type="", + ) + + # 4th request should fail with 429 + with pytest.raises(HTTPException) as exc_info: + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={"metadata": {"session_id": "session-abc"}}, + call_type="", + ) + assert exc_info.value.status_code == 429 + assert "max_iterations" in str(exc_info.value.detail).lower() + + +@pytest.mark.asyncio +async def test_max_iterations_different_sessions_independent(): + """ + Test that different session_ids have independent iteration counters. + + - Session A and Session B each get their own max_iterations budget + - Exhausting Session A does not affect Session B + """ + local_cache = DualCache() + handler = _PROXY_MaxIterationsHandler( + internal_usage_cache=InternalUsageCache(local_cache), + ) + user_api_key_dict = UserAPIKeyAuth( + api_key="sk-test-key-5678", metadata={"max_iterations": 2} + ) + + # Session A: 2 calls succeed + for _ in range(2): + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={"metadata": {"session_id": "session-A"}}, + call_type="", + ) + + # Session B: 2 calls succeed (independent counter) + for _ in range(2): + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={"metadata": {"session_id": "session-B"}}, + call_type="", + ) + + # Session A: 3rd call fails + with pytest.raises(HTTPException) as exc_info: + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={"metadata": {"session_id": "session-A"}}, + call_type="", + ) + assert exc_info.value.status_code == 429 + + # Session B: 3rd call also fails + with pytest.raises(HTTPException): + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={"metadata": {"session_id": "session-B"}}, + call_type="", + ) diff --git a/ui/litellm-dashboard/package-lock.json b/ui/litellm-dashboard/package-lock.json index fc2aa1599d..cc04e67400 100644 --- a/ui/litellm-dashboard/package-lock.json +++ b/ui/litellm-dashboard/package-lock.json @@ -13056,6 +13056,21 @@ "type": "github", "url": "https://github.com/sponsors/wooorm" } + }, + "node_modules/@next/swc-win32-ia32-msvc": { + "version": "14.2.33", + "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.33.tgz", + "integrity": "sha512-pc9LpGNKhJ0dXQhZ5QMmYxtARwwmWLpeocFmVG5Z0DzWq5Uf0izcI8tLc+qOpqxO1PWqZ5A7J1blrUIKrIFc7Q==", + "cpu": [ + "ia32" + ], + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } } } } diff --git a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/ModelRetrySettingsTab.test.tsx b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/ModelRetrySettingsTab.test.tsx index 32e619d084..5b756a833d 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/ModelRetrySettingsTab.test.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/ModelRetrySettingsTab.test.tsx @@ -1,7 +1,7 @@ import { render, screen } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; import React from "react"; -import { describe, beforeEach, expect, it, vi } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import ModelRetrySettingsTab from "./ModelRetrySettingsTab"; // TabPanel requires a parent Tabs context in Tremor. We stub it to render children diff --git a/ui/litellm-dashboard/src/components/agents.tsx b/ui/litellm-dashboard/src/components/agents.tsx index d6d8b320ec..7ab6a5f038 100644 --- a/ui/litellm-dashboard/src/components/agents.tsx +++ b/ui/litellm-dashboard/src/components/agents.tsx @@ -1,13 +1,13 @@ import React, { useState, useEffect } from "react"; import { Button } from "@tremor/react"; -import { Modal } from "antd"; -import { getAgentsList, deleteAgentCall } from "./networking"; +import { Modal, Alert } from "antd"; +import { getAgentsList, deleteAgentCall, keyListCall } from "./networking"; import AddAgentForm from "./agents/add_agent_form"; -import AgentTable from "./agents/agent_table"; +import AgentCardGrid from "./agents/agent_card_grid"; import { isAdminRole } from "@/utils/roles"; import AgentInfoView from "./agents/agent_info"; import NotificationsManager from "./molecules/notifications_manager"; -import { Agent } from "./agents/types"; +import { Agent, AgentKeyInfo } from "./agents/types"; interface AgentsPanelProps { accessToken: string | null; @@ -20,6 +20,7 @@ interface AgentsResponse { const AgentsPanel: React.FC = ({ accessToken, userRole }) => { const [agentsList, setAgentsList] = useState([]); + const [keyInfoMap, setKeyInfoMap] = useState>({}); const [isAddModalVisible, setIsAddModalVisible] = useState(false); const [isLoading, setIsLoading] = useState(false); const [isDeleting, setIsDeleting] = useState(false); @@ -36,8 +37,7 @@ const AgentsPanel: React.FC = ({ accessToken, userRole }) => { setIsLoading(true); try { const response: AgentsResponse = await getAgentsList(accessToken); - console.log(`agents: ${JSON.stringify(response)}`); - setAgentsList(response.agents); + setAgentsList(response.agents || []); } catch (error) { console.error("Error fetching agents:", error); } finally { @@ -45,10 +45,50 @@ const AgentsPanel: React.FC = ({ accessToken, userRole }) => { } }; + const fetchKeysForAgents = async () => { + if (!accessToken) return; + try { + const { keys = [] } = await keyListCall( + accessToken, + null, + null, + null, + null, + null, + 1, + 500 + ); + const map: Record = {}; + for (const key of keys) { + const agentId = (key as { agent_id?: string }).agent_id; + if (agentId && !map[agentId]) { + map[agentId] = { + has_key: true, + key_alias: (key as { key_alias?: string }).key_alias, + token_prefix: (key as { token?: string }).token + ? `${(key as { token: string }).token.slice(0, 8)}…` + : undefined, + }; + } + } + setKeyInfoMap(map); + } catch (error) { + console.error("Error fetching keys for agents:", error); + } + }; + useEffect(() => { fetchAgents(); }, [accessToken]); + useEffect(() => { + if (accessToken && agentsList.length > 0) { + fetchKeysForAgents(); + } else if (agentsList.length === 0) { + setKeyInfoMap({}); + } + }, [accessToken, agentsList.length]); + const handleAddAgent = () => { if (selectedAgentId) { setSelectedAgentId(null); @@ -94,6 +134,13 @@ const AgentsPanel: React.FC = ({ accessToken, userRole }) => {

Agents

List of A2A-spec agents that are available to be used in your organization. Go to AI Hub, to make agents public.

+