diff --git a/AGENTS.md b/AGENTS.md index d43f41dbe3..546f2997bf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -109,6 +109,8 @@ Key files: - `litellm/proxy/auth/` - Authentication logic - `litellm/proxy/management_endpoints/` - Admin API endpoints +**Database (proxy)**: Use Prisma model methods (`prisma_client.db..upsert`, `.find_many`, `.find_unique`, etc.), not raw SQL (`execute_raw`/`query_raw`). See COMMON PITFALLS for details. + ## MCP (MODEL CONTEXT PROTOCOL) SUPPORT LiteLLM supports MCP for agent workflows: @@ -176,6 +178,7 @@ When opening issues or pull requests, follow these templates: 5. **Dependencies**: Keep dependencies minimal and well-justified 6. **UI/Backend Contract Mismatch**: When adding a new entity type to the UI, always check whether the backend endpoint accepts a single value or an array. Match the UI control accordingly (single-select vs. multi-select) to avoid silently dropping user selections 7. **Missing Tests for New Entity Types**: When adding a new entity type (e.g., in `EntityUsage`, `UsageViewSelect`), always add corresponding tests in the existing test files and update any icon/component mocks +8. **Raw SQL in proxy DB code**: Do not use `execute_raw` or `query_raw` for proxy database access. Use Prisma model methods (e.g. `prisma_client.db.litellm_tooltable.upsert()`, `.find_many()`, `.find_unique()`) so behavior stays consistent with the schema, the client stays mockable in tests, and you avoid the pitfalls of hand-written SQL (parameter ordering, type casting, schema drift) 8. **Do not hardcode model-specific flags**: Put model-specific capability flags in `model_prices_and_context_window.json` and read them via `get_model_info` (or existing helpers like `supports_reasoning`). This prevents users from needing to upgrade LiteLLM each time a new model supports a feature. diff --git a/CLAUDE.md b/CLAUDE.md index 3b597fb8a9..c1eb75d251 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -107,6 +107,10 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components: - Migration files auto-generated with `prisma migrate dev` - Always test migrations against both PostgreSQL and SQLite +### Proxy database access +- **Do not write raw SQL** for proxy DB operations. Use Prisma model methods instead of `execute_raw` / `query_raw`. +- Use the generated client: `prisma_client.db.` (e.g. `litellm_tooltable`, `litellm_usertable`) with `.upsert()`, `.find_many()`, `.find_unique()`, `.update()`, `.update_many()` as appropriate. This avoids schema/client drift, keeps code testable with simple mocks, and matches patterns used in spend logs and other proxy code. + ### Enterprise Features - Enterprise-specific code in `enterprise/` directory - Optional features enabled via environment variables diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260226000000_add_blocked_tools_to_object_permission/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260226000000_add_blocked_tools_to_object_permission/migration.sql new file mode 100644 index 0000000000..cba0668419 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260226000000_add_blocked_tools_to_object_permission/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "LiteLLM_ObjectPermissionTable" ADD COLUMN "blocked_tools" TEXT[] DEFAULT ARRAY[]::TEXT[]; diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260226120000_add_spend_log_tool_index/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260226120000_add_spend_log_tool_index/migration.sql new file mode 100644 index 0000000000..e3199679ce --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260226120000_add_spend_log_tool_index/migration.sql @@ -0,0 +1,11 @@ +-- CreateTable +CREATE TABLE "LiteLLM_SpendLogToolIndex" ( + "request_id" TEXT NOT NULL, + "tool_name" TEXT NOT NULL, + "start_time" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "LiteLLM_SpendLogToolIndex_pkey" PRIMARY KEY ("request_id","tool_name") +); + +-- CreateIndex +CREATE INDEX "LiteLLM_SpendLogToolIndex_tool_name_start_time_idx" ON "LiteLLM_SpendLogToolIndex"("tool_name", "start_time"); diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index e0b28a4e01..5abe7a0a2b 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable { vector_stores String[] @default([]) agents String[] @default([]) agent_access_groups String[] @default([]) + blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission teams LiteLLM_TeamTable[] projects LiteLLM_ProjectTable[] verification_tokens LiteLLM_VerificationToken[] @@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex { @@index([policy_id, start_time]) } +// Index for fast "last N logs for tool" from SpendLogs – see how a tool is called in production +model LiteLLM_SpendLogToolIndex { + request_id String + tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc. + start_time DateTime + + @@id([request_id, tool_name]) + @@index([tool_name, start_time]) +} + // Prompt table for storing prompt configurations model LiteLLM_PromptTable { id String @id @default(uuid()) @@ -1065,26 +1076,31 @@ model LiteLLM_PolicyAttachmentTable { updated_by String? } -// Global tool registry - auto-discovered from LLM responses; admins set call_policy here +// Global tool registry - auto-discovered from LLM responses; admins set input_policy/output_policy here model LiteLLM_ToolTable { - tool_id String @id @default(uuid()) - tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space" - origin String? // MCP server name or "user_defined" - call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked" - call_count Int @default(0) // cumulative number of times this tool was seen - assignments Json? @default("{}") - key_hash String? // hash of the virtual key that first called this tool - team_id String? // team that first called this tool - key_alias String? // human-readable alias of the virtual key - created_at DateTime @default(now()) - created_by String? - updated_at DateTime @default(now()) @updatedAt - updated_by String? + tool_id String @id @default(uuid()) + tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space" + origin String? // MCP server name or "user_defined" + input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked" + output_policy String @default("untrusted") // "trusted" | "untrusted" + call_count Int @default(0) // cumulative number of times this tool was seen + assignments Json? @default("{}") + key_hash String? // hash of the virtual key that first called this tool + team_id String? // team that first called this tool + key_alias String? // human-readable alias of the virtual key + user_agent String? // user-agent of the first request that discovered this tool + last_used_at DateTime? // timestamp of the most recent call + created_at DateTime @default(now()) + created_by String? + updated_at DateTime @default(now()) @updatedAt + updated_by String? - @@index([call_policy]) + @@index([input_policy]) + @@index([output_policy]) @@index([team_id]) } +// Per-(tool, team/key) policy overrides. When present, override replaces global tool policy for that scope. //Unified Access Groups table for storing unified access groups model LiteLLM_AccessGroupTable { access_group_id String @id @default(uuid()) diff --git a/litellm/llms/anthropic/chat/guardrail_translation/handler.py b/litellm/llms/anthropic/chat/guardrail_translation/handler.py index 98650a238e..a6df346e8a 100644 --- a/litellm/llms/anthropic/chat/guardrail_translation/handler.py +++ b/litellm/llms/anthropic/chat/guardrail_translation/handler.py @@ -75,7 +75,7 @@ class AnthropicMessagesHandler(BaseTranslation): if messages is None: return data - chat_completion_compatible_request, tool_name_mapping = ( + chat_completion_compatible_request, _tool_name_mapping = ( LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai( # Use a shallow copy to avoid mutating request data (pop on litellm_metadata). anthropic_message_request=cast(AnthropicMessagesRequest, data.copy()) @@ -141,6 +141,14 @@ class AnthropicMessagesHandler(BaseTranslation): return data + def extract_request_tool_names(self, data: dict) -> List[str]: + """Extract tool names from Anthropic messages request (tools[].name).""" + names: List[str] = [] + for tool in data.get("tools") or []: + if isinstance(tool, dict) and tool.get("name"): + names.append(str(tool["name"])) + return names + def _extract_input_text_and_images( self, message: Dict[str, Any], diff --git a/litellm/llms/base_llm/guardrail_translation/base_translation.py b/litellm/llms/base_llm/guardrail_translation/base_translation.py index 7106c207bd..a7982cb606 100644 --- a/litellm/llms/base_llm/guardrail_translation/base_translation.py +++ b/litellm/llms/base_llm/guardrail_translation/base_translation.py @@ -98,3 +98,10 @@ class BaseTranslation(ABC): Optional to override in subclasses. """ return responses_so_far + + def extract_request_tool_names(self, data: dict) -> List[str]: + """ + Extract tool names from the request body for allowlist/policy checks. + Override in tool-capable handlers; default returns []. + """ + return [] diff --git a/litellm/llms/openai/chat/guardrail_translation/handler.py b/litellm/llms/openai/chat/guardrail_translation/handler.py index 67e9e42bc3..10b0b58b6a 100644 --- a/litellm/llms/openai/chat/guardrail_translation/handler.py +++ b/litellm/llms/openai/chat/guardrail_translation/handler.py @@ -135,6 +135,19 @@ class OpenAIChatCompletionsHandler(BaseTranslation): return data + def extract_request_tool_names(self, data: dict) -> List[str]: + """Extract tool names from OpenAI chat completions request (tools[].function.name, functions[].name).""" + names: List[str] = [] + for tool in data.get("tools") or []: + if isinstance(tool, dict) and tool.get("type") == "function": + fn = tool.get("function") + if isinstance(fn, dict) and fn.get("name"): + names.append(str(fn["name"])) + for fn in data.get("functions") or []: + if isinstance(fn, dict) and fn.get("name"): + names.append(str(fn["name"])) + return names + def _extract_inputs( self, message: Dict[str, Any], diff --git a/litellm/llms/openai/responses/guardrail_translation/handler.py b/litellm/llms/openai/responses/guardrail_translation/handler.py index 6b092911d3..7c3354cf88 100644 --- a/litellm/llms/openai/responses/guardrail_translation/handler.py +++ b/litellm/llms/openai/responses/guardrail_translation/handler.py @@ -30,27 +30,22 @@ Output: response.output is List[GenericResponseOutputItem] where each has: from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast -from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.responses.response_function_tool_call import \ + ResponseFunctionToolCall from pydantic import BaseModel from litellm._logging import verbose_proxy_logger from litellm.completion_extras.litellm_responses_transformation.transformation import ( LiteLLMResponsesTransformationHandler, - OpenAiResponsesToChatCompletionStreamIterator, -) -from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation -from litellm.responses.litellm_completion_transformation.transformation import ( - LiteLLMCompletionResponsesConfig, -) -from litellm.types.llms.openai import ( - ChatCompletionToolCallChunk, - ChatCompletionToolParam, -) -from litellm.types.responses.main import ( - GenericResponseOutputItem, - OutputFunctionToolCall, - OutputText, -) + OpenAiResponsesToChatCompletionStreamIterator) +from litellm.llms.base_llm.guardrail_translation.base_translation import \ + BaseTranslation +from litellm.responses.litellm_completion_transformation.transformation import \ + LiteLLMCompletionResponsesConfig +from litellm.types.llms.openai import (ChatCompletionToolCallChunk, + ChatCompletionToolParam) +from litellm.types.responses.main import (GenericResponseOutputItem, + OutputFunctionToolCall, OutputText) from litellm.types.utils import GenericGuardrailAPIInputs if TYPE_CHECKING: @@ -188,6 +183,18 @@ class OpenAIResponsesHandler(BaseTranslation): return data + def extract_request_tool_names(self, data: dict) -> List[str]: + """Extract tool names from Responses API request (tools[].name for function, tools[].server_label for mcp).""" + names: List[str] = [] + for tool in data.get("tools") or []: + if not isinstance(tool, dict): + continue + if tool.get("type") == "function" and tool.get("name"): + names.append(str(tool["name"])) + elif tool.get("type") == "mcp" and tool.get("server_label"): + names.append(str(tool["server_label"])) + return names + def _extract_and_transform_tools( self, tools: List[Dict[str, Any]], diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 3dfbc5cb21..b92e272797 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -9779,6 +9779,122 @@ } ] }, + "dashscope/qwen3-max-2026-01-23": { + "litellm_provider": "dashscope", + "max_input_tokens": 258048, + "max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "source": "https://www.alibabacloud.com/help/en/model-studio/models", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "tiered_pricing": [ + { + "input_cost_per_token": 1.2e-06, + "output_cost_per_token": 6e-06, + "range": [ + 0, + 32000.0 + ] + }, + { + "input_cost_per_token": 2.4e-06, + "output_cost_per_token": 1.2e-05, + "range": [ + 32000.0, + 128000.0 + ] + }, + { + "input_cost_per_token": 3e-06, + "output_cost_per_token": 1.5e-05, + "range": [ + 128000.0, + 252000.0 + ] + } + ] + }, + "dashscope/qwen3-next-80b-a3b-instruct": { + "input_cost_per_token": 1.5e-07, + "litellm_provider": "dashscope", + "max_input_tokens": 262144, + "max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "output_cost_per_token": 1.2e-06, + "source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing", + "supports_function_calling": true, + "supports_tool_choice": true + }, + "dashscope/qwen3-next-80b-a3b-thinking": { + "input_cost_per_token": 1.5e-07, + "litellm_provider": "dashscope", + "max_input_tokens": 262144, + "max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "output_cost_per_token": 1.2e-06, + "source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true + }, + "dashscope/qwen3-vl-235b-a22b-instruct": { + "input_cost_per_token": 4e-07, + "litellm_provider": "dashscope", + "max_input_tokens": 131072, + "max_output_tokens": 32768, + "max_tokens": 32768, + "mode": "chat", + "output_cost_per_token": 1.6e-06, + "source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_vision": true + }, + "dashscope/qwen3-vl-235b-a22b-thinking": { + "input_cost_per_token": 4e-07, + "litellm_provider": "dashscope", + "max_input_tokens": 131072, + "max_output_tokens": 32768, + "max_tokens": 32768, + "mode": "chat", + "output_cost_per_token": 4e-06, + "source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true + }, + "dashscope/qwen3-vl-32b-instruct": { + "input_cost_per_token": 1.6e-07, + "litellm_provider": "dashscope", + "max_input_tokens": 131072, + "max_output_tokens": 32768, + "max_tokens": 32768, + "mode": "chat", + "output_cost_per_token": 6.4e-07, + "source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_vision": true + }, + "dashscope/qwen3-vl-32b-thinking": { + "input_cost_per_token": 1.6e-07, + "litellm_provider": "dashscope", + "max_input_tokens": 131072, + "max_output_tokens": 32768, + "max_tokens": 32768, + "mode": "chat", + "output_cost_per_token": 2.87e-06, + "source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true + }, "dashscope/qwen3-vl-plus": { "litellm_provider": "dashscope", "max_input_tokens": 260096, @@ -10844,7 +10960,8 @@ "output_cost_per_token": 9e-08, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/NousResearch/Hermes-3-Llama-3.1-405B": { "max_tokens": 131072, @@ -10854,7 +10971,8 @@ "output_cost_per_token": 1e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/NousResearch/Hermes-3-Llama-3.1-70B": { "max_tokens": 131072, @@ -10874,7 +10992,8 @@ "output_cost_per_token": 4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen2.5-72B-Instruct": { "max_tokens": 32768, @@ -10884,7 +11003,8 @@ "output_cost_per_token": 3.9e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen2.5-7B-Instruct": { "max_tokens": 32768, @@ -10905,7 +11025,8 @@ "litellm_provider": "deepinfra", "mode": "chat", "supports_tool_choice": true, - "supports_vision": true + "supports_vision": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-14B": { "max_tokens": 40960, @@ -10915,7 +11036,8 @@ "output_cost_per_token": 2.4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-235B-A22B": { "max_tokens": 40960, @@ -10925,7 +11047,8 @@ "output_cost_per_token": 5.4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-235B-A22B-Instruct-2507": { "max_tokens": 262144, @@ -10935,7 +11058,8 @@ "output_cost_per_token": 6e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-235B-A22B-Thinking-2507": { "max_tokens": 262144, @@ -10945,7 +11069,8 @@ "output_cost_per_token": 2.9e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-30B-A3B": { "max_tokens": 40960, @@ -10955,7 +11080,8 @@ "output_cost_per_token": 2.9e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-32B": { "max_tokens": 40960, @@ -10965,7 +11091,8 @@ "output_cost_per_token": 2.8e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-Coder-480B-A35B-Instruct": { "max_tokens": 262144, @@ -10975,7 +11102,8 @@ "output_cost_per_token": 1.6e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo": { "max_tokens": 262144, @@ -10985,7 +11113,8 @@ "output_cost_per_token": 1.2e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-Next-80B-A3B-Instruct": { "max_tokens": 262144, @@ -10995,7 +11124,8 @@ "output_cost_per_token": 1.4e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Qwen/Qwen3-Next-80B-A3B-Thinking": { "max_tokens": 262144, @@ -11005,7 +11135,8 @@ "output_cost_per_token": 1.4e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/Sao10K/L3-8B-Lunaris-v1-Turbo": { "max_tokens": 8192, @@ -11056,7 +11187,8 @@ "cache_read_input_token_cost": 3.3e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/anthropic/claude-4-opus": { "max_tokens": 200000, @@ -11066,7 +11198,8 @@ "output_cost_per_token": 8.25e-05, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/anthropic/claude-4-sonnet": { "max_tokens": 200000, @@ -11076,7 +11209,8 @@ "output_cost_per_token": 1.65e-05, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-R1": { "max_tokens": 163840, @@ -11086,7 +11220,8 @@ "output_cost_per_token": 2.4e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-R1-0528": { "max_tokens": 163840, @@ -11097,7 +11232,8 @@ "cache_read_input_token_cost": 4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-R1-0528-Turbo": { "max_tokens": 32768, @@ -11107,7 +11243,8 @@ "output_cost_per_token": 3e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-R1-Distill-Llama-70B": { "max_tokens": 131072, @@ -11127,7 +11264,8 @@ "output_cost_per_token": 2.7e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-R1-Turbo": { "max_tokens": 40960, @@ -11137,7 +11275,8 @@ "output_cost_per_token": 3e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-V3": { "max_tokens": 163840, @@ -11147,7 +11286,8 @@ "output_cost_per_token": 8.9e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-V3-0324": { "max_tokens": 163840, @@ -11157,7 +11297,8 @@ "output_cost_per_token": 8.8e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-V3.1": { "max_tokens": 163840, @@ -11169,7 +11310,8 @@ "litellm_provider": "deepinfra", "mode": "chat", "supports_tool_choice": true, - "supports_reasoning": true + "supports_reasoning": true, + "supports_function_calling": true }, "deepinfra/deepseek-ai/DeepSeek-V3.1-Terminus": { "max_tokens": 163840, @@ -11180,7 +11322,8 @@ "cache_read_input_token_cost": 2.16e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/google/gemini-2.0-flash-001": { "deprecation_date": "2026-06-01", @@ -11191,7 +11334,8 @@ "output_cost_per_token": 4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/google/gemini-2.5-flash": { "max_tokens": 1000000, @@ -11201,7 +11345,8 @@ "output_cost_per_token": 2.5e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/google/gemini-2.5-pro": { "max_tokens": 1000000, @@ -11211,7 +11356,8 @@ "output_cost_per_token": 1e-05, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/google/gemma-3-12b-it": { "max_tokens": 131072, @@ -11221,7 +11367,8 @@ "output_cost_per_token": 1e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/google/gemma-3-27b-it": { "max_tokens": 131072, @@ -11231,7 +11378,8 @@ "output_cost_per_token": 1.6e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/google/gemma-3-4b-it": { "max_tokens": 131072, @@ -11241,7 +11389,8 @@ "output_cost_per_token": 8e-08, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Llama-3.2-11B-Vision-Instruct": { "max_tokens": 131072, @@ -11261,7 +11410,8 @@ "output_cost_per_token": 2e-08, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Llama-3.3-70B-Instruct": { "max_tokens": 131072, @@ -11271,7 +11421,8 @@ "output_cost_per_token": 4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Llama-3.3-70B-Instruct-Turbo": { "max_tokens": 131072, @@ -11281,6 +11432,7 @@ "output_cost_per_token": 3.9e-07, "litellm_provider": "deepinfra", "mode": "chat", + "supports_function_calling": true, "supports_tool_choice": true }, "deepinfra/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8": { @@ -11291,7 +11443,8 @@ "output_cost_per_token": 6e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Llama-4-Scout-17B-16E-Instruct": { "max_tokens": 327680, @@ -11301,7 +11454,8 @@ "output_cost_per_token": 3e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Llama-Guard-3-8B": { "max_tokens": 131072, @@ -11331,7 +11485,8 @@ "output_cost_per_token": 6e-08, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Meta-Llama-3.1-70B-Instruct": { "max_tokens": 131072, @@ -11341,7 +11496,8 @@ "output_cost_per_token": 4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": { "max_tokens": 131072, @@ -11351,7 +11507,8 @@ "output_cost_per_token": 2.8e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Meta-Llama-3.1-8B-Instruct": { "max_tokens": 131072, @@ -11361,7 +11518,8 @@ "output_cost_per_token": 5e-08, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": { "max_tokens": 131072, @@ -11371,7 +11529,8 @@ "output_cost_per_token": 3e-08, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/microsoft/WizardLM-2-8x22B": { "max_tokens": 65536, @@ -11391,7 +11550,8 @@ "output_cost_per_token": 1.4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/mistralai/Mistral-Nemo-Instruct-2407": { "max_tokens": 131072, @@ -11401,7 +11561,8 @@ "output_cost_per_token": 4e-08, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/mistralai/Mistral-Small-24B-Instruct-2501": { "max_tokens": 32768, @@ -11411,7 +11572,8 @@ "output_cost_per_token": 8e-08, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/mistralai/Mistral-Small-3.2-24B-Instruct-2506": { "max_tokens": 128000, @@ -11421,7 +11583,8 @@ "output_cost_per_token": 2e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/mistralai/Mixtral-8x7B-Instruct-v0.1": { "max_tokens": 32768, @@ -11431,7 +11594,8 @@ "output_cost_per_token": 4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/moonshotai/Kimi-K2-Instruct": { "max_tokens": 131072, @@ -11441,7 +11605,8 @@ "output_cost_per_token": 2e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/moonshotai/Kimi-K2-Instruct-0905": { "max_tokens": 262144, @@ -11452,7 +11617,8 @@ "cache_read_input_token_cost": 4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/nvidia/Llama-3.1-Nemotron-70B-Instruct": { "max_tokens": 131072, @@ -11462,7 +11628,8 @@ "output_cost_per_token": 6e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/nvidia/Llama-3.3-Nemotron-Super-49B-v1.5": { "max_tokens": 131072, @@ -11472,7 +11639,8 @@ "output_cost_per_token": 4e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/nvidia/NVIDIA-Nemotron-Nano-9B-v2": { "max_tokens": 131072, @@ -11482,7 +11650,8 @@ "output_cost_per_token": 1.6e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/openai/gpt-oss-120b": { "max_tokens": 131072, @@ -11492,7 +11661,8 @@ "output_cost_per_token": 4.5e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/openai/gpt-oss-20b": { "max_tokens": 131072, @@ -11502,7 +11672,8 @@ "output_cost_per_token": 1.5e-07, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepinfra/zai-org/GLM-4.5": { "max_tokens": 131072, @@ -11512,7 +11683,8 @@ "output_cost_per_token": 1.6e-06, "litellm_provider": "deepinfra", "mode": "chat", - "supports_tool_choice": true + "supports_tool_choice": true, + "supports_function_calling": true }, "deepseek/deepseek-chat": { "cache_creation_input_token_cost": 0.0, @@ -25806,6 +25978,30 @@ "supports_vision": true, "tool_use_system_prompt_tokens": 159 }, + "openrouter/anthropic/claude-sonnet-4.6": { + "cache_creation_input_token_cost": 3.75e-06, + "cache_creation_input_token_cost_above_200k_tokens": 7.5e-06, + "cache_read_input_token_cost": 3e-07, + "cache_read_input_token_cost_above_200k_tokens": 6e-07, + "input_cost_per_token": 3e-06, + "input_cost_per_token_above_200k_tokens": 6e-06, + "litellm_provider": "openrouter", + "max_input_tokens": 1000000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.5e-05, + "output_cost_per_token_above_200k_tokens": 2.25e-05, + "source": "https://openrouter.ai/anthropic/claude-sonnet-4.6", + "supports_assistant_prefill": true, + "supports_computer_use": true, + "supports_function_calling": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true, + "tool_use_system_prompt_tokens": 159 + }, "openrouter/anthropic/claude-opus-4.5": { "cache_creation_input_token_cost": 6.25e-06, "cache_read_input_token_cost": 5e-07, @@ -26156,6 +26352,39 @@ "supports_web_search": true, "tpm": 800000 }, + "openrouter/google/gemini-3.1-pro-preview": { + "cache_read_input_token_cost": 2e-07, + "cache_read_input_token_cost_above_200k_tokens": 4e-07, + "cache_creation_input_token_cost_above_200k_tokens": 2.5e-07, + "input_cost_per_token": 2e-06, + "input_cost_per_token_above_200k_tokens": 4e-06, + "litellm_provider": "openrouter", + "max_input_tokens": 1048576, + "max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "output_cost_per_token": 1.2e-05, + "output_cost_per_token_above_200k_tokens": 1.8e-05, + "source": "https://openrouter.ai/google/gemini-3.1-pro-preview", + "supported_modalities": [ + "text", + "image", + "audio", + "video" + ], + "supported_output_modalities": [ + "text" + ], + "supports_audio_input": true, + "supports_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, "openrouter/gryphe/mythomax-l2-13b": { "input_cost_per_token": 1.875e-06, "litellm_provider": "openrouter", @@ -26533,6 +26762,29 @@ "supports_reasoning": true, "supports_tool_choice": true }, + "openrouter/openai/gpt-5.1-codex-max": { + "cache_read_input_token_cost": 1.25e-07, + "input_cost_per_token": 1.25e-06, + "litellm_provider": "openrouter", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1e-05, + "source": "https://openrouter.ai/openai/gpt-5.1-codex-max", + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text" + ], + "supports_function_calling": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "supports_vision": true + }, "openrouter/openai/gpt-5.2": { "input_cost_per_image": 0, "cache_read_input_token_cost": 1.75e-07, @@ -26687,6 +26939,19 @@ "supports_tool_choice": true, "supports_function_calling": true }, + "openrouter/qwen/qwen3-coder-plus": { + "input_cost_per_token": 1e-06, + "litellm_provider": "openrouter", + "max_input_tokens": 997952, + "max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "output_cost_per_token": 5e-06, + "source": "https://openrouter.ai/qwen/qwen3-coder-plus", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true + }, "openrouter/qwen/qwen3-235b-a22b-2507": { "input_cost_per_token": 7.1e-08, "litellm_provider": "openrouter", @@ -26822,6 +27087,19 @@ "supports_vision": true, "supports_prompt_caching": false }, + "openrouter/z-ai/glm-5": { + "input_cost_per_token": 8e-07, + "litellm_provider": "openrouter", + "max_input_tokens": 202752, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 2.56e-06, + "source": "https://openrouter.ai/z-ai/glm-5", + "supports_function_calling": true, + "supports_reasoning": true, + "supports_tool_choice": true + }, "openrouter/minimax/minimax-m2.1": { "input_cost_per_token": 2.7e-07, "output_cost_per_token": 1.2e-06, @@ -34327,6 +34605,36 @@ "supports_tool_choice": true, "source": "https://aws.amazon.com/bedrock/pricing/" }, + "zai/glm-5": { + "cache_creation_input_token_cost": 0, + "cache_read_input_token_cost": 2e-07, + "input_cost_per_token": 1e-06, + "output_cost_per_token": 3.2e-06, + "litellm_provider": "zai", + "max_input_tokens": 200000, + "max_output_tokens": 128000, + "mode": "chat", + "supports_function_calling": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "source": "https://docs.z.ai/guides/overview/pricing" + }, + "zai/glm-5-code": { + "cache_creation_input_token_cost": 0, + "cache_read_input_token_cost": 3e-07, + "input_cost_per_token": 1.2e-06, + "output_cost_per_token": 5e-06, + "litellm_provider": "zai", + "max_input_tokens": 200000, + "max_output_tokens": 128000, + "mode": "chat", + "supports_function_calling": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_tool_choice": true, + "source": "https://docs.z.ai/guides/overview/pricing" + }, "zai/glm-4.7": { "cache_creation_input_token_cost": 0, "cache_read_input_token_cost": 1.1e-07, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 6b84d90a32..508c1c9465 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -23,33 +23,11 @@ model_list: guardrails: - - guardrail_name: "airline-competitor-intent" - guardrail_id: "airline-competitor-intent" + - guardrail_name: "tool_policy" litellm_params: - guardrail: litellm_content_filter - mode: pre_call - default_on: false - competitor_intent_config: - brand_self: - - emirates - - ek - competitors: - - qatar airways - - qatar - - etihad - locations: - - qatar - - doha - - doh - competitor_aliases: - qatar airways: [qr, doha airline] - qatar: [qr] - policy: - competitor_comparison: refuse - possible_competitor_comparison: reframe - threshold_high: 0.70 - threshold_medium: 0.45 - threshold_low: 0.30 + guardrail: tool_policy + mode: [pre_call, post_call] + default_on: true mcp_servers: my_http_server: diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8d49020461..42b48446e7 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -77,6 +77,7 @@ class SupportedDBObjectType(str, enum.Enum): PASS_THROUGH_ENDPOINTS = "pass_through_endpoints" PROMPTS = "prompts" MODEL_COST_MAP = "model_cost_map" + TOOLS = "tools" def __str__(self): return str(self.value) @@ -2133,7 +2134,7 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase): user_header_mappings: Optional[List[UserHeaderMapping]] = None supported_db_objects: Optional[List[SupportedDBObjectType]] = Field( None, - description="Fine-grained control over which object types to load from the database when store_model_in_db is True. Available types: 'models', 'mcp', 'guardrails', 'vector_stores', 'pass_through_endpoints', 'prompts', 'model_cost_map'. If not set, all objects are loaded (default behavior).", + description="Fine-grained control over which object types to load from the database when store_model_in_db is True. Available types: 'models', 'mcp', 'guardrails', 'vector_stores', 'pass_through_endpoints', 'prompts', 'model_cost_map', 'tools'. If not set, all objects are loaded (default behavior).", ) user_mcp_management_mode: Optional[UserMCPManagementMode] = Field( None, @@ -3377,6 +3378,11 @@ class ProxyErrorTypes(str, enum.Enum): Team member is already in team """ + tool_access_denied = "tool_access_denied" + """ + Tool is not in the allowed tools list for this key/team + """ + @classmethod def get_model_access_error_type_for_object( cls, object_type: Literal["key", "user", "team", "org", "project"] @@ -4161,6 +4167,7 @@ class ToolDiscoveryQueueItem(TypedDict, total=False): key_hash: Optional[str] # hash of virtual key that triggered discovery team_id: Optional[str] # team that triggered discovery key_alias: Optional[str] # human-readable key alias + user_agent: Optional[str] # HTTP User-Agent of the caller class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase): diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 91ac58215a..79e5f78f68 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -58,6 +58,10 @@ from litellm.proxy._types import ( ) from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler +from litellm.proxy.guardrails.tool_name_extraction import ( + TOOL_CAPABLE_CALL_TYPES, + extract_request_tool_names, +) from litellm.proxy.route_llm_request import route_request from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics from litellm.router import Router @@ -220,7 +224,48 @@ async def _run_project_checks( ) -async def common_checks( +async def check_tools_allowlist( + request_body: dict, + valid_token: Optional[UserAPIKeyAuth], + team_object: Optional[LiteLLM_TeamTable], + route: str, +) -> None: + """ + Enforce key/team tool allowlist (metadata.allowed_tools). No DB in hot path — + effective allowlist is read from valid_token.metadata and valid_token.team_metadata. + Raises ProxyException with tool_access_denied if a tool is not allowed. + """ + from litellm.litellm_core_utils.api_route_to_call_types import ( + get_call_types_for_route, + ) + + if valid_token is None: + return + call_types = get_call_types_for_route(route) + if not call_types or not any(ct.value in TOOL_CAPABLE_CALL_TYPES for ct in call_types): + return + tool_names = extract_request_tool_names(route, request_body) + if not tool_names: + return + key_meta = (valid_token.metadata or {}) if isinstance(valid_token.metadata, dict) else {} + team_meta = (valid_token.team_metadata or {}) if isinstance(valid_token.team_metadata, dict) else {} + key_allowed = key_meta.get("allowed_tools") + team_allowed = team_meta.get("allowed_tools") + effective = key_allowed if (isinstance(key_allowed, list) and len(key_allowed) > 0) else team_allowed + if not isinstance(effective, list) or len(effective) == 0: + return + allowed_set = {str(t) for t in effective} + disallowed = [n for n in tool_names if n not in allowed_set] + if disallowed: + raise ProxyException( + message=f"Tool(s) {disallowed} are not in the allowed tools list for this key/team.", + type=ProxyErrorTypes.tool_access_denied, + param="tools", + code=status.HTTP_403_FORBIDDEN, + ) + + +async def common_checks( # noqa: PLR0915 request_body: dict, team_object: Optional[LiteLLM_TeamTable], user_object: Optional[LiteLLM_UserTable], @@ -477,6 +522,14 @@ async def common_checks( valid_token=valid_token, ) + # 12. [OPTIONAL] Tool allowlist - key/team allowed_tools (no DB in hot path) + await check_tools_allowlist( + request_body=request_body, + valid_token=valid_token, + team_object=team_object, + route=route, + ) + return True diff --git a/litellm/proxy/db/db_spend_update_writer.py b/litellm/proxy/db/db_spend_update_writer.py index 0c25424cea..4c96e079c9 100644 --- a/litellm/proxy/db/db_spend_update_writer.py +++ b/litellm/proxy/db/db_spend_update_writer.py @@ -13,49 +13,36 @@ import random import time import traceback from datetime import datetime, timedelta, timezone -from typing import ( - TYPE_CHECKING, - Any, - Dict, - List, - Literal, - Optional, - Union, - cast, - overload, -) +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, + cast, overload) import litellm from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache, RedisCache from litellm.constants import DB_SPEND_UPDATE_JOB_NAME from litellm.litellm_core_utils.safe_json_loads import safe_json_loads -from litellm.proxy._types import ( - DB_CONNECTION_ERROR_TYPES, - BaseDailySpendTransaction, - DailyAgentSpendTransaction, - DailyEndUserSpendTransaction, - DailyOrganizationSpendTransaction, - DailyTagSpendTransaction, - DailyTeamSpendTransaction, - DailyUserSpendTransaction, - DBSpendUpdateTransactions, - Litellm_EntityType, - LiteLLM_UserTable, - SpendLogsMetadata, - SpendLogsPayload, - SpendUpdateQueueItem, - ToolDiscoveryQueueItem, -) -from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import ( - DailySpendUpdateQueue, -) -from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager -from litellm.proxy.db.db_transaction_queue.redis_update_buffer import RedisUpdateBuffer -from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue -from litellm.proxy.db.db_transaction_queue.tool_discovery_queue import ( - ToolDiscoveryQueue, -) +from litellm.proxy._types import (DB_CONNECTION_ERROR_TYPES, + BaseDailySpendTransaction, + DailyAgentSpendTransaction, + DailyEndUserSpendTransaction, + DailyOrganizationSpendTransaction, + DailyTagSpendTransaction, + DailyTeamSpendTransaction, + DailyUserSpendTransaction, + DBSpendUpdateTransactions, + Litellm_EntityType, LiteLLM_UserTable, + SpendLogsMetadata, SpendLogsPayload, + SpendUpdateQueueItem, ToolDiscoveryQueueItem) +from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import \ + DailySpendUpdateQueue +from litellm.proxy.db.db_transaction_queue.pod_lock_manager import \ + PodLockManager +from litellm.proxy.db.db_transaction_queue.redis_update_buffer import \ + RedisUpdateBuffer +from litellm.proxy.db.db_transaction_queue.spend_update_queue import \ + SpendUpdateQueue +from litellm.proxy.db.db_transaction_queue.tool_discovery_queue import \ + ToolDiscoveryQueue from litellm.proxy.route_llm_request import ROUTE_ENDPOINT_MAPPING if TYPE_CHECKING: @@ -104,12 +91,10 @@ class DBSpendUpdateWriter: end_time: Optional[datetime], response_cost: Optional[float], ): - from litellm.proxy.proxy_server import ( - disable_spend_logs, - litellm_proxy_budget_name, - prisma_client, - user_api_key_cache, - ) + from litellm.proxy.proxy_server import (disable_spend_logs, + litellm_proxy_budget_name, + prisma_client, + user_api_key_cache) from litellm.proxy.utils import ProxyUpdateSpend, hash_token try: @@ -124,9 +109,8 @@ class DBSpendUpdateWriter: hashed_token = token ## CREATE SPEND LOG PAYLOAD ## - from litellm.proxy.spend_tracking.spend_tracking_utils import ( - get_logging_payload, - ) + from litellm.proxy.spend_tracking.spend_tracking_utils import \ + get_logging_payload payload = get_logging_payload( kwargs=kwargs, @@ -230,6 +214,7 @@ class DBSpendUpdateWriter: _litellm_params = kwargs.get("litellm_params") or {} _metadata = _litellm_params.get("metadata") or {} key_alias = _metadata.get("user_api_key_alias") or None + user_agent = _metadata.get("user_agent") or None def _enqueue(tool_name: str, origin: str = "user_defined") -> None: self.tool_discovery_queue.add_update( @@ -239,17 +224,20 @@ class DBSpendUpdateWriter: key_hash=hashed_token, team_id=team_id, key_alias=key_alias, + user_agent=user_agent, ) ) # --- MCP tool calls --- sl_object = kwargs.get("standard_logging_object") if sl_object is not None: - mcp_metadata = ( - sl_object.get("metadata", {}) or {} - ).get("mcp_tool_call_metadata") + mcp_metadata = (sl_object.get("metadata", {}) or {}).get( + "mcp_tool_call_metadata" + ) if mcp_metadata and isinstance(mcp_metadata, dict): - tool_name = mcp_metadata.get("namespaced_tool_name") or mcp_metadata.get("name") + tool_name = mcp_metadata.get( + "namespaced_tool_name" + ) or mcp_metadata.get("name") mcp_server_name = mcp_metadata.get("mcp_server_name") if tool_name: _enqueue(tool_name, origin=mcp_server_name or "user_defined") @@ -280,7 +268,9 @@ class DBSpendUpdateWriter: _enqueue(name) # --- Response tool_calls (OpenAI format; Anthropic pass-through converts tool_use here) --- - if completion_response is not None and hasattr(completion_response, "choices"): + if completion_response is not None and hasattr( + completion_response, "choices" + ): for choice in completion_response.choices or []: message = getattr(choice, "message", None) if message is None: @@ -768,19 +758,46 @@ class DBSpendUpdateWriter: daily_end_user_spend_update_transactions, daily_agent_spend_update_transactions, daily_tag_spend_update_transactions, - ) = await self.redis_update_buffer.get_all_transactions_from_redis_buffer_pipeline() + ) = ( + await self.redis_update_buffer.get_all_transactions_from_redis_buffer_pipeline() + ) if db_spend_update_transactions is not None: verbose_proxy_logger.info( "Spend tracking - committing spend updates from Redis to DB: " "keys=%d, users=%d, teams=%d, orgs=%d, end_users=%d, team_members=%d, tags=%d", - len(db_spend_update_transactions.get("key_list_transactions") or {}), - len(db_spend_update_transactions.get("user_list_transactions") or {}), - len(db_spend_update_transactions.get("team_list_transactions") or {}), - len(db_spend_update_transactions.get("org_list_transactions") or {}), - len(db_spend_update_transactions.get("end_user_list_transactions") or {}), - len(db_spend_update_transactions.get("team_member_list_transactions") or {}), - len(db_spend_update_transactions.get("tag_list_transactions") or {}), + len( + db_spend_update_transactions.get("key_list_transactions") + or {} + ), + len( + db_spend_update_transactions.get("user_list_transactions") + or {} + ), + len( + db_spend_update_transactions.get("team_list_transactions") + or {} + ), + len( + db_spend_update_transactions.get("org_list_transactions") + or {} + ), + len( + db_spend_update_transactions.get( + "end_user_list_transactions" + ) + or {} + ), + len( + db_spend_update_transactions.get( + "team_member_list_transactions" + ) + or {} + ), + len( + db_spend_update_transactions.get("tag_list_transactions") + or {} + ), ) await self._commit_spend_updates_to_db( prisma_client=prisma_client, @@ -985,10 +1002,8 @@ class DBSpendUpdateWriter: Commits all the spend `UPDATE` transactions to the Database """ - from litellm.proxy.utils import ( - ProxyUpdateSpend, - _raise_failed_update_spend_exception, - ) + from litellm.proxy.utils import (ProxyUpdateSpend, + _raise_failed_update_spend_exception) ### UPDATE USER TABLE ### user_list_transactions = db_spend_update_transactions["user_list_transactions"] @@ -1523,14 +1538,14 @@ class DBSpendUpdateWriter: # Add cache-related fields if they exist if "cache_read_input_tokens" in transaction: - common_data[ - "cache_read_input_tokens" - ] = transaction.get("cache_read_input_tokens", 0) + common_data["cache_read_input_tokens"] = ( + transaction.get("cache_read_input_tokens", 0) + ) if "cache_creation_input_tokens" in transaction: - common_data[ - "cache_creation_input_tokens" - ] = transaction.get( - "cache_creation_input_tokens", 0 + common_data["cache_creation_input_tokens"] = ( + transaction.get( + "cache_creation_input_tokens", 0 + ) ) if entity_type == "tag" and "request_id" in transaction: diff --git a/litellm/proxy/db/spend_log_tool_index.py b/litellm/proxy/db/spend_log_tool_index.py new file mode 100644 index 0000000000..6e8c63675e --- /dev/null +++ b/litellm/proxy/db/spend_log_tool_index.py @@ -0,0 +1,147 @@ +""" +Track tool usage for the dashboard: insert into SpendLogToolIndex when spend logs +are written, so "last N requests for tool X" and "how is this tool called in production" +queries are fast. +""" + +from datetime import datetime, timezone +from typing import Any, Dict, List, Set + +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.safe_json_loads import safe_json_loads +from litellm.proxy.utils import PrismaClient + + +def _add_tool_calls_to_set(tool_calls: Any, out: Set[str]) -> None: + """Extract tool names from OpenAI-style tool_calls list into out.""" + if not isinstance(tool_calls, list): + return + for tc in tool_calls: + if not isinstance(tc, dict): + continue + fn = tc.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if name and isinstance(name, str) and name.strip(): + out.add(name.strip()) + + +def _parse_tool_names_from_payload(payload: Dict[str, Any]) -> Set[str]: + """ + Extract deduplicated tool names from a spend log payload. + Sources: mcp_namespaced_tool_name, response (tool_calls), proxy_server_request (tools). + """ + tool_names: Set[str] = set() + + # Top-level MCP tool name (single tool per request for that flow) + mcp_name = payload.get("mcp_namespaced_tool_name") + if mcp_name and isinstance(mcp_name, str) and mcp_name.strip(): + tool_names.add(mcp_name.strip()) + + # Response: OpenAI-style tool_calls[].function.name or choices[0].message.tool_calls + response_raw = payload.get("response") + if response_raw: + response_obj = ( + safe_json_loads(response_raw, default=None) + if isinstance(response_raw, str) + else response_raw + ) + if isinstance(response_obj, dict): + _add_tool_calls_to_set(response_obj.get("tool_calls"), tool_names) + choices = response_obj.get("choices") + if isinstance(choices, list) and choices: + msg = choices[0].get("message") if isinstance(choices[0], dict) else None + if isinstance(msg, dict): + _add_tool_calls_to_set(msg.get("tool_calls"), tool_names) + + # Request body: tools[].function.name + request_raw = payload.get("proxy_server_request") + if request_raw: + request_obj = ( + safe_json_loads(request_raw, default=None) + if isinstance(request_raw, str) + else request_raw + ) + if isinstance(request_obj, dict): + body = request_obj.get("body", request_obj) + if isinstance(body, dict): + request_obj = body + if isinstance(request_obj, dict): + tools = request_obj.get("tools") + if isinstance(tools, list): + for t in tools: + if isinstance(t, dict): + fn = t.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if name and isinstance(name, str) and name.strip(): + tool_names.add(name.strip()) + + return tool_names + + +async def process_spend_logs_tool_usage( + prisma_client: PrismaClient, + logs_to_process: List[Dict[str, Any]], +) -> None: + """ + After spend logs are written: insert SpendLogToolIndex rows from each payload. + Extracts tool names from mcp_namespaced_tool_name, response tool_calls, and + proxy_server_request tools. + """ + if not logs_to_process: + return + + index_rows: List[Dict[str, Any]] = [] + + for payload in logs_to_process: + request_id = payload.get("request_id") + start_time = payload.get("startTime") + if not request_id or not start_time: + continue + if isinstance(start_time, str): + try: + start_time = datetime.fromisoformat( + start_time.replace("Z", "+00:00") + ) + except (ValueError, TypeError): + continue + if start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=timezone.utc) + + tool_names = _parse_tool_names_from_payload(payload) + for tool_name in tool_names: + index_rows.append({ + "request_id": request_id, + "tool_name": tool_name, + "start_time": start_time, + }) + + if not index_rows: + return + + try: + index_data = [] + for r in index_rows: + st = r["start_time"] + if isinstance(st, str): + try: + st = datetime.fromisoformat(st.replace("Z", "+00:00")) + except (ValueError, TypeError): + continue + if st.tzinfo is None: + st = st.replace(tzinfo=timezone.utc) + index_data.append({ + "request_id": r["request_id"], + "tool_name": r["tool_name"], + "start_time": st, + }) + if index_data: + await prisma_client.db.litellm_spendlogtoolindex.create_many( + data=index_data, + skip_duplicates=True, + ) + except Exception as e: + verbose_proxy_logger.warning( + "Tool usage tracking (SpendLogToolIndex) failed (non-fatal): %s", e + ) diff --git a/litellm/proxy/db/tool_registry_writer.py b/litellm/proxy/db/tool_registry_writer.py index 4e0a8095a0..0eda012d51 100644 --- a/litellm/proxy/db/tool_registry_writer.py +++ b/litellm/proxy/db/tool_registry_writer.py @@ -2,36 +2,64 @@ DB helpers for LiteLLM_ToolTable — the global tool registry. Tools are auto-discovered from LLM responses and upserted here. -Admins use the management endpoints to read and update call_policy. - -NOTE: Uses raw SQL (query_raw / execute_raw) instead of Prisma model methods -because the generated Prisma Python client may not have LiteLLM_ToolTable -when running against an older generated schema. +Admins use the management endpoints to read and update input_policy / output_policy. """ import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ToolDiscoveryQueueItem -from litellm.types.tool_management import LiteLLM_ToolTableRow, ToolCallPolicy +from litellm.types.tool_management import ( + LiteLLM_ToolTableRow, + ToolPolicyOverrideRow, +) if TYPE_CHECKING: from litellm.proxy.utils import PrismaClient -def _row_to_model(row: dict) -> LiteLLM_ToolTableRow: +def _row_to_model(row: Union[dict, Any]) -> LiteLLM_ToolTableRow: + """Convert a Prisma model instance or dict to LiteLLM_ToolTableRow.""" + model_dump = getattr(row, "model_dump", None) + if callable(model_dump): + row = model_dump() + elif not isinstance(row, dict): + row = { + k: getattr(row, k, None) + for k in ( + "tool_id", + "tool_name", + "origin", + "input_policy", + "output_policy", + "call_count", + "assignments", + "key_hash", + "team_id", + "key_alias", + "user_agent", + "last_used_at", + "created_at", + "updated_at", + "created_by", + "updated_by", + ) + } return LiteLLM_ToolTableRow( tool_id=row.get("tool_id", ""), tool_name=row.get("tool_name", ""), origin=row.get("origin"), - call_policy=row.get("call_policy", "untrusted"), + input_policy=row.get("input_policy") or "untrusted", + output_policy=row.get("output_policy") or "untrusted", call_count=int(row.get("call_count") or 0), assignments=row.get("assignments"), key_hash=row.get("key_hash"), team_id=row.get("team_id"), key_alias=row.get("key_alias"), + user_agent=row.get("user_agent"), + last_used_at=row.get("last_used_at"), created_at=row.get("created_at"), updated_at=row.get("updated_at"), created_by=row.get("created_by"), @@ -44,10 +72,10 @@ async def batch_upsert_tools( items: List[ToolDiscoveryQueueItem], ) -> None: """ - Batch-upsert tool registry rows via raw SQL. + Batch-upsert tool registry rows via Prisma. - On first insert: sets call_policy = "untrusted" (schema default), call_count = 1. - On conflict: increments call_count; preserves existing call_policy. + On first insert: sets input_policy/output_policy = "untrusted" (default), call_count = 1. + On conflict: increments call_count; preserves existing policies. """ if not items: return @@ -55,6 +83,8 @@ async def batch_upsert_tools( data = [item for item in items if item.get("tool_name")] if not data: return + now = datetime.now(timezone.utc) + table = prisma_client.db.litellm_tooltable for item in data: tool_name = item.get("tool_name", "") origin = item.get("origin") or "user_defined" @@ -62,49 +92,52 @@ async def batch_upsert_tools( key_hash = item.get("key_hash") team_id = item.get("team_id") key_alias = item.get("key_alias") - now = datetime.now(timezone.utc).isoformat() - await prisma_client.db.execute_raw( - 'INSERT INTO "LiteLLM_ToolTable" ' - "(tool_id, tool_name, origin, call_policy, call_count, created_by, updated_by, key_hash, team_id, key_alias, created_at, updated_at) " - "VALUES ($7, $1, $2, 'untrusted', 1, $3, $3, $4, $5, $6, $8, $8) " - "ON CONFLICT (tool_name) DO UPDATE SET " - "call_count = \"LiteLLM_ToolTable\".call_count + 1, " - "updated_at = $8", - tool_name, - origin, - created_by, - key_hash, - team_id, - key_alias, - str(uuid.uuid4()), - now, + user_agent = item.get("user_agent") + await table.upsert( + where={"tool_name": tool_name}, + data={ + "create": { + "tool_id": str(uuid.uuid4()), + "tool_name": tool_name, + "origin": origin, + "input_policy": "untrusted", + "output_policy": "untrusted", + "call_count": 1, + "created_by": created_by, + "updated_by": created_by, + "key_hash": key_hash, + "team_id": team_id, + "key_alias": key_alias, + "user_agent": user_agent, + "last_used_at": now, + }, + "update": { + "call_count": {"increment": 1}, + "updated_at": now, + "last_used_at": now, + }, + }, ) verbose_proxy_logger.debug( "tool_registry_writer: upserted %d tool(s)", len(data) ) except Exception as e: - verbose_proxy_logger.error("tool_registry_writer batch_upsert_tools error: %s", e) + verbose_proxy_logger.error( + "tool_registry_writer batch_upsert_tools error: %s", e + ) async def list_tools( prisma_client: "PrismaClient", - call_policy: Optional[ToolCallPolicy] = None, + input_policy: Optional[str] = None, ) -> List[LiteLLM_ToolTableRow]: - """Return all tools, optionally filtered by call_policy.""" + """Return all tools, optionally filtered by input_policy.""" try: - if call_policy is not None: - rows = await prisma_client.db.query_raw( - 'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, ' - 'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by ' - 'FROM "LiteLLM_ToolTable" WHERE call_policy = $1 ORDER BY created_at DESC', - call_policy, - ) - else: - rows = await prisma_client.db.query_raw( - 'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, ' - 'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by ' - 'FROM "LiteLLM_ToolTable" ORDER BY created_at DESC', - ) + where = {"input_policy": input_policy} if input_policy is not None else {} + rows = await prisma_client.db.litellm_tooltable.find_many( + where=where, + order={"created_at": "desc"}, + ) return [_row_to_model(row) for row in rows] except Exception as e: verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e) @@ -117,15 +150,12 @@ async def get_tool( ) -> Optional[LiteLLM_ToolTableRow]: """Return a single tool row by tool_name.""" try: - rows = await prisma_client.db.query_raw( - 'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, ' - 'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by ' - 'FROM "LiteLLM_ToolTable" WHERE tool_name = $1', - tool_name, + row = await prisma_client.db.litellm_tooltable.find_unique( + where={"tool_name": tool_name}, ) - if not rows: + if row is None: return None - return _row_to_model(rows[0]) + return _row_to_model(row) except Exception as e: verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e) return None @@ -134,46 +164,279 @@ async def get_tool( async def update_tool_policy( prisma_client: "PrismaClient", tool_name: str, - call_policy: ToolCallPolicy, updated_by: Optional[str], + input_policy: Optional[str] = None, + output_policy: Optional[str] = None, ) -> Optional[LiteLLM_ToolTableRow]: - """Update the call_policy for a tool. Upserts the row if it does not exist yet.""" + """Update input_policy and/or output_policy for a tool. Upserts the row if it does not exist yet.""" try: _updated_by = updated_by or "system" - now = datetime.now(timezone.utc).isoformat() - await prisma_client.db.execute_raw( - 'INSERT INTO "LiteLLM_ToolTable" (tool_id, tool_name, call_policy, created_by, updated_by, created_at, updated_at) ' - "VALUES ($4, $1, $2, $3, $3, $5, $5) " - "ON CONFLICT (tool_name) DO UPDATE SET call_policy = $2, updated_by = $3, updated_at = $5", - tool_name, - call_policy, - _updated_by, - str(uuid.uuid4()), - now, + now = datetime.now(timezone.utc) + + create_data: dict = { + "tool_id": str(uuid.uuid4()), + "tool_name": tool_name, + "input_policy": input_policy or "untrusted", + "output_policy": output_policy or "untrusted", + "created_by": _updated_by, + "updated_by": _updated_by, + "created_at": now, + "updated_at": now, + } + update_data: dict = { + "updated_by": _updated_by, + "updated_at": now, + } + if input_policy is not None: + update_data["input_policy"] = input_policy + if output_policy is not None: + update_data["output_policy"] = output_policy + + await prisma_client.db.litellm_tooltable.upsert( + where={"tool_name": tool_name}, + data={ + "create": create_data, + "update": update_data, + }, ) return await get_tool(prisma_client, tool_name) except Exception as e: - verbose_proxy_logger.error("tool_registry_writer update_tool_policy error: %s", e) + verbose_proxy_logger.error( + "tool_registry_writer update_tool_policy error: %s", e + ) return None async def get_tools_by_names( prisma_client: "PrismaClient", tool_names: List[str], -) -> Dict[str, str]: +) -> Dict[str, Tuple[str, str]]: """ - Return a {tool_name: call_policy} map for the given tool names. - Used by the policy enforcement guardrail — single batch query, never N+1. + Return a {tool_name: (input_policy, output_policy)} map for the given tool names. """ if not tool_names: return {} try: - placeholders = ", ".join(f"${i+1}" for i in range(len(tool_names))) - rows = await prisma_client.db.query_raw( - f'SELECT tool_name, call_policy FROM "LiteLLM_ToolTable" WHERE tool_name IN ({placeholders})', - *tool_names, + rows = await prisma_client.db.litellm_tooltable.find_many( + where={"tool_name": {"in": tool_names}}, ) - return {row["tool_name"]: row["call_policy"] for row in rows} + return { + row.tool_name: ( + getattr(row, "input_policy", "untrusted") or "untrusted", + getattr(row, "output_policy", "untrusted") or "untrusted", + ) + for row in rows + } except Exception as e: - verbose_proxy_logger.error("tool_registry_writer get_tools_by_names error: %s", e) + verbose_proxy_logger.error( + "tool_registry_writer get_tools_by_names error: %s", e + ) return {} + + +async def list_overrides_for_tool( + prisma_client: "PrismaClient", + tool_name: str, +) -> List[ToolPolicyOverrideRow]: + """ + Return override-like rows for a tool by finding object permissions that have + this tool in blocked_tools, then resolving each permission to key/team scope for display. + """ + out: List[ToolPolicyOverrideRow] = [] + try: + perms = await prisma_client.db.litellm_objectpermissiontable.find_many( + where={"blocked_tools": {"has": tool_name}}, + include={ + "verification_tokens": True, + "teams": True, + }, + ) + for perm in perms: + op_id = getattr(perm, "object_permission_id", None) or "" + tokens = getattr(perm, "verification_tokens", []) or [] + teams = getattr(perm, "teams", []) or [] + for t in tokens: + out.append( + ToolPolicyOverrideRow( + override_id=op_id, + tool_name=tool_name, + team_id=None, + key_hash=getattr(t, "token", None), + input_policy="blocked", + key_alias=getattr(t, "key_alias", None), + created_at=None, + updated_at=None, + ) + ) + for team in teams: + out.append( + ToolPolicyOverrideRow( + override_id=op_id, + tool_name=tool_name, + team_id=getattr(team, "team_id", None), + key_hash=None, + input_policy="blocked", + key_alias=getattr(team, "team_alias", None), + created_at=None, + updated_at=None, + ) + ) + return out + except Exception as e: + verbose_proxy_logger.error( + "tool_registry_writer list_overrides_for_tool error: %s", e + ) + return [] + + +class ToolPolicyRegistry: + """ + In-memory registry of tool policies synced from DB. + Hot path uses get_effective_policies only — no DB, no cache. + """ + + def __init__(self) -> None: + self._tool_input_policies: Dict[str, str] = {} + self._tool_output_policies: Dict[str, str] = {} + self._blocked_tools_by_op_id: Dict[str, List[str]] = {} + self._initialized: bool = False + + def is_initialized(self) -> bool: + return self._initialized + + async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None: + """Load all tool policies and object-permission blocked_tools from DB.""" + try: + tools = await prisma_client.db.litellm_tooltable.find_many() + self._tool_input_policies = { + row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted" + for row in tools + } + self._tool_output_policies = { + row.tool_name: getattr(row, "output_policy", "untrusted") or "untrusted" + for row in tools + } + + perms = await prisma_client.db.litellm_objectpermissiontable.find_many() + self._blocked_tools_by_op_id = {} + for row in perms: + op_id = getattr(row, "object_permission_id", None) + blocked = getattr(row, "blocked_tools", None) or [] + if op_id: + self._blocked_tools_by_op_id[op_id] = list(blocked) + + self._initialized = True + verbose_proxy_logger.info( + "ToolPolicyRegistry: synced %d tool policies and %d object permissions from DB", + len(self._tool_input_policies), + len(self._blocked_tools_by_op_id), + ) + except Exception as e: + verbose_proxy_logger.exception( + "ToolPolicyRegistry sync_tool_policy_from_db error: %s", e + ) + raise + + def get_input_policy(self, tool_name: str) -> str: + return self._tool_input_policies.get(tool_name, "untrusted") + + def get_output_policy(self, tool_name: str) -> str: + return self._tool_output_policies.get(tool_name, "untrusted") + + def get_effective_policies( + self, + tool_names: List[str], + object_permission_id: Optional[str] = None, + team_object_permission_id: Optional[str] = None, + ) -> Dict[str, str]: + """ + Return effective input_policy per tool from in-memory state. + If tool is in key or team blocked_tools -> "blocked", else global input_policy or "untrusted". + """ + if not tool_names: + return {} + blocked: set = set() + for op_id in (object_permission_id, team_object_permission_id): + if op_id and op_id.strip(): + blocked.update( + self._blocked_tools_by_op_id.get(op_id.strip(), []) + ) + result: Dict[str, str] = {} + for name in tool_names: + if name in blocked: + result[name] = "blocked" + else: + result[name] = self._tool_input_policies.get(name, "untrusted") + return result + + +_tool_policy_registry: Optional[ToolPolicyRegistry] = None + + +def get_tool_policy_registry() -> ToolPolicyRegistry: + """Return the global ToolPolicyRegistry singleton.""" + global _tool_policy_registry + if _tool_policy_registry is None: + _tool_policy_registry = ToolPolicyRegistry() + return _tool_policy_registry + + +async def add_tool_to_object_permission_blocked( + prisma_client: "PrismaClient", + object_permission_id: str, + tool_name: str, +) -> bool: + """Add tool_name to the permission's blocked_tools if not already present.""" + if not object_permission_id or not tool_name: + return False + try: + row = await prisma_client.db.litellm_objectpermissiontable.find_unique( + where={"object_permission_id": object_permission_id}, + ) + if row is None: + return False + current = list(getattr(row, "blocked_tools", []) or []) + if tool_name in current: + return True + current.append(tool_name) + await prisma_client.db.litellm_objectpermissiontable.update( + where={"object_permission_id": object_permission_id}, + data={"blocked_tools": current}, + ) + return True + except Exception as e: + verbose_proxy_logger.error( + "tool_registry_writer add_tool_to_object_permission_blocked error: %s", e + ) + return False + + +async def remove_tool_from_object_permission_blocked( + prisma_client: "PrismaClient", + object_permission_id: str, + tool_name: str, +) -> bool: + """Remove tool_name from the permission's blocked_tools. Returns False if tool was not in list.""" + if not object_permission_id or not tool_name: + return False + try: + row = await prisma_client.db.litellm_objectpermissiontable.find_unique( + where={"object_permission_id": object_permission_id}, + ) + if row is None: + return False + current = list(getattr(row, "blocked_tools", []) or []) + if tool_name not in current: + return False + current = [t for t in current if t != tool_name] + await prisma_client.db.litellm_objectpermissiontable.update( + where={"object_permission_id": object_permission_id}, + data={"blocked_tools": current}, + ) + return True + except Exception as e: + verbose_proxy_logger.error( + "tool_registry_writer remove_tool_from_object_permission_blocked error: %s", + e, + ) + return False diff --git a/litellm/proxy/guardrails/guardrail_hooks/tool_policy/tool_policy_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/tool_policy/tool_policy_guardrail.py index 87558566c4..368948414e 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/tool_policy/tool_policy_guardrail.py +++ b/litellm/proxy/guardrails/guardrail_hooks/tool_policy/tool_policy_guardrail.py @@ -1,13 +1,16 @@ """ Tool Policy Guardrail -Reads call_policy from LiteLLM_ToolTable and enforces it on LLM requests/responses. +Reads input_policy / output_policy from LiteLLM_ToolTable and enforces them. -Policy values: - "trusted" - allow through (no action) - "untrusted" - allow through (no action; default for newly discovered tools) +Input policy values: + "untrusted" - allow through (default for newly discovered tools) + "trusted" - only allow if conversation contains no untrusted tool output "blocked" - raise HTTPException, preventing the tool call - "dual_llm" - (Phase 3) send to second LLM for verification; currently treated as allowed + +Output policy values: + "untrusted" - output may be tainted (default) + "trusted" - output is verified safe Configuration in proxy config YAML: guardrails: @@ -15,25 +18,18 @@ Configuration in proxy config YAML: litellm_params: guardrail: tool_policy mode: post_call - -or both pre and post call: - - guardrail_name: "tool_policy" - litellm_params: - guardrail: tool_policy - mode: during_call # runs before LLM and on response """ -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple from fastapi import HTTPException from litellm._logging import verbose_proxy_logger -from litellm.caching.dual_cache import DualCache -from litellm.constants import TOOL_POLICY_CACHE_TTL_SECONDS from litellm.integrations.custom_guardrail import ( CustomGuardrail, log_guardrail_information, ) +from litellm.proxy.guardrails.tool_name_extraction import extract_request_tool_names from litellm.types.guardrails import GuardrailEventHooks from litellm.types.utils import GenericGuardrailAPIInputs @@ -43,12 +39,71 @@ if TYPE_CHECKING: GUARDRAIL_NAME = "tool_policy" +def _get_request_object_permission_ids( + request_data: dict, +) -> Tuple[Optional[str], Optional[str]]: + """Extract object_permission_id and team_object_permission_id from request_data.""" + if not request_data: + return None, None + for key in ("litellm_metadata", "metadata"): + meta = request_data.get(key) + if not isinstance(meta, dict): + continue + auth = meta.get("user_api_key_auth") + if auth is not None and hasattr(auth, "object_permission_id"): + key_op = getattr(auth, "object_permission_id", None) + team_op = getattr(auth, "team_object_permission_id", None) + if key_op is not None or team_op is not None: + return ( + str(key_op).strip() if key_op else None, + str(team_op).strip() if team_op else None, + ) + key_op = meta.get("user_api_key_object_permission_id") + team_op = meta.get("user_api_key_team_object_permission_id") + if key_op is not None or team_op is not None: + return ( + str(key_op).strip() if key_op else None, + str(team_op).strip() if team_op else None, + ) + return None, None + + +def _get_request_route_from_data(request_data: dict) -> Optional[str]: + """Get request route from request_data (metadata or top-level).""" + route = request_data.get("user_api_key_request_route") + if route: + return route + meta = request_data.get("metadata") or request_data.get("litellm_metadata") or {} + return meta.get("user_api_key_request_route") + + +def _resolve_tool_names_from_messages(messages: List[dict]) -> Dict[str, str]: + """ + Build a map of tool_call_id -> tool_name from assistant messages' tool_calls. + Used to resolve which tool produced each tool result in the conversation. + """ + mapping: Dict[str, str] = {} + for msg in messages: + if msg.get("role") != "assistant": + continue + tool_calls = msg.get("tool_calls") or [] + for tc in tool_calls: + if isinstance(tc, dict): + tc_id = tc.get("id") + fn = (tc.get("function") or {}).get("name") + else: + tc_id = getattr(tc, "id", None) + fn_obj = getattr(tc, "function", None) + fn = getattr(fn_obj, "name", None) if fn_obj else None + if tc_id and fn: + mapping[tc_id] = fn + return mapping + + class ToolPolicyGuardrail(CustomGuardrail): """ - Guardrail that enforces per-tool call policies stored in LiteLLM_ToolTable. - - Tools with call_policy="blocked" are rejected before/after the LLM call. - Tools with call_policy="trusted" or "untrusted" pass through unchanged. + Guardrail that enforces per-tool input/output policies from the in-memory + ToolPolicyRegistry (synced from DB). """ def __init__(self, **kwargs: Any) -> None: @@ -59,7 +114,6 @@ class ToolPolicyGuardrail(CustomGuardrail): GuardrailEventHooks.during_call, ] super().__init__(**kwargs) - self._policy_cache: DualCache = DualCache() @log_guardrail_information async def apply_guardrail( @@ -70,12 +124,7 @@ class ToolPolicyGuardrail(CustomGuardrail): logging_obj: Optional["LiteLLMLoggingObj"] = None, ) -> GenericGuardrailAPIInputs: """ - Enforce tool policies on both request tools and response tool_calls. - - - input_type="request": check inputs["tools"] (tool definitions in the LLM request) - - input_type="response": check inputs["tool_calls"] (tool_calls in the LLM response) - - Raises HTTPException (400) if any tool is "blocked". + Enforce input_policy and output_policy trust chain on request tools / response tool_calls. """ if input_type == "request": tools = inputs.get("tools") or [] @@ -86,7 +135,11 @@ class ToolPolicyGuardrail(CustomGuardrail): and isinstance(t.get("function"), dict) and t["function"].get("name") ] - else: # response + if not tool_names: + route = _get_request_route_from_data(request_data) + if route: + tool_names = extract_request_tool_names(route, request_data) + else: tool_calls = inputs.get("tool_calls") or [] tool_names = [] for tc in tool_calls: @@ -101,12 +154,25 @@ class ToolPolicyGuardrail(CustomGuardrail): if not tool_names: return inputs - policy_map = await self._get_policies_cached(tool_names) + object_permission_id, team_object_permission_id = ( + _get_request_object_permission_ids(request_data) + ) + from litellm.proxy.db.tool_registry_writer import get_tool_policy_registry + registry = get_tool_policy_registry() + if not registry.is_initialized(): + return inputs + + # Stage 1: Check for blocked tools (input_policy=blocked or per-key/team override) + policy_map = registry.get_effective_policies( + tool_names, + object_permission_id=object_permission_id, + team_object_permission_id=team_object_permission_id, + ) blocked = [name for name in tool_names if policy_map.get(name) == "blocked"] if blocked: verbose_proxy_logger.warning( - "ToolPolicyGuardrail: blocking tool(s) %s (policy=blocked)", blocked + "ToolPolicyGuardrail: blocking tool(s) %s (input_policy=blocked)", blocked ) raise HTTPException( status_code=400, @@ -117,47 +183,47 @@ class ToolPolicyGuardrail(CustomGuardrail): }, ) + # Stage 2: Trust chain enforcement (response path only) + # For each tool with input_policy=trusted, check if conversation + # contains output from tools with output_policy=untrusted + if input_type == "response": + trusted_input_tools = [ + name for name in tool_names if policy_map.get(name) == "trusted" + ] + if trusted_input_tools: + messages = request_data.get("messages") or [] + tc_id_to_name = _resolve_tool_names_from_messages(messages) + + untrusted_sources: List[str] = [] + for msg in messages: + if msg.get("role") != "tool": + continue + tool_call_id = msg.get("tool_call_id") + source_tool = tc_id_to_name.get(tool_call_id, "") if tool_call_id else "" + if not source_tool: + continue + if registry.get_output_policy(source_tool) == "untrusted": + if source_tool not in untrusted_sources: + untrusted_sources.append(source_tool) + + if untrusted_sources: + verbose_proxy_logger.warning( + "ToolPolicyGuardrail: trust chain violation — %s require trusted input " + "but conversation has untrusted output from %s", + trusted_input_tools, + untrusted_sources, + ) + raise HTTPException( + status_code=400, + detail={ + "error": "Violated tool policy", + "blocked_tools": trusted_input_tools, + "untrusted_sources": untrusted_sources, + "message": ( + f"{', '.join(trusted_input_tools)} requires trusted input but " + f"conversation contains untrusted output from {', '.join(untrusted_sources)}." + ), + }, + ) + return inputs - - async def _get_policies_cached(self, tool_names: List[str]) -> Dict[str, str]: - """ - Batch-fetch call_policy for the given tool names. - - Caches per individual tool name (not per combination) so that adding - a new tool to a request doesn't invalidate the cached policies for all - the other tools already in the cache. - """ - from litellm.proxy.db.tool_registry_writer import get_tools_by_names - from litellm.proxy.proxy_server import prisma_client - - if not tool_names or prisma_client is None: - return {} - - result: Dict[str, str] = {} - cache_misses: List[str] = [] - - for name in tool_names: - cached = await self._policy_cache.async_get_cache(f"tool_policy:{name}") - if cached is not None and isinstance(cached, str): - result[name] = cached - else: - cache_misses.append(name) - - if cache_misses: - fetched = await get_tools_by_names( - prisma_client=prisma_client, tool_names=cache_misses - ) - for name, policy in fetched.items(): - result[name] = policy - await self._policy_cache.async_set_cache( - key=f"tool_policy:{name}", - value=policy, - ttl=TOOL_POLICY_CACHE_TTL_SECONDS, - ) - verbose_proxy_logger.debug( - "ToolPolicyGuardrail: fetched %d policies from DB (cache hits: %d)", - len(cache_misses), - len(tool_names) - len(cache_misses), - ) - - return result diff --git a/litellm/proxy/guardrails/tool_name_extraction.py b/litellm/proxy/guardrails/tool_name_extraction.py new file mode 100644 index 0000000000..db24fa2277 --- /dev/null +++ b/litellm/proxy/guardrails/tool_name_extraction.py @@ -0,0 +1,85 @@ +""" +Extract tool names from request body by route/call type. + +Used by auth (check_tools_allowlist) and ToolPolicyGuardrail so tool-format +knowledge lives in one place. Uses guardrail translation handlers where available, +with standalone extractors for generate_content and MCP. +""" + +from typing import Any, Dict, List + +from litellm.litellm_core_utils.api_route_to_call_types import get_call_types_for_route +from litellm.llms import load_guardrail_translation_mappings +from litellm.types.utils import CallTypes + +# Call types that have no guardrail translation handler; we use standalone extractors +STANDALONE_EXTRACTORS: Dict[str, Any] = {} + + +def _extract_generate_content_tool_names(data: dict) -> List[str]: + """Google generateContent: tools[].functionDeclarations[].name""" + names: List[str] = [] + for tool in data.get("tools") or []: + if not isinstance(tool, dict): + continue + for decl in tool.get("functionDeclarations") or []: + if isinstance(decl, dict) and decl.get("name"): + names.append(str(decl["name"])) + return names + + +def _extract_mcp_tool_names(data: dict) -> List[str]: + """MCP call_tool: name or mcp_tool_name in body""" + names: List[str] = [] + name = data.get("name") or data.get("mcp_tool_name") + if name: + names.append(str(name)) + return names + + +def _register_standalone_extractors() -> None: + if STANDALONE_EXTRACTORS: + return + STANDALONE_EXTRACTORS[CallTypes.generate_content.value] = _extract_generate_content_tool_names + STANDALONE_EXTRACTORS[CallTypes.agenerate_content.value] = _extract_generate_content_tool_names + STANDALONE_EXTRACTORS[CallTypes.call_mcp_tool.value] = _extract_mcp_tool_names + + +# Tool-capable call types (routes that can send tools in the request) +TOOL_CAPABLE_CALL_TYPES = frozenset({ + CallTypes.completion.value, + CallTypes.acompletion.value, + CallTypes.responses.value, + CallTypes.aresponses.value, + CallTypes.anthropic_messages.value, + CallTypes.generate_content.value, + CallTypes.agenerate_content.value, + CallTypes.call_mcp_tool.value, +}) + + +def extract_request_tool_names(route: str, data: dict) -> List[str]: + """ + Extract tool names from the request body for the given route. + Uses guardrail translation handlers when available, else standalone extractors + for generate_content and MCP. Returns [] for non-tool-capable routes or when + no tools are present. + """ + call_types = get_call_types_for_route(route) + if not call_types: + return [] + _register_standalone_extractors() + mappings = load_guardrail_translation_mappings() + for call_type in call_types: + if not isinstance(call_type, CallTypes): + continue + if call_type.value not in TOOL_CAPABLE_CALL_TYPES: + continue + if call_type.value in STANDALONE_EXTRACTORS: + return STANDALONE_EXTRACTORS[call_type.value](data) + handler_cls = mappings.get(call_type) + if handler_cls is not None: + names = handler_cls().extract_request_tool_names(data) + if names: + return names + return [] diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 68f1ab114b..32eab99fb9 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1091,6 +1091,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915 ] = user_api_key_dict.user_max_budget data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata + data[_metadata_variable_name]["user_api_key_team_metadata"] = ( + user_api_key_dict.team_metadata + ) + data[_metadata_variable_name]["user_api_key_object_permission_id"] = ( + getattr(user_api_key_dict, "object_permission_id", None) + ) + data[_metadata_variable_name]["user_api_key_team_object_permission_id"] = ( + getattr(user_api_key_dict, "team_object_permission_id", None) + ) data[_metadata_variable_name]["headers"] = _headers data[_metadata_variable_name]["endpoint"] = str(request.url) diff --git a/litellm/proxy/management_endpoints/tool_management_endpoints.py b/litellm/proxy/management_endpoints/tool_management_endpoints.py index 89880c9a4e..7fdd3475c0 100644 --- a/litellm/proxy/management_endpoints/tool_management_endpoints.py +++ b/litellm/proxy/management_endpoints/tool_management_endpoints.py @@ -4,27 +4,87 @@ TOOL POLICY MANAGEMENT All /tool management endpoints GET /v1/tool/list - List all discovered tools and their policies +GET /v1/tool/policy/options - List available input/output policy options with descriptions GET /v1/tool/{tool_name} - Get a single tool's details -POST /v1/tool/policy - Update the call_policy for a tool +POST /v1/tool/policy - Update the input_policy / output_policy for a tool """ -from typing import Optional +import uuid +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, List, Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query + +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient from litellm._logging import verbose_proxy_logger from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.types.tool_management import ( LiteLLM_ToolTableRow, - ToolCallPolicy, + ToolDetailResponse, + ToolInputPolicy, ToolListResponse, + ToolOutputPolicy, + ToolPolicyOption, + ToolPolicyOptionsResponse, ToolPolicyUpdateRequest, ToolPolicyUpdateResponse, + ToolUsageLogEntry, + ToolUsageLogsResponse, ) router = APIRouter() +TOOL_POLICY_OPTIONS = ToolPolicyOptionsResponse( + input_policies=[ + ToolPolicyOption( + value="untrusted", + label="Untrusted", + description="Tool accepts any input, including data from untrusted tool outputs. Default for newly discovered tools.", + ), + ToolPolicyOption( + value="trusted", + label="Trusted", + description="Tool requires trusted input. Blocked if the conversation contains output from any tool with output_policy=untrusted.", + ), + ToolPolicyOption( + value="blocked", + label="Blocked", + description="Tool is completely prohibited. Any attempt to call it is rejected.", + ), + ], + output_policies=[ + ToolPolicyOption( + value="untrusted", + label="Untrusted", + description="Tool output may contain unsafe content (prompt injection, risky code). Downstream tools with input_policy=trusted will be blocked.", + ), + ToolPolicyOption( + value="trusted", + label="Trusted", + description="Tool output is verified safe. Will not trigger trust-chain blocks on downstream tools.", + ), + ], +) + + +@router.get( + "/v1/tool/policy/options", + tags=["tool management"], + dependencies=[Depends(user_api_key_auth)], + response_model=ToolPolicyOptionsResponse, +) +async def get_tool_policy_options( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Return the available input and output policy options with descriptions. + Static data — no DB call. + """ + return TOOL_POLICY_OPTIONS + @router.get( "/v1/tool/list", @@ -33,14 +93,14 @@ router = APIRouter() response_model=ToolListResponse, ) async def list_tools( - call_policy: Optional[ToolCallPolicy] = None, + input_policy: Optional[ToolInputPolicy] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - List all auto-discovered tools and their call policies. + List all auto-discovered tools and their policies. Parameters: - - call_policy: Optional filter — one of "trusted", "untrusted", "dual_llm", "blocked" + - input_policy: Optional filter — one of "trusted", "untrusted", "blocked" """ from litellm.proxy.db.tool_registry_writer import list_tools as db_list_tools from litellm.proxy.proxy_server import prisma_client @@ -51,13 +111,201 @@ async def list_tools( ) try: - tools = await db_list_tools(prisma_client=prisma_client, call_policy=call_policy) + tools = await db_list_tools( + prisma_client=prisma_client, input_policy=input_policy + ) return ToolListResponse(tools=tools, total=len(tools)) except Exception as e: verbose_proxy_logger.exception("Error listing tools: %s", e) raise HTTPException(status_code=500, detail=str(e)) +@router.get( + "/v1/tool/{tool_name:path}/detail", + tags=["tool management"], + dependencies=[Depends(user_api_key_auth)], + response_model=ToolDetailResponse, +) +async def get_tool_detail( + tool_name: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Get a single tool with its policy overrides (for UI detail view). + """ + from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool + from litellm.proxy.db.tool_registry_writer import list_overrides_for_tool + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, detail=CommonProxyErrors.db_not_connected_error.value + ) + + try: + tool = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name) + if tool is None: + raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found") + overrides = await list_overrides_for_tool( + prisma_client=prisma_client, tool_name=tool_name + ) + return ToolDetailResponse(tool=tool, overrides=overrides) + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception("Error getting tool detail: %s", e) + raise HTTPException(status_code=500, detail=str(e)) + + +def _input_snippet_for_tool_log(sl: Any, max_len: int = 200) -> Optional[str]: + """Short snippet from messages or proxy_server_request for tool usage log row.""" + if sl is None: + return None + messages = getattr(sl, "messages", None) + if messages is not None: + s = _snippet_str(messages, max_len) + if s: + return s + psr = getattr(sl, "proxy_server_request", None) + if not psr: + return None + if isinstance(psr, str): + import json + + try: + psr = json.loads(psr) + except Exception: + return _snippet_str(psr, max_len) + if isinstance(psr, dict): + msgs = psr.get("messages") + if msgs is None and isinstance(psr.get("body"), dict): + msgs = psr["body"].get("messages") + s = _snippet_str(msgs, max_len) + if s: + return s + return _snippet_str(psr, max_len) + + +def _snippet_str(text: Any, max_len: int = 200) -> Optional[str]: + if text is None: + return None + if isinstance(text, str): + s = text + elif isinstance(text, list): + parts = [] + for item in text: + if isinstance(item, dict) and "content" in item: + c = item["content"] + parts.append(c if isinstance(c, str) else str(c)) + else: + parts.append(str(item)) + s = " ".join(parts) + else: + s = str(text) + if not s or s == "{}": + return None + return (s[:max_len] + "...") if len(s) > max_len else s + + +@router.get( + "/v1/tool/{tool_name:path}/logs", + tags=["tool management"], + dependencies=[Depends(user_api_key_auth)], + response_model=ToolUsageLogsResponse, +) +async def get_tool_usage_logs( + tool_name: str, + page: int = Query(1, ge=1), + page_size: int = Query(50, ge=1, le=100), + start_date: Optional[str] = Query(None, description="YYYY-MM-DD"), + end_date: Optional[str] = Query(None, description="YYYY-MM-DD"), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Return paginated spend logs for requests that used this tool (from SpendLogToolIndex). + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, detail=CommonProxyErrors.db_not_connected_error.value + ) + + try: + where: dict = {"tool_name": tool_name} + if start_date or end_date: + start_time_filter: Optional[datetime] = None + end_time_filter: Optional[datetime] = None + if start_date: + try: + start_time_filter = datetime.strptime( + start_date + "T00:00:00", "%Y-%m-%dT%H:%M:%S" + ).replace(tzinfo=timezone.utc) + except ValueError: + pass + if end_date: + try: + end_time_filter = datetime.strptime( + end_date + "T23:59:59", "%Y-%m-%dT%H:%M:%S" + ).replace(tzinfo=timezone.utc) + except ValueError: + pass + if start_time_filter is not None or end_time_filter is not None: + where["start_time"] = {} + if start_time_filter is not None: + where["start_time"]["gte"] = start_time_filter + if end_time_filter is not None: + where["start_time"]["lte"] = end_time_filter + + total = await prisma_client.db.litellm_spendlogtoolindex.count(where=where) + index_rows = await prisma_client.db.litellm_spendlogtoolindex.find_many( + where=where, + order={"start_time": "desc"}, + skip=(page - 1) * page_size, + take=page_size, + ) + request_ids = [r.request_id for r in index_rows] + if not request_ids: + return ToolUsageLogsResponse( + logs=[], total=total, page=page, page_size=page_size + ) + + spend_logs = await prisma_client.db.litellm_spendlogs.find_many( + where={"request_id": {"in": request_ids}} + ) + log_by_id = {s.request_id: s for s in spend_logs} + + logs_out: List[ToolUsageLogEntry] = [] + for r in index_rows: + sl = log_by_id.get(r.request_id) + if not sl: + continue + ts = ( + sl.startTime.isoformat() + if hasattr(sl.startTime, "isoformat") + else str(sl.startTime) + ) + logs_out.append( + ToolUsageLogEntry( + id=sl.request_id, + timestamp=ts, + model=getattr(sl, "model", None) or None, + spend=getattr(sl, "spend", None), + total_tokens=getattr(sl, "total_tokens", None), + input_snippet=_input_snippet_for_tool_log(sl), + ) + ) + + return ToolUsageLogsResponse( + logs=logs_out, total=total, page=page, page_size=page_size + ) + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception("Error getting tool usage logs: %s", e) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get( "/v1/tool/{tool_name:path}", tags=["tool management"], @@ -70,9 +318,6 @@ async def get_tool( ): """ Get details for a single tool. - - Parameters: - - tool_name: The tool name (supports namespaced names with slashes) """ from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool from litellm.proxy.proxy_server import prisma_client @@ -85,9 +330,7 @@ async def get_tool( try: tool = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name) if tool is None: - raise HTTPException( - status_code=404, detail=f"Tool '{tool_name}' not found" - ) + raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found") return tool except HTTPException: raise @@ -96,6 +339,80 @@ async def get_tool( raise HTTPException(status_code=500, detail=str(e)) +async def _resolve_key_hash_to_object_permission_id( + prisma_client: "PrismaClient", + key_hash: str, +) -> Optional[str]: + """Resolve key (hash or raw) to object_permission_id; create permission if key has none.""" + from litellm.proxy.proxy_server import hash_token + + hashed = key_hash if "sk-" not in (key_hash or "") else hash_token(key_hash) + if not hashed: + return None + row = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed} + ) + if row is None: + return None + op_id = getattr(row, "object_permission_id", None) + if op_id: + return op_id + new_id = str(uuid.uuid4()) + await prisma_client.db.litellm_objectpermissiontable.create( + data={"object_permission_id": new_id, "blocked_tools": []} + ) + updated_count = await prisma_client.db.litellm_verificationtoken.update_many( + where={"token": hashed, "object_permission_id": None}, + data={"object_permission_id": new_id}, + ) + if updated_count == 0: + await prisma_client.db.litellm_objectpermissiontable.delete( + where={"object_permission_id": new_id} + ) + row = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed} + ) + return getattr(row, "object_permission_id", None) if row else None + return new_id + + +async def _resolve_team_id_to_object_permission_id( + prisma_client: "PrismaClient", + team_id: str, +) -> Optional[str]: + """Resolve team_id to object_permission_id; create permission if team has none.""" + if not team_id or not team_id.strip(): + return None + team_id_clean = team_id.strip() + row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id_clean}, + select={"object_permission_id": True}, + ) + if row is None: + return None + op_id = getattr(row, "object_permission_id", None) + if op_id: + return op_id + new_id = str(uuid.uuid4()) + await prisma_client.db.litellm_objectpermissiontable.create( + data={"object_permission_id": new_id, "blocked_tools": []} + ) + updated_count = await prisma_client.db.litellm_teamtable.update_many( + where={"team_id": team_id_clean, "object_permission_id": None}, + data={"object_permission_id": new_id}, + ) + if updated_count == 0: + await prisma_client.db.litellm_objectpermissiontable.delete( + where={"object_permission_id": new_id} + ) + row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id_clean}, + select={"object_permission_id": True}, + ) + return getattr(row, "object_permission_id", None) if row else None + return new_id + + @router.post( "/v1/tool/policy", tags=["tool management"], @@ -107,15 +424,20 @@ async def update_tool_policy( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - Set the call policy for a tool. + Set the input_policy and/or output_policy for a tool (global), or block for a specific team/key (override). Parameters: - tool_name: str - The tool to update - - call_policy: "trusted" | "untrusted" | "dual_llm" | "blocked" - - Setting a tool to "blocked" will cause the ToolPolicyGuardrail to remove - that tool_call from LLM responses before returning them to the client. + - input_policy: optional - "trusted" | "untrusted" | "blocked" + - output_policy: optional - "trusted" | "untrusted" + - team_id: optional - if set, create/update override for this team only + - key_hash: optional - if set, create/update override for this key only """ + from litellm.proxy.db.tool_registry_writer import ( + add_tool_to_object_permission_blocked, + get_tool_policy_registry, + remove_tool_from_object_permission_blocked, + ) from litellm.proxy.db.tool_registry_writer import ( update_tool_policy as db_update_tool_policy, ) @@ -127,19 +449,80 @@ async def update_tool_policy( ) try: + if data.team_id is not None or data.key_hash is not None: + if data.team_id is not None and data.key_hash is not None: + raise HTTPException( + status_code=400, + detail="Provide either team_id or key_hash, not both", + ) + if data.key_hash is not None: + op_id = await _resolve_key_hash_to_object_permission_id( + prisma_client, data.key_hash + ) + else: + op_id = await _resolve_team_id_to_object_permission_id( + prisma_client, data.team_id or "" + ) + if op_id is None: + raise HTTPException( + status_code=404, + detail="Key or team not found for the given identifier", + ) + is_blocking = data.input_policy == "blocked" + if is_blocking: + ok = await add_tool_to_object_permission_blocked( + prisma_client=prisma_client, + object_permission_id=op_id, + tool_name=data.tool_name, + ) + else: + ok = await remove_tool_from_object_permission_blocked( + prisma_client=prisma_client, + object_permission_id=op_id, + tool_name=data.tool_name, + ) + if not ok: + raise HTTPException( + status_code=500, + detail=f"Failed to update policy override for tool '{data.tool_name}'", + ) + registry = get_tool_policy_registry() + if registry.is_initialized(): + await registry.sync_tool_policy_from_db(prisma_client) + return ToolPolicyUpdateResponse( + tool_name=data.tool_name, + input_policy=data.input_policy, + output_policy=data.output_policy, + updated=True, + team_id=data.team_id, + key_hash=data.key_hash, + ) + + if data.input_policy is None and data.output_policy is None: + raise HTTPException( + status_code=400, + detail="At least one of input_policy or output_policy must be provided", + ) + updated = await db_update_tool_policy( prisma_client=prisma_client, tool_name=data.tool_name, - call_policy=data.call_policy, updated_by=user_api_key_dict.user_id, + input_policy=data.input_policy, + output_policy=data.output_policy, ) if updated is None: raise HTTPException( - status_code=500, detail=f"Failed to update policy for tool '{data.tool_name}'" + status_code=500, + detail=f"Failed to update policy for tool '{data.tool_name}'", ) + registry = get_tool_policy_registry() + if registry.is_initialized(): + await registry.sync_tool_policy_from_db(prisma_client) return ToolPolicyUpdateResponse( tool_name=updated.tool_name, - call_policy=updated.call_policy, + input_policy=updated.input_policy, + output_policy=updated.output_policy, updated=True, ) except HTTPException: @@ -147,3 +530,77 @@ async def update_tool_policy( except Exception as e: verbose_proxy_logger.exception("Error updating tool policy: %s", e) raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete( + "/v1/tool/{tool_name:path}/overrides", + tags=["tool management"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_tool_policy_override( + tool_name: str, + team_id: Optional[str] = Query( + None, description="Team ID of the override to remove" + ), + key_hash: Optional[str] = Query( + None, description="Key hash of the override to remove" + ), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Remove a policy override for a tool. Specify the override by team_id or key_hash + (exactly one required). + """ + from litellm.proxy.db.tool_registry_writer import ( + get_tool_policy_registry, + remove_tool_from_object_permission_blocked, + ) + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, detail=CommonProxyErrors.db_not_connected_error.value + ) + if team_id is None and key_hash is None: + raise HTTPException( + status_code=400, + detail="At least one of team_id or key_hash is required to identify the override", + ) + if team_id is not None and key_hash is not None: + raise HTTPException( + status_code=400, + detail="Provide either team_id or key_hash, not both", + ) + try: + if key_hash is not None: + op_id = await _resolve_key_hash_to_object_permission_id( + prisma_client, key_hash + ) + else: + op_id = await _resolve_team_id_to_object_permission_id( + prisma_client, team_id or "" + ) + if op_id is None: + raise HTTPException( + status_code=404, + detail="Key or team not found for the given identifier", + ) + deleted = await remove_tool_from_object_permission_blocked( + prisma_client=prisma_client, + object_permission_id=op_id, + tool_name=tool_name, + ) + if not deleted: + raise HTTPException( + status_code=404, + detail=f"No override found for tool '{tool_name}' with the given scope", + ) + registry = get_tool_policy_registry() + if registry.is_initialized(): + await registry.sync_tool_policy_from_db(prisma_client) + return {"deleted": True, "tool_name": tool_name} + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception("Error deleting tool policy override: %s", e) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index b3d707b1aa..6a2b0accb0 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -4411,6 +4411,9 @@ class ProxyConfig: if self._should_load_db_object(object_type="search_tools"): await self._init_search_tools_in_db(prisma_client=prisma_client) + if self._should_load_db_object(object_type="tools"): + await self._init_tool_policy_in_db(prisma_client=prisma_client) + if self._should_load_db_object(object_type="model_cost_map"): await self._check_and_reload_model_cost_map(prisma_client=prisma_client) @@ -4847,6 +4850,24 @@ class ProxyConfig: ) ) + async def _init_tool_policy_in_db(self, prisma_client: PrismaClient): + """ + Initialize tool policy from database into the in-memory registry. + Synced periodically by add_deployment -> _init_non_llm_objects_in_db. + """ + from litellm.proxy.db.tool_registry_writer import get_tool_policy_registry + + try: + registry = get_tool_policy_registry() + await registry.sync_tool_policy_from_db(prisma_client=prisma_client) + verbose_proxy_logger.debug("Successfully synced tool policy from DB") + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.py::ProxyConfig:_init_tool_policy_in_db - {}".format( + str(e) + ) + ) + async def _init_vector_stores_in_db(self, prisma_client: PrismaClient): from litellm.vector_stores.vector_store_registry import VectorStoreRegistry @@ -10577,6 +10598,12 @@ async def async_queue_request( data["metadata"]["user_api_key_team_id"] = getattr( user_api_key_dict, "team_id", None ) + data["metadata"]["user_api_key_object_permission_id"] = getattr( + user_api_key_dict, "object_permission_id", None + ) + data["metadata"]["user_api_key_team_object_permission_id"] = getattr( + user_api_key_dict, "team_object_permission_id", None + ) data["metadata"]["endpoint"] = str(request.url) global user_temperature, user_request_timeout, user_max_tokens, user_api_base @@ -11093,9 +11120,7 @@ async def get_favicon(): if favicon_url.startswith(("http://", "https://")): try: - from litellm.llms.custom_httpx.http_handler import ( - get_async_httpx_client, - ) + from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.types.llms.custom_http import httpxSpecialProvider async_client = get_async_httpx_client( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index e0b28a4e01..25ee275054 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable { vector_stores String[] @default([]) agents String[] @default([]) agent_access_groups String[] @default([]) + blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission teams LiteLLM_TeamTable[] projects LiteLLM_ProjectTable[] verification_tokens LiteLLM_VerificationToken[] @@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex { @@index([policy_id, start_time]) } +// Index for fast "last N logs for tool" from SpendLogs – see how a tool is called in production +model LiteLLM_SpendLogToolIndex { + request_id String + tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc. + start_time DateTime + + @@id([request_id, tool_name]) + @@index([tool_name, start_time]) +} + // Prompt table for storing prompt configurations model LiteLLM_PromptTable { id String @id @default(uuid()) @@ -1065,23 +1076,27 @@ model LiteLLM_PolicyAttachmentTable { updated_by String? } -// Global tool registry - auto-discovered from LLM responses; admins set call_policy here +// Global tool registry - auto-discovered from LLM responses; admins set input/output policies here model LiteLLM_ToolTable { - tool_id String @id @default(uuid()) - tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space" - origin String? // MCP server name or "user_defined" - call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked" - call_count Int @default(0) // cumulative number of times this tool was seen - assignments Json? @default("{}") - key_hash String? // hash of the virtual key that first called this tool - team_id String? // team that first called this tool - key_alias String? // human-readable alias of the virtual key - created_at DateTime @default(now()) - created_by String? - updated_at DateTime @default(now()) @updatedAt - updated_by String? + tool_id String @id @default(uuid()) + tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space" + origin String? // MCP server name or "user_defined" + input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked" + output_policy String @default("untrusted") // "trusted" | "untrusted" + call_count Int @default(0) // cumulative number of times this tool was seen + assignments Json? @default("{}") + key_hash String? // hash of the virtual key that first called this tool + team_id String? // team that first called this tool + key_alias String? // human-readable alias of the virtual key + user_agent String? // user-agent of the first request that discovered this tool + last_used_at DateTime? // timestamp of the most recent call + created_at DateTime @default(now()) + created_by String? + updated_at DateTime @default(now()) @updatedAt + updated_by String? - @@index([call_policy]) + @@index([input_policy]) + @@index([output_policy]) @@index([team_id]) } diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index afcdd9d0c5..e6da95bb78 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -3583,8 +3583,9 @@ class PrismaClient: def _get_engine_pid(self) -> int: try: engine = self.db._original_prisma._engine # type: ignore[attr-defined] - if engine is not None and engine.process is not None: - return engine.process.pid + process = getattr(engine, "process", None) if engine is not None else None + if process is not None: + return process.pid except (AttributeError, TypeError): pass return 0 @@ -4688,6 +4689,19 @@ async def update_spend_logs_job( guardrail_tracking_err, ) + # Tool usage tracking (same batch): SpendLogToolIndex for "last N requests for tool X" + try: + from litellm.proxy.db.spend_log_tool_index import process_spend_logs_tool_usage + await process_spend_logs_tool_usage( + prisma_client=prisma_client, + logs_to_process=logs_to_process, + ) + except Exception as tool_tracking_err: + verbose_proxy_logger.warning( + "Spend tracking - tool usage tracking failed (non-fatal): %s", + tool_tracking_err, + ) + async def _monitor_spend_logs_queue( prisma_client: PrismaClient, diff --git a/litellm/types/tool_management.py b/litellm/types/tool_management.py index 8704ff2775..1c5e1df9e9 100644 --- a/litellm/types/tool_management.py +++ b/litellm/types/tool_management.py @@ -5,21 +5,27 @@ Pydantic models for Tool Policy management endpoints. from datetime import datetime from typing import Dict, List, Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field ToolCallPolicy = Literal["trusted", "untrusted", "dual_llm", "blocked"] +ToolInputPolicy = Literal["trusted", "untrusted", "blocked"] +ToolOutputPolicy = Literal["trusted", "untrusted"] + class LiteLLM_ToolTableRow(BaseModel): tool_id: str tool_name: str origin: Optional[str] = None - call_policy: ToolCallPolicy = "untrusted" + input_policy: ToolInputPolicy = "untrusted" + output_policy: ToolOutputPolicy = "untrusted" call_count: int = 0 assignments: Optional[Dict] = None key_hash: Optional[str] = None team_id: Optional[str] = None key_alias: Optional[str] = None + user_agent: Optional[str] = None + last_used_at: Optional[datetime] = None created_at: Optional[datetime] = None updated_at: Optional[datetime] = None created_by: Optional[str] = None @@ -33,10 +39,62 @@ class ToolListResponse(BaseModel): class ToolPolicyUpdateRequest(BaseModel): tool_name: str - call_policy: ToolCallPolicy + input_policy: Optional[ToolInputPolicy] = None + output_policy: Optional[ToolOutputPolicy] = None + team_id: Optional[str] = None + key_hash: Optional[str] = None + key_alias: Optional[str] = None class ToolPolicyUpdateResponse(BaseModel): tool_name: str - call_policy: ToolCallPolicy + input_policy: Optional[ToolInputPolicy] = None + output_policy: Optional[ToolOutputPolicy] = None updated: bool + team_id: Optional[str] = None + key_hash: Optional[str] = None + + +class ToolPolicyOverrideRow(BaseModel): + override_id: str + tool_name: str + team_id: Optional[str] = None + key_hash: Optional[str] = None + input_policy: ToolInputPolicy = "blocked" + key_alias: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + +class ToolPolicyOption(BaseModel): + value: str + label: str + description: str + + +class ToolPolicyOptionsResponse(BaseModel): + input_policies: List[ToolPolicyOption] + output_policies: List[ToolPolicyOption] + + +class ToolDetailResponse(BaseModel): + tool: LiteLLM_ToolTableRow + overrides: List[ToolPolicyOverrideRow] = Field(default_factory=list) + + +class ToolUsageLogEntry(BaseModel): + """One spend log row for a tool call (for UI "recent logs" table).""" + + id: str # request_id + timestamp: str + model: Optional[str] = None + spend: Optional[float] = None + total_tokens: Optional[int] = None + input_snippet: Optional[str] = None + + +class ToolUsageLogsResponse(BaseModel): + logs: List[ToolUsageLogEntry] + total: int + page: int + page_size: int diff --git a/schema.prisma b/schema.prisma index e0b28a4e01..25ee275054 100644 --- a/schema.prisma +++ b/schema.prisma @@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable { vector_stores String[] @default([]) agents String[] @default([]) agent_access_groups String[] @default([]) + blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission teams LiteLLM_TeamTable[] projects LiteLLM_ProjectTable[] verification_tokens LiteLLM_VerificationToken[] @@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex { @@index([policy_id, start_time]) } +// Index for fast "last N logs for tool" from SpendLogs – see how a tool is called in production +model LiteLLM_SpendLogToolIndex { + request_id String + tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc. + start_time DateTime + + @@id([request_id, tool_name]) + @@index([tool_name, start_time]) +} + // Prompt table for storing prompt configurations model LiteLLM_PromptTable { id String @id @default(uuid()) @@ -1065,23 +1076,27 @@ model LiteLLM_PolicyAttachmentTable { updated_by String? } -// Global tool registry - auto-discovered from LLM responses; admins set call_policy here +// Global tool registry - auto-discovered from LLM responses; admins set input/output policies here model LiteLLM_ToolTable { - tool_id String @id @default(uuid()) - tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space" - origin String? // MCP server name or "user_defined" - call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked" - call_count Int @default(0) // cumulative number of times this tool was seen - assignments Json? @default("{}") - key_hash String? // hash of the virtual key that first called this tool - team_id String? // team that first called this tool - key_alias String? // human-readable alias of the virtual key - created_at DateTime @default(now()) - created_by String? - updated_at DateTime @default(now()) @updatedAt - updated_by String? + tool_id String @id @default(uuid()) + tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space" + origin String? // MCP server name or "user_defined" + input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked" + output_policy String @default("untrusted") // "trusted" | "untrusted" + call_count Int @default(0) // cumulative number of times this tool was seen + assignments Json? @default("{}") + key_hash String? // hash of the virtual key that first called this tool + team_id String? // team that first called this tool + key_alias String? // human-readable alias of the virtual key + user_agent String? // user-agent of the first request that discovered this tool + last_used_at DateTime? // timestamp of the most recent call + created_at DateTime @default(now()) + created_by String? + updated_at DateTime @default(now()) @updatedAt + updated_by String? - @@index([call_policy]) + @@index([input_policy]) + @@index([output_policy]) @@index([team_id]) } diff --git a/scripts/test_tool_allowlist_script.py b/scripts/test_tool_allowlist_script.py new file mode 100644 index 0000000000..75a50d09b8 --- /dev/null +++ b/scripts/test_tool_allowlist_script.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +""" +Standalone script to test tool allowlist enforcement and tool name extraction. + +Run from repo root: + poetry run python scripts/test_tool_allowlist_script.py + +Or run the unit tests: + poetry run pytest tests/test_litellm/proxy/test_tools_allowlist_enforcement.py -v +""" + +import asyncio +import sys +from pathlib import Path + +# Ensure repo root is on path +repo_root = Path(__file__).resolve().parent.parent +if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + + +def test_extraction(): + """Test extract_request_tool_names for each API shape.""" + from litellm.proxy.guardrails.tool_name_extraction import extract_request_tool_names + + cases = [ + ("OpenAI chat tools", "/v1/chat/completions", {"tools": [{"type": "function", "function": {"name": "get_weather"}}]}), + ("OpenAI chat functions", "/v1/chat/completions", {"functions": [{"name": "run_sql"}]}), + ("OpenAI responses function", "/v1/responses", {"tools": [{"type": "function", "name": "get_current_weather"}]}), + ("OpenAI responses MCP", "/v1/responses", {"tools": [{"type": "mcp", "server_label": "dmcp"}]}), + ("Anthropic", "/v1/messages", {"tools": [{"name": "get_weather"}, {"name": "run_sql"}]}), + ("Google generateContent", "/generate_content", {"tools": [{"functionDeclarations": [{"name": "schedule_meeting"}]}]}), + ("MCP call_tool", "/mcp/call_tool", {"name": "my_tool", "arguments": {}}), + ("Non-tool route", "/v1/embeddings", {"tools": [{"type": "function", "function": {"name": "x"}}]}), + ] + print("=== extract_request_tool_names(route, data) ===\n") + for label, route, data in cases: + names = extract_request_tool_names(route, data) + print(f" {label}: {names}") + print() + + +async def test_check_tools_allowlist(): + """Test check_tools_allowlist with mock tokens.""" + from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth + from litellm.proxy.auth.auth_checks import check_tools_allowlist + + def token(metadata=None, team_metadata=None): + return UserAPIKeyAuth( + api_key="test-key", + user_id="user", + team_id="team", + org_id=None, + models=["*"], + metadata=metadata or {}, + team_metadata=team_metadata or {}, + ) + + print("=== check_tools_allowlist (auth) ===\n") + + # No allowlist -> pass + await check_tools_allowlist( + request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]}, + valid_token=token(), + team_object=None, + route="/v1/chat/completions", + ) + print(" No allowlist, body has tools: PASS") + + # Allowed tool -> pass + await check_tools_allowlist( + request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]}, + valid_token=token(metadata={"allowed_tools": ["get_weather"]}), + team_object=None, + route="/v1/chat/completions", + ) + print(" allowed_tools=['get_weather'], body has get_weather: PASS") + + # Disallowed tool -> raise + try: + await check_tools_allowlist( + request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]}, + valid_token=token(metadata={"allowed_tools": ["other_tool"]}), + team_object=None, + route="/v1/chat/completions", + ) + print(" DISALLOWED: expected ProxyException") + except ProxyException as e: + if e.type == ProxyErrorTypes.tool_access_denied: + print(" allowed_tools=['other_tool'], body has get_weather: PASS (raised tool_access_denied)") + else: + print(f" Unexpected ProxyException type: {e.type}") + except Exception as e: + print(f" Unexpected: {e}") + + # Team allowlist when key empty + await check_tools_allowlist( + request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]}, + valid_token=token(team_metadata={"allowed_tools": ["get_weather"]}), + team_object=None, + route="/v1/chat/completions", + ) + print(" team_metadata.allowed_tools=['get_weather']: PASS") + print() + + +def main(): + print("Tool allowlist / tool name extraction – script checks\n") + test_extraction() + asyncio.run(test_check_tools_allowlist()) + print("Done. For full unit tests run:") + print(" poetry run pytest tests/test_litellm/proxy/test_tools_allowlist_enforcement.py -v") + + +if __name__ == "__main__": + main() diff --git a/tests/test_litellm/proxy/db/test_tool_registry_writer.py b/tests/test_litellm/proxy/db/test_tool_registry_writer.py index 44f9e32058..1b1ee7afcb 100644 --- a/tests/test_litellm/proxy/db/test_tool_registry_writer.py +++ b/tests/test_litellm/proxy/db/test_tool_registry_writer.py @@ -1,6 +1,6 @@ """ Unit tests for tool_registry_writer.py — uses a mock prisma client -that exposes execute_raw / query_raw (matching the actual raw-SQL implementation). +that exposes litellm_tooltable.upsert / find_many / find_unique. """ import os @@ -13,21 +13,28 @@ import pytest sys.path.insert(0, os.path.abspath("../../..")) from litellm.proxy.db.tool_registry_writer import ( + ToolPolicyRegistry, batch_upsert_tools, get_tool, + get_tool_policy_registry, get_tools_by_names, list_tools, update_tool_policy, ) -def _make_prisma(query_rows=None): - """Return a minimal mock prisma_client with execute_raw / query_raw.""" - default_row = { +def _mock_row(**kwargs): + """Build a row-like object with real attributes (no MagicMock) for _row_to_model.""" + + class Row: + pass + + default = { "tool_id": "uuid-1", "tool_name": "my_tool", "origin": "user_defined", - "call_policy": "untrusted", + "input_policy": "untrusted", + "output_policy": "untrusted", "call_count": 1, "assignments": {}, "key_hash": None, @@ -38,31 +45,54 @@ def _make_prisma(query_rows=None): "created_by": None, "updated_by": None, } - rows = query_rows if query_rows is not None else [default_row] + default.update(kwargs) + row = Row() + for k, v in default.items(): + setattr(row, k, v) + return row + +def _make_prisma( + *, + upsert_return=None, + find_many_rows=None, + find_unique_row=None, +): + """Return a mock prisma_client with litellm_tooltable.upsert, find_many, find_unique.""" prisma = MagicMock() - prisma.db.execute_raw = AsyncMock(return_value=None) - prisma.db.query_raw = AsyncMock(return_value=rows) + prisma.db.litellm_tooltable = MagicMock() + prisma.db.litellm_tooltable.upsert = AsyncMock(return_value=upsert_return) + prisma.db.litellm_tooltable.find_many = AsyncMock( + return_value=find_many_rows if find_many_rows is not None else [] + ) + prisma.db.litellm_tooltable.find_unique = AsyncMock( + return_value=find_unique_row + ) return prisma @pytest.mark.asyncio -async def test_batch_upsert_tools_calls_execute_raw(): +async def test_batch_upsert_tools_calls_upsert(): prisma = _make_prisma() items = [{"tool_name": "tool_a", "origin": "mcp_server", "created_by": None}] await batch_upsert_tools(prisma, items) - prisma.db.execute_raw.assert_awaited_once() - call_args = prisma.db.execute_raw.call_args - sql = call_args.args[0] - assert "LiteLLM_ToolTable" in sql - assert "ON CONFLICT" in sql + prisma.db.litellm_tooltable.upsert.assert_awaited_once() + call_kw = prisma.db.litellm_tooltable.upsert.call_args.kwargs + assert call_kw["where"] == {"tool_name": "tool_a"} + assert call_kw["data"]["create"]["tool_name"] == "tool_a" + assert call_kw["data"]["create"]["origin"] == "mcp_server" + assert call_kw["data"]["create"]["input_policy"] == "untrusted" + assert call_kw["data"]["create"]["output_policy"] == "untrusted" + assert call_kw["data"]["create"]["call_count"] == 1 + assert call_kw["data"]["update"]["call_count"] == {"increment": 1} + assert "updated_at" in call_kw["data"]["update"] @pytest.mark.asyncio async def test_batch_upsert_tools_empty_list(): prisma = _make_prisma() await batch_upsert_tools(prisma, []) - prisma.db.execute_raw.assert_not_awaited() + prisma.db.litellm_tooltable.upsert.assert_not_awaited() @pytest.mark.asyncio @@ -70,123 +100,120 @@ async def test_batch_upsert_tools_skips_empty_names(): prisma = _make_prisma() items = [{"tool_name": "", "origin": None}, {"tool_name": None}] # type: ignore[list-item] await batch_upsert_tools(prisma, items) - prisma.db.execute_raw.assert_not_awaited() + prisma.db.litellm_tooltable.upsert.assert_not_awaited() @pytest.mark.asyncio -async def test_batch_upsert_multiple_tools_calls_execute_raw_per_tool(): +async def test_batch_upsert_multiple_tools_calls_upsert_per_tool(): prisma = _make_prisma() items = [ {"tool_name": "tool_a", "origin": "mcp_server", "created_by": None}, {"tool_name": "tool_b", "origin": "user_defined", "created_by": "alice"}, ] await batch_upsert_tools(prisma, items) - assert prisma.db.execute_raw.await_count == 2 + assert prisma.db.litellm_tooltable.upsert.await_count == 2 + calls = prisma.db.litellm_tooltable.upsert.call_args_list + assert calls[0].kwargs["where"]["tool_name"] == "tool_a" + assert calls[1].kwargs["where"]["tool_name"] == "tool_b" @pytest.mark.asyncio async def test_list_tools_no_filter(): - row = { - "tool_id": "id1", - "tool_name": "tool_a", - "origin": "mcp", - "call_policy": "untrusted", - "call_count": 5, - "assignments": {}, - "key_hash": None, - "team_id": None, - "key_alias": None, - "created_at": datetime.now(timezone.utc), - "updated_at": datetime.now(timezone.utc), - "created_by": None, - "updated_by": None, - } - prisma = _make_prisma(query_rows=[row]) + row = _mock_row( + tool_id="id1", + tool_name="tool_a", + origin="mcp", + input_policy="untrusted", + output_policy="untrusted", + call_count=5, + ) + prisma = _make_prisma(find_many_rows=[row]) result = await list_tools(prisma) assert len(result) == 1 assert result[0].tool_name == "tool_a" assert result[0].call_count == 5 - prisma.db.query_raw.assert_awaited_once() + prisma.db.litellm_tooltable.find_many.assert_awaited_once() + call_kw = prisma.db.litellm_tooltable.find_many.call_args.kwargs + assert call_kw["where"] == {} + assert call_kw["order"] == {"created_at": "desc"} @pytest.mark.asyncio -async def test_list_tools_with_policy_filter(): - row = { - "tool_id": "id1", - "tool_name": "blocked_tool", - "origin": None, - "call_policy": "blocked", - "call_count": 2, - "assignments": None, - "key_hash": None, - "team_id": None, - "key_alias": None, - "created_at": datetime.now(timezone.utc), - "updated_at": datetime.now(timezone.utc), - "created_by": None, - "updated_by": None, - } - prisma = _make_prisma(query_rows=[row]) - result = await list_tools(prisma, call_policy="blocked") - assert result[0].call_policy == "blocked" - call_args = prisma.db.query_raw.call_args - sql = call_args.args[0] - assert "WHERE call_policy" in sql +async def test_list_tools_with_input_policy_filter(): + row = _mock_row( + tool_id="id1", + tool_name="blocked_tool", + origin=None, + input_policy="blocked", + output_policy="untrusted", + call_count=2, + assignments=None, + ) + prisma = _make_prisma(find_many_rows=[row]) + result = await list_tools(prisma, input_policy="blocked") + assert result[0].input_policy == "blocked" + call_kw = prisma.db.litellm_tooltable.find_many.call_args.kwargs + assert call_kw["where"] == {"input_policy": "blocked"} @pytest.mark.asyncio async def test_get_tool_found(): - prisma = _make_prisma() + row = _mock_row(tool_name="my_tool") + prisma = _make_prisma(find_unique_row=row) result = await get_tool(prisma, "my_tool") assert result is not None assert result.tool_name == "my_tool" - prisma.db.query_raw.assert_awaited_once() + prisma.db.litellm_tooltable.find_unique.assert_awaited_once_with( + where={"tool_name": "my_tool"} + ) @pytest.mark.asyncio async def test_get_tool_not_found(): - prisma = _make_prisma(query_rows=[]) + prisma = _make_prisma(find_unique_row=None) result = await get_tool(prisma, "nonexistent") assert result is None @pytest.mark.asyncio -async def test_update_tool_policy_calls_execute_raw(): - row = { - "tool_id": "uuid-1", - "tool_name": "my_tool", - "origin": "user_defined", - "call_policy": "blocked", - "call_count": 1, - "assignments": {}, - "key_hash": None, - "team_id": None, - "key_alias": None, - "created_at": datetime.now(timezone.utc), - "updated_at": datetime.now(timezone.utc), - "created_by": None, - "updated_by": "admin", - } - prisma = _make_prisma(query_rows=[row]) - result = await update_tool_policy(prisma, "my_tool", "blocked", "admin") +async def test_update_tool_policy_calls_upsert_then_get_tool(): + row = _mock_row( + tool_name="my_tool", + input_policy="blocked", + output_policy="untrusted", + updated_by="admin", + ) + prisma = _make_prisma(find_unique_row=row) + result = await update_tool_policy( + prisma, "my_tool", updated_by="admin", input_policy="blocked" + ) assert result is not None - assert result.call_policy == "blocked" - prisma.db.execute_raw.assert_awaited_once() - call_args = prisma.db.execute_raw.call_args - sql = call_args.args[0] - assert "ON CONFLICT" in sql - assert "call_policy" in sql + assert result.input_policy == "blocked" + prisma.db.litellm_tooltable.upsert.assert_awaited_once() + call_kw = prisma.db.litellm_tooltable.upsert.call_args.kwargs + assert call_kw["where"] == {"tool_name": "my_tool"} + assert call_kw["data"]["update"]["input_policy"] == "blocked" + assert call_kw["data"]["update"]["updated_by"] == "admin" + prisma.db.litellm_tooltable.find_unique.assert_awaited_with( + where={"tool_name": "my_tool"} + ) @pytest.mark.asyncio async def test_get_tools_by_names_returns_policy_map(): rows = [ - {"tool_name": "tool_a", "call_policy": "trusted"}, - {"tool_name": "tool_b", "call_policy": "blocked"}, + _mock_row(tool_name="tool_a", input_policy="trusted", output_policy="untrusted"), + _mock_row(tool_name="tool_b", input_policy="blocked", output_policy="untrusted"), ] - prisma = _make_prisma(query_rows=rows) + prisma = _make_prisma(find_many_rows=rows) result = await get_tools_by_names(prisma, ["tool_a", "tool_b"]) - assert result == {"tool_a": "trusted", "tool_b": "blocked"} + assert result == { + "tool_a": ("trusted", "untrusted"), + "tool_b": ("blocked", "untrusted"), + } + prisma.db.litellm_tooltable.find_many.assert_awaited_once_with( + where={"tool_name": {"in": ["tool_a", "tool_b"]}} + ) @pytest.mark.asyncio @@ -194,4 +221,71 @@ async def test_get_tools_by_names_empty_list(): prisma = _make_prisma() result = await get_tools_by_names(prisma, []) assert result == {} - prisma.db.query_raw.assert_not_awaited() + prisma.db.litellm_tooltable.find_many.assert_not_awaited() + + +# --- ToolPolicyRegistry --- + + +def _mock_tool_row( + tool_name: str, + input_policy: str = "untrusted", + output_policy: str = "untrusted", +): + row = MagicMock() + row.tool_name = tool_name + row.input_policy = input_policy + row.output_policy = output_policy + return row + + +def _mock_perm_row(object_permission_id: str, blocked_tools: list): + row = MagicMock() + row.object_permission_id = object_permission_id + row.blocked_tools = blocked_tools + return row + + +@pytest.mark.asyncio +async def test_tool_policy_registry_sync_and_get_effective_policies(): + """Registry syncs from DB; get_effective_policies returns merged blocked + global.""" + prisma = MagicMock() + prisma.db.litellm_tooltable.find_many = AsyncMock( + return_value=[ + _mock_tool_row("tool_a", input_policy="trusted"), + _mock_tool_row("tool_b", input_policy="blocked"), + _mock_tool_row("tool_c", input_policy="untrusted"), + ] + ) + prisma.db.litellm_objectpermissiontable.find_many = AsyncMock( + return_value=[ + _mock_perm_row("op-key-1", ["tool_a"]), + _mock_perm_row("op-team-1", ["tool_c"]), + ] + ) + registry = get_tool_policy_registry() + await registry.sync_tool_policy_from_db(prisma) + assert registry.is_initialized() + # Key blocked: tool_a. Team blocked: tool_c. Global: tool_b blocked. + result = registry.get_effective_policies( + ["tool_a", "tool_b", "tool_c"], + object_permission_id="op-key-1", + team_object_permission_id="op-team-1", + ) + assert result["tool_a"] == "blocked" + assert result["tool_b"] == "blocked" + assert result["tool_c"] == "blocked" + # No op ids: only global + result_global = registry.get_effective_policies(["tool_a", "tool_b", "tool_c"]) + assert result_global["tool_a"] == "trusted" + assert result_global["tool_b"] == "blocked" + assert result_global["tool_c"] == "untrusted" + + +@pytest.mark.asyncio +async def test_tool_policy_registry_not_initialized_returns_untrusted(): + """When not synced, get_effective_policies still returns untrusted for unknown tools.""" + registry = ToolPolicyRegistry() + assert not registry.is_initialized() + result = registry.get_effective_policies(["unknown_tool"]) + assert result == {"unknown_tool": "untrusted"} diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_tool_policy_guardrail.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_tool_policy_guardrail.py index c6a81efbf0..943a8d4be7 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_tool_policy_guardrail.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_tool_policy_guardrail.py @@ -12,9 +12,8 @@ from fastapi import HTTPException sys.path.insert(0, os.path.abspath("../../../../../..")) -from litellm.proxy.guardrails.guardrail_hooks.tool_policy.tool_policy_guardrail import ( - ToolPolicyGuardrail, -) +from litellm.proxy.guardrails.guardrail_hooks.tool_policy.tool_policy_guardrail import \ + ToolPolicyGuardrail from litellm.types.guardrails import GuardrailEventHooks @@ -70,10 +69,21 @@ async def test_no_tool_calls_in_response_passes_through(guardrail): assert result is inputs +def _registry_mock(policy_map: dict): + """Return a mock registry with is_initialized=True and get_effective_policies returning policy_map.""" + reg = MagicMock() + reg.is_initialized.return_value = True + reg.get_effective_policies.return_value = policy_map + return reg + + @pytest.mark.asyncio async def test_untrusted_tools_pass_through(guardrail): policy_map = {"search": "untrusted", "read_file": "trusted"} - with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)): + with patch( + "litellm.proxy.db.tool_registry_writer.get_tool_policy_registry", + return_value=_registry_mock(policy_map), + ): inputs: Any = _tool_request_inputs(["search", "read_file"]) result = await guardrail.apply_guardrail( inputs=inputs, request_data={}, input_type="request" @@ -84,7 +94,10 @@ async def test_untrusted_tools_pass_through(guardrail): @pytest.mark.asyncio async def test_blocked_tool_in_request_raises_http_exception(guardrail): policy_map = {"dangerous_tool": "blocked"} - with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)): + with patch( + "litellm.proxy.db.tool_registry_writer.get_tool_policy_registry", + return_value=_registry_mock(policy_map), + ): inputs: Any = _tool_request_inputs(["dangerous_tool"]) with pytest.raises(HTTPException) as exc_info: await guardrail.apply_guardrail( @@ -97,7 +110,10 @@ async def test_blocked_tool_in_request_raises_http_exception(guardrail): @pytest.mark.asyncio async def test_blocked_tool_in_response_raises_http_exception(guardrail): policy_map = {"exfil_tool": "blocked"} - with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)): + with patch( + "litellm.proxy.db.tool_registry_writer.get_tool_policy_registry", + return_value=_registry_mock(policy_map), + ): inputs: Any = _tool_response_inputs(["exfil_tool"]) with pytest.raises(HTTPException) as exc_info: await guardrail.apply_guardrail( @@ -110,7 +126,10 @@ async def test_blocked_tool_in_response_raises_http_exception(guardrail): @pytest.mark.asyncio async def test_mixed_blocked_and_allowed_raises_for_blocked(guardrail): policy_map = {"safe_tool": "trusted", "bad_tool": "blocked"} - with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)): + with patch( + "litellm.proxy.db.tool_registry_writer.get_tool_policy_registry", + return_value=_registry_mock(policy_map), + ): inputs: Any = _tool_request_inputs(["safe_tool", "bad_tool"]) with pytest.raises(HTTPException) as exc_info: await guardrail.apply_guardrail( @@ -123,8 +142,11 @@ async def test_mixed_blocked_and_allowed_raises_for_blocked(guardrail): @pytest.mark.asyncio async def test_tool_not_in_db_passes_through(guardrail): - """Tools not found in the DB (no entry) should not be blocked.""" - with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value={})): + """When registry returns no policy (or empty), tools are not blocked.""" + with patch( + "litellm.proxy.db.tool_registry_writer.get_tool_policy_registry", + return_value=_registry_mock({}), + ): inputs: Any = _tool_request_inputs(["unknown_tool"]) result = await guardrail.apply_guardrail( inputs=inputs, request_data={}, input_type="request" @@ -133,43 +155,30 @@ async def test_tool_not_in_db_passes_through(guardrail): @pytest.mark.asyncio -async def test_get_policies_cached_uses_cache(guardrail): - """Second call with same tool names should return the cached result.""" - policy_map = {"tool_a": "trusted"} +async def test_registry_not_initialized_passes_through(guardrail): + """When registry is not initialized, no tools are blocked (empty policy map).""" + reg = MagicMock() + reg.is_initialized.return_value = False with patch( - "litellm.proxy.db.tool_registry_writer.get_tools_by_names", - new=AsyncMock(return_value=policy_map), - ) as mock_db, patch( - "litellm.proxy.proxy_server.prisma_client", - new=MagicMock(), + "litellm.proxy.db.tool_registry_writer.get_tool_policy_registry", + return_value=reg, ): - # first call — should hit DB - result1 = await guardrail._get_policies_cached(["tool_a"]) - assert result1 == policy_map - - # second call — should hit cache, not DB again - result2 = await guardrail._get_policies_cached(["tool_a"]) - assert result2 == policy_map - - assert mock_db.call_count == 1 - - -@pytest.mark.asyncio -async def test_get_policies_cached_no_prisma(guardrail): - """Without a prisma client, returns empty dict.""" - with patch( - "litellm.proxy.proxy_server.prisma_client", - None, - ): - result = await guardrail._get_policies_cached(["tool_a"]) - assert result == {} + inputs: Any = _tool_request_inputs(["any_tool"]) + result = await guardrail.apply_guardrail( + inputs=inputs, request_data={}, input_type="request" + ) + assert result is inputs + reg.get_effective_policies.assert_not_called() @pytest.mark.asyncio async def test_response_tool_calls_as_objects(guardrail): """tool_calls that are objects (not dicts) with .function.name should work.""" policy_map = {"obj_tool": "blocked"} - with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)): + with patch( + "litellm.proxy.db.tool_registry_writer.get_tool_policy_registry", + return_value=_registry_mock(policy_map), + ): fn = MagicMock() fn.name = "obj_tool" tc = MagicMock() diff --git a/tests/test_litellm/proxy/test_tools_allowlist_enforcement.py b/tests/test_litellm/proxy/test_tools_allowlist_enforcement.py new file mode 100644 index 0000000000..4adc5acde8 --- /dev/null +++ b/tests/test_litellm/proxy/test_tools_allowlist_enforcement.py @@ -0,0 +1,200 @@ +""" +Tests for tool allowlist enforcement (key/team metadata.allowed_tools). + +Covers: +- check_tools_allowlist: allowed, disallowed, no allowlist, non-tool routes +- extract_request_tool_names: OpenAI chat, responses, Anthropic, generate_content, MCP +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from litellm.proxy._types import (ProxyErrorTypes, ProxyException, + UserAPIKeyAuth) +from litellm.proxy.auth.auth_checks import check_tools_allowlist +from litellm.proxy.guardrails.tool_name_extraction import ( + TOOL_CAPABLE_CALL_TYPES, extract_request_tool_names) + + +def _token(metadata=None, team_metadata=None): + return UserAPIKeyAuth( + api_key="test-key", + user_id="user", + team_id="team", + org_id=None, + models=["*"], + metadata=metadata or {}, + team_metadata=team_metadata or {}, + ) + + +class TestExtractRequestToolNames: + """Test tool name extraction per API format.""" + + def test_openai_chat_tools(self): + data = { + "tools": [ + {"type": "function", "function": {"name": "get_weather"}}, + {"type": "function", "function": {"name": "run_sql"}}, + ] + } + assert extract_request_tool_names("/v1/chat/completions", data) == [ + "get_weather", + "run_sql", + ] + + def test_openai_chat_functions_legacy(self): + data = {"functions": [{"name": "get_weather"}, {"name": "run_sql"}]} + assert extract_request_tool_names("/v1/chat/completions", data) == [ + "get_weather", + "run_sql", + ] + + def test_openai_responses_function_tools(self): + data = { + "tools": [ + {"type": "function", "name": "get_current_weather", "description": "x"}, + ] + } + assert extract_request_tool_names("/v1/responses", data) == [ + "get_current_weather" + ] + + def test_openai_responses_mcp_tools(self): + data = { + "tools": [ + {"type": "mcp", "server_label": "dmcp", "server_url": "http://x"}, + ] + } + assert extract_request_tool_names("/v1/responses", data) == ["dmcp"] + + def test_anthropic_tools(self): + data = {"tools": [{"name": "get_weather"}, {"name": "run_sql"}]} + assert extract_request_tool_names("/v1/messages", data) == [ + "get_weather", + "run_sql", + ] + + def test_generate_content_tools(self): + data = { + "tools": [ + { + "functionDeclarations": [ + {"name": "schedule_meeting", "description": "x"}, + ] + }, + ] + } + assert extract_request_tool_names("/generate_content", data) == [ + "schedule_meeting" + ] + + def test_mcp_call_tool_name(self): + data = {"name": "my_tool", "arguments": {}} + assert extract_request_tool_names("/mcp/call_tool", data) == ["my_tool"] + + def test_mcp_call_tool_mcp_tool_name(self): + data = {"mcp_tool_name": "other_tool"} + assert extract_request_tool_names("/mcp/call_tool", data) == ["other_tool"] + + def test_non_tool_route_returns_empty(self): + data = {"tools": [{"type": "function", "function": {"name": "x"}}]} + assert extract_request_tool_names("/v1/embeddings", data) == [] + + +class TestCheckToolsAllowlist: + """Test allowlist enforcement in auth (no DB in hot path).""" + + @pytest.mark.asyncio + async def test_no_allowlist_passes(self): + token = _token(metadata={}, team_metadata={}) + body = { + "tools": [{"type": "function", "function": {"name": "get_weather"}}] + } + await check_tools_allowlist( + request_body=body, + valid_token=token, + team_object=None, + route="/v1/chat/completions", + ) + + @pytest.mark.asyncio + async def test_allowed_tool_passes(self): + token = _token(metadata={"allowed_tools": ["get_weather"]}) + body = { + "tools": [{"type": "function", "function": {"name": "get_weather"}}] + } + await check_tools_allowlist( + request_body=body, + valid_token=token, + team_object=None, + route="/v1/chat/completions", + ) + + @pytest.mark.asyncio + async def test_disallowed_tool_raises(self): + token = _token(metadata={"allowed_tools": ["other_tool"]}) + body = { + "tools": [{"type": "function", "function": {"name": "get_weather"}}] + } + with pytest.raises(ProxyException) as exc_info: + await check_tools_allowlist( + request_body=body, + valid_token=token, + team_object=None, + route="/v1/chat/completions", + ) + assert exc_info.value.type == ProxyErrorTypes.tool_access_denied + assert "get_weather" in str(exc_info.value.message) + + @pytest.mark.asyncio + async def test_team_allowlist_used_when_key_empty(self): + token = _token( + metadata={}, + team_metadata={"allowed_tools": ["get_weather"]}, + ) + body = { + "tools": [{"type": "function", "function": {"name": "get_weather"}}] + } + await check_tools_allowlist( + request_body=body, + valid_token=token, + team_object=None, + route="/v1/chat/completions", + ) + + @pytest.mark.asyncio + async def test_key_allowlist_overrides_team(self): + token = _token( + metadata={"allowed_tools": ["get_weather"]}, + team_metadata={"allowed_tools": ["other_tool"]}, + ) + body = { + "tools": [{"type": "function", "function": {"name": "get_weather"}}] + } + await check_tools_allowlist( + request_body=body, + valid_token=token, + team_object=None, + route="/v1/chat/completions", + ) + + @pytest.mark.asyncio + async def test_valid_token_none_skips(self): + await check_tools_allowlist( + request_body={"tools": [{"type": "function", "function": {"name": "x"}}]}, + valid_token=None, + team_object=None, + route="/v1/chat/completions", + ) + + @pytest.mark.asyncio + async def test_no_tools_in_body_passes(self): + token = _token(metadata={"allowed_tools": ["get_weather"]}) + await check_tools_allowlist( + request_body={"messages": []}, + valid_token=token, + team_object=None, + route="/v1/chat/completions", + ) diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index 622c3bf70a..b927f312df 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -39,7 +39,7 @@ import UserDashboard from "@/components/user_dashboard"; import { AccessGroupsPage } from "@/components/AccessGroups/AccessGroupsPage"; import { ProjectsPage } from "@/components/Projects/ProjectsPage"; import VectorStoreManagement from "@/components/vector_store_management"; -import ToolPolicies from "@/components/ToolPolicies"; +import ToolPoliciesView from "@/components/ToolPoliciesView"; import SpendLogsTable from "@/components/view_logs"; import ViewUserDashboard from "@/components/view_users"; import { ThemeProvider } from "@/contexts/ThemeContext"; @@ -549,7 +549,7 @@ function CreateKeyPageContent() { ) : page == "vector-stores" ? ( ) : page == "tool-policies" ? ( - + ) : page == "guardrails-monitor" ? ( ) : page == "new_usage" ? ( diff --git a/ui/litellm-dashboard/src/components/ToolDetail.tsx b/ui/litellm-dashboard/src/components/ToolDetail.tsx new file mode 100644 index 0000000000..ed0f866acb --- /dev/null +++ b/ui/litellm-dashboard/src/components/ToolDetail.tsx @@ -0,0 +1,445 @@ +"use client"; + +import { ArrowLeftOutlined, HistoryOutlined, ToolOutlined } from "@ant-design/icons"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { Button, Select, Spin } from "antd"; +import React, { useCallback, useEffect, useMemo, useState } from "react"; +import TeamDropdown from "@/components/common_components/team_dropdown"; +import { LogViewer } from "@/components/GuardrailsMonitor/LogViewer"; +import type { LogEntry } from "@/components/GuardrailsMonitor/mockData"; +import { PolicySelect } from "@/components/ToolPolicies/PolicySelect"; +import { + deleteToolPolicyOverride, + fetchToolDetail, + fetchToolPolicyOptions, + getToolUsageLogs, + keyListCall, + teamListCall, + updateToolPolicy, + type ToolPolicyOption, + type ToolPolicyOverrideRow, +} from "@/components/networking"; +import type { Team } from "@/components/key_team_helpers/key_list"; + +interface ToolDetailProps { + toolName: string; + onBack: () => void; + accessToken: string | null; +} + +interface KeyOption { + token: string; + key_alias?: string; +} + +const TOOL_DETAIL_QUERY_KEY = "tool-detail"; + +const LOGS_PAGE_SIZE = 50; + +function getDefaultLogsDateRange(): { start: string; end: string } { + const end = new Date(); + const start = new Date(); + start.setDate(start.getDate() - 90); + const fmt = (d: Date) => + d.toISOString().slice(0, 19).replace("T", " "); + return { start: fmt(start), end: fmt(end) }; +} + +export function ToolDetail({ toolName, onBack, accessToken }: ToolDetailProps) { + const queryClient = useQueryClient(); + const [overrideSaving, setOverrideSaving] = useState(false); + const [inputPolicySaving, setInputPolicySaving] = useState(false); + const [outputPolicySaving, setOutputPolicySaving] = useState(false); + const [blockScope, setBlockScope] = useState<"team" | "key">("team"); + const [blockTeamId, setBlockTeamId] = useState(null); + const [blockKey, setBlockKey] = useState(null); + + const logsDateRange = useMemo(() => getDefaultLogsDateRange(), []); + + const { data: detail, isLoading: detailLoading, error: detailError } = useQuery({ + queryKey: [TOOL_DETAIL_QUERY_KEY, toolName], + queryFn: () => fetchToolDetail(accessToken!, toolName), + enabled: !!accessToken && !!toolName, + }); + + const { data: policyOptions } = useQuery({ + queryKey: ["tool-policy-options"], + queryFn: () => fetchToolPolicyOptions(accessToken!), + enabled: !!accessToken, + staleTime: 60_000, + }); + + const { data: teamsData } = useQuery({ + queryKey: ["teams-list-tool-detail"], + queryFn: () => teamListCall(accessToken!, null, null), + enabled: !!accessToken, + }); + + const { data: keysData } = useQuery({ + queryKey: ["keys-list-tool-detail"], + queryFn: () => keyListCall(accessToken!, null, null, null, null, null, 1, 100), + enabled: !!accessToken, + }); + + const { data: logsData, isLoading: logsLoading } = useQuery({ + queryKey: ["tool-usage-logs", toolName, logsDateRange.start, logsDateRange.end], + queryFn: () => + getToolUsageLogs(accessToken!, toolName, { + page: 1, + pageSize: LOGS_PAGE_SIZE, + startDate: logsDateRange.start, + endDate: logsDateRange.end, + }), + enabled: !!accessToken && !!toolName, + }); + + const logs: LogEntry[] = useMemo(() => { + const list = logsData?.logs ?? []; + return list.map((l) => ({ + id: l.id, + timestamp: l.timestamp, + action: "passed" as const, + model: l.model ?? undefined, + input_snippet: l.input_snippet ?? undefined, + })); + }, [logsData?.logs]); + + const teams: Team[] = useMemo(() => { + const arr = Array.isArray(teamsData) ? teamsData : teamsData?.data ?? []; + return arr.map((t: { team_id?: string; id?: string; team_alias?: string }) => ({ + team_id: t.team_id ?? t.id ?? "", + team_alias: t.team_alias ?? t.team_id ?? "", + models: [], + max_budget: null, + budget_duration: null, + tpm_limit: null, + rpm_limit: null, + organization_id: "", + created_at: "", + keys: [], + members_with_roles: [], + spend: 0, + })); + }, [teamsData]); + + const keys: KeyOption[] = useMemo(() => { + const keysRes = keysData?.keys ?? keysData?.data ?? []; + return keysRes.map((k: { token?: string; api_key?: string; key_hash?: string; key_alias?: string }) => ({ + token: k.token ?? k.api_key ?? k.key_hash ?? "", + key_alias: k.key_alias ?? (k.token ?? k.api_key ?? k.key_hash)?.toString?.()?.substring?.(0, 8), + })); + }, [keysData]); + + const invalidateDetail = useCallback(() => { + queryClient.invalidateQueries({ queryKey: [TOOL_DETAIL_QUERY_KEY, toolName] }); + }, [queryClient, toolName]); + + const handleInputPolicyChange = useCallback( + async (_name: string, newPolicy: string) => { + if (!accessToken) return; + setInputPolicySaving(true); + try { + await updateToolPolicy(accessToken, toolName, { input_policy: newPolicy }); + invalidateDetail(); + } catch (e: unknown) { + alert(`Failed to update input policy: ${e instanceof Error ? e.message : String(e)}`); + } finally { + setInputPolicySaving(false); + } + }, + [accessToken, toolName, invalidateDetail] + ); + + const handleOutputPolicyChange = useCallback( + async (_name: string, newPolicy: string) => { + if (!accessToken) return; + setOutputPolicySaving(true); + try { + await updateToolPolicy(accessToken, toolName, { output_policy: newPolicy }); + invalidateDetail(); + } catch (e: unknown) { + alert(`Failed to update output policy: ${e instanceof Error ? e.message : String(e)}`); + } finally { + setOutputPolicySaving(false); + } + }, + [accessToken, toolName, invalidateDetail] + ); + + const handleAddOverride = useCallback(async () => { + if (!accessToken || !toolName) return; + const isTeam = blockScope === "team"; + if (isTeam && !blockTeamId) return; + if (!isTeam && !blockKey?.token) return; + setOverrideSaving(true); + try { + await updateToolPolicy(accessToken, toolName, { input_policy: "blocked" }, { + team_id: isTeam ? blockTeamId : undefined, + key_hash: !isTeam ? blockKey!.token : undefined, + key_alias: !isTeam ? blockKey!.key_alias : undefined, + }); + invalidateDetail(); + setBlockTeamId(null); + setBlockKey(null); + } catch (e: unknown) { + alert(`Failed to add override: ${e instanceof Error ? e.message : String(e)}`); + } finally { + setOverrideSaving(false); + } + }, [accessToken, toolName, blockScope, blockTeamId, blockKey, invalidateDetail]); + + const handleRemoveOverride = useCallback( + async (override: ToolPolicyOverrideRow) => { + if (!accessToken || !toolName) return; + setOverrideSaving(true); + try { + await deleteToolPolicyOverride(accessToken, toolName, { + team_id: override.team_id ?? undefined, + key_hash: override.key_hash ?? undefined, + }); + invalidateDetail(); + } catch (e: unknown) { + alert(`Failed to remove override: ${e instanceof Error ? e.message : String(e)}`); + } finally { + setOverrideSaving(false); + } + }, + [accessToken, toolName, invalidateDetail] + ); + + if (detailLoading && !detail) { + return ( +
+ +
+ ); + } + + if (detailError && !detail) { + return ( +
+ +

Failed to load tool details.

+
+ ); + } + + if (!detail) { + return null; + } + + const { tool, overrides } = detail; + + const inputDesc = policyOptions?.input_policies?.find( + (p) => p.value === tool.input_policy + )?.description; + const outputDesc = policyOptions?.output_policies?.find( + (p) => p.value === tool.output_policy + )?.description; + + return ( +
+
+ + +
+
+
+ +

{tool.tool_name}

+ + {tool.origin ?? "—"} + + + {(tool.call_count ?? 0).toLocaleString()} calls + +
+
+ {tool.user_agent && ( +
+
User Agent:
+
{tool.user_agent}
+
+ )} + {tool.created_at && ( +
+
First Discovered:
+
{new Date(tool.created_at).toLocaleString()}
+
+ )} + {tool.last_used_at && ( +
+
Last Used:
+
{new Date(tool.last_used_at).toLocaleString()}
+
+ )} +
+
+
+
+ +
+ {/* Two-panel policy layout */} +
+
+

Input Policy

+

+ {inputDesc ?? "Controls what data this tool is allowed to accept."} +

+ +
+ +
+

Output Policy

+

+ {outputDesc ?? "Controls how this tool's output is trusted by downstream tools."} +

+ +
+
+ + {overrides.length > 0 && ( +
+

Blocked for team or key

+
    + {overrides.map((ov) => ( +
  • + + {ov.team_id ? `Team: ${ov.team_id}` : ""} + {ov.team_id && ov.key_hash ? " · " : ""} + {ov.key_hash ? `Key: ${ov.key_alias || ov.key_hash.substring(0, 8)}` : ""} + {!ov.team_id && !ov.key_hash ? "—" : ""} + + +
  • + ))} +
+
+ )} + +
+

Block for team or key

+
+
+ Scope +
+ + +
+
+
+ + {blockScope === "team" ? "Team" : "Key"} + + {blockScope === "team" ? ( + setBlockTeamId(id || null)} + /> + ) : ( + onChange(toolName, v)} - onClick={(e) => e.stopPropagation()} - style={{ - minWidth: 110, - fontWeight: 500, - }} - popupMatchSelectWidth={false} - options={POLICY_OPTIONS.map((o) => ({ - value: o.value, - label: ( - - - {o.label} - - ), - }))} - /> - ); -}; - -export const ToolPolicies: React.FC = ({ accessToken }) => { +export const ToolPolicies: React.FC = ({ accessToken, onSelectTool }) => { const [tools, setTools] = useState([]); const [loading, setLoading] = useState(true); const [isFetching, setIsFetching] = useState(false); const [error, setError] = useState(null); - const [saving, setSaving] = useState(null); + const [savingInput, setSavingInput] = useState(null); + const [savingOutput, setSavingOutput] = useState(null); const [searchTerm, setSearchTerm] = useState(""); const [sortField, setSortField] = useState("created_at"); @@ -123,16 +96,29 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { return () => clearInterval(id); }, [isLiveTail, load]); - const handlePolicyChange = async (toolName: string, newPolicy: string) => { + const handleInputPolicyChange = async (toolName: string, newPolicy: string) => { if (!accessToken) return; - setSaving(toolName); + setSavingInput(toolName); try { - await updateToolPolicy(accessToken, toolName, newPolicy); - setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, call_policy: newPolicy } : t))); + await updateToolPolicy(accessToken, toolName, { input_policy: newPolicy }); + setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, input_policy: newPolicy } : t))); } catch (e: any) { - alert(`Failed to update policy: ${e.message}`); + alert(`Failed to update input policy: ${e.message}`); } finally { - setSaving(null); + setSavingInput(null); + } + }; + + const handleOutputPolicyChange = async (toolName: string, newPolicy: string) => { + if (!accessToken) return; + setSavingOutput(toolName); + try { + await updateToolPolicy(accessToken, toolName, { output_policy: newPolicy }); + setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, output_policy: newPolicy } : t))); + } catch (e: any) { + alert(`Failed to update output policy: ${e.message}`); + } finally { + setSavingOutput(null); } }; @@ -157,7 +143,6 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { setCurrentPage(1); }; - // Build unique team/key options from loaded data const teamOptions = Array.from(new Set(tools.map((t) => t.team_id).filter(Boolean))).map((v) => ({ label: v as string, value: v as string, @@ -169,9 +154,14 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { const filterOptions: FilterOption[] = [ { - name: "Policy", - label: "Policy", - options: POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })), + name: "Input Policy", + label: "Input Policy", + options: INPUT_POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })), + }, + { + name: "Output Policy", + label: "Output Policy", + options: OUTPUT_POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })), }, { name: "Team Name", @@ -185,6 +175,39 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { }, ]; + const { newToday, newYesterday, trendSubtitle, totalTools, blockedCount, activeTeamsCount, needsReviewTools } = + useMemo(() => { + const now = new Date(); + const todayKey = getUTCDateKey(now); + const yesterday = new Date(now); + yesterday.setUTCDate(yesterday.getUTCDate() - 1); + const yesterdayKey = getUTCDateKey(yesterday); + + const newToday = countToolsInUTCDay(tools, todayKey); + const newYesterday = countToolsInUTCDay(tools, yesterdayKey); + const trendSubtitle = getTrendSubtitle(newToday, newYesterday); + + const totalTools = tools.length; + const blockedCount = tools.filter((t) => t.input_policy === "blocked").length; + const activeTeamsCount = new Set(tools.map((t) => t.team_id).filter(Boolean)).size; + + const needsReviewTools = tools.filter( + (t) => + isCreatedInUTCDay(t.created_at, todayKey) && + t.input_policy === "untrusted" + ); + + return { + newToday, + newYesterday, + trendSubtitle, + totalTools, + blockedCount, + activeTeamsCount, + needsReviewTools, + }; + }, [tools]); + const SortHeader = ({ label, field }: { label: string; field: SortField }) => (
{label} @@ -203,10 +226,12 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { (t.team_id ?? "").toLowerCase().includes(q) || (t.key_alias ?? "").toLowerCase().includes(q) || (t.key_hash ?? "").toLowerCase().includes(q) || - t.call_policy.toLowerCase().includes(q); + t.input_policy.toLowerCase().includes(q) || + t.output_policy.toLowerCase().includes(q); if (!matchesSearch) return false; } - if (activeFilters["Policy"] && t.call_policy !== activeFilters["Policy"]) return false; + if (activeFilters["Input Policy"] && t.input_policy !== activeFilters["Input Policy"]) return false; + if (activeFilters["Output Policy"] && t.output_policy !== activeFilters["Output Policy"]) return false; if (activeFilters["Team Name"] && t.team_id !== activeFilters["Team Name"]) return false; if (activeFilters["Key Name"] && t.key_alias !== activeFilters["Key Name"]) return false; return true; @@ -223,11 +248,74 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { const totalPages = Math.max(1, Math.ceil(sorted.length / pageSize)); const paginated = sorted.slice((currentPage - 1) * pageSize, currentPage * pageSize); + const scrollToToolRow = (toolId: string) => { + const idx = sorted.findIndex((t) => t.tool_id === toolId); + if (idx >= 0) { + const page = Math.floor(idx / pageSize) + 1; + if (page !== currentPage) setCurrentPage(page); + requestAnimationFrame(() => { + setTimeout(() => { + document.getElementById(`tool-row-${toolId}`)?.scrollIntoView({ behavior: "smooth", block: "center" }); + }, 100); + }); + } + }; + return ( -
+

Tool Policies

+ +
+ + + + } + /> + + 0 ? "text-red-600" : undefined} + /> + 0 ? activeTeamsCount : "—"} /> +
+ + {needsReviewTools.length > 0 && ( +
+

Needs Review

+

+ {needsReviewTools.length} new tool{needsReviewTools.length !== 1 ? "s" : ""} discovered that require + policy decisions. +

+
+ {needsReviewTools.map((t) => ( + + + {t.tool_name} + + + + ))} +
+
+ )} +
- {/* Toolbar */}
@@ -311,7 +399,6 @@ export const ToolPolicies: React.FC = ({ accessToken }) => {
- {/* Filter row */}
= ({ accessToken }) => {
- {/* Auto-refresh banner */} {isLiveTail && (
Auto-refreshing every 15 seconds @@ -336,7 +422,6 @@ export const ToolPolicies: React.FC = ({ accessToken }) => {
{error}
)} - {/* Table */} @@ -347,7 +432,10 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { - + + + + @@ -359,45 +447,61 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { - Origin + User Agent {loading ? ( - + Loading tools… ) : paginated.length === 0 ? ( - + No tools discovered yet. Make a chat completion that returns tool_calls to start auto-discovery. ) : ( paginated.map((tool) => ( - + - - - {tool.tool_name} - - + - - {(tool.call_count ?? 0).toLocaleString()} + + + + +
+ {(tool.call_count ?? 0).toLocaleString()} +
@@ -417,8 +521,8 @@ export const ToolPolicies: React.FC = ({ accessToken }) => { - - {tool.origin ?? "-"} + + {tool.user_agent ?? "-"}
@@ -427,7 +531,6 @@ export const ToolPolicies: React.FC = ({ accessToken }) => {
- {/* Bottom pagination (only when > 1 page) */} {totalPages > 1 && (
@@ -453,6 +556,7 @@ export const ToolPolicies: React.FC = ({ accessToken }) => {
)}
+
); }; diff --git a/ui/litellm-dashboard/src/components/ToolPolicies/PolicySelect.tsx b/ui/litellm-dashboard/src/components/ToolPolicies/PolicySelect.tsx new file mode 100644 index 0000000000..1317351931 --- /dev/null +++ b/ui/litellm-dashboard/src/components/ToolPolicies/PolicySelect.tsx @@ -0,0 +1,92 @@ +"use client"; + +import React from "react"; +import { Select } from "antd"; + +export const INPUT_POLICY_OPTIONS = [ + { value: "untrusted", label: "untrusted", color: "#92400e", bg: "#fef3c7", border: "#fcd34d" }, + { value: "trusted", label: "trusted", color: "#065f46", bg: "#d1fae5", border: "#6ee7b7" }, + { value: "blocked", label: "blocked", color: "#991b1b", bg: "#fee2e2", border: "#fca5a5" }, +] as const; + +export const OUTPUT_POLICY_OPTIONS = [ + { value: "untrusted", label: "untrusted", color: "#92400e", bg: "#fef3c7", border: "#fcd34d" }, + { value: "trusted", label: "trusted", color: "#065f46", bg: "#d1fae5", border: "#6ee7b7" }, +] as const; + +export const POLICY_OPTIONS = INPUT_POLICY_OPTIONS; + +export const policyStyle = (p: string) => + INPUT_POLICY_OPTIONS.find((o) => o.value === p) ?? INPUT_POLICY_OPTIONS[0]; + +export interface PolicySelectProps { + value: string; + toolName: string; + saving: boolean; + onChange: (toolName: string, policy: string) => void; + policyType?: "input" | "output"; + size?: "small" | "middle"; + minWidth?: number; + stopPropagation?: boolean; +} + +export const PolicySelect: React.FC = ({ + value, + toolName, + saving, + onChange, + policyType = "input", + size = "small", + minWidth = 110, + stopPropagation = true, +}) => { + const options = policyType === "output" ? OUTPUT_POLICY_OPTIONS : INPUT_POLICY_OPTIONS; + const style = policyStyle(value); + return ( +