[Feat] Add Tool Policies for AI Gateway (#22732)

* fix: fix ui render

* fix: fix minor bugs

* refactor: use prisma functions instead of raw sql (safer)

* fix(add-new-tiles-to-tool-policies): allow developer to see what's available

* feat: ensure tool allowlist runs correctly for tool names + mcp's

* refactor: more ui improvements

* feat: working key tool blocking

* feat(tools): show tool logs

* refactor: backend code improvements

* refactor: improve log viewer for tools

* fix: address PR review feedback for tool access control

- Add missing blocked_tools column to root schema.prisma (schema drift)
- Invalidate ToolPolicyRegistry after policy mutations so changes take effect immediately
- Remove dead code: unused get_effective_policies, get_tool_policies_cached, and helpers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: race condition in permission resolution and remove duplicate allowlist check

- Use atomic update_many with object_permission_id=None to prevent concurrent
  requests from creating orphaned permission rows and losing tool blocks
- Remove duplicate allowed_tools enforcement from guardrail (already enforced
  in auth layer via check_tools_allowlist)
- Move inline uuid import to module level

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* update to account for userAgent

* UI - Add ToolDetails

* input/output policy

* LiteLLM_PolicyAttachmentTable

* LiteLLM_PolicyAttachmentTable

* fix: add _enqueue_tool_registry_upsert

* fix: tool mgmt endpoints

* tool mgmt endpoints

* Update tests/test_litellm/proxy/db/test_tool_registry_writer.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Update tests/test_litellm/proxy/db/test_tool_registry_writer.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Update tests/test_litellm/proxy/db/test_tool_registry_writer.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* fix: sync root schema.prisma and fix test_tool_registry_writer for input/output policy

- Migrate root schema.prisma LiteLLM_ToolTable from call_policy to
  input_policy/output_policy, add missing user_agent and last_used_at columns
  (now consistent with litellm/proxy/schema.prisma and litellm-proxy-extras)
- Fix SpendLogToolIndex comment across all three schema files
- Fix all call_policy references in test_tool_registry_writer.py:
  swapped update_tool_policy arguments, wrong get_tools_by_names return type
  assertions, _mock_tool_row setting call_policy instead of input_policy

Addresses Greptile review feedback on PR #22732.

Made-with: Cursor

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
This commit is contained in:
Ishaan Jaff 2026-03-03 20:22:20 -08:00 committed by GitHub
parent 8baa3ae8cb
commit 1f412bc6d8
35 changed files with 3467 additions and 627 deletions

View File

@ -109,6 +109,8 @@ Key files:
- `litellm/proxy/auth/` - Authentication logic
- `litellm/proxy/management_endpoints/` - Admin API endpoints
**Database (proxy)**: Use Prisma model methods (`prisma_client.db.<model>.upsert`, `.find_many`, `.find_unique`, etc.), not raw SQL (`execute_raw`/`query_raw`). See COMMON PITFALLS for details.
## MCP (MODEL CONTEXT PROTOCOL) SUPPORT
LiteLLM supports MCP for agent workflows:
@ -176,6 +178,7 @@ When opening issues or pull requests, follow these templates:
5. **Dependencies**: Keep dependencies minimal and well-justified
6. **UI/Backend Contract Mismatch**: When adding a new entity type to the UI, always check whether the backend endpoint accepts a single value or an array. Match the UI control accordingly (single-select vs. multi-select) to avoid silently dropping user selections
7. **Missing Tests for New Entity Types**: When adding a new entity type (e.g., in `EntityUsage`, `UsageViewSelect`), always add corresponding tests in the existing test files and update any icon/component mocks
8. **Raw SQL in proxy DB code**: Do not use `execute_raw` or `query_raw` for proxy database access. Use Prisma model methods (e.g. `prisma_client.db.litellm_tooltable.upsert()`, `.find_many()`, `.find_unique()`) so behavior stays consistent with the schema, the client stays mockable in tests, and you avoid the pitfalls of hand-written SQL (parameter ordering, type casting, schema drift)
9. **Do not hardcode model-specific flags**: Put model-specific capability flags in `model_prices_and_context_window.json` and read them via `get_model_info` (or existing helpers like `supports_reasoning`). This prevents users from needing to upgrade LiteLLM each time a new model supports a feature.
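
To make the raw SQL pitfall concrete, here is a minimal sketch of a Prisma model-method call and the kind of mock that replaces it in a test. It assumes the Prisma Client Python call shape `upsert(where=..., data={"create": ..., "update": ...})`. The helper `record_tool_seen` is hypothetical and not part of this change.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def record_tool_seen(prisma_client, tool_name: str) -> None:
    """Hypothetical helper: register a tool, or bump its counter if already known."""
    await prisma_client.db.litellm_tooltable.upsert(
        where={"tool_name": tool_name},
        data={
            "create": {"tool_name": tool_name, "call_count": 1},
            "update": {"call_count": {"increment": 1}},
        },
    )


# In a test the client is a plain mock, so no database is required.
prisma_client = MagicMock()
prisma_client.db.litellm_tooltable.upsert = AsyncMock()

asyncio.run(record_tool_seen(prisma_client, "get_weather"))

# Inspect the keyword arguments the helper passed to the model method.
call = prisma_client.db.litellm_tooltable.upsert.await_args
print(call.kwargs["where"])  # {'tool_name': 'get_weather'}
```

Because the call goes through a named model method, the test can assert on structured arguments instead of matching SQL strings and positional parameters.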

View File

@ -107,6 +107,10 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components:
- Migration files auto-generated with `prisma migrate dev`
- Always test migrations against both PostgreSQL and SQLite
### Proxy database access
- **Do not write raw SQL** for proxy DB operations. Use Prisma model methods instead of `execute_raw` / `query_raw`.
- Use the generated client: `prisma_client.db.<model>` (e.g. `litellm_tooltable`, `litellm_usertable`) with `.upsert()`, `.find_many()`, `.find_unique()`, `.update()`, `.update_many()` as appropriate. This avoids schema/client drift, keeps code testable with simple mocks, and matches patterns used in spend logs and other proxy code.
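
A filtered `.update_many()` can also act as a compare-and-set: the write only lands on rows that still match the filter, and the returned count tells the caller whether it won. The sketch below illustrates that idea with an in-memory stand-in. `FakeKeyTable`, `claim_permission_row`, and the row fields are hypothetical and only mimic the `update_many(where=..., data=...) -> int` shape; this is not the code from this change.

```python
import asyncio
import uuid
from typing import Any, Dict, List


class FakeKeyTable:
    """In-memory stand-in mimicking update_many(where=..., data=...) -> int."""

    def __init__(self, rows: List[Dict[str, Any]]) -> None:
        self.rows = rows

    async def update_many(self, where: Dict[str, Any], data: Dict[str, Any]) -> int:
        # Update only rows whose fields equal every value in the filter.
        matched = [r for r in self.rows if all(r.get(k) == v for k, v in where.items())]
        for row in matched:
            row.update(data)
        return len(matched)


async def claim_permission_row(table: FakeKeyTable, token: str) -> bool:
    """Attach a fresh permission id only if the key does not have one yet."""
    new_id = str(uuid.uuid4())
    updated = await table.update_many(
        where={"token": token, "object_permission_id": None},
        data={"object_permission_id": new_id},
    )
    return updated == 1  # False: another request already attached an id


async def main() -> List[bool]:
    table = FakeKeyTable([{"token": "hashed-key", "object_permission_id": None}])
    # Two concurrent requests for the same key: exactly one claim succeeds.
    return list(
        await asyncio.gather(
            claim_permission_row(table, "hashed-key"),
            claim_permission_row(table, "hashed-key"),
        )
    )


results = asyncio.run(main())
print(sorted(results))  # [False, True]
```

A caller that loses the claim would re-read the row and reuse the id that is already attached rather than create a second one.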
### Enterprise Features
- Enterprise-specific code in `enterprise/` directory
- Optional features enabled via environment variables

View File

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "LiteLLM_ObjectPermissionTable" ADD COLUMN "blocked_tools" TEXT[] DEFAULT ARRAY[]::TEXT[];

View File

@ -0,0 +1,11 @@
-- CreateTable
CREATE TABLE "LiteLLM_SpendLogToolIndex" (
"request_id" TEXT NOT NULL,
"tool_name" TEXT NOT NULL,
"start_time" TIMESTAMP(3) NOT NULL,
CONSTRAINT "LiteLLM_SpendLogToolIndex_pkey" PRIMARY KEY ("request_id","tool_name")
);
-- CreateIndex
CREATE INDEX "LiteLLM_SpendLogToolIndex_tool_name_start_time_idx" ON "LiteLLM_SpendLogToolIndex"("tool_name", "start_time");
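
The composite index on `("tool_name", "start_time")` is shaped for "last N requests for a given tool" lookups. The sketch below reproduces the table in an in-memory SQLite database purely to illustrate that access pattern; the sample rows are made up, `TIMESTAMP(3)` is simplified to `TIMESTAMP`, and proxy code would presumably express the same query through a Prisma model method rather than SQL.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE "LiteLLM_SpendLogToolIndex" (
        "request_id" TEXT NOT NULL,
        "tool_name" TEXT NOT NULL,
        "start_time" TIMESTAMP NOT NULL,
        PRIMARY KEY ("request_id", "tool_name")
    );
    CREATE INDEX "LiteLLM_SpendLogToolIndex_tool_name_start_time_idx"
        ON "LiteLLM_SpendLogToolIndex" ("tool_name", "start_time");
    """
)

# One request may reference several tools, hence the composite primary key.
rows = [
    ("req-1", "get_weather", "2026-03-01T10:00:00"),
    ("req-2", "get_weather", "2026-03-02T10:00:00"),
    ("req-2", "search_docs", "2026-03-02T10:00:00"),
    ("req-3", "get_weather", "2026-03-03T10:00:00"),
]
conn.executemany('INSERT INTO "LiteLLM_SpendLogToolIndex" VALUES (?, ?, ?)', rows)

# Equality on tool_name, then newest first: the order the index is built in.
latest = conn.execute(
    'SELECT "request_id" FROM "LiteLLM_SpendLogToolIndex" '
    'WHERE "tool_name" = ? ORDER BY "start_time" DESC LIMIT ?',
    ("get_weather", 2),
).fetchall()
latest_ids = [r[0] for r in latest]
print(latest_ids)  # ['req-3', 'req-2']
```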

View File

@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable {
vector_stores String[] @default([])
agents String[] @default([])
agent_access_groups String[] @default([])
blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission
teams LiteLLM_TeamTable[]
projects LiteLLM_ProjectTable[]
verification_tokens LiteLLM_VerificationToken[]
@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex {
@@index([policy_id, start_time])
}
// Index for fast "last N logs for tool" lookups from SpendLogs, to see how a tool is called in production
model LiteLLM_SpendLogToolIndex {
request_id String
tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc.
start_time DateTime
@@id([request_id, tool_name])
@@index([tool_name, start_time])
}
// Prompt table for storing prompt configurations
model LiteLLM_PromptTable {
id String @id @default(uuid())
@ -1065,26 +1076,31 @@ model LiteLLM_PolicyAttachmentTable {
updated_by String?
}
// Global tool registry - auto-discovered from LLM responses; admins set call_policy here
// Global tool registry - auto-discovered from LLM responses; admins set input_policy/output_policy here
model LiteLLM_ToolTable {
tool_id String @id @default(uuid())
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
origin String? // MCP server name or "user_defined"
call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked"
call_count Int @default(0) // cumulative number of times this tool was seen
assignments Json? @default("{}")
key_hash String? // hash of the virtual key that first called this tool
team_id String? // team that first called this tool
key_alias String? // human-readable alias of the virtual key
created_at DateTime @default(now())
created_by String?
updated_at DateTime @default(now()) @updatedAt
updated_by String?
tool_id String @id @default(uuid())
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
origin String? // MCP server name or "user_defined"
input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked"
output_policy String @default("untrusted") // "trusted" | "untrusted"
call_count Int @default(0) // cumulative number of times this tool was seen
assignments Json? @default("{}")
key_hash String? // hash of the virtual key that first called this tool
team_id String? // team that first called this tool
key_alias String? // human-readable alias of the virtual key
user_agent String? // user-agent of the first request that discovered this tool
last_used_at DateTime? // timestamp of the most recent call
created_at DateTime @default(now())
created_by String?
updated_at DateTime @default(now()) @updatedAt
updated_by String?
@@index([call_policy])
@@index([input_policy])
@@index([output_policy])
@@index([team_id])
}
// Per-(tool, team/key) policy overrides. When present, override replaces global tool policy for that scope.
//Unified Access Groups table for storing unified access groups
model LiteLLM_AccessGroupTable {
access_group_id String @id @default(uuid())

View File

@ -75,7 +75,7 @@ class AnthropicMessagesHandler(BaseTranslation):
if messages is None:
return data
chat_completion_compatible_request, tool_name_mapping = (
chat_completion_compatible_request, _tool_name_mapping = (
LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai(
# Use a shallow copy to avoid mutating request data (pop on litellm_metadata).
anthropic_message_request=cast(AnthropicMessagesRequest, data.copy())
@ -141,6 +141,14 @@ class AnthropicMessagesHandler(BaseTranslation):
return data
def extract_request_tool_names(self, data: dict) -> List[str]:
"""Extract tool names from Anthropic messages request (tools[].name)."""
names: List[str] = []
for tool in data.get("tools") or []:
if isinstance(tool, dict) and tool.get("name"):
names.append(str(tool["name"]))
return names
def _extract_input_text_and_images(
self,
message: Dict[str, Any],

View File

@ -98,3 +98,10 @@ class BaseTranslation(ABC):
Optional to override in subclasses.
"""
return responses_so_far
def extract_request_tool_names(self, data: dict) -> List[str]:
"""
Extract tool names from the request body for allowlist/policy checks.
Override in tool-capable handlers; default returns [].
"""
return []
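
The handlers in this diff override this hook once per request format. As a rough illustration, the sketch below restates the chat completions and Responses API extraction logic as free functions and runs them on small sample payloads. The payloads and the `find_blocked` check are hypothetical and are not the enforcement code from this change.

```python
from typing import Any, Dict, List


def names_from_chat_completions(data: Dict[str, Any]) -> List[str]:
    """Collect tools[].function.name plus legacy functions[].name."""
    names: List[str] = []
    for tool in data.get("tools") or []:
        if isinstance(tool, dict) and tool.get("type") == "function":
            fn = tool.get("function")
            if isinstance(fn, dict) and fn.get("name"):
                names.append(str(fn["name"]))
    for fn in data.get("functions") or []:
        if isinstance(fn, dict) and fn.get("name"):
            names.append(str(fn["name"]))
    return names


def names_from_responses_api(data: Dict[str, Any]) -> List[str]:
    """Collect tools[].name for function tools and tools[].server_label for MCP tools."""
    names: List[str] = []
    for tool in data.get("tools") or []:
        if not isinstance(tool, dict):
            continue
        if tool.get("type") == "function" and tool.get("name"):
            names.append(str(tool["name"]))
        elif tool.get("type") == "mcp" and tool.get("server_label"):
            names.append(str(tool["server_label"]))
    return names


def find_blocked(requested: List[str], blocked_tools: List[str]) -> List[str]:
    """Hypothetical check: return requested names that appear in a block list."""
    blocked = set(blocked_tools)
    return [name for name in requested if name in blocked]


chat_request = {
    "tools": [{"type": "function", "function": {"name": "get_weather"}}],
    "functions": [{"name": "legacy_lookup"}],
}
responses_request = {
    "tools": [
        {"type": "function", "name": "get_weather"},
        {"type": "mcp", "server_label": "deepwiki"},
        {"type": "web_search"},  # neither function nor mcp, so it is skipped
    ]
}

chat_names = names_from_chat_completions(chat_request)
responses_names = names_from_responses_api(responses_request)
print(chat_names)  # ['get_weather', 'legacy_lookup']
print(responses_names)  # ['get_weather', 'deepwiki']
print(find_blocked(responses_names, ["deepwiki"]))  # ['deepwiki']
```

Note that for MCP tools the extracted name is the server label, so a block list entry in this sketch matches the whole server rather than an individual tool on it.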

View File

@ -135,6 +135,19 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
return data
def extract_request_tool_names(self, data: dict) -> List[str]:
"""Extract tool names from OpenAI chat completions request (tools[].function.name, functions[].name)."""
names: List[str] = []
for tool in data.get("tools") or []:
if isinstance(tool, dict) and tool.get("type") == "function":
fn = tool.get("function")
if isinstance(fn, dict) and fn.get("name"):
names.append(str(fn["name"]))
for fn in data.get("functions") or []:
if isinstance(fn, dict) and fn.get("name"):
names.append(str(fn["name"]))
return names
def _extract_inputs(
self,
message: Dict[str, Any],

View File

@ -30,27 +30,22 @@ Output: response.output is List[GenericResponseOutputItem] where each has:
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_tool_call import \
ResponseFunctionToolCall
from pydantic import BaseModel
from litellm._logging import verbose_proxy_logger
from litellm.completion_extras.litellm_responses_transformation.transformation import (
LiteLLMResponsesTransformationHandler,
OpenAiResponsesToChatCompletionStreamIterator,
)
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionToolParam,
)
from litellm.types.responses.main import (
GenericResponseOutputItem,
OutputFunctionToolCall,
OutputText,
)
OpenAiResponsesToChatCompletionStreamIterator)
from litellm.llms.base_llm.guardrail_translation.base_translation import \
BaseTranslation
from litellm.responses.litellm_completion_transformation.transformation import \
LiteLLMCompletionResponsesConfig
from litellm.types.llms.openai import (ChatCompletionToolCallChunk,
ChatCompletionToolParam)
from litellm.types.responses.main import (GenericResponseOutputItem,
OutputFunctionToolCall, OutputText)
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
@ -188,6 +183,18 @@ class OpenAIResponsesHandler(BaseTranslation):
return data
def extract_request_tool_names(self, data: dict) -> List[str]:
"""Extract tool names from Responses API request (tools[].name for function, tools[].server_label for mcp)."""
names: List[str] = []
for tool in data.get("tools") or []:
if not isinstance(tool, dict):
continue
if tool.get("type") == "function" and tool.get("name"):
names.append(str(tool["name"]))
elif tool.get("type") == "mcp" and tool.get("server_label"):
names.append(str(tool["server_label"]))
return names
def _extract_and_transform_tools(
self,
tools: List[Dict[str, Any]],

View File

@ -9779,6 +9779,122 @@
}
]
},
"dashscope/qwen3-max-2026-01-23": {
"litellm_provider": "dashscope",
"max_input_tokens": 258048,
"max_output_tokens": 65536,
"max_tokens": 65536,
"mode": "chat",
"source": "https://www.alibabacloud.com/help/en/model-studio/models",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"tiered_pricing": [
{
"input_cost_per_token": 1.2e-06,
"output_cost_per_token": 6e-06,
"range": [
0,
32000.0
]
},
{
"input_cost_per_token": 2.4e-06,
"output_cost_per_token": 1.2e-05,
"range": [
32000.0,
128000.0
]
},
{
"input_cost_per_token": 3e-06,
"output_cost_per_token": 1.5e-05,
"range": [
128000.0,
252000.0
]
}
]
},
"dashscope/qwen3-next-80b-a3b-instruct": {
"input_cost_per_token": 1.5e-07,
"litellm_provider": "dashscope",
"max_input_tokens": 262144,
"max_output_tokens": 65536,
"max_tokens": 65536,
"mode": "chat",
"output_cost_per_token": 1.2e-06,
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
"supports_function_calling": true,
"supports_tool_choice": true
},
"dashscope/qwen3-next-80b-a3b-thinking": {
"input_cost_per_token": 1.5e-07,
"litellm_provider": "dashscope",
"max_input_tokens": 262144,
"max_output_tokens": 65536,
"max_tokens": 65536,
"mode": "chat",
"output_cost_per_token": 1.2e-06,
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true
},
"dashscope/qwen3-vl-235b-a22b-instruct": {
"input_cost_per_token": 4e-07,
"litellm_provider": "dashscope",
"max_input_tokens": 131072,
"max_output_tokens": 32768,
"max_tokens": 32768,
"mode": "chat",
"output_cost_per_token": 1.6e-06,
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_vision": true
},
"dashscope/qwen3-vl-235b-a22b-thinking": {
"input_cost_per_token": 4e-07,
"litellm_provider": "dashscope",
"max_input_tokens": 131072,
"max_output_tokens": 32768,
"max_tokens": 32768,
"mode": "chat",
"output_cost_per_token": 4e-06,
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true
},
"dashscope/qwen3-vl-32b-instruct": {
"input_cost_per_token": 1.6e-07,
"litellm_provider": "dashscope",
"max_input_tokens": 131072,
"max_output_tokens": 32768,
"max_tokens": 32768,
"mode": "chat",
"output_cost_per_token": 6.4e-07,
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_vision": true
},
"dashscope/qwen3-vl-32b-thinking": {
"input_cost_per_token": 1.6e-07,
"litellm_provider": "dashscope",
"max_input_tokens": 131072,
"max_output_tokens": 32768,
"max_tokens": 32768,
"mode": "chat",
"output_cost_per_token": 2.87e-06,
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true
},
"dashscope/qwen3-vl-plus": {
"litellm_provider": "dashscope",
"max_input_tokens": 260096,
@ -10844,7 +10960,8 @@
"output_cost_per_token": 9e-08,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/NousResearch/Hermes-3-Llama-3.1-405B": {
"max_tokens": 131072,
@ -10854,7 +10971,8 @@
"output_cost_per_token": 1e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/NousResearch/Hermes-3-Llama-3.1-70B": {
"max_tokens": 131072,
@ -10874,7 +10992,8 @@
"output_cost_per_token": 4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen2.5-72B-Instruct": {
"max_tokens": 32768,
@ -10884,7 +11003,8 @@
"output_cost_per_token": 3.9e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen2.5-7B-Instruct": {
"max_tokens": 32768,
@ -10905,7 +11025,8 @@
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true,
"supports_vision": true
"supports_vision": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-14B": {
"max_tokens": 40960,
@ -10915,7 +11036,8 @@
"output_cost_per_token": 2.4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-235B-A22B": {
"max_tokens": 40960,
@ -10925,7 +11047,8 @@
"output_cost_per_token": 5.4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-235B-A22B-Instruct-2507": {
"max_tokens": 262144,
@ -10935,7 +11058,8 @@
"output_cost_per_token": 6e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-235B-A22B-Thinking-2507": {
"max_tokens": 262144,
@ -10945,7 +11069,8 @@
"output_cost_per_token": 2.9e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-30B-A3B": {
"max_tokens": 40960,
@ -10955,7 +11080,8 @@
"output_cost_per_token": 2.9e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-32B": {
"max_tokens": 40960,
@ -10965,7 +11091,8 @@
"output_cost_per_token": 2.8e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-Coder-480B-A35B-Instruct": {
"max_tokens": 262144,
@ -10975,7 +11102,8 @@
"output_cost_per_token": 1.6e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo": {
"max_tokens": 262144,
@ -10985,7 +11113,8 @@
"output_cost_per_token": 1.2e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-Next-80B-A3B-Instruct": {
"max_tokens": 262144,
@ -10995,7 +11124,8 @@
"output_cost_per_token": 1.4e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Qwen/Qwen3-Next-80B-A3B-Thinking": {
"max_tokens": 262144,
@ -11005,7 +11135,8 @@
"output_cost_per_token": 1.4e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/Sao10K/L3-8B-Lunaris-v1-Turbo": {
"max_tokens": 8192,
@ -11056,7 +11187,8 @@
"cache_read_input_token_cost": 3.3e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/anthropic/claude-4-opus": {
"max_tokens": 200000,
@ -11066,7 +11198,8 @@
"output_cost_per_token": 8.25e-05,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/anthropic/claude-4-sonnet": {
"max_tokens": 200000,
@ -11076,7 +11209,8 @@
"output_cost_per_token": 1.65e-05,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-R1": {
"max_tokens": 163840,
@ -11086,7 +11220,8 @@
"output_cost_per_token": 2.4e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-R1-0528": {
"max_tokens": 163840,
@ -11097,7 +11232,8 @@
"cache_read_input_token_cost": 4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-R1-0528-Turbo": {
"max_tokens": 32768,
@ -11107,7 +11243,8 @@
"output_cost_per_token": 3e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-R1-Distill-Llama-70B": {
"max_tokens": 131072,
@ -11127,7 +11264,8 @@
"output_cost_per_token": 2.7e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-R1-Turbo": {
"max_tokens": 40960,
@ -11137,7 +11275,8 @@
"output_cost_per_token": 3e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-V3": {
"max_tokens": 163840,
@ -11147,7 +11286,8 @@
"output_cost_per_token": 8.9e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-V3-0324": {
"max_tokens": 163840,
@ -11157,7 +11297,8 @@
"output_cost_per_token": 8.8e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-V3.1": {
"max_tokens": 163840,
@ -11169,7 +11310,8 @@
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true,
"supports_reasoning": true
"supports_reasoning": true,
"supports_function_calling": true
},
"deepinfra/deepseek-ai/DeepSeek-V3.1-Terminus": {
"max_tokens": 163840,
@ -11180,7 +11322,8 @@
"cache_read_input_token_cost": 2.16e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/google/gemini-2.0-flash-001": {
"deprecation_date": "2026-06-01",
@ -11191,7 +11334,8 @@
"output_cost_per_token": 4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/google/gemini-2.5-flash": {
"max_tokens": 1000000,
@ -11201,7 +11345,8 @@
"output_cost_per_token": 2.5e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/google/gemini-2.5-pro": {
"max_tokens": 1000000,
@ -11211,7 +11356,8 @@
"output_cost_per_token": 1e-05,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/google/gemma-3-12b-it": {
"max_tokens": 131072,
@ -11221,7 +11367,8 @@
"output_cost_per_token": 1e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/google/gemma-3-27b-it": {
"max_tokens": 131072,
@ -11231,7 +11378,8 @@
"output_cost_per_token": 1.6e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/google/gemma-3-4b-it": {
"max_tokens": 131072,
@ -11241,7 +11389,8 @@
"output_cost_per_token": 8e-08,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Llama-3.2-11B-Vision-Instruct": {
"max_tokens": 131072,
@ -11261,7 +11410,8 @@
"output_cost_per_token": 2e-08,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Llama-3.3-70B-Instruct": {
"max_tokens": 131072,
@ -11271,7 +11421,8 @@
"output_cost_per_token": 4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Llama-3.3-70B-Instruct-Turbo": {
"max_tokens": 131072,
@ -11281,6 +11432,7 @@
"output_cost_per_token": 3.9e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true
},
"deepinfra/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8": {
@ -11291,7 +11443,8 @@
"output_cost_per_token": 6e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Llama-4-Scout-17B-16E-Instruct": {
"max_tokens": 327680,
@ -11301,7 +11454,8 @@
"output_cost_per_token": 3e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Llama-Guard-3-8B": {
"max_tokens": 131072,
@ -11331,7 +11485,8 @@
"output_cost_per_token": 6e-08,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Meta-Llama-3.1-70B-Instruct": {
"max_tokens": 131072,
@ -11341,7 +11496,8 @@
"output_cost_per_token": 4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
"max_tokens": 131072,
@ -11351,7 +11507,8 @@
"output_cost_per_token": 2.8e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Meta-Llama-3.1-8B-Instruct": {
"max_tokens": 131072,
@ -11361,7 +11518,8 @@
"output_cost_per_token": 5e-08,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {
"max_tokens": 131072,
@ -11371,7 +11529,8 @@
"output_cost_per_token": 3e-08,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/microsoft/WizardLM-2-8x22B": {
"max_tokens": 65536,
@ -11391,7 +11550,8 @@
"output_cost_per_token": 1.4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/mistralai/Mistral-Nemo-Instruct-2407": {
"max_tokens": 131072,
@ -11401,7 +11561,8 @@
"output_cost_per_token": 4e-08,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/mistralai/Mistral-Small-24B-Instruct-2501": {
"max_tokens": 32768,
@ -11411,7 +11572,8 @@
"output_cost_per_token": 8e-08,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/mistralai/Mistral-Small-3.2-24B-Instruct-2506": {
"max_tokens": 128000,
@ -11421,7 +11583,8 @@
"output_cost_per_token": 2e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/mistralai/Mixtral-8x7B-Instruct-v0.1": {
"max_tokens": 32768,
@ -11431,7 +11594,8 @@
"output_cost_per_token": 4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/moonshotai/Kimi-K2-Instruct": {
"max_tokens": 131072,
@ -11441,7 +11605,8 @@
"output_cost_per_token": 2e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/moonshotai/Kimi-K2-Instruct-0905": {
"max_tokens": 262144,
@ -11452,7 +11617,8 @@
"cache_read_input_token_cost": 4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/nvidia/Llama-3.1-Nemotron-70B-Instruct": {
"max_tokens": 131072,
@ -11462,7 +11628,8 @@
"output_cost_per_token": 6e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/nvidia/Llama-3.3-Nemotron-Super-49B-v1.5": {
"max_tokens": 131072,
@ -11472,7 +11639,8 @@
"output_cost_per_token": 4e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/nvidia/NVIDIA-Nemotron-Nano-9B-v2": {
"max_tokens": 131072,
@ -11482,7 +11650,8 @@
"output_cost_per_token": 1.6e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/openai/gpt-oss-120b": {
"max_tokens": 131072,
@ -11492,7 +11661,8 @@
"output_cost_per_token": 4.5e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/openai/gpt-oss-20b": {
"max_tokens": 131072,
@ -11502,7 +11672,8 @@
"output_cost_per_token": 1.5e-07,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepinfra/zai-org/GLM-4.5": {
"max_tokens": 131072,
@ -11512,7 +11683,8 @@
"output_cost_per_token": 1.6e-06,
"litellm_provider": "deepinfra",
"mode": "chat",
"supports_tool_choice": true
"supports_tool_choice": true,
"supports_function_calling": true
},
"deepseek/deepseek-chat": {
"cache_creation_input_token_cost": 0.0,
@ -25806,6 +25978,30 @@
"supports_vision": true,
"tool_use_system_prompt_tokens": 159
},
"openrouter/anthropic/claude-sonnet-4.6": {
"cache_creation_input_token_cost": 3.75e-06,
"cache_creation_input_token_cost_above_200k_tokens": 7.5e-06,
"cache_read_input_token_cost": 3e-07,
"cache_read_input_token_cost_above_200k_tokens": 6e-07,
"input_cost_per_token": 3e-06,
"input_cost_per_token_above_200k_tokens": 6e-06,
"litellm_provider": "openrouter",
"max_input_tokens": 1000000,
"max_output_tokens": 128000,
"max_tokens": 128000,
"mode": "chat",
"output_cost_per_token": 1.5e-05,
"output_cost_per_token_above_200k_tokens": 2.25e-05,
"source": "https://openrouter.ai/anthropic/claude-sonnet-4.6",
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 159
},
"openrouter/anthropic/claude-opus-4.5": {
"cache_creation_input_token_cost": 6.25e-06,
"cache_read_input_token_cost": 5e-07,
@ -26156,6 +26352,39 @@
"supports_web_search": true,
"tpm": 800000
},
"openrouter/google/gemini-3.1-pro-preview": {
"cache_read_input_token_cost": 2e-07,
"cache_read_input_token_cost_above_200k_tokens": 4e-07,
"cache_creation_input_token_cost_above_200k_tokens": 2.5e-07,
"input_cost_per_token": 2e-06,
"input_cost_per_token_above_200k_tokens": 4e-06,
"litellm_provider": "openrouter",
"max_input_tokens": 1048576,
"max_output_tokens": 65536,
"max_tokens": 65536,
"mode": "chat",
"output_cost_per_token": 1.2e-05,
"output_cost_per_token_above_200k_tokens": 1.8e-05,
"source": "https://openrouter.ai/google/gemini-3.1-pro-preview",
"supported_modalities": [
"text",
"image",
"audio",
"video"
],
"supported_output_modalities": [
"text"
],
"supports_audio_input": true,
"supports_function_calling": true,
"supports_pdf_input": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_response_schema": true,
"supports_system_messages": true,
"supports_tool_choice": true,
"supports_vision": true
},
"openrouter/gryphe/mythomax-l2-13b": {
"input_cost_per_token": 1.875e-06,
"litellm_provider": "openrouter",
@ -26533,6 +26762,29 @@
"supports_reasoning": true,
"supports_tool_choice": true
},
"openrouter/openai/gpt-5.1-codex-max": {
"cache_read_input_token_cost": 1.25e-07,
"input_cost_per_token": 1.25e-06,
"litellm_provider": "openrouter",
"max_input_tokens": 400000,
"max_output_tokens": 128000,
"max_tokens": 128000,
"mode": "chat",
"output_cost_per_token": 1e-05,
"source": "https://openrouter.ai/openai/gpt-5.1-codex-max",
"supported_modalities": [
"text",
"image"
],
"supported_output_modalities": [
"text"
],
"supports_function_calling": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true
},
"openrouter/openai/gpt-5.2": {
"input_cost_per_image": 0,
"cache_read_input_token_cost": 1.75e-07,
@ -26687,6 +26939,19 @@
"supports_tool_choice": true,
"supports_function_calling": true
},
"openrouter/qwen/qwen3-coder-plus": {
"input_cost_per_token": 1e-06,
"litellm_provider": "openrouter",
"max_input_tokens": 997952,
"max_output_tokens": 65536,
"max_tokens": 65536,
"mode": "chat",
"output_cost_per_token": 5e-06,
"source": "https://openrouter.ai/qwen/qwen3-coder-plus",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true
},
"openrouter/qwen/qwen3-235b-a22b-2507": {
"input_cost_per_token": 7.1e-08,
"litellm_provider": "openrouter",
@ -26822,6 +27087,19 @@
"supports_vision": true,
"supports_prompt_caching": false
},
"openrouter/z-ai/glm-5": {
"input_cost_per_token": 8e-07,
"litellm_provider": "openrouter",
"max_input_tokens": 202752,
"max_output_tokens": 128000,
"max_tokens": 128000,
"mode": "chat",
"output_cost_per_token": 2.56e-06,
"source": "https://openrouter.ai/z-ai/glm-5",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true
},
"openrouter/minimax/minimax-m2.1": {
"input_cost_per_token": 2.7e-07,
"output_cost_per_token": 1.2e-06,
@ -34327,6 +34605,36 @@
"supports_tool_choice": true,
"source": "https://aws.amazon.com/bedrock/pricing/"
},
"zai/glm-5": {
"cache_creation_input_token_cost": 0,
"cache_read_input_token_cost": 2e-07,
"input_cost_per_token": 1e-06,
"output_cost_per_token": 3.2e-06,
"litellm_provider": "zai",
"max_input_tokens": 200000,
"max_output_tokens": 128000,
"mode": "chat",
"supports_function_calling": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"source": "https://docs.z.ai/guides/overview/pricing"
},
"zai/glm-5-code": {
"cache_creation_input_token_cost": 0,
"cache_read_input_token_cost": 3e-07,
"input_cost_per_token": 1.2e-06,
"output_cost_per_token": 5e-06,
"litellm_provider": "zai",
"max_input_tokens": 200000,
"max_output_tokens": 128000,
"mode": "chat",
"supports_function_calling": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"source": "https://docs.z.ai/guides/overview/pricing"
},
"zai/glm-4.7": {
"cache_creation_input_token_cost": 0,
"cache_read_input_token_cost": 1.1e-07,


@ -23,33 +23,11 @@ model_list:
guardrails:
- guardrail_name: "airline-competitor-intent"
guardrail_id: "airline-competitor-intent"
- guardrail_name: "tool_policy"
litellm_params:
guardrail: litellm_content_filter
mode: pre_call
default_on: false
competitor_intent_config:
brand_self:
- emirates
- ek
competitors:
- qatar airways
- qatar
- etihad
locations:
- qatar
- doha
- doh
competitor_aliases:
qatar airways: [qr, doha airline]
qatar: [qr]
policy:
competitor_comparison: refuse
possible_competitor_comparison: reframe
threshold_high: 0.70
threshold_medium: 0.45
threshold_low: 0.30
guardrail: tool_policy
mode: [pre_call, post_call]
default_on: true
mcp_servers:
my_http_server:


@ -77,6 +77,7 @@ class SupportedDBObjectType(str, enum.Enum):
PASS_THROUGH_ENDPOINTS = "pass_through_endpoints"
PROMPTS = "prompts"
MODEL_COST_MAP = "model_cost_map"
TOOLS = "tools"
def __str__(self):
return str(self.value)
@ -2133,7 +2134,7 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
user_header_mappings: Optional[List[UserHeaderMapping]] = None
supported_db_objects: Optional[List[SupportedDBObjectType]] = Field(
None,
description="Fine-grained control over which object types to load from the database when store_model_in_db is True. Available types: 'models', 'mcp', 'guardrails', 'vector_stores', 'pass_through_endpoints', 'prompts', 'model_cost_map'. If not set, all objects are loaded (default behavior).",
description="Fine-grained control over which object types to load from the database when store_model_in_db is True. Available types: 'models', 'mcp', 'guardrails', 'vector_stores', 'pass_through_endpoints', 'prompts', 'model_cost_map', 'tools'. If not set, all objects are loaded (default behavior).",
)
user_mcp_management_mode: Optional[UserMCPManagementMode] = Field(
None,
@ -3377,6 +3378,11 @@ class ProxyErrorTypes(str, enum.Enum):
Team member is already in team
"""
tool_access_denied = "tool_access_denied"
"""
Tool is not in the allowed tools list for this key/team
"""
@classmethod
def get_model_access_error_type_for_object(
cls, object_type: Literal["key", "user", "team", "org", "project"]
@ -4161,6 +4167,7 @@ class ToolDiscoveryQueueItem(TypedDict, total=False):
key_hash: Optional[str] # hash of virtual key that triggered discovery
team_id: Optional[str] # team that triggered discovery
key_alias: Optional[str] # human-readable key alias
user_agent: Optional[str] # HTTP User-Agent of the caller
class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):


@ -58,6 +58,10 @@ from litellm.proxy._types import (
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
from litellm.proxy.guardrails.tool_name_extraction import (
TOOL_CAPABLE_CALL_TYPES,
extract_request_tool_names,
)
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.router import Router
@ -220,7 +224,48 @@ async def _run_project_checks(
)
async def common_checks(
async def check_tools_allowlist(
request_body: dict,
valid_token: Optional[UserAPIKeyAuth],
team_object: Optional[LiteLLM_TeamTable],
route: str,
) -> None:
"""
Enforce key/team tool allowlist (metadata.allowed_tools). No DB in hot path:
the effective allowlist is read from valid_token.metadata and valid_token.team_metadata.
Raises ProxyException with tool_access_denied if a tool is not allowed.
"""
from litellm.litellm_core_utils.api_route_to_call_types import (
get_call_types_for_route,
)
if valid_token is None:
return
call_types = get_call_types_for_route(route)
if not call_types or not any(ct.value in TOOL_CAPABLE_CALL_TYPES for ct in call_types):
return
tool_names = extract_request_tool_names(route, request_body)
if not tool_names:
return
key_meta = (valid_token.metadata or {}) if isinstance(valid_token.metadata, dict) else {}
team_meta = (valid_token.team_metadata or {}) if isinstance(valid_token.team_metadata, dict) else {}
key_allowed = key_meta.get("allowed_tools")
team_allowed = team_meta.get("allowed_tools")
effective = key_allowed if (isinstance(key_allowed, list) and len(key_allowed) > 0) else team_allowed
if not isinstance(effective, list) or len(effective) == 0:
return
allowed_set = {str(t) for t in effective}
disallowed = [n for n in tool_names if n not in allowed_set]
if disallowed:
raise ProxyException(
message=f"Tool(s) {disallowed} are not in the allowed tools list for this key/team.",
type=ProxyErrorTypes.tool_access_denied,
param="tools",
code=status.HTTP_403_FORBIDDEN,
)
async def common_checks( # noqa: PLR0915
request_body: dict,
team_object: Optional[LiteLLM_TeamTable],
user_object: Optional[LiteLLM_UserTable],
@ -477,6 +522,14 @@ async def common_checks(
valid_token=valid_token,
)
# 12. [OPTIONAL] Tool allowlist - key/team allowed_tools (no DB in hot path)
await check_tools_allowlist(
request_body=request_body,
valid_token=valid_token,
team_object=team_object,
route=route,
)
return True
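
The precedence rule in `check_tools_allowlist` can be sketched in isolation. This is a minimal sketch, not the proxy's code: `resolve_effective_allowlist` and `find_disallowed` are hypothetical helper names introduced here, and plain dicts stand in for `valid_token.metadata` and `valid_token.team_metadata`.

```python
from typing import Any, Dict, List, Optional


def resolve_effective_allowlist(
    key_meta: Dict[str, Any], team_meta: Dict[str, Any]
) -> Optional[List[str]]:
    """Hypothetical helper: pick the allowlist that applies to a request."""
    key_allowed = key_meta.get("allowed_tools")
    team_allowed = team_meta.get("allowed_tools")
    # A non-empty key-level list wins; otherwise fall back to the team list.
    effective = (
        key_allowed if isinstance(key_allowed, list) and key_allowed else team_allowed
    )
    if not isinstance(effective, list) or not effective:
        return None  # no allowlist configured, so nothing is restricted
    return [str(t) for t in effective]


def find_disallowed(tool_names: List[str], allowlist: Optional[List[str]]) -> List[str]:
    """Hypothetical helper: return requested tools missing from the allowlist."""
    if allowlist is None:
        return []
    allowed = set(allowlist)
    return [name for name in tool_names if name not in allowed]


key_meta = {"allowed_tools": ["search"]}
team_meta = {"allowed_tools": ["search", "shell"]}
allowlist = resolve_effective_allowlist(key_meta, team_meta)
blocked = find_disallowed(["search", "shell"], allowlist)
```

Under these assumptions the key-level list shadows the team list entirely rather than being merged with it, so `shell` is rejected even though the team allows it. An empty or missing list at both levels means no restriction is applied.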


@ -13,49 +13,36 @@ import random
import time
import traceback
from datetime import datetime, timedelta, timezone
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Union,
cast,
overload,
)
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union,
cast, overload)
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache, RedisCache
from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
BaseDailySpendTransaction,
DailyAgentSpendTransaction,
DailyEndUserSpendTransaction,
DailyOrganizationSpendTransaction,
DailyTagSpendTransaction,
DailyTeamSpendTransaction,
DailyUserSpendTransaction,
DBSpendUpdateTransactions,
Litellm_EntityType,
LiteLLM_UserTable,
SpendLogsMetadata,
SpendLogsPayload,
SpendUpdateQueueItem,
ToolDiscoveryQueueItem,
)
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
DailySpendUpdateQueue,
)
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager
from litellm.proxy.db.db_transaction_queue.redis_update_buffer import RedisUpdateBuffer
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
from litellm.proxy.db.db_transaction_queue.tool_discovery_queue import (
ToolDiscoveryQueue,
)
from litellm.proxy._types import (DB_CONNECTION_ERROR_TYPES,
BaseDailySpendTransaction,
DailyAgentSpendTransaction,
DailyEndUserSpendTransaction,
DailyOrganizationSpendTransaction,
DailyTagSpendTransaction,
DailyTeamSpendTransaction,
DailyUserSpendTransaction,
DBSpendUpdateTransactions,
Litellm_EntityType, LiteLLM_UserTable,
SpendLogsMetadata, SpendLogsPayload,
SpendUpdateQueueItem, ToolDiscoveryQueueItem)
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import \
DailySpendUpdateQueue
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import \
PodLockManager
from litellm.proxy.db.db_transaction_queue.redis_update_buffer import \
RedisUpdateBuffer
from litellm.proxy.db.db_transaction_queue.spend_update_queue import \
SpendUpdateQueue
from litellm.proxy.db.db_transaction_queue.tool_discovery_queue import \
ToolDiscoveryQueue
from litellm.proxy.route_llm_request import ROUTE_ENDPOINT_MAPPING
if TYPE_CHECKING:
@ -104,12 +91,10 @@ class DBSpendUpdateWriter:
end_time: Optional[datetime],
response_cost: Optional[float],
):
from litellm.proxy.proxy_server import (
disable_spend_logs,
litellm_proxy_budget_name,
prisma_client,
user_api_key_cache,
)
from litellm.proxy.proxy_server import (disable_spend_logs,
litellm_proxy_budget_name,
prisma_client,
user_api_key_cache)
from litellm.proxy.utils import ProxyUpdateSpend, hash_token
try:
@ -124,9 +109,8 @@ class DBSpendUpdateWriter:
hashed_token = token
## CREATE SPEND LOG PAYLOAD ##
from litellm.proxy.spend_tracking.spend_tracking_utils import (
get_logging_payload,
)
from litellm.proxy.spend_tracking.spend_tracking_utils import \
get_logging_payload
payload = get_logging_payload(
kwargs=kwargs,
@ -230,6 +214,7 @@ class DBSpendUpdateWriter:
_litellm_params = kwargs.get("litellm_params") or {}
_metadata = _litellm_params.get("metadata") or {}
key_alias = _metadata.get("user_api_key_alias") or None
user_agent = _metadata.get("user_agent") or None
def _enqueue(tool_name: str, origin: str = "user_defined") -> None:
self.tool_discovery_queue.add_update(
@ -239,17 +224,20 @@ class DBSpendUpdateWriter:
key_hash=hashed_token,
team_id=team_id,
key_alias=key_alias,
user_agent=user_agent,
)
)
# --- MCP tool calls ---
sl_object = kwargs.get("standard_logging_object")
if sl_object is not None:
mcp_metadata = (
sl_object.get("metadata", {}) or {}
).get("mcp_tool_call_metadata")
mcp_metadata = (sl_object.get("metadata", {}) or {}).get(
"mcp_tool_call_metadata"
)
if mcp_metadata and isinstance(mcp_metadata, dict):
tool_name = mcp_metadata.get("namespaced_tool_name") or mcp_metadata.get("name")
tool_name = mcp_metadata.get(
"namespaced_tool_name"
) or mcp_metadata.get("name")
mcp_server_name = mcp_metadata.get("mcp_server_name")
if tool_name:
_enqueue(tool_name, origin=mcp_server_name or "user_defined")
@ -280,7 +268,9 @@ class DBSpendUpdateWriter:
_enqueue(name)
# --- Response tool_calls (OpenAI format; Anthropic pass-through converts tool_use here) ---
if completion_response is not None and hasattr(completion_response, "choices"):
if completion_response is not None and hasattr(
completion_response, "choices"
):
for choice in completion_response.choices or []:
message = getattr(choice, "message", None)
if message is None:
@ -768,19 +758,46 @@ class DBSpendUpdateWriter:
daily_end_user_spend_update_transactions,
daily_agent_spend_update_transactions,
daily_tag_spend_update_transactions,
) = await self.redis_update_buffer.get_all_transactions_from_redis_buffer_pipeline()
) = (
await self.redis_update_buffer.get_all_transactions_from_redis_buffer_pipeline()
)
if db_spend_update_transactions is not None:
verbose_proxy_logger.info(
"Spend tracking - committing spend updates from Redis to DB: "
"keys=%d, users=%d, teams=%d, orgs=%d, end_users=%d, team_members=%d, tags=%d",
len(db_spend_update_transactions.get("key_list_transactions") or {}),
len(db_spend_update_transactions.get("user_list_transactions") or {}),
len(db_spend_update_transactions.get("team_list_transactions") or {}),
len(db_spend_update_transactions.get("org_list_transactions") or {}),
len(db_spend_update_transactions.get("end_user_list_transactions") or {}),
len(db_spend_update_transactions.get("team_member_list_transactions") or {}),
len(db_spend_update_transactions.get("tag_list_transactions") or {}),
len(
db_spend_update_transactions.get("key_list_transactions")
or {}
),
len(
db_spend_update_transactions.get("user_list_transactions")
or {}
),
len(
db_spend_update_transactions.get("team_list_transactions")
or {}
),
len(
db_spend_update_transactions.get("org_list_transactions")
or {}
),
len(
db_spend_update_transactions.get(
"end_user_list_transactions"
)
or {}
),
len(
db_spend_update_transactions.get(
"team_member_list_transactions"
)
or {}
),
len(
db_spend_update_transactions.get("tag_list_transactions")
or {}
),
)
await self._commit_spend_updates_to_db(
prisma_client=prisma_client,
@ -985,10 +1002,8 @@ class DBSpendUpdateWriter:
Commits all the spend `UPDATE` transactions to the Database
"""
from litellm.proxy.utils import (
ProxyUpdateSpend,
_raise_failed_update_spend_exception,
)
from litellm.proxy.utils import (ProxyUpdateSpend,
_raise_failed_update_spend_exception)
### UPDATE USER TABLE ###
user_list_transactions = db_spend_update_transactions["user_list_transactions"]
@ -1523,14 +1538,14 @@ class DBSpendUpdateWriter:
# Add cache-related fields if they exist
if "cache_read_input_tokens" in transaction:
common_data[
"cache_read_input_tokens"
] = transaction.get("cache_read_input_tokens", 0)
common_data["cache_read_input_tokens"] = (
transaction.get("cache_read_input_tokens", 0)
)
if "cache_creation_input_tokens" in transaction:
common_data[
"cache_creation_input_tokens"
] = transaction.get(
"cache_creation_input_tokens", 0
common_data["cache_creation_input_tokens"] = (
transaction.get(
"cache_creation_input_tokens", 0
)
)
if entity_type == "tag" and "request_id" in transaction:


@ -0,0 +1,147 @@
"""
Track tool usage for the dashboard: insert into SpendLogToolIndex when spend logs
are written, so "last N requests for tool X" and "how is this tool called in production"
queries are fast.
"""
from datetime import datetime, timezone
from typing import Any, Dict, List, Set
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.proxy.utils import PrismaClient
def _add_tool_calls_to_set(tool_calls: Any, out: Set[str]) -> None:
"""Extract tool names from OpenAI-style tool_calls list into out."""
if not isinstance(tool_calls, list):
return
for tc in tool_calls:
if not isinstance(tc, dict):
continue
fn = tc.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if name and isinstance(name, str) and name.strip():
out.add(name.strip())
def _parse_tool_names_from_payload(payload: Dict[str, Any]) -> Set[str]:
"""
Extract deduplicated tool names from a spend log payload.
Sources: mcp_namespaced_tool_name, response (tool_calls), proxy_server_request (tools).
"""
tool_names: Set[str] = set()
# Top-level MCP tool name (single tool per request for that flow)
mcp_name = payload.get("mcp_namespaced_tool_name")
if mcp_name and isinstance(mcp_name, str) and mcp_name.strip():
tool_names.add(mcp_name.strip())
# Response: OpenAI-style tool_calls[].function.name or choices[0].message.tool_calls
response_raw = payload.get("response")
if response_raw:
response_obj = (
safe_json_loads(response_raw, default=None)
if isinstance(response_raw, str)
else response_raw
)
if isinstance(response_obj, dict):
_add_tool_calls_to_set(response_obj.get("tool_calls"), tool_names)
choices = response_obj.get("choices")
if isinstance(choices, list) and choices:
msg = choices[0].get("message") if isinstance(choices[0], dict) else None
if isinstance(msg, dict):
_add_tool_calls_to_set(msg.get("tool_calls"), tool_names)
# Request body: tools[].function.name
request_raw = payload.get("proxy_server_request")
if request_raw:
request_obj = (
safe_json_loads(request_raw, default=None)
if isinstance(request_raw, str)
else request_raw
)
if isinstance(request_obj, dict):
body = request_obj.get("body", request_obj)
if isinstance(body, dict):
request_obj = body
if isinstance(request_obj, dict):
tools = request_obj.get("tools")
if isinstance(tools, list):
for t in tools:
if isinstance(t, dict):
fn = t.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if name and isinstance(name, str) and name.strip():
tool_names.add(name.strip())
return tool_names
async def process_spend_logs_tool_usage(
prisma_client: PrismaClient,
logs_to_process: List[Dict[str, Any]],
) -> None:
"""
After spend logs are written: insert SpendLogToolIndex rows from each payload.
Extracts tool names from mcp_namespaced_tool_name, response tool_calls, and
proxy_server_request tools.
"""
if not logs_to_process:
return
index_rows: List[Dict[str, Any]] = []
for payload in logs_to_process:
request_id = payload.get("request_id")
start_time = payload.get("startTime")
if not request_id or not start_time:
continue
if isinstance(start_time, str):
try:
start_time = datetime.fromisoformat(
start_time.replace("Z", "+00:00")
)
except (ValueError, TypeError):
continue
if start_time.tzinfo is None:
start_time = start_time.replace(tzinfo=timezone.utc)
tool_names = _parse_tool_names_from_payload(payload)
for tool_name in tool_names:
index_rows.append({
"request_id": request_id,
"tool_name": tool_name,
"start_time": start_time,
})
if not index_rows:
return
try:
index_data = []
for r in index_rows:
st = r["start_time"]
if isinstance(st, str):
try:
st = datetime.fromisoformat(st.replace("Z", "+00:00"))
except (ValueError, TypeError):
continue
if st.tzinfo is None:
st = st.replace(tzinfo=timezone.utc)
index_data.append({
"request_id": r["request_id"],
"tool_name": r["tool_name"],
"start_time": st,
})
if index_data:
await prisma_client.db.litellm_spendlogtoolindex.create_many(
data=index_data,
skip_duplicates=True,
)
except Exception as e:
verbose_proxy_logger.warning(
"Tool usage tracking (SpendLogToolIndex) failed (non-fatal): %s", e
)
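
To illustrate what the payload parsing above yields, here is a simplified, self-contained sketch. It is not the module's implementation: `tool_names_from_payload` is a hypothetical name, it uses `json.loads` in place of `safe_json_loads`, it covers only the MCP name and `choices[0].message.tool_calls`, and the sample payload is invented for illustration.

```python
import json
from typing import Any, Dict, Set


def tool_names_from_payload(payload: Dict[str, Any]) -> Set[str]:
    """Hypothetical, reduced version of the extraction: MCP name + response tool_calls."""
    names: Set[str] = set()

    mcp_name = payload.get("mcp_namespaced_tool_name")
    if isinstance(mcp_name, str) and mcp_name.strip():
        names.add(mcp_name.strip())

    response = payload.get("response")
    if isinstance(response, str):
        try:
            response = json.loads(response)
        except ValueError:
            response = None  # unparseable response: contribute no names
    if isinstance(response, dict):
        choices = response.get("choices")
        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
            message = choices[0].get("message")
            tool_calls = message.get("tool_calls") if isinstance(message, dict) else None
            if isinstance(tool_calls, list):
                for call in tool_calls:
                    fn = call.get("function") if isinstance(call, dict) else None
                    name = fn.get("name") if isinstance(fn, dict) else None
                    if isinstance(name, str) and name.strip():
                        names.add(name.strip())
    return names


# Invented sample payload; field names follow the spend log payload used above.
sample_payload = {
    "request_id": "req-1",
    "mcp_namespaced_tool_name": "github-list_issues",
    "response": json.dumps(
        {
            "choices": [
                {
                    "message": {
                        "tool_calls": [
                            {"function": {"name": " get_weather "}},
                            {"function": {"name": "get_weather"}},
                        ]
                    }
                }
            ]
        }
    ),
}
names = tool_names_from_payload(sample_payload)
```

Because names are stripped and collected into a set, the two `get_weather` calls collapse into one entry, which is what keeps the index at one row per tool per request.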


@ -2,36 +2,64 @@
DB helpers for LiteLLM_ToolTable, the global tool registry.
Tools are auto-discovered from LLM responses and upserted here.
Admins use the management endpoints to read and update call_policy.
NOTE: Uses raw SQL (query_raw / execute_raw) instead of Prisma model methods
because the generated Prisma Python client may not have LiteLLM_ToolTable
when running against an older generated schema.
Admins use the management endpoints to read and update input_policy / output_policy.
"""
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ToolDiscoveryQueueItem
from litellm.types.tool_management import LiteLLM_ToolTableRow, ToolCallPolicy
from litellm.types.tool_management import (
LiteLLM_ToolTableRow,
ToolPolicyOverrideRow,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
def _row_to_model(row: dict) -> LiteLLM_ToolTableRow:
def _row_to_model(row: Union[dict, Any]) -> LiteLLM_ToolTableRow:
"""Convert a Prisma model instance or dict to LiteLLM_ToolTableRow."""
model_dump = getattr(row, "model_dump", None)
if callable(model_dump):
row = model_dump()
elif not isinstance(row, dict):
row = {
k: getattr(row, k, None)
for k in (
"tool_id",
"tool_name",
"origin",
"input_policy",
"output_policy",
"call_count",
"assignments",
"key_hash",
"team_id",
"key_alias",
"user_agent",
"last_used_at",
"created_at",
"updated_at",
"created_by",
"updated_by",
)
}
return LiteLLM_ToolTableRow(
tool_id=row.get("tool_id", ""),
tool_name=row.get("tool_name", ""),
origin=row.get("origin"),
call_policy=row.get("call_policy", "untrusted"),
input_policy=row.get("input_policy") or "untrusted",
output_policy=row.get("output_policy") or "untrusted",
call_count=int(row.get("call_count") or 0),
assignments=row.get("assignments"),
key_hash=row.get("key_hash"),
team_id=row.get("team_id"),
key_alias=row.get("key_alias"),
user_agent=row.get("user_agent"),
last_used_at=row.get("last_used_at"),
created_at=row.get("created_at"),
updated_at=row.get("updated_at"),
created_by=row.get("created_by"),
@ -44,10 +72,10 @@ async def batch_upsert_tools(
items: List[ToolDiscoveryQueueItem],
) -> None:
"""
Batch-upsert tool registry rows via raw SQL.
Batch-upsert tool registry rows via Prisma.
On first insert: sets call_policy = "untrusted" (schema default), call_count = 1.
On conflict: increments call_count; preserves existing call_policy.
On first insert: sets input_policy/output_policy = "untrusted" (default), call_count = 1.
On conflict: increments call_count; preserves existing policies.
"""
if not items:
return
@ -55,6 +83,8 @@ async def batch_upsert_tools(
data = [item for item in items if item.get("tool_name")]
if not data:
return
now = datetime.now(timezone.utc)
table = prisma_client.db.litellm_tooltable
for item in data:
tool_name = item.get("tool_name", "")
origin = item.get("origin") or "user_defined"
@ -62,49 +92,52 @@ async def batch_upsert_tools(
key_hash = item.get("key_hash")
team_id = item.get("team_id")
key_alias = item.get("key_alias")
now = datetime.now(timezone.utc).isoformat()
await prisma_client.db.execute_raw(
'INSERT INTO "LiteLLM_ToolTable" '
"(tool_id, tool_name, origin, call_policy, call_count, created_by, updated_by, key_hash, team_id, key_alias, created_at, updated_at) "
"VALUES ($7, $1, $2, 'untrusted', 1, $3, $3, $4, $5, $6, $8, $8) "
"ON CONFLICT (tool_name) DO UPDATE SET "
"call_count = \"LiteLLM_ToolTable\".call_count + 1, "
"updated_at = $8",
tool_name,
origin,
created_by,
key_hash,
team_id,
key_alias,
str(uuid.uuid4()),
now,
user_agent = item.get("user_agent")
await table.upsert(
where={"tool_name": tool_name},
data={
"create": {
"tool_id": str(uuid.uuid4()),
"tool_name": tool_name,
"origin": origin,
"input_policy": "untrusted",
"output_policy": "untrusted",
"call_count": 1,
"created_by": created_by,
"updated_by": created_by,
"key_hash": key_hash,
"team_id": team_id,
"key_alias": key_alias,
"user_agent": user_agent,
"last_used_at": now,
},
"update": {
"call_count": {"increment": 1},
"updated_at": now,
"last_used_at": now,
},
},
)
verbose_proxy_logger.debug(
"tool_registry_writer: upserted %d tool(s)", len(data)
)
except Exception as e:
verbose_proxy_logger.error("tool_registry_writer batch_upsert_tools error: %s", e)
verbose_proxy_logger.error(
"tool_registry_writer batch_upsert_tools error: %s", e
)
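
The create/update split of the upsert can be modelled without a database. The sketch below is an in-memory stand-in that assumes a plain dict in place of the Prisma table. `ToolRow` and `upsert_tool` are hypothetical names, and only the fields needed to show the semantics are kept.

```python
from typing import Dict, TypedDict


class ToolRow(TypedDict):
    input_policy: str
    output_policy: str
    call_count: int


def upsert_tool(registry: Dict[str, ToolRow], tool_name: str) -> ToolRow:
    """Hypothetical in-memory analogue of the upsert keyed on tool_name."""
    row = registry.get(tool_name)
    if row is None:
        # "create" branch: both policies default to "untrusted", first call counted.
        row = {"input_policy": "untrusted", "output_policy": "untrusted", "call_count": 1}
        registry[tool_name] = row
    else:
        # "update" branch: only the counter changes; existing policies are preserved.
        row["call_count"] += 1
    return row


registry: Dict[str, ToolRow] = {}
upsert_tool(registry, "get_weather")
registry["get_weather"]["input_policy"] = "blocked"  # simulate an admin policy change
row = upsert_tool(registry, "get_weather")
```

The point of keeping policies out of the update branch is that rediscovering a tool in later traffic does not reset a policy an admin has already set.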
async def list_tools(
prisma_client: "PrismaClient",
call_policy: Optional[ToolCallPolicy] = None,
input_policy: Optional[str] = None,
) -> List[LiteLLM_ToolTableRow]:
"""Return all tools, optionally filtered by call_policy."""
"""Return all tools, optionally filtered by input_policy."""
try:
if call_policy is not None:
rows = await prisma_client.db.query_raw(
'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, '
'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by '
'FROM "LiteLLM_ToolTable" WHERE call_policy = $1 ORDER BY created_at DESC',
call_policy,
)
else:
rows = await prisma_client.db.query_raw(
'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, '
'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by '
'FROM "LiteLLM_ToolTable" ORDER BY created_at DESC',
)
where = {"input_policy": input_policy} if input_policy is not None else {}
rows = await prisma_client.db.litellm_tooltable.find_many(
where=where,
order={"created_at": "desc"},
)
return [_row_to_model(row) for row in rows]
except Exception as e:
verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e)
@ -117,15 +150,12 @@ async def get_tool(
) -> Optional[LiteLLM_ToolTableRow]:
"""Return a single tool row by tool_name."""
try:
rows = await prisma_client.db.query_raw(
'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, '
'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by '
'FROM "LiteLLM_ToolTable" WHERE tool_name = $1',
tool_name,
row = await prisma_client.db.litellm_tooltable.find_unique(
where={"tool_name": tool_name},
)
if not rows:
if row is None:
return None
return _row_to_model(rows[0])
return _row_to_model(row)
except Exception as e:
verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e)
return None
@ -134,46 +164,279 @@ async def get_tool(
async def update_tool_policy(
prisma_client: "PrismaClient",
tool_name: str,
call_policy: ToolCallPolicy,
updated_by: Optional[str],
input_policy: Optional[str] = None,
output_policy: Optional[str] = None,
) -> Optional[LiteLLM_ToolTableRow]:
"""Update the call_policy for a tool. Upserts the row if it does not exist yet."""
"""Update input_policy and/or output_policy for a tool. Upserts the row if it does not exist yet."""
try:
_updated_by = updated_by or "system"
now = datetime.now(timezone.utc).isoformat()
await prisma_client.db.execute_raw(
'INSERT INTO "LiteLLM_ToolTable" (tool_id, tool_name, call_policy, created_by, updated_by, created_at, updated_at) '
"VALUES ($4, $1, $2, $3, $3, $5, $5) "
"ON CONFLICT (tool_name) DO UPDATE SET call_policy = $2, updated_by = $3, updated_at = $5",
tool_name,
call_policy,
_updated_by,
str(uuid.uuid4()),
now,
now = datetime.now(timezone.utc)
create_data: dict = {
"tool_id": str(uuid.uuid4()),
"tool_name": tool_name,
"input_policy": input_policy or "untrusted",
"output_policy": output_policy or "untrusted",
"created_by": _updated_by,
"updated_by": _updated_by,
"created_at": now,
"updated_at": now,
}
update_data: dict = {
"updated_by": _updated_by,
"updated_at": now,
}
if input_policy is not None:
update_data["input_policy"] = input_policy
if output_policy is not None:
update_data["output_policy"] = output_policy
await prisma_client.db.litellm_tooltable.upsert(
where={"tool_name": tool_name},
data={
"create": create_data,
"update": update_data,
},
)
return await get_tool(prisma_client, tool_name)
except Exception as e:
verbose_proxy_logger.error("tool_registry_writer update_tool_policy error: %s", e)
verbose_proxy_logger.error(
"tool_registry_writer update_tool_policy error: %s", e
)
return None
async def get_tools_by_names(
prisma_client: "PrismaClient",
tool_names: List[str],
) -> Dict[str, str]:
) -> Dict[str, Tuple[str, str]]:
"""
Return a {tool_name: call_policy} map for the given tool names.
Used by the policy enforcement guardrail: single batch query, never N+1.
Return a {tool_name: (input_policy, output_policy)} map for the given tool names.
"""
if not tool_names:
return {}
try:
placeholders = ", ".join(f"${i+1}" for i in range(len(tool_names)))
rows = await prisma_client.db.query_raw(
f'SELECT tool_name, call_policy FROM "LiteLLM_ToolTable" WHERE tool_name IN ({placeholders})',
*tool_names,
rows = await prisma_client.db.litellm_tooltable.find_many(
where={"tool_name": {"in": tool_names}},
)
return {row["tool_name"]: row["call_policy"] for row in rows}
return {
row.tool_name: (
getattr(row, "input_policy", "untrusted") or "untrusted",
getattr(row, "output_policy", "untrusted") or "untrusted",
)
for row in rows
}
except Exception as e:
verbose_proxy_logger.error("tool_registry_writer get_tools_by_names error: %s", e)
verbose_proxy_logger.error(
"tool_registry_writer get_tools_by_names error: %s", e
)
return {}
async def list_overrides_for_tool(
prisma_client: "PrismaClient",
tool_name: str,
) -> List[ToolPolicyOverrideRow]:
"""
Return override-like rows for a tool by finding object permissions that have
this tool in blocked_tools, then resolving each permission to key/team scope for display.
"""
out: List[ToolPolicyOverrideRow] = []
try:
perms = await prisma_client.db.litellm_objectpermissiontable.find_many(
where={"blocked_tools": {"has": tool_name}},
include={
"verification_tokens": True,
"teams": True,
},
)
for perm in perms:
op_id = getattr(perm, "object_permission_id", None) or ""
tokens = getattr(perm, "verification_tokens", []) or []
teams = getattr(perm, "teams", []) or []
for t in tokens:
out.append(
ToolPolicyOverrideRow(
override_id=op_id,
tool_name=tool_name,
team_id=None,
key_hash=getattr(t, "token", None),
input_policy="blocked",
key_alias=getattr(t, "key_alias", None),
created_at=None,
updated_at=None,
)
)
for team in teams:
out.append(
ToolPolicyOverrideRow(
override_id=op_id,
tool_name=tool_name,
team_id=getattr(team, "team_id", None),
key_hash=None,
input_policy="blocked",
key_alias=getattr(team, "team_alias", None),
created_at=None,
updated_at=None,
)
)
return out
except Exception as e:
verbose_proxy_logger.error(
"tool_registry_writer list_overrides_for_tool error: %s", e
)
return []
class ToolPolicyRegistry:
"""
In-memory registry of tool policies synced from DB.
Hot path uses get_effective_policies only: no DB, no cache.
"""
def __init__(self) -> None:
self._tool_input_policies: Dict[str, str] = {}
self._tool_output_policies: Dict[str, str] = {}
self._blocked_tools_by_op_id: Dict[str, List[str]] = {}
self._initialized: bool = False
def is_initialized(self) -> bool:
return self._initialized
async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
"""Load all tool policies and object-permission blocked_tools from DB."""
try:
tools = await prisma_client.db.litellm_tooltable.find_many()
self._tool_input_policies = {
row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted"
for row in tools
}
self._tool_output_policies = {
row.tool_name: getattr(row, "output_policy", "untrusted") or "untrusted"
for row in tools
}
perms = await prisma_client.db.litellm_objectpermissiontable.find_many()
self._blocked_tools_by_op_id = {}
for row in perms:
op_id = getattr(row, "object_permission_id", None)
blocked = getattr(row, "blocked_tools", None) or []
if op_id:
self._blocked_tools_by_op_id[op_id] = list(blocked)
self._initialized = True
verbose_proxy_logger.info(
"ToolPolicyRegistry: synced %d tool policies and %d object permissions from DB",
len(self._tool_input_policies),
len(self._blocked_tools_by_op_id),
)
except Exception as e:
verbose_proxy_logger.exception(
"ToolPolicyRegistry sync_tool_policy_from_db error: %s", e
)
raise
def get_input_policy(self, tool_name: str) -> str:
return self._tool_input_policies.get(tool_name, "untrusted")
def get_output_policy(self, tool_name: str) -> str:
return self._tool_output_policies.get(tool_name, "untrusted")
def get_effective_policies(
self,
tool_names: List[str],
object_permission_id: Optional[str] = None,
team_object_permission_id: Optional[str] = None,
) -> Dict[str, str]:
"""
Return effective input_policy per tool from in-memory state.
If tool is in key or team blocked_tools -> "blocked", else global input_policy or "untrusted".
"""
if not tool_names:
return {}
blocked: set = set()
for op_id in (object_permission_id, team_object_permission_id):
if op_id and op_id.strip():
blocked.update(
self._blocked_tools_by_op_id.get(op_id.strip(), [])
)
result: Dict[str, str] = {}
for name in tool_names:
if name in blocked:
result[name] = "blocked"
else:
result[name] = self._tool_input_policies.get(name, "untrusted")
return result
_tool_policy_registry: Optional[ToolPolicyRegistry] = None
def get_tool_policy_registry() -> ToolPolicyRegistry:
"""Return the global ToolPolicyRegistry singleton."""
global _tool_policy_registry
if _tool_policy_registry is None:
_tool_policy_registry = ToolPolicyRegistry()
return _tool_policy_registry
async def add_tool_to_object_permission_blocked(
prisma_client: "PrismaClient",
object_permission_id: str,
tool_name: str,
) -> bool:
"""Add tool_name to the permission's blocked_tools if not already present."""
if not object_permission_id or not tool_name:
return False
try:
row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": object_permission_id},
)
if row is None:
return False
current = list(getattr(row, "blocked_tools", []) or [])
if tool_name in current:
return True
current.append(tool_name)
await prisma_client.db.litellm_objectpermissiontable.update(
where={"object_permission_id": object_permission_id},
data={"blocked_tools": current},
)
return True
except Exception as e:
verbose_proxy_logger.error(
"tool_registry_writer add_tool_to_object_permission_blocked error: %s", e
)
return False
async def remove_tool_from_object_permission_blocked(
prisma_client: "PrismaClient",
object_permission_id: str,
tool_name: str,
) -> bool:
"""Remove tool_name from the permission's blocked_tools. Returns False if tool was not in list."""
if not object_permission_id or not tool_name:
return False
try:
row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": object_permission_id},
)
if row is None:
return False
current = list(getattr(row, "blocked_tools", []) or [])
if tool_name not in current:
return False
current = [t for t in current if t != tool_name]
await prisma_client.db.litellm_objectpermissiontable.update(
where={"object_permission_id": object_permission_id},
data={"blocked_tools": current},
)
return True
except Exception as e:
verbose_proxy_logger.error(
"tool_registry_writer remove_tool_from_object_permission_blocked error: %s",
e,
)
return False
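
The resolution rule in `ToolPolicyRegistry.get_effective_policies` can be sketched in isolation as follows. This is a minimal standalone illustration, not the registry itself: the function name `resolve_effective_policies`, its parameters, and the sample data are hypothetical, and plain dicts stand in for the in-memory state that the registry syncs from the DB.

```python
from typing import Dict, Iterable, List, Optional


def resolve_effective_policies(
    tool_names: List[str],
    global_input_policies: Dict[str, str],
    blocked_by_permission: Dict[str, List[str]],
    permission_ids: Iterable[Optional[str]] = (),
) -> Dict[str, str]:
    """Per-key/team blocks win; otherwise fall back to the global input policy."""
    blocked = set()
    for op_id in permission_ids:
        # Missing or blank permission ids contribute nothing.
        if op_id and op_id.strip():
            blocked.update(blocked_by_permission.get(op_id.strip(), []))
    return {
        name: "blocked" if name in blocked else global_input_policies.get(name, "untrusted")
        for name in tool_names
    }


# Hypothetical sample data, not taken from a real deployment.
policies = resolve_effective_policies(
    tool_names=["search", "send_email", "new_tool"],
    global_input_policies={"search": "untrusted", "send_email": "trusted"},
    blocked_by_permission={"key-perm-1": ["search"]},
    permission_ids=("key-perm-1", None),
)
print(policies)
```

In this sketch a key-level block overrides the global policy for `search`, while a tool the registry has never seen (`new_tool`) falls back to "untrusted".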


@ -1,13 +1,16 @@
"""
Tool Policy Guardrail
Reads call_policy from LiteLLM_ToolTable and enforces it on LLM requests/responses.
Reads input_policy / output_policy from LiteLLM_ToolTable and enforces them.
Policy values:
"trusted" - allow through (no action)
"untrusted" - allow through (no action; default for newly discovered tools)
Input policy values:
"untrusted" - allow through (default for newly discovered tools)
"trusted" - only allow if conversation contains no untrusted tool output
"blocked" - raise HTTPException, preventing the tool call
"dual_llm" - (Phase 3) send to second LLM for verification; currently treated as allowed
Output policy values:
"untrusted" - output may be tainted (default)
"trusted" - output is verified safe
Configuration in proxy config YAML:
guardrails:
@ -15,25 +18,18 @@ Configuration in proxy config YAML:
litellm_params:
guardrail: tool_policy
mode: post_call
or both pre and post call:
- guardrail_name: "tool_policy"
litellm_params:
guardrail: tool_policy
mode: during_call # runs before LLM and on response
"""
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.caching.dual_cache import DualCache
from litellm.constants import TOOL_POLICY_CACHE_TTL_SECONDS
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.proxy.guardrails.tool_name_extraction import extract_request_tool_names
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import GenericGuardrailAPIInputs
@ -43,12 +39,71 @@ if TYPE_CHECKING:
GUARDRAIL_NAME = "tool_policy"
def _get_request_object_permission_ids(
request_data: dict,
) -> Tuple[Optional[str], Optional[str]]:
"""Extract object_permission_id and team_object_permission_id from request_data."""
if not request_data:
return None, None
for key in ("litellm_metadata", "metadata"):
meta = request_data.get(key)
if not isinstance(meta, dict):
continue
auth = meta.get("user_api_key_auth")
if auth is not None and hasattr(auth, "object_permission_id"):
key_op = getattr(auth, "object_permission_id", None)
team_op = getattr(auth, "team_object_permission_id", None)
if key_op is not None or team_op is not None:
return (
str(key_op).strip() if key_op else None,
str(team_op).strip() if team_op else None,
)
key_op = meta.get("user_api_key_object_permission_id")
team_op = meta.get("user_api_key_team_object_permission_id")
if key_op is not None or team_op is not None:
return (
str(key_op).strip() if key_op else None,
str(team_op).strip() if team_op else None,
)
return None, None
def _get_request_route_from_data(request_data: dict) -> Optional[str]:
"""Get request route from request_data (metadata or top-level)."""
route = request_data.get("user_api_key_request_route")
if route:
return route
meta = request_data.get("metadata") or request_data.get("litellm_metadata") or {}
return meta.get("user_api_key_request_route")
def _resolve_tool_names_from_messages(messages: List[dict]) -> Dict[str, str]:
"""
Build a map of tool_call_id -> tool_name from assistant messages' tool_calls.
Used to resolve which tool produced each tool result in the conversation.
"""
mapping: Dict[str, str] = {}
for msg in messages:
if msg.get("role") != "assistant":
continue
tool_calls = msg.get("tool_calls") or []
for tc in tool_calls:
if isinstance(tc, dict):
tc_id = tc.get("id")
fn = (tc.get("function") or {}).get("name")
else:
tc_id = getattr(tc, "id", None)
fn_obj = getattr(tc, "function", None)
fn = getattr(fn_obj, "name", None) if fn_obj else None
if tc_id and fn:
mapping[tc_id] = fn
return mapping
class ToolPolicyGuardrail(CustomGuardrail):
"""
Guardrail that enforces per-tool call policies stored in LiteLLM_ToolTable.
Tools with call_policy="blocked" are rejected before/after the LLM call.
Tools with call_policy="trusted" or "untrusted" pass through unchanged.
Guardrail that enforces per-tool input/output policies from the in-memory
ToolPolicyRegistry (synced from DB).
"""
def __init__(self, **kwargs: Any) -> None:
@ -59,7 +114,6 @@ class ToolPolicyGuardrail(CustomGuardrail):
GuardrailEventHooks.during_call,
]
super().__init__(**kwargs)
self._policy_cache: DualCache = DualCache()
@log_guardrail_information
async def apply_guardrail(
@ -70,12 +124,7 @@ class ToolPolicyGuardrail(CustomGuardrail):
logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> GenericGuardrailAPIInputs:
"""
Enforce tool policies on both request tools and response tool_calls.
- input_type="request": check inputs["tools"] (tool definitions in the LLM request)
- input_type="response": check inputs["tool_calls"] (tool_calls in the LLM response)
Raises HTTPException (400) if any tool is "blocked".
Enforce input_policy and output_policy trust chain on request tools / response tool_calls.
"""
if input_type == "request":
tools = inputs.get("tools") or []
@ -86,7 +135,11 @@ class ToolPolicyGuardrail(CustomGuardrail):
and isinstance(t.get("function"), dict)
and t["function"].get("name")
]
else: # response
if not tool_names:
route = _get_request_route_from_data(request_data)
if route:
tool_names = extract_request_tool_names(route, request_data)
else:
tool_calls = inputs.get("tool_calls") or []
tool_names = []
for tc in tool_calls:
@ -101,12 +154,25 @@ class ToolPolicyGuardrail(CustomGuardrail):
if not tool_names:
return inputs
policy_map = await self._get_policies_cached(tool_names)
object_permission_id, team_object_permission_id = (
_get_request_object_permission_ids(request_data)
)
from litellm.proxy.db.tool_registry_writer import get_tool_policy_registry
registry = get_tool_policy_registry()
if not registry.is_initialized():
return inputs
# Stage 1: Check for blocked tools (input_policy=blocked or per-key/team override)
policy_map = registry.get_effective_policies(
tool_names,
object_permission_id=object_permission_id,
team_object_permission_id=team_object_permission_id,
)
blocked = [name for name in tool_names if policy_map.get(name) == "blocked"]
if blocked:
verbose_proxy_logger.warning(
"ToolPolicyGuardrail: blocking tool(s) %s (policy=blocked)", blocked
"ToolPolicyGuardrail: blocking tool(s) %s (input_policy=blocked)", blocked
)
raise HTTPException(
status_code=400,
@ -117,47 +183,47 @@ class ToolPolicyGuardrail(CustomGuardrail):
},
)
# Stage 2: Trust chain enforcement (response path only)
# For each tool with input_policy=trusted, check if conversation
# contains output from tools with output_policy=untrusted
if input_type == "response":
trusted_input_tools = [
name for name in tool_names if policy_map.get(name) == "trusted"
]
if trusted_input_tools:
messages = request_data.get("messages") or []
tc_id_to_name = _resolve_tool_names_from_messages(messages)
untrusted_sources: List[str] = []
for msg in messages:
if msg.get("role") != "tool":
continue
tool_call_id = msg.get("tool_call_id")
source_tool = tc_id_to_name.get(tool_call_id, "") if tool_call_id else ""
if not source_tool:
continue
if registry.get_output_policy(source_tool) == "untrusted":
if source_tool not in untrusted_sources:
untrusted_sources.append(source_tool)
if untrusted_sources:
verbose_proxy_logger.warning(
"ToolPolicyGuardrail: trust chain violation — %s require trusted input "
"but conversation has untrusted output from %s",
trusted_input_tools,
untrusted_sources,
)
raise HTTPException(
status_code=400,
detail={
"error": "Violated tool policy",
"blocked_tools": trusted_input_tools,
"untrusted_sources": untrusted_sources,
"message": (
f"{', '.join(trusted_input_tools)} requires trusted input but "
f"conversation contains untrusted output from {', '.join(untrusted_sources)}."
),
},
)
return inputs
async def _get_policies_cached(self, tool_names: List[str]) -> Dict[str, str]:
"""
Batch-fetch call_policy for the given tool names.
Caches per individual tool name (not per combination) so that adding
a new tool to a request doesn't invalidate the cached policies for all
the other tools already in the cache.
"""
from litellm.proxy.db.tool_registry_writer import get_tools_by_names
from litellm.proxy.proxy_server import prisma_client
if not tool_names or prisma_client is None:
return {}
result: Dict[str, str] = {}
cache_misses: List[str] = []
for name in tool_names:
cached = await self._policy_cache.async_get_cache(f"tool_policy:{name}")
if cached is not None and isinstance(cached, str):
result[name] = cached
else:
cache_misses.append(name)
if cache_misses:
fetched = await get_tools_by_names(
prisma_client=prisma_client, tool_names=cache_misses
)
for name, policy in fetched.items():
result[name] = policy
await self._policy_cache.async_set_cache(
key=f"tool_policy:{name}",
value=policy,
ttl=TOOL_POLICY_CACHE_TTL_SECONDS,
)
verbose_proxy_logger.debug(
"ToolPolicyGuardrail: fetched %d policies from DB (cache hits: %d)",
len(cache_misses),
len(tool_names) - len(cache_misses),
)
return result
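
The Stage 2 trust-chain check in `apply_guardrail` can be illustrated without the proxy. The sketch below is an assumption-laden simplification: `find_untrusted_sources` is a hypothetical helper, it handles only dict-shaped messages, and a plain dict replaces `registry.get_output_policy` (with the same "untrusted" default).

```python
from typing import Dict, List


def find_untrusted_sources(messages: List[dict], output_policies: Dict[str, str]) -> List[str]:
    """Return names of tools whose results appear in the conversation and are untrusted."""
    # Step 1: map tool_call_id -> tool name from the assistant's tool_calls.
    call_id_to_name: Dict[str, str] = {}
    for msg in messages:
        if msg.get("role") != "assistant":
            continue
        for tc in msg.get("tool_calls") or []:
            name = (tc.get("function") or {}).get("name")
            if tc.get("id") and name:
                call_id_to_name[tc["id"]] = name

    # Step 2: for each tool result, look up which tool produced it and its output policy.
    sources: List[str] = []
    for msg in messages:
        if msg.get("role") != "tool":
            continue
        source = call_id_to_name.get(msg.get("tool_call_id") or "", "")
        if not source:
            continue
        if output_policies.get(source, "untrusted") == "untrusted" and source not in sources:
            sources.append(source)
    return sources


# Hypothetical conversation: the assistant called web_search and got a result back.
conversation = [
    {"role": "user", "content": "Summarise the page and email it."},
    {
        "role": "assistant",
        "tool_calls": [{"id": "call_1", "function": {"name": "web_search"}}],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "page text"},
]

print(find_untrusted_sources(conversation, {}))
print(find_untrusted_sources(conversation, {"web_search": "trusted"}))
```

If the first call returns a non-empty list and the response contains a tool call whose input_policy is "trusted", the guardrail raises the 400 shown above.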


@ -0,0 +1,85 @@
"""
Extract tool names from request body by route/call type.
Used by auth (check_tools_allowlist) and ToolPolicyGuardrail so tool-format
knowledge lives in one place. Uses guardrail translation handlers where available,
with standalone extractors for generate_content and MCP.
"""
from typing import Any, Dict, List
from litellm.litellm_core_utils.api_route_to_call_types import get_call_types_for_route
from litellm.llms import load_guardrail_translation_mappings
from litellm.types.utils import CallTypes
# Call types that have no guardrail translation handler; we use standalone extractors
STANDALONE_EXTRACTORS: Dict[str, Any] = {}
def _extract_generate_content_tool_names(data: dict) -> List[str]:
"""Google generateContent: tools[].functionDeclarations[].name"""
names: List[str] = []
for tool in data.get("tools") or []:
if not isinstance(tool, dict):
continue
for decl in tool.get("functionDeclarations") or []:
if isinstance(decl, dict) and decl.get("name"):
names.append(str(decl["name"]))
return names
def _extract_mcp_tool_names(data: dict) -> List[str]:
"""MCP call_tool: name or mcp_tool_name in body"""
names: List[str] = []
name = data.get("name") or data.get("mcp_tool_name")
if name:
names.append(str(name))
return names
def _register_standalone_extractors() -> None:
if STANDALONE_EXTRACTORS:
return
STANDALONE_EXTRACTORS[CallTypes.generate_content.value] = _extract_generate_content_tool_names
STANDALONE_EXTRACTORS[CallTypes.agenerate_content.value] = _extract_generate_content_tool_names
STANDALONE_EXTRACTORS[CallTypes.call_mcp_tool.value] = _extract_mcp_tool_names
# Tool-capable call types (routes that can send tools in the request)
TOOL_CAPABLE_CALL_TYPES = frozenset({
CallTypes.completion.value,
CallTypes.acompletion.value,
CallTypes.responses.value,
CallTypes.aresponses.value,
CallTypes.anthropic_messages.value,
CallTypes.generate_content.value,
CallTypes.agenerate_content.value,
CallTypes.call_mcp_tool.value,
})
def extract_request_tool_names(route: str, data: dict) -> List[str]:
"""
Extract tool names from the request body for the given route.
Uses guardrail translation handlers when available, else standalone extractors
for generate_content and MCP. Returns [] for non-tool-capable routes or when
no tools are present.
"""
call_types = get_call_types_for_route(route)
if not call_types:
return []
_register_standalone_extractors()
mappings = load_guardrail_translation_mappings()
for call_type in call_types:
if not isinstance(call_type, CallTypes):
continue
if call_type.value not in TOOL_CAPABLE_CALL_TYPES:
continue
if call_type.value in STANDALONE_EXTRACTORS:
return STANDALONE_EXTRACTORS[call_type.value](data)
handler_cls = mappings.get(call_type)
if handler_cls is not None:
names = handler_cls().extract_request_tool_names(data)
if names:
return names
return []
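
To make the two standalone extractors concrete, here is a small sketch of the request shapes they expect. The function names and the sample bodies are hypothetical; the field names (`tools[].functionDeclarations[].name`, `name`, `mcp_tool_name`) are the ones read by the extractors in this file.

```python
from typing import List


def generate_content_tool_names(data: dict) -> List[str]:
    """Google generateContent style: tools[].functionDeclarations[].name."""
    names: List[str] = []
    for tool in data.get("tools") or []:
        if not isinstance(tool, dict):
            continue
        for decl in tool.get("functionDeclarations") or []:
            if isinstance(decl, dict) and decl.get("name"):
                names.append(str(decl["name"]))
    return names


def mcp_tool_names(data: dict) -> List[str]:
    """MCP call_tool style: a single name under `name` or `mcp_tool_name`."""
    name = data.get("name") or data.get("mcp_tool_name")
    return [str(name)] if name else []


# Hypothetical request bodies for illustration only.
gemini_body = {
    "tools": [{"functionDeclarations": [{"name": "get_weather"}, {"name": "get_time"}]}]
}
mcp_body = {"mcp_tool_name": "github/create_issue"}

print(generate_content_tool_names(gemini_body))
print(mcp_tool_names(mcp_body))
```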


@ -1091,6 +1091,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915
] = user_api_key_dict.user_max_budget
data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
data[_metadata_variable_name]["user_api_key_team_metadata"] = (
user_api_key_dict.team_metadata
)
data[_metadata_variable_name]["user_api_key_object_permission_id"] = (
getattr(user_api_key_dict, "object_permission_id", None)
)
data[_metadata_variable_name]["user_api_key_team_object_permission_id"] = (
getattr(user_api_key_dict, "team_object_permission_id", None)
)
data[_metadata_variable_name]["headers"] = _headers
data[_metadata_variable_name]["endpoint"] = str(request.url)
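
The two metadata keys added here are what the guardrail later reads back. A minimal sketch of that lookup, assuming only the plain metadata keys (the real `_get_request_object_permission_ids` also checks a `user_api_key_auth` object first); `read_permission_ids` and the sample request are hypothetical.

```python
from typing import Optional, Tuple


def read_permission_ids(request_data: dict) -> Tuple[Optional[str], Optional[str]]:
    """Return (key permission id, team permission id) from request metadata."""
    for key in ("litellm_metadata", "metadata"):
        meta = request_data.get(key)
        if not isinstance(meta, dict):
            continue
        key_op = meta.get("user_api_key_object_permission_id")
        team_op = meta.get("user_api_key_team_object_permission_id")
        if key_op is not None or team_op is not None:
            return (
                str(key_op).strip() if key_op else None,
                str(team_op).strip() if team_op else None,
            )
    return None, None


# Hypothetical request data as it might look after add_litellm_data_to_request.
sample_request = {
    "metadata": {
        "user_api_key_object_permission_id": " perm-key-1 ",
        "user_api_key_team_object_permission_id": None,
    }
}
print(read_permission_ids(sample_request))
```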


@ -4,27 +4,87 @@ TOOL POLICY MANAGEMENT
All /tool management endpoints
GET /v1/tool/list - List all discovered tools and their policies
GET /v1/tool/policy/options - List available input/output policy options with descriptions
GET /v1/tool/{tool_name} - Get a single tool's details
POST /v1/tool/policy - Update the call_policy for a tool
POST /v1/tool/policy - Update the input_policy / output_policy for a tool
"""
from typing import Optional
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.tool_management import (
LiteLLM_ToolTableRow,
ToolCallPolicy,
ToolDetailResponse,
ToolInputPolicy,
ToolListResponse,
ToolOutputPolicy,
ToolPolicyOption,
ToolPolicyOptionsResponse,
ToolPolicyUpdateRequest,
ToolPolicyUpdateResponse,
ToolUsageLogEntry,
ToolUsageLogsResponse,
)
router = APIRouter()
TOOL_POLICY_OPTIONS = ToolPolicyOptionsResponse(
input_policies=[
ToolPolicyOption(
value="untrusted",
label="Untrusted",
description="Tool accepts any input, including data from untrusted tool outputs. Default for newly discovered tools.",
),
ToolPolicyOption(
value="trusted",
label="Trusted",
description="Tool requires trusted input. Blocked if the conversation contains output from any tool with output_policy=untrusted.",
),
ToolPolicyOption(
value="blocked",
label="Blocked",
description="Tool is completely prohibited. Any attempt to call it is rejected.",
),
],
output_policies=[
ToolPolicyOption(
value="untrusted",
label="Untrusted",
description="Tool output may contain unsafe content (prompt injection, risky code). Downstream tools with input_policy=trusted will be blocked.",
),
ToolPolicyOption(
value="trusted",
label="Trusted",
description="Tool output is verified safe. Will not trigger trust-chain blocks on downstream tools.",
),
],
)
@router.get(
"/v1/tool/policy/options",
tags=["tool management"],
dependencies=[Depends(user_api_key_auth)],
response_model=ToolPolicyOptionsResponse,
)
async def get_tool_policy_options(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Return the available input and output policy options with descriptions.
Static data; no DB call.
"""
return TOOL_POLICY_OPTIONS
@router.get(
"/v1/tool/list",
@ -33,14 +93,14 @@ router = APIRouter()
response_model=ToolListResponse,
)
async def list_tools(
call_policy: Optional[ToolCallPolicy] = None,
input_policy: Optional[ToolInputPolicy] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
List all auto-discovered tools and their call policies.
List all auto-discovered tools and their policies.
Parameters:
- call_policy: Optional filter; one of "trusted", "untrusted", "dual_llm", "blocked"
- input_policy: Optional filter; one of "trusted", "untrusted", "blocked"
"""
from litellm.proxy.db.tool_registry_writer import list_tools as db_list_tools
from litellm.proxy.proxy_server import prisma_client
@ -51,13 +111,201 @@ async def list_tools(
)
try:
tools = await db_list_tools(prisma_client=prisma_client, call_policy=call_policy)
tools = await db_list_tools(
prisma_client=prisma_client, input_policy=input_policy
)
return ToolListResponse(tools=tools, total=len(tools))
except Exception as e:
verbose_proxy_logger.exception("Error listing tools: %s", e)
raise HTTPException(status_code=500, detail=str(e))
@router.get(
"/v1/tool/{tool_name:path}/detail",
tags=["tool management"],
dependencies=[Depends(user_api_key_auth)],
response_model=ToolDetailResponse,
)
async def get_tool_detail(
tool_name: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Get a single tool with its policy overrides (for UI detail view).
"""
from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool
from litellm.proxy.db.tool_registry_writer import list_overrides_for_tool
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
)
try:
tool = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name)
if tool is None:
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
overrides = await list_overrides_for_tool(
prisma_client=prisma_client, tool_name=tool_name
)
return ToolDetailResponse(tool=tool, overrides=overrides)
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.exception("Error getting tool detail: %s", e)
raise HTTPException(status_code=500, detail=str(e))
def _input_snippet_for_tool_log(sl: Any, max_len: int = 200) -> Optional[str]:
"""Short snippet from messages or proxy_server_request for tool usage log row."""
if sl is None:
return None
messages = getattr(sl, "messages", None)
if messages is not None:
s = _snippet_str(messages, max_len)
if s:
return s
psr = getattr(sl, "proxy_server_request", None)
if not psr:
return None
if isinstance(psr, str):
import json
try:
psr = json.loads(psr)
except Exception:
return _snippet_str(psr, max_len)
if isinstance(psr, dict):
msgs = psr.get("messages")
if msgs is None and isinstance(psr.get("body"), dict):
msgs = psr["body"].get("messages")
s = _snippet_str(msgs, max_len)
if s:
return s
return _snippet_str(psr, max_len)
def _snippet_str(text: Any, max_len: int = 200) -> Optional[str]:
if text is None:
return None
if isinstance(text, str):
s = text
elif isinstance(text, list):
parts = []
for item in text:
if isinstance(item, dict) and "content" in item:
c = item["content"]
parts.append(c if isinstance(c, str) else str(c))
else:
parts.append(str(item))
s = " ".join(parts)
else:
s = str(text)
if not s or s == "{}":
return None
return (s[:max_len] + "...") if len(s) > max_len else s
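
The snippet logic is easy to check in isolation. The sketch below re-creates the same flatten-then-truncate steps under a hypothetical name, `make_snippet`, with made-up inputs; it is an illustration of the behaviour, not a replacement for `_snippet_str`.

```python
from typing import Any, Optional


def make_snippet(value: Any, max_len: int = 200) -> Optional[str]:
    """Flatten a message list (or any value) to text and truncate it."""
    if value is None:
        return None
    if isinstance(value, str):
        s = value
    elif isinstance(value, list):
        parts = []
        for item in value:
            # Chat messages contribute their content; anything else is stringified.
            if isinstance(item, dict) and "content" in item:
                c = item["content"]
                parts.append(c if isinstance(c, str) else str(c))
            else:
                parts.append(str(item))
        s = " ".join(parts)
    else:
        s = str(value)
    if not s or s == "{}":
        return None
    return (s[:max_len] + "...") if len(s) > max_len else s


chat = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "world"},
]
print(make_snippet(chat))
print(make_snippet("a" * 250, max_len=10))
```

Note that empty inputs such as `[]` or `{}` yield `None`, which lets the caller fall through to `proxy_server_request`.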
@router.get(
"/v1/tool/{tool_name:path}/logs",
tags=["tool management"],
dependencies=[Depends(user_api_key_auth)],
response_model=ToolUsageLogsResponse,
)
async def get_tool_usage_logs(
tool_name: str,
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=100),
start_date: Optional[str] = Query(None, description="YYYY-MM-DD"),
end_date: Optional[str] = Query(None, description="YYYY-MM-DD"),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Return paginated spend logs for requests that used this tool (from SpendLogToolIndex).
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
)
try:
where: dict = {"tool_name": tool_name}
if start_date or end_date:
start_time_filter: Optional[datetime] = None
end_time_filter: Optional[datetime] = None
if start_date:
try:
start_time_filter = datetime.strptime(
start_date + "T00:00:00", "%Y-%m-%dT%H:%M:%S"
).replace(tzinfo=timezone.utc)
except ValueError:
pass
if end_date:
try:
end_time_filter = datetime.strptime(
end_date + "T23:59:59", "%Y-%m-%dT%H:%M:%S"
).replace(tzinfo=timezone.utc)
except ValueError:
pass
if start_time_filter is not None or end_time_filter is not None:
where["start_time"] = {}
if start_time_filter is not None:
where["start_time"]["gte"] = start_time_filter
if end_time_filter is not None:
where["start_time"]["lte"] = end_time_filter
total = await prisma_client.db.litellm_spendlogtoolindex.count(where=where)
index_rows = await prisma_client.db.litellm_spendlogtoolindex.find_many(
where=where,
order={"start_time": "desc"},
skip=(page - 1) * page_size,
take=page_size,
)
request_ids = [r.request_id for r in index_rows]
if not request_ids:
return ToolUsageLogsResponse(
logs=[], total=total, page=page, page_size=page_size
)
spend_logs = await prisma_client.db.litellm_spendlogs.find_many(
where={"request_id": {"in": request_ids}}
)
log_by_id = {s.request_id: s for s in spend_logs}
logs_out: List[ToolUsageLogEntry] = []
for r in index_rows:
sl = log_by_id.get(r.request_id)
if not sl:
continue
ts = (
sl.startTime.isoformat()
if hasattr(sl.startTime, "isoformat")
else str(sl.startTime)
)
logs_out.append(
ToolUsageLogEntry(
id=sl.request_id,
timestamp=ts,
model=getattr(sl, "model", None) or None,
spend=getattr(sl, "spend", None),
total_tokens=getattr(sl, "total_tokens", None),
input_snippet=_input_snippet_for_tool_log(sl),
)
)
return ToolUsageLogsResponse(
logs=logs_out, total=total, page=page, page_size=page_size
)
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.exception("Error getting tool usage logs: %s", e)
raise HTTPException(status_code=500, detail=str(e))
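
The date-range handling in `get_tool_usage_logs` amounts to building an optional `start_time` bound. A minimal sketch, assuming the same `YYYY-MM-DD` inputs and the same choice to ignore unparseable dates; `parse_day` and `build_where` are hypothetical helper names.

```python
from datetime import datetime, timezone
from typing import Optional


def parse_day(value: Optional[str], end_of_day: bool = False) -> Optional[datetime]:
    """Parse YYYY-MM-DD into a UTC datetime at the start or end of that day."""
    if not value:
        return None
    suffix = "T23:59:59" if end_of_day else "T00:00:00"
    try:
        return datetime.strptime(value + suffix, "%Y-%m-%dT%H:%M:%S").replace(
            tzinfo=timezone.utc
        )
    except ValueError:
        # Invalid dates are silently ignored, mirroring the endpoint above.
        return None


def build_where(
    tool_name: str, start_date: Optional[str] = None, end_date: Optional[str] = None
) -> dict:
    """Build a filter dict with an optional inclusive start_time range."""
    where: dict = {"tool_name": tool_name}
    bounds = {}
    start = parse_day(start_date)
    end = parse_day(end_date, end_of_day=True)
    if start is not None:
        bounds["gte"] = start
    if end is not None:
        bounds["lte"] = end
    if bounds:
        where["start_time"] = bounds
    return where


print(build_where("search", "2024-01-01", "2024-01-31"))
```

One consequence of this design is that a malformed date widens the query instead of failing it; a stricter variant could return a 400 for bad input.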
@router.get(
"/v1/tool/{tool_name:path}",
tags=["tool management"],
@ -70,9 +318,6 @@ async def get_tool(
):
"""
Get details for a single tool.
Parameters:
- tool_name: The tool name (supports namespaced names with slashes)
"""
from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool
from litellm.proxy.proxy_server import prisma_client
@ -85,9 +330,7 @@ async def get_tool(
try:
tool = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name)
if tool is None:
raise HTTPException(
status_code=404, detail=f"Tool '{tool_name}' not found"
)
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
return tool
except HTTPException:
raise
@ -96,6 +339,80 @@ async def get_tool(
raise HTTPException(status_code=500, detail=str(e))
async def _resolve_key_hash_to_object_permission_id(
prisma_client: "PrismaClient",
key_hash: str,
) -> Optional[str]:
"""Resolve key (hash or raw) to object_permission_id; create permission if key has none."""
from litellm.proxy.proxy_server import hash_token
hashed = key_hash if "sk-" not in (key_hash or "") else hash_token(key_hash)
if not hashed:
return None
row = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed}
)
if row is None:
return None
op_id = getattr(row, "object_permission_id", None)
if op_id:
return op_id
new_id = str(uuid.uuid4())
await prisma_client.db.litellm_objectpermissiontable.create(
data={"object_permission_id": new_id, "blocked_tools": []}
)
updated_count = await prisma_client.db.litellm_verificationtoken.update_many(
where={"token": hashed, "object_permission_id": None},
data={"object_permission_id": new_id},
)
if updated_count == 0:
await prisma_client.db.litellm_objectpermissiontable.delete(
where={"object_permission_id": new_id}
)
row = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed}
)
return getattr(row, "object_permission_id", None) if row else None
return new_id
async def _resolve_team_id_to_object_permission_id(
prisma_client: "PrismaClient",
team_id: str,
) -> Optional[str]:
"""Resolve team_id to object_permission_id; create permission if team has none."""
if not team_id or not team_id.strip():
return None
team_id_clean = team_id.strip()
row = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id_clean},
select={"object_permission_id": True},
)
if row is None:
return None
op_id = getattr(row, "object_permission_id", None)
if op_id:
return op_id
new_id = str(uuid.uuid4())
await prisma_client.db.litellm_objectpermissiontable.create(
data={"object_permission_id": new_id, "blocked_tools": []}
)
updated_count = await prisma_client.db.litellm_teamtable.update_many(
where={"team_id": team_id_clean, "object_permission_id": None},
data={"object_permission_id": new_id},
)
if updated_count == 0:
await prisma_client.db.litellm_objectpermissiontable.delete(
where={"object_permission_id": new_id}
)
row = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id_clean},
select={"object_permission_id": True},
)
return getattr(row, "object_permission_id", None) if row else None
return new_id
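
Both resolver helpers follow the same create, conditionally claim, roll back on a lost race pattern. The sketch below models it with in-memory stand-ins: `KeyStore`, `RacingKeyStore`, and `resolve_permission_id` are hypothetical, and `update_if_unset` plays the role of `update_many` with `object_permission_id=None` in the filter.

```python
import uuid
from typing import Dict, List, Optional


class KeyStore:
    """In-memory stand-in for the key table: token -> object_permission_id."""

    def __init__(self, rows: Dict[str, Optional[str]]) -> None:
        self.permission_by_token = dict(rows)

    def update_if_unset(self, token: str, new_id: str) -> int:
        # Conditional update: only rows whose permission id is still None match.
        if token in self.permission_by_token and self.permission_by_token[token] is None:
            self.permission_by_token[token] = new_id
            return 1
        return 0


class RacingKeyStore(KeyStore):
    """Simulates a concurrent request that claims the row first."""

    def update_if_unset(self, token: str, new_id: str) -> int:
        self.permission_by_token[token] = "winner-id"
        return 0


def resolve_permission_id(
    store: KeyStore, permissions: Dict[str, List[str]], token: str
) -> Optional[str]:
    if token not in store.permission_by_token:
        return None
    existing = store.permission_by_token[token]
    if existing:
        return existing
    new_id = str(uuid.uuid4())
    permissions[new_id] = []  # create the candidate permission row
    if store.update_if_unset(token, new_id) == 0:
        # Lost the race: delete the orphan and use whatever the winner attached.
        permissions.pop(new_id, None)
        return store.permission_by_token.get(token)
    return new_id


perms: Dict[str, List[str]] = {}
store = KeyStore({"hashed-key": None})
first = resolve_permission_id(store, perms, "hashed-key")
second = resolve_permission_id(store, perms, "hashed-key")
print(first == second, len(perms))
```

The conditional update is what keeps two concurrent requests from each attaching a different permission row to the same key or team.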
@router.post(
"/v1/tool/policy",
tags=["tool management"],
@ -107,15 +424,20 @@ async def update_tool_policy(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Set the call policy for a tool.
Set the input_policy and/or output_policy for a tool (global), or block for a specific team/key (override).
Parameters:
- tool_name: str - The tool to update
- call_policy: "trusted" | "untrusted" | "dual_llm" | "blocked"
Setting a tool to "blocked" will cause the ToolPolicyGuardrail to remove
that tool_call from LLM responses before returning them to the client.
- input_policy: optional - "trusted" | "untrusted" | "blocked"
- output_policy: optional - "trusted" | "untrusted"
- team_id: optional - if set, create/update override for this team only
- key_hash: optional - if set, create/update override for this key only
"""
from litellm.proxy.db.tool_registry_writer import (
add_tool_to_object_permission_blocked,
get_tool_policy_registry,
remove_tool_from_object_permission_blocked,
)
from litellm.proxy.db.tool_registry_writer import (
update_tool_policy as db_update_tool_policy,
)
@ -127,19 +449,80 @@ async def update_tool_policy(
)
try:
if data.team_id is not None or data.key_hash is not None:
if data.team_id is not None and data.key_hash is not None:
raise HTTPException(
status_code=400,
detail="Provide either team_id or key_hash, not both",
)
if data.key_hash is not None:
op_id = await _resolve_key_hash_to_object_permission_id(
prisma_client, data.key_hash
)
else:
op_id = await _resolve_team_id_to_object_permission_id(
prisma_client, data.team_id or ""
)
if op_id is None:
raise HTTPException(
status_code=404,
detail="Key or team not found for the given identifier",
)
is_blocking = data.input_policy == "blocked"
if is_blocking:
ok = await add_tool_to_object_permission_blocked(
prisma_client=prisma_client,
object_permission_id=op_id,
tool_name=data.tool_name,
)
else:
ok = await remove_tool_from_object_permission_blocked(
prisma_client=prisma_client,
object_permission_id=op_id,
tool_name=data.tool_name,
)
if not ok:
raise HTTPException(
status_code=500,
detail=f"Failed to update policy override for tool '{data.tool_name}'",
)
registry = get_tool_policy_registry()
if registry.is_initialized():
await registry.sync_tool_policy_from_db(prisma_client)
return ToolPolicyUpdateResponse(
tool_name=data.tool_name,
input_policy=data.input_policy,
output_policy=data.output_policy,
updated=True,
team_id=data.team_id,
key_hash=data.key_hash,
)
if data.input_policy is None and data.output_policy is None:
raise HTTPException(
status_code=400,
detail="At least one of input_policy or output_policy must be provided",
)
updated = await db_update_tool_policy(
prisma_client=prisma_client,
tool_name=data.tool_name,
call_policy=data.call_policy,
updated_by=user_api_key_dict.user_id,
input_policy=data.input_policy,
output_policy=data.output_policy,
)
if updated is None:
raise HTTPException(
status_code=500, detail=f"Failed to update policy for tool '{data.tool_name}'"
status_code=500,
detail=f"Failed to update policy for tool '{data.tool_name}'",
)
registry = get_tool_policy_registry()
if registry.is_initialized():
await registry.sync_tool_policy_from_db(prisma_client)
return ToolPolicyUpdateResponse(
tool_name=updated.tool_name,
call_policy=updated.call_policy,
input_policy=updated.input_policy,
output_policy=updated.output_policy,
updated=True,
)
except HTTPException:
@ -147,3 +530,77 @@ async def update_tool_policy(
except Exception as e:
verbose_proxy_logger.exception("Error updating tool policy: %s", e)
raise HTTPException(status_code=500, detail=str(e))
@router.delete(
"/v1/tool/{tool_name:path}/overrides",
tags=["tool management"],
dependencies=[Depends(user_api_key_auth)],
)
async def delete_tool_policy_override(
tool_name: str,
team_id: Optional[str] = Query(
None, description="Team ID of the override to remove"
),
key_hash: Optional[str] = Query(
None, description="Key hash of the override to remove"
),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Remove a policy override for a tool. Specify the override by team_id or key_hash
(exactly one required).
"""
from litellm.proxy.db.tool_registry_writer import (
get_tool_policy_registry,
remove_tool_from_object_permission_blocked,
)
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
)
if team_id is None and key_hash is None:
raise HTTPException(
status_code=400,
detail="Exactly one of team_id or key_hash is required to identify the override",
)
if team_id is not None and key_hash is not None:
raise HTTPException(
status_code=400,
detail="Provide either team_id or key_hash, not both",
)
try:
if key_hash is not None:
op_id = await _resolve_key_hash_to_object_permission_id(
prisma_client, key_hash
)
else:
op_id = await _resolve_team_id_to_object_permission_id(
prisma_client, team_id or ""
)
if op_id is None:
raise HTTPException(
status_code=404,
detail="Key or team not found for the given identifier",
)
deleted = await remove_tool_from_object_permission_blocked(
prisma_client=prisma_client,
object_permission_id=op_id,
tool_name=tool_name,
)
if not deleted:
raise HTTPException(
status_code=404,
detail=f"No override found for tool '{tool_name}' with the given scope",
)
registry = get_tool_policy_registry()
if registry.is_initialized():
await registry.sync_tool_policy_from_db(prisma_client)
return {"deleted": True, "tool_name": tool_name}
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.exception("Error deleting tool policy override: %s", e)
raise HTTPException(status_code=500, detail=str(e))
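
The override handling in the two endpoints above follows two rules: exactly one of team_id or key_hash identifies the scope, and an input_policy of "blocked" adds the tool to that scope's blocked_tools list while any other value removes it. Below is a minimal in-memory sketch of those rules. It assumes a plain dict stands in for LiteLLM_ObjectPermissionTable, and `resolve_scope` and `apply_override` are hypothetical names for illustration, not functions from this PR. The boolean returned on removal is also an assumption.

```python
from typing import Dict, List, Optional, Tuple

Scope = Tuple[str, str]  # ("key", key_hash) or ("team", team_id)


def resolve_scope(team_id: Optional[str], key_hash: Optional[str]) -> Scope:
    """Return the single scope identified by the caller; reject zero or two identifiers."""
    if team_id is None and key_hash is None:
        raise ValueError("one of team_id or key_hash is required")
    if team_id is not None and key_hash is not None:
        raise ValueError("provide either team_id or key_hash, not both")
    if key_hash is not None:
        return ("key", key_hash)
    assert team_id is not None
    return ("team", team_id)


def apply_override(
    blocked: Dict[Scope, List[str]],
    scope: Scope,
    tool_name: str,
    input_policy: Optional[str],
) -> bool:
    """Block the tool for the scope when input_policy is "blocked", otherwise unblock it.

    Returns True when the stored list ends up in the requested state after a
    change (or the tool was already blocked), False when there was nothing to remove.
    """
    tools = blocked.setdefault(scope, [])
    if input_policy == "blocked":
        if tool_name not in tools:
            tools.append(tool_name)
        return True
    if tool_name in tools:
        tools.remove(tool_name)
        return True
    return False
```

In this sketch, removing a tool that was never blocked returns False, which is the case the delete endpoint reports as a 404.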


@ -4411,6 +4411,9 @@ class ProxyConfig:
if self._should_load_db_object(object_type="search_tools"):
await self._init_search_tools_in_db(prisma_client=prisma_client)
if self._should_load_db_object(object_type="tools"):
await self._init_tool_policy_in_db(prisma_client=prisma_client)
if self._should_load_db_object(object_type="model_cost_map"):
await self._check_and_reload_model_cost_map(prisma_client=prisma_client)
@ -4847,6 +4850,24 @@ class ProxyConfig:
)
)
async def _init_tool_policy_in_db(self, prisma_client: PrismaClient):
"""
Initialize tool policy from database into the in-memory registry.
Synced periodically by add_deployment -> _init_non_llm_objects_in_db.
"""
from litellm.proxy.db.tool_registry_writer import get_tool_policy_registry
try:
registry = get_tool_policy_registry()
await registry.sync_tool_policy_from_db(prisma_client=prisma_client)
verbose_proxy_logger.debug("Successfully synced tool policy from DB")
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.py::ProxyConfig:_init_tool_policy_in_db - {}".format(
str(e)
)
)
async def _init_vector_stores_in_db(self, prisma_client: PrismaClient):
from litellm.vector_stores.vector_store_registry import VectorStoreRegistry
@ -10577,6 +10598,12 @@ async def async_queue_request(
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["user_api_key_object_permission_id"] = getattr(
user_api_key_dict, "object_permission_id", None
)
data["metadata"]["user_api_key_team_object_permission_id"] = getattr(
user_api_key_dict, "team_object_permission_id", None
)
data["metadata"]["endpoint"] = str(request.url)
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
@ -11093,9 +11120,7 @@ async def get_favicon():
if favicon_url.startswith(("http://", "https://")):
try:
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.custom_http import httpxSpecialProvider
async_client = get_async_httpx_client(


@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable {
vector_stores String[] @default([])
agents String[] @default([])
agent_access_groups String[] @default([])
blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission
teams LiteLLM_TeamTable[]
projects LiteLLM_ProjectTable[]
verification_tokens LiteLLM_VerificationToken[]
@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex {
@@index([policy_id, start_time])
}
// Index for fast "last N logs for tool" lookups from SpendLogs; shows how a tool is called in production
model LiteLLM_SpendLogToolIndex {
request_id String
tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc.
start_time DateTime
@@id([request_id, tool_name])
@@index([tool_name, start_time])
}
// Prompt table for storing prompt configurations
model LiteLLM_PromptTable {
id String @id @default(uuid())
@ -1065,23 +1076,27 @@ model LiteLLM_PolicyAttachmentTable {
updated_by String?
}
// Global tool registry - auto-discovered from LLM responses; admins set call_policy here
// Global tool registry - auto-discovered from LLM responses; admins set input/output policies here
model LiteLLM_ToolTable {
tool_id String @id @default(uuid())
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
origin String? // MCP server name or "user_defined"
call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked"
call_count Int @default(0) // cumulative number of times this tool was seen
assignments Json? @default("{}")
key_hash String? // hash of the virtual key that first called this tool
team_id String? // team that first called this tool
key_alias String? // human-readable alias of the virtual key
created_at DateTime @default(now())
created_by String?
updated_at DateTime @default(now()) @updatedAt
updated_by String?
tool_id String @id @default(uuid())
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
origin String? // MCP server name or "user_defined"
input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked"
output_policy String @default("untrusted") // "trusted" | "untrusted"
call_count Int @default(0) // cumulative number of times this tool was seen
assignments Json? @default("{}")
key_hash String? // hash of the virtual key that first called this tool
team_id String? // team that first called this tool
key_alias String? // human-readable alias of the virtual key
user_agent String? // user-agent of the first request that discovered this tool
last_used_at DateTime? // timestamp of the most recent call
created_at DateTime @default(now())
created_by String?
updated_at DateTime @default(now()) @updatedAt
updated_by String?
@@index([call_policy])
@@index([input_policy])
@@index([output_policy])
@@index([team_id])
}
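
The LiteLLM_SpendLogToolIndex model above pairs a composite primary key (request_id, tool_name) with an index on (tool_name, start_time), which is the shape a "last N requests for tool X" query needs. The following sketch illustrates that access pattern with Python's built-in sqlite3 module. SQLite is only a stand-in here (the proxy itself goes through Prisma), and the table name, index name, and `last_n_requests` helper are hypothetical.

```python
import sqlite3
from typing import List


def build_index_db() -> sqlite3.Connection:
    """Create an in-memory table mirroring the (request_id, tool_name, start_time) shape."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE spend_log_tool_index (
            request_id TEXT NOT NULL,
            tool_name  TEXT NOT NULL,
            start_time TEXT NOT NULL,
            PRIMARY KEY (request_id, tool_name)
        )
        """
    )
    # Mirrors @@index([tool_name, start_time]): filter by tool, order by time.
    conn.execute(
        "CREATE INDEX idx_tool_time ON spend_log_tool_index (tool_name, start_time)"
    )
    rows = [
        ("req-1", "search", "2024-01-01T00:00:00"),
        ("req-2", "search", "2024-01-02T00:00:00"),
        ("req-2", "read_file", "2024-01-02T00:00:00"),
        ("req-3", "search", "2024-01-03T00:00:00"),
        ("req-3", "search", "2024-01-03T00:00:00"),  # duplicate key, ignored
    ]
    conn.executemany(
        "INSERT OR IGNORE INTO spend_log_tool_index VALUES (?, ?, ?)", rows
    )
    return conn


def last_n_requests(conn: sqlite3.Connection, tool_name: str, n: int) -> List[str]:
    """Return the request_ids of the n most recent calls to tool_name, newest first."""
    cur = conn.execute(
        "SELECT request_id FROM spend_log_tool_index "
        "WHERE tool_name = ? ORDER BY start_time DESC LIMIT ?",
        (tool_name, n),
    )
    return [row[0] for row in cur.fetchall()]
```

The composite key means one request that uses two tools produces two index rows, while a repeated write for the same request and tool is dropped.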


@ -3583,8 +3583,9 @@ class PrismaClient:
def _get_engine_pid(self) -> int:
try:
engine = self.db._original_prisma._engine # type: ignore[attr-defined]
if engine is not None and engine.process is not None:
return engine.process.pid
process = getattr(engine, "process", None) if engine is not None else None
if process is not None:
return process.pid
except (AttributeError, TypeError):
pass
return 0
@ -4688,6 +4689,19 @@ async def update_spend_logs_job(
guardrail_tracking_err,
)
# Tool usage tracking (same batch): SpendLogToolIndex for "last N requests for tool X"
try:
from litellm.proxy.db.spend_log_tool_index import process_spend_logs_tool_usage
await process_spend_logs_tool_usage(
prisma_client=prisma_client,
logs_to_process=logs_to_process,
)
except Exception as tool_tracking_err:
verbose_proxy_logger.warning(
"Spend tracking - tool usage tracking failed (non-fatal): %s",
tool_tracking_err,
)
async def _monitor_spend_logs_queue(
prisma_client: PrismaClient,


@ -5,21 +5,27 @@ Pydantic models for Tool Policy management endpoints.
from datetime import datetime
from typing import Dict, List, Literal, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
ToolCallPolicy = Literal["trusted", "untrusted", "dual_llm", "blocked"]
ToolInputPolicy = Literal["trusted", "untrusted", "blocked"]
ToolOutputPolicy = Literal["trusted", "untrusted"]
class LiteLLM_ToolTableRow(BaseModel):
tool_id: str
tool_name: str
origin: Optional[str] = None
call_policy: ToolCallPolicy = "untrusted"
input_policy: ToolInputPolicy = "untrusted"
output_policy: ToolOutputPolicy = "untrusted"
call_count: int = 0
assignments: Optional[Dict] = None
key_hash: Optional[str] = None
team_id: Optional[str] = None
key_alias: Optional[str] = None
user_agent: Optional[str] = None
last_used_at: Optional[datetime] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
created_by: Optional[str] = None
@ -33,10 +39,62 @@ class ToolListResponse(BaseModel):
class ToolPolicyUpdateRequest(BaseModel):
tool_name: str
call_policy: ToolCallPolicy
input_policy: Optional[ToolInputPolicy] = None
output_policy: Optional[ToolOutputPolicy] = None
team_id: Optional[str] = None
key_hash: Optional[str] = None
key_alias: Optional[str] = None
class ToolPolicyUpdateResponse(BaseModel):
tool_name: str
call_policy: ToolCallPolicy
input_policy: Optional[ToolInputPolicy] = None
output_policy: Optional[ToolOutputPolicy] = None
updated: bool
team_id: Optional[str] = None
key_hash: Optional[str] = None
class ToolPolicyOverrideRow(BaseModel):
override_id: str
tool_name: str
team_id: Optional[str] = None
key_hash: Optional[str] = None
input_policy: ToolInputPolicy = "blocked"
key_alias: Optional[str] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class ToolPolicyOption(BaseModel):
value: str
label: str
description: str
class ToolPolicyOptionsResponse(BaseModel):
input_policies: List[ToolPolicyOption]
output_policies: List[ToolPolicyOption]
class ToolDetailResponse(BaseModel):
tool: LiteLLM_ToolTableRow
overrides: List[ToolPolicyOverrideRow] = Field(default_factory=list)
class ToolUsageLogEntry(BaseModel):
"""One spend log row for a tool call (for UI "recent logs" table)."""
id: str # request_id
timestamp: str
model: Optional[str] = None
spend: Optional[float] = None
total_tokens: Optional[int] = None
input_snippet: Optional[str] = None
class ToolUsageLogsResponse(BaseModel):
logs: List[ToolUsageLogEntry]
total: int
page: int
page_size: int


@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable {
vector_stores String[] @default([])
agents String[] @default([])
agent_access_groups String[] @default([])
blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission
teams LiteLLM_TeamTable[]
projects LiteLLM_ProjectTable[]
verification_tokens LiteLLM_VerificationToken[]
@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex {
@@index([policy_id, start_time])
}
// Index for fast "last N logs for tool" lookups from SpendLogs; shows how a tool is called in production
model LiteLLM_SpendLogToolIndex {
request_id String
tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc.
start_time DateTime
@@id([request_id, tool_name])
@@index([tool_name, start_time])
}
// Prompt table for storing prompt configurations
model LiteLLM_PromptTable {
id String @id @default(uuid())
@ -1065,23 +1076,27 @@ model LiteLLM_PolicyAttachmentTable {
updated_by String?
}
// Global tool registry - auto-discovered from LLM responses; admins set call_policy here
// Global tool registry - auto-discovered from LLM responses; admins set input/output policies here
model LiteLLM_ToolTable {
tool_id String @id @default(uuid())
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
origin String? // MCP server name or "user_defined"
call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked"
call_count Int @default(0) // cumulative number of times this tool was seen
assignments Json? @default("{}")
key_hash String? // hash of the virtual key that first called this tool
team_id String? // team that first called this tool
key_alias String? // human-readable alias of the virtual key
created_at DateTime @default(now())
created_by String?
updated_at DateTime @default(now()) @updatedAt
updated_by String?
tool_id String @id @default(uuid())
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
origin String? // MCP server name or "user_defined"
input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked"
output_policy String @default("untrusted") // "trusted" | "untrusted"
call_count Int @default(0) // cumulative number of times this tool was seen
assignments Json? @default("{}")
key_hash String? // hash of the virtual key that first called this tool
team_id String? // team that first called this tool
key_alias String? // human-readable alias of the virtual key
user_agent String? // user-agent of the first request that discovered this tool
last_used_at DateTime? // timestamp of the most recent call
created_at DateTime @default(now())
created_by String?
updated_at DateTime @default(now()) @updatedAt
updated_by String?
@@index([call_policy])
@@index([input_policy])
@@index([output_policy])
@@index([team_id])
}


@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
Standalone script to test tool allowlist enforcement and tool name extraction.
Run from repo root:
poetry run python scripts/test_tool_allowlist_script.py
Or run the unit tests:
poetry run pytest tests/test_litellm/proxy/test_tools_allowlist_enforcement.py -v
"""
import asyncio
import sys
from pathlib import Path
# Ensure repo root is on path
repo_root = Path(__file__).resolve().parent.parent
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
def test_extraction():
"""Test extract_request_tool_names for each API shape."""
from litellm.proxy.guardrails.tool_name_extraction import extract_request_tool_names
cases = [
("OpenAI chat tools", "/v1/chat/completions", {"tools": [{"type": "function", "function": {"name": "get_weather"}}]}),
("OpenAI chat functions", "/v1/chat/completions", {"functions": [{"name": "run_sql"}]}),
("OpenAI responses function", "/v1/responses", {"tools": [{"type": "function", "name": "get_current_weather"}]}),
("OpenAI responses MCP", "/v1/responses", {"tools": [{"type": "mcp", "server_label": "dmcp"}]}),
("Anthropic", "/v1/messages", {"tools": [{"name": "get_weather"}, {"name": "run_sql"}]}),
("Google generateContent", "/generate_content", {"tools": [{"functionDeclarations": [{"name": "schedule_meeting"}]}]}),
("MCP call_tool", "/mcp/call_tool", {"name": "my_tool", "arguments": {}}),
("Non-tool route", "/v1/embeddings", {"tools": [{"type": "function", "function": {"name": "x"}}]}),
]
print("=== extract_request_tool_names(route, data) ===\n")
for label, route, data in cases:
names = extract_request_tool_names(route, data)
print(f" {label}: {names}")
print()
async def test_check_tools_allowlist():
"""Test check_tools_allowlist with mock tokens."""
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import check_tools_allowlist
def token(metadata=None, team_metadata=None):
return UserAPIKeyAuth(
api_key="test-key",
user_id="user",
team_id="team",
org_id=None,
models=["*"],
metadata=metadata or {},
team_metadata=team_metadata or {},
)
print("=== check_tools_allowlist (auth) ===\n")
# No allowlist -> pass
await check_tools_allowlist(
request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]},
valid_token=token(),
team_object=None,
route="/v1/chat/completions",
)
print(" No allowlist, body has tools: PASS")
# Allowed tool -> pass
await check_tools_allowlist(
request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]},
valid_token=token(metadata={"allowed_tools": ["get_weather"]}),
team_object=None,
route="/v1/chat/completions",
)
print(" allowed_tools=['get_weather'], body has get_weather: PASS")
# Disallowed tool -> raise
try:
await check_tools_allowlist(
request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]},
valid_token=token(metadata={"allowed_tools": ["other_tool"]}),
team_object=None,
route="/v1/chat/completions",
)
print(" FAIL: expected ProxyException for disallowed tool")
except ProxyException as e:
if e.type == ProxyErrorTypes.tool_access_denied:
print(" allowed_tools=['other_tool'], body has get_weather: PASS (raised tool_access_denied)")
else:
print(f" Unexpected ProxyException type: {e.type}")
except Exception as e:
print(f" Unexpected: {e}")
# Team allowlist when key empty
await check_tools_allowlist(
request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]},
valid_token=token(team_metadata={"allowed_tools": ["get_weather"]}),
team_object=None,
route="/v1/chat/completions",
)
print(" team_metadata.allowed_tools=['get_weather']: PASS")
print()
def main():
print("Tool allowlist / tool name extraction script checks\n")
test_extraction()
asyncio.run(test_check_tools_allowlist())
print("Done. For full unit tests run:")
print(" poetry run pytest tests/test_litellm/proxy/test_tools_allowlist_enforcement.py -v")
if __name__ == "__main__":
main()
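
The four cases exercised by the script above (no allowlist, allowed tool, disallowed tool, team allowlist when the key has none) can be summarised in a small stand-alone sketch. It is not the body of check_tools_allowlist: `ToolAccessDenied` and `check_allowlist` are hypothetical names, and letting a non-empty key allowlist take precedence over the team allowlist is an assumption inferred from the "Team allowlist when key empty" case.

```python
from typing import Iterable, List, Optional


class ToolAccessDenied(Exception):
    """Raised when a request names a tool outside the effective allowlist."""


def check_allowlist(
    requested: Iterable[str],
    key_allowed: Optional[List[str]] = None,
    team_allowed: Optional[List[str]] = None,
) -> None:
    """Raise ToolAccessDenied if any requested tool is outside the allowlist.

    Assumption: a non-empty key allowlist wins; otherwise the team allowlist
    applies; with neither set, every tool passes.
    """
    allowlist = key_allowed or team_allowed
    if not allowlist:
        return
    denied = [name for name in requested if name not in allowlist]
    if denied:
        raise ToolAccessDenied(f"tools not allowed: {denied}")
```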


@ -1,6 +1,6 @@
"""
Unit tests for tool_registry_writer.py, using a mock prisma client
that exposes execute_raw / query_raw (matching the actual raw-SQL implementation).
that exposes litellm_tooltable.upsert / find_many / find_unique.
"""
import os
@ -13,21 +13,28 @@ import pytest
sys.path.insert(0, os.path.abspath("../../.."))
from litellm.proxy.db.tool_registry_writer import (
ToolPolicyRegistry,
batch_upsert_tools,
get_tool,
get_tool_policy_registry,
get_tools_by_names,
list_tools,
update_tool_policy,
)
def _make_prisma(query_rows=None):
"""Return a minimal mock prisma_client with execute_raw / query_raw."""
default_row = {
def _mock_row(**kwargs):
"""Build a row-like object with real attributes (no MagicMock) for _row_to_model."""
class Row:
pass
default = {
"tool_id": "uuid-1",
"tool_name": "my_tool",
"origin": "user_defined",
"call_policy": "untrusted",
"input_policy": "untrusted",
"output_policy": "untrusted",
"call_count": 1,
"assignments": {},
"key_hash": None,
@ -38,31 +45,54 @@ def _make_prisma(query_rows=None):
"created_by": None,
"updated_by": None,
}
rows = query_rows if query_rows is not None else [default_row]
default.update(kwargs)
row = Row()
for k, v in default.items():
setattr(row, k, v)
return row
def _make_prisma(
*,
upsert_return=None,
find_many_rows=None,
find_unique_row=None,
):
"""Return a mock prisma_client with litellm_tooltable.upsert, find_many, find_unique."""
prisma = MagicMock()
prisma.db.execute_raw = AsyncMock(return_value=None)
prisma.db.query_raw = AsyncMock(return_value=rows)
prisma.db.litellm_tooltable = MagicMock()
prisma.db.litellm_tooltable.upsert = AsyncMock(return_value=upsert_return)
prisma.db.litellm_tooltable.find_many = AsyncMock(
return_value=find_many_rows if find_many_rows is not None else []
)
prisma.db.litellm_tooltable.find_unique = AsyncMock(
return_value=find_unique_row
)
return prisma
@pytest.mark.asyncio
async def test_batch_upsert_tools_calls_execute_raw():
async def test_batch_upsert_tools_calls_upsert():
prisma = _make_prisma()
items = [{"tool_name": "tool_a", "origin": "mcp_server", "created_by": None}]
await batch_upsert_tools(prisma, items)
prisma.db.execute_raw.assert_awaited_once()
call_args = prisma.db.execute_raw.call_args
sql = call_args.args[0]
assert "LiteLLM_ToolTable" in sql
assert "ON CONFLICT" in sql
prisma.db.litellm_tooltable.upsert.assert_awaited_once()
call_kw = prisma.db.litellm_tooltable.upsert.call_args.kwargs
assert call_kw["where"] == {"tool_name": "tool_a"}
assert call_kw["data"]["create"]["tool_name"] == "tool_a"
assert call_kw["data"]["create"]["origin"] == "mcp_server"
assert call_kw["data"]["create"]["input_policy"] == "untrusted"
assert call_kw["data"]["create"]["output_policy"] == "untrusted"
assert call_kw["data"]["create"]["call_count"] == 1
assert call_kw["data"]["update"]["call_count"] == {"increment": 1}
assert "updated_at" in call_kw["data"]["update"]
@pytest.mark.asyncio
async def test_batch_upsert_tools_empty_list():
prisma = _make_prisma()
await batch_upsert_tools(prisma, [])
prisma.db.execute_raw.assert_not_awaited()
prisma.db.litellm_tooltable.upsert.assert_not_awaited()
@pytest.mark.asyncio
@ -70,123 +100,120 @@ async def test_batch_upsert_tools_skips_empty_names():
prisma = _make_prisma()
items = [{"tool_name": "", "origin": None}, {"tool_name": None}] # type: ignore[list-item]
await batch_upsert_tools(prisma, items)
prisma.db.execute_raw.assert_not_awaited()
prisma.db.litellm_tooltable.upsert.assert_not_awaited()
@pytest.mark.asyncio
async def test_batch_upsert_multiple_tools_calls_execute_raw_per_tool():
async def test_batch_upsert_multiple_tools_calls_upsert_per_tool():
prisma = _make_prisma()
items = [
{"tool_name": "tool_a", "origin": "mcp_server", "created_by": None},
{"tool_name": "tool_b", "origin": "user_defined", "created_by": "alice"},
]
await batch_upsert_tools(prisma, items)
assert prisma.db.execute_raw.await_count == 2
assert prisma.db.litellm_tooltable.upsert.await_count == 2
calls = prisma.db.litellm_tooltable.upsert.call_args_list
assert calls[0].kwargs["where"]["tool_name"] == "tool_a"
assert calls[1].kwargs["where"]["tool_name"] == "tool_b"
@pytest.mark.asyncio
async def test_list_tools_no_filter():
row = {
"tool_id": "id1",
"tool_name": "tool_a",
"origin": "mcp",
"call_policy": "untrusted",
"call_count": 5,
"assignments": {},
"key_hash": None,
"team_id": None,
"key_alias": None,
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
"created_by": None,
"updated_by": None,
}
prisma = _make_prisma(query_rows=[row])
row = _mock_row(
tool_id="id1",
tool_name="tool_a",
origin="mcp",
input_policy="untrusted",
output_policy="untrusted",
call_count=5,
)
prisma = _make_prisma(find_many_rows=[row])
result = await list_tools(prisma)
assert len(result) == 1
assert result[0].tool_name == "tool_a"
assert result[0].call_count == 5
prisma.db.query_raw.assert_awaited_once()
prisma.db.litellm_tooltable.find_many.assert_awaited_once()
call_kw = prisma.db.litellm_tooltable.find_many.call_args.kwargs
assert call_kw["where"] == {}
assert call_kw["order"] == {"created_at": "desc"}
@pytest.mark.asyncio
async def test_list_tools_with_policy_filter():
row = {
"tool_id": "id1",
"tool_name": "blocked_tool",
"origin": None,
"call_policy": "blocked",
"call_count": 2,
"assignments": None,
"key_hash": None,
"team_id": None,
"key_alias": None,
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
"created_by": None,
"updated_by": None,
}
prisma = _make_prisma(query_rows=[row])
result = await list_tools(prisma, call_policy="blocked")
assert result[0].call_policy == "blocked"
call_args = prisma.db.query_raw.call_args
sql = call_args.args[0]
assert "WHERE call_policy" in sql
async def test_list_tools_with_input_policy_filter():
row = _mock_row(
tool_id="id1",
tool_name="blocked_tool",
origin=None,
input_policy="blocked",
output_policy="untrusted",
call_count=2,
assignments=None,
)
prisma = _make_prisma(find_many_rows=[row])
result = await list_tools(prisma, input_policy="blocked")
assert result[0].input_policy == "blocked"
call_kw = prisma.db.litellm_tooltable.find_many.call_args.kwargs
assert call_kw["where"] == {"input_policy": "blocked"}
@pytest.mark.asyncio
async def test_get_tool_found():
prisma = _make_prisma()
row = _mock_row(tool_name="my_tool")
prisma = _make_prisma(find_unique_row=row)
result = await get_tool(prisma, "my_tool")
assert result is not None
assert result.tool_name == "my_tool"
prisma.db.query_raw.assert_awaited_once()
prisma.db.litellm_tooltable.find_unique.assert_awaited_once_with(
where={"tool_name": "my_tool"}
)
@pytest.mark.asyncio
async def test_get_tool_not_found():
prisma = _make_prisma(query_rows=[])
prisma = _make_prisma(find_unique_row=None)
result = await get_tool(prisma, "nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_update_tool_policy_calls_execute_raw():
row = {
"tool_id": "uuid-1",
"tool_name": "my_tool",
"origin": "user_defined",
"call_policy": "blocked",
"call_count": 1,
"assignments": {},
"key_hash": None,
"team_id": None,
"key_alias": None,
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
"created_by": None,
"updated_by": "admin",
}
prisma = _make_prisma(query_rows=[row])
result = await update_tool_policy(prisma, "my_tool", "blocked", "admin")
async def test_update_tool_policy_calls_upsert_then_get_tool():
row = _mock_row(
tool_name="my_tool",
input_policy="blocked",
output_policy="untrusted",
updated_by="admin",
)
prisma = _make_prisma(find_unique_row=row)
result = await update_tool_policy(
prisma, "my_tool", updated_by="admin", input_policy="blocked"
)
assert result is not None
assert result.call_policy == "blocked"
prisma.db.execute_raw.assert_awaited_once()
call_args = prisma.db.execute_raw.call_args
sql = call_args.args[0]
assert "ON CONFLICT" in sql
assert "call_policy" in sql
assert result.input_policy == "blocked"
prisma.db.litellm_tooltable.upsert.assert_awaited_once()
call_kw = prisma.db.litellm_tooltable.upsert.call_args.kwargs
assert call_kw["where"] == {"tool_name": "my_tool"}
assert call_kw["data"]["update"]["input_policy"] == "blocked"
assert call_kw["data"]["update"]["updated_by"] == "admin"
prisma.db.litellm_tooltable.find_unique.assert_awaited_with(
where={"tool_name": "my_tool"}
)
@pytest.mark.asyncio
async def test_get_tools_by_names_returns_policy_map():
rows = [
{"tool_name": "tool_a", "call_policy": "trusted"},
{"tool_name": "tool_b", "call_policy": "blocked"},
_mock_row(tool_name="tool_a", input_policy="trusted", output_policy="untrusted"),
_mock_row(tool_name="tool_b", input_policy="blocked", output_policy="untrusted"),
]
prisma = _make_prisma(query_rows=rows)
prisma = _make_prisma(find_many_rows=rows)
result = await get_tools_by_names(prisma, ["tool_a", "tool_b"])
assert result == {"tool_a": "trusted", "tool_b": "blocked"}
assert result == {
"tool_a": ("trusted", "untrusted"),
"tool_b": ("blocked", "untrusted"),
}
prisma.db.litellm_tooltable.find_many.assert_awaited_once_with(
where={"tool_name": {"in": ["tool_a", "tool_b"]}}
)
@pytest.mark.asyncio
@ -194,4 +221,71 @@ async def test_get_tools_by_names_empty_list():
prisma = _make_prisma()
result = await get_tools_by_names(prisma, [])
assert result == {}
prisma.db.query_raw.assert_not_awaited()
prisma.db.litellm_tooltable.find_many.assert_not_awaited()
# --- ToolPolicyRegistry ---
def _mock_tool_row(
tool_name: str,
input_policy: str = "untrusted",
output_policy: str = "untrusted",
):
row = MagicMock()
row.tool_name = tool_name
row.input_policy = input_policy
row.output_policy = output_policy
return row
def _mock_perm_row(object_permission_id: str, blocked_tools: list):
row = MagicMock()
row.object_permission_id = object_permission_id
row.blocked_tools = blocked_tools
return row
@pytest.mark.asyncio
async def test_tool_policy_registry_sync_and_get_effective_policies():
"""Registry syncs from DB; get_effective_policies returns merged blocked + global."""
prisma = MagicMock()
prisma.db.litellm_tooltable.find_many = AsyncMock(
return_value=[
_mock_tool_row("tool_a", input_policy="trusted"),
_mock_tool_row("tool_b", input_policy="blocked"),
_mock_tool_row("tool_c", input_policy="untrusted"),
]
)
prisma.db.litellm_objectpermissiontable.find_many = AsyncMock(
return_value=[
_mock_perm_row("op-key-1", ["tool_a"]),
_mock_perm_row("op-team-1", ["tool_c"]),
]
)
registry = get_tool_policy_registry()
await registry.sync_tool_policy_from_db(prisma)
assert registry.is_initialized()
# Key blocked: tool_a. Team blocked: tool_c. Global: tool_b blocked.
result = registry.get_effective_policies(
["tool_a", "tool_b", "tool_c"],
object_permission_id="op-key-1",
team_object_permission_id="op-team-1",
)
assert result["tool_a"] == "blocked"
assert result["tool_b"] == "blocked"
assert result["tool_c"] == "blocked"
# No op ids: only global
result_global = registry.get_effective_policies(["tool_a", "tool_b", "tool_c"])
assert result_global["tool_a"] == "trusted"
assert result_global["tool_b"] == "blocked"
assert result_global["tool_c"] == "untrusted"
@pytest.mark.asyncio
async def test_tool_policy_registry_not_initialized_returns_untrusted():
"""When not synced, get_effective_policies still returns untrusted for unknown tools."""
registry = ToolPolicyRegistry()
assert not registry.is_initialized()
result = registry.get_effective_policies(["unknown_tool"])
assert result == {"unknown_tool": "untrusted"}
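
The two registry tests above pin down the merge rule: a tool is "blocked" if it appears in the blocked_tools of either the key's or the team's object permission, otherwise its global input_policy applies, and unknown tools default to "untrusted". Here is a minimal sketch of that rule. `PolicyRegistrySketch` is a hypothetical stand-in that takes its data in the constructor, and it is not the PR's ToolPolicyRegistry.

```python
from typing import Dict, Iterable, Optional, Set


class PolicyRegistrySketch:
    """In-memory stand-in holding global policies and per-permission block lists."""

    def __init__(
        self,
        global_policies: Optional[Dict[str, str]] = None,
        blocked_by_permission: Optional[Dict[str, Set[str]]] = None,
    ) -> None:
        self.global_policies = global_policies or {}
        self.blocked_by_permission = blocked_by_permission or {}

    def get_effective_policies(
        self,
        tool_names: Iterable[str],
        object_permission_id: Optional[str] = None,
        team_object_permission_id: Optional[str] = None,
    ) -> Dict[str, str]:
        # Union of the key-level and team-level block lists.
        blocked: Set[str] = set()
        for op_id in (object_permission_id, team_object_permission_id):
            if op_id is not None:
                blocked |= self.blocked_by_permission.get(op_id, set())
        # A scoped block wins; otherwise fall back to the global policy.
        return {
            name: "blocked" if name in blocked
            else self.global_policies.get(name, "untrusted")
            for name in tool_names
        }
```

Keeping the merge a pure lookup over in-memory maps is what lets the request path avoid a database call; the endpoints only need to resync the maps after a mutation.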


@ -12,9 +12,8 @@ from fastapi import HTTPException
sys.path.insert(0, os.path.abspath("../../../../../.."))
from litellm.proxy.guardrails.guardrail_hooks.tool_policy.tool_policy_guardrail import (
ToolPolicyGuardrail,
)
from litellm.proxy.guardrails.guardrail_hooks.tool_policy.tool_policy_guardrail import \
ToolPolicyGuardrail
from litellm.types.guardrails import GuardrailEventHooks
@ -70,10 +69,21 @@ async def test_no_tool_calls_in_response_passes_through(guardrail):
assert result is inputs
def _registry_mock(policy_map: dict):
"""Return a mock registry with is_initialized=True and get_effective_policies returning policy_map."""
reg = MagicMock()
reg.is_initialized.return_value = True
reg.get_effective_policies.return_value = policy_map
return reg
@pytest.mark.asyncio
async def test_untrusted_tools_pass_through(guardrail):
policy_map = {"search": "untrusted", "read_file": "trusted"}
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
with patch(
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
return_value=_registry_mock(policy_map),
):
inputs: Any = _tool_request_inputs(["search", "read_file"])
result = await guardrail.apply_guardrail(
inputs=inputs, request_data={}, input_type="request"
@ -84,7 +94,10 @@ async def test_untrusted_tools_pass_through(guardrail):
@pytest.mark.asyncio
async def test_blocked_tool_in_request_raises_http_exception(guardrail):
policy_map = {"dangerous_tool": "blocked"}
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
with patch(
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
return_value=_registry_mock(policy_map),
):
inputs: Any = _tool_request_inputs(["dangerous_tool"])
with pytest.raises(HTTPException) as exc_info:
await guardrail.apply_guardrail(
@ -97,7 +110,10 @@ async def test_blocked_tool_in_request_raises_http_exception(guardrail):
@pytest.mark.asyncio
async def test_blocked_tool_in_response_raises_http_exception(guardrail):
policy_map = {"exfil_tool": "blocked"}
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
with patch(
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
return_value=_registry_mock(policy_map),
):
inputs: Any = _tool_response_inputs(["exfil_tool"])
with pytest.raises(HTTPException) as exc_info:
await guardrail.apply_guardrail(
@ -110,7 +126,10 @@ async def test_blocked_tool_in_response_raises_http_exception(guardrail):
@pytest.mark.asyncio
async def test_mixed_blocked_and_allowed_raises_for_blocked(guardrail):
policy_map = {"safe_tool": "trusted", "bad_tool": "blocked"}
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
with patch(
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
return_value=_registry_mock(policy_map),
):
inputs: Any = _tool_request_inputs(["safe_tool", "bad_tool"])
with pytest.raises(HTTPException) as exc_info:
await guardrail.apply_guardrail(
@@ -123,8 +142,11 @@ async def test_mixed_blocked_and_allowed_raises_for_blocked(guardrail):
@pytest.mark.asyncio
async def test_tool_not_in_db_passes_through(guardrail):
"""Tools not found in the DB (no entry) should not be blocked."""
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value={})):
"""When registry returns no policy (or empty), tools are not blocked."""
with patch(
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
return_value=_registry_mock({}),
):
inputs: Any = _tool_request_inputs(["unknown_tool"])
result = await guardrail.apply_guardrail(
inputs=inputs, request_data={}, input_type="request"
@@ -133,43 +155,30 @@ async def test_tool_not_in_db_passes_through(guardrail):
@pytest.mark.asyncio
async def test_get_policies_cached_uses_cache(guardrail):
"""Second call with same tool names should return the cached result."""
policy_map = {"tool_a": "trusted"}
async def test_registry_not_initialized_passes_through(guardrail):
"""When registry is not initialized, no tools are blocked (empty policy map)."""
reg = MagicMock()
reg.is_initialized.return_value = False
with patch(
"litellm.proxy.db.tool_registry_writer.get_tools_by_names",
new=AsyncMock(return_value=policy_map),
) as mock_db, patch(
"litellm.proxy.proxy_server.prisma_client",
new=MagicMock(),
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
return_value=reg,
):
# first call — should hit DB
result1 = await guardrail._get_policies_cached(["tool_a"])
assert result1 == policy_map
# second call — should hit cache, not DB again
result2 = await guardrail._get_policies_cached(["tool_a"])
assert result2 == policy_map
assert mock_db.call_count == 1
@pytest.mark.asyncio
async def test_get_policies_cached_no_prisma(guardrail):
"""Without a prisma client, returns empty dict."""
with patch(
"litellm.proxy.proxy_server.prisma_client",
None,
):
result = await guardrail._get_policies_cached(["tool_a"])
assert result == {}
inputs: Any = _tool_request_inputs(["any_tool"])
result = await guardrail.apply_guardrail(
inputs=inputs, request_data={}, input_type="request"
)
assert result is inputs
reg.get_effective_policies.assert_not_called()
@pytest.mark.asyncio
async def test_response_tool_calls_as_objects(guardrail):
"""tool_calls that are objects (not dicts) with .function.name should work."""
policy_map = {"obj_tool": "blocked"}
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
with patch(
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
return_value=_registry_mock(policy_map),
):
fn = MagicMock()
fn.name = "obj_tool"
tc = MagicMock()

@@ -0,0 +1,200 @@
"""
Tests for tool allowlist enforcement (key/team metadata.allowed_tools).
Covers:
- check_tools_allowlist: allowed, disallowed, no allowlist, non-tool routes
- extract_request_tool_names: OpenAI chat, responses, Anthropic, generate_content, MCP
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from litellm.proxy._types import (ProxyErrorTypes, ProxyException,
UserAPIKeyAuth)
from litellm.proxy.auth.auth_checks import check_tools_allowlist
from litellm.proxy.guardrails.tool_name_extraction import (
TOOL_CAPABLE_CALL_TYPES, extract_request_tool_names)
def _token(metadata=None, team_metadata=None):
return UserAPIKeyAuth(
api_key="test-key",
user_id="user",
team_id="team",
org_id=None,
models=["*"],
metadata=metadata or {},
team_metadata=team_metadata or {},
)
class TestExtractRequestToolNames:
"""Test tool name extraction per API format."""
def test_openai_chat_tools(self):
data = {
"tools": [
{"type": "function", "function": {"name": "get_weather"}},
{"type": "function", "function": {"name": "run_sql"}},
]
}
assert extract_request_tool_names("/v1/chat/completions", data) == [
"get_weather",
"run_sql",
]
def test_openai_chat_functions_legacy(self):
data = {"functions": [{"name": "get_weather"}, {"name": "run_sql"}]}
assert extract_request_tool_names("/v1/chat/completions", data) == [
"get_weather",
"run_sql",
]
def test_openai_responses_function_tools(self):
data = {
"tools": [
{"type": "function", "name": "get_current_weather", "description": "x"},
]
}
assert extract_request_tool_names("/v1/responses", data) == [
"get_current_weather"
]
def test_openai_responses_mcp_tools(self):
data = {
"tools": [
{"type": "mcp", "server_label": "dmcp", "server_url": "http://x"},
]
}
assert extract_request_tool_names("/v1/responses", data) == ["dmcp"]
def test_anthropic_tools(self):
data = {"tools": [{"name": "get_weather"}, {"name": "run_sql"}]}
assert extract_request_tool_names("/v1/messages", data) == [
"get_weather",
"run_sql",
]
def test_generate_content_tools(self):
data = {
"tools": [
{
"functionDeclarations": [
{"name": "schedule_meeting", "description": "x"},
]
},
]
}
assert extract_request_tool_names("/generate_content", data) == [
"schedule_meeting"
]
def test_mcp_call_tool_name(self):
data = {"name": "my_tool", "arguments": {}}
assert extract_request_tool_names("/mcp/call_tool", data) == ["my_tool"]
def test_mcp_call_tool_mcp_tool_name(self):
data = {"mcp_tool_name": "other_tool"}
assert extract_request_tool_names("/mcp/call_tool", data) == ["other_tool"]
def test_non_tool_route_returns_empty(self):
data = {"tools": [{"type": "function", "function": {"name": "x"}}]}
assert extract_request_tool_names("/v1/embeddings", data) == []
class TestCheckToolsAllowlist:
"""Test allowlist enforcement in auth (no DB in hot path)."""
@pytest.mark.asyncio
async def test_no_allowlist_passes(self):
token = _token(metadata={}, team_metadata={})
body = {
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
}
await check_tools_allowlist(
request_body=body,
valid_token=token,
team_object=None,
route="/v1/chat/completions",
)
@pytest.mark.asyncio
async def test_allowed_tool_passes(self):
token = _token(metadata={"allowed_tools": ["get_weather"]})
body = {
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
}
await check_tools_allowlist(
request_body=body,
valid_token=token,
team_object=None,
route="/v1/chat/completions",
)
@pytest.mark.asyncio
async def test_disallowed_tool_raises(self):
token = _token(metadata={"allowed_tools": ["other_tool"]})
body = {
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
}
with pytest.raises(ProxyException) as exc_info:
await check_tools_allowlist(
request_body=body,
valid_token=token,
team_object=None,
route="/v1/chat/completions",
)
assert exc_info.value.type == ProxyErrorTypes.tool_access_denied
assert "get_weather" in str(exc_info.value.message)
@pytest.mark.asyncio
async def test_team_allowlist_used_when_key_empty(self):
token = _token(
metadata={},
team_metadata={"allowed_tools": ["get_weather"]},
)
body = {
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
}
await check_tools_allowlist(
request_body=body,
valid_token=token,
team_object=None,
route="/v1/chat/completions",
)
@pytest.mark.asyncio
async def test_key_allowlist_overrides_team(self):
token = _token(
metadata={"allowed_tools": ["get_weather"]},
team_metadata={"allowed_tools": ["other_tool"]},
)
body = {
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
}
await check_tools_allowlist(
request_body=body,
valid_token=token,
team_object=None,
route="/v1/chat/completions",
)
@pytest.mark.asyncio
async def test_valid_token_none_skips(self):
await check_tools_allowlist(
request_body={"tools": [{"type": "function", "function": {"name": "x"}}]},
valid_token=None,
team_object=None,
route="/v1/chat/completions",
)
@pytest.mark.asyncio
async def test_no_tools_in_body_passes(self):
token = _token(metadata={"allowed_tools": ["get_weather"]})
await check_tools_allowlist(
request_body={"messages": []},
valid_token=token,
team_object=None,
route="/v1/chat/completions",
)
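The allowlist tests above imply a precedence rule: a key-level `allowed_tools` list is used when present, otherwise the team-level list applies, and no list at all means no restriction. The sketch below illustrates that rule in isolation. It is not the actual `check_tools_allowlist` implementation; the helper names `resolve_allowed_tools` and `find_disallowed_tools` are hypothetical, and treating an empty key list as "not set" is an assumption drawn from `test_team_allowlist_used_when_key_empty`.

```python
from typing import Any, Dict, List, Optional


def resolve_allowed_tools(
    key_metadata: Optional[Dict[str, Any]],
    team_metadata: Optional[Dict[str, Any]],
) -> Optional[List[str]]:
    """Return the effective allowlist, or None when no allowlist is configured.

    Assumption: a non-empty key-level list takes precedence over the team-level list.
    """
    for metadata in (key_metadata, team_metadata):
        allowed = (metadata or {}).get("allowed_tools")
        if allowed:
            return list(allowed)
    return None


def find_disallowed_tools(
    tool_names: List[str], allowed_tools: Optional[List[str]]
) -> List[str]:
    """List requested tools outside the allowlist (empty when unrestricted)."""
    if allowed_tools is None:
        return []
    allowed = set(allowed_tools)
    return [name for name in tool_names if name not in allowed]
```

Under this sketch, a caller would extract the tool names from the request body, resolve the allowlist once from the token's metadata, and raise an access-denied error if `find_disallowed_tools` returns anything. No database lookup is needed, which is consistent with the "no DB in hot path" note in the test class docstring.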

@@ -39,7 +39,7 @@ import UserDashboard from "@/components/user_dashboard";
import { AccessGroupsPage } from "@/components/AccessGroups/AccessGroupsPage";
import { ProjectsPage } from "@/components/Projects/ProjectsPage";
import VectorStoreManagement from "@/components/vector_store_management";
import ToolPolicies from "@/components/ToolPolicies";
import ToolPoliciesView from "@/components/ToolPoliciesView";
import SpendLogsTable from "@/components/view_logs";
import ViewUserDashboard from "@/components/view_users";
import { ThemeProvider } from "@/contexts/ThemeContext";
@@ -549,7 +549,7 @@ function CreateKeyPageContent() {
) : page == "vector-stores" ? (
<VectorStoreManagement accessToken={accessToken} userRole={userRole} userID={userID} />
) : page == "tool-policies" ? (
<ToolPolicies accessToken={accessToken} userRole={userRole} />
<ToolPoliciesView accessToken={accessToken} userRole={userRole} />
) : page == "guardrails-monitor" ? (
<GuardrailsMonitorView accessToken={accessToken} />
) : page == "new_usage" ? (

@@ -0,0 +1,445 @@
"use client";
import { ArrowLeftOutlined, HistoryOutlined, ToolOutlined } from "@ant-design/icons";
import { useQuery, useQueryClient } from "@tanstack/react-query";
import { Button, Select, Spin } from "antd";
import React, { useCallback, useEffect, useMemo, useState } from "react";
import TeamDropdown from "@/components/common_components/team_dropdown";
import { LogViewer } from "@/components/GuardrailsMonitor/LogViewer";
import type { LogEntry } from "@/components/GuardrailsMonitor/mockData";
import { PolicySelect } from "@/components/ToolPolicies/PolicySelect";
import {
deleteToolPolicyOverride,
fetchToolDetail,
fetchToolPolicyOptions,
getToolUsageLogs,
keyListCall,
teamListCall,
updateToolPolicy,
type ToolPolicyOption,
type ToolPolicyOverrideRow,
} from "@/components/networking";
import type { Team } from "@/components/key_team_helpers/key_list";
interface ToolDetailProps {
toolName: string;
onBack: () => void;
accessToken: string | null;
}
interface KeyOption {
token: string;
key_alias?: string;
}
const TOOL_DETAIL_QUERY_KEY = "tool-detail";
const LOGS_PAGE_SIZE = 50;
function getDefaultLogsDateRange(): { start: string; end: string } {
const end = new Date();
const start = new Date();
start.setDate(start.getDate() - 90);
const fmt = (d: Date) =>
d.toISOString().slice(0, 19).replace("T", " ");
return { start: fmt(start), end: fmt(end) };
}
export function ToolDetail({ toolName, onBack, accessToken }: ToolDetailProps) {
const queryClient = useQueryClient();
const [overrideSaving, setOverrideSaving] = useState(false);
const [inputPolicySaving, setInputPolicySaving] = useState(false);
const [outputPolicySaving, setOutputPolicySaving] = useState(false);
const [blockScope, setBlockScope] = useState<"team" | "key">("team");
const [blockTeamId, setBlockTeamId] = useState<string | null>(null);
const [blockKey, setBlockKey] = useState<KeyOption | null>(null);
const logsDateRange = useMemo(() => getDefaultLogsDateRange(), []);
const { data: detail, isLoading: detailLoading, error: detailError } = useQuery({
queryKey: [TOOL_DETAIL_QUERY_KEY, toolName],
queryFn: () => fetchToolDetail(accessToken!, toolName),
enabled: !!accessToken && !!toolName,
});
const { data: policyOptions } = useQuery({
queryKey: ["tool-policy-options"],
queryFn: () => fetchToolPolicyOptions(accessToken!),
enabled: !!accessToken,
staleTime: 60_000,
});
const { data: teamsData } = useQuery({
queryKey: ["teams-list-tool-detail"],
queryFn: () => teamListCall(accessToken!, null, null),
enabled: !!accessToken,
});
const { data: keysData } = useQuery({
queryKey: ["keys-list-tool-detail"],
queryFn: () => keyListCall(accessToken!, null, null, null, null, null, 1, 100),
enabled: !!accessToken,
});
const { data: logsData, isLoading: logsLoading } = useQuery({
queryKey: ["tool-usage-logs", toolName, logsDateRange.start, logsDateRange.end],
queryFn: () =>
getToolUsageLogs(accessToken!, toolName, {
page: 1,
pageSize: LOGS_PAGE_SIZE,
startDate: logsDateRange.start,
endDate: logsDateRange.end,
}),
enabled: !!accessToken && !!toolName,
});
const logs: LogEntry[] = useMemo(() => {
const list = logsData?.logs ?? [];
return list.map((l) => ({
id: l.id,
timestamp: l.timestamp,
action: "passed" as const,
model: l.model ?? undefined,
input_snippet: l.input_snippet ?? undefined,
}));
}, [logsData?.logs]);
const teams: Team[] = useMemo(() => {
const arr = Array.isArray(teamsData) ? teamsData : teamsData?.data ?? [];
return arr.map((t: { team_id?: string; id?: string; team_alias?: string }) => ({
team_id: t.team_id ?? t.id ?? "",
team_alias: t.team_alias ?? t.team_id ?? "",
models: [],
max_budget: null,
budget_duration: null,
tpm_limit: null,
rpm_limit: null,
organization_id: "",
created_at: "",
keys: [],
members_with_roles: [],
spend: 0,
}));
}, [teamsData]);
const keys: KeyOption[] = useMemo(() => {
const keysRes = keysData?.keys ?? keysData?.data ?? [];
return keysRes.map((k: { token?: string; api_key?: string; key_hash?: string; key_alias?: string }) => ({
token: k.token ?? k.api_key ?? k.key_hash ?? "",
key_alias: k.key_alias ?? (k.token ?? k.api_key ?? k.key_hash)?.toString?.()?.substring?.(0, 8),
}));
}, [keysData]);
const invalidateDetail = useCallback(() => {
queryClient.invalidateQueries({ queryKey: [TOOL_DETAIL_QUERY_KEY, toolName] });
}, [queryClient, toolName]);
const handleInputPolicyChange = useCallback(
async (_name: string, newPolicy: string) => {
if (!accessToken) return;
setInputPolicySaving(true);
try {
await updateToolPolicy(accessToken, toolName, { input_policy: newPolicy });
invalidateDetail();
} catch (e: unknown) {
alert(`Failed to update input policy: ${e instanceof Error ? e.message : String(e)}`);
} finally {
setInputPolicySaving(false);
}
},
[accessToken, toolName, invalidateDetail]
);
const handleOutputPolicyChange = useCallback(
async (_name: string, newPolicy: string) => {
if (!accessToken) return;
setOutputPolicySaving(true);
try {
await updateToolPolicy(accessToken, toolName, { output_policy: newPolicy });
invalidateDetail();
} catch (e: unknown) {
alert(`Failed to update output policy: ${e instanceof Error ? e.message : String(e)}`);
} finally {
setOutputPolicySaving(false);
}
},
[accessToken, toolName, invalidateDetail]
);
const handleAddOverride = useCallback(async () => {
if (!accessToken || !toolName) return;
const isTeam = blockScope === "team";
if (isTeam && !blockTeamId) return;
if (!isTeam && !blockKey?.token) return;
setOverrideSaving(true);
try {
await updateToolPolicy(accessToken, toolName, { input_policy: "blocked" }, {
team_id: isTeam ? blockTeamId : undefined,
key_hash: !isTeam ? blockKey!.token : undefined,
key_alias: !isTeam ? blockKey!.key_alias : undefined,
});
invalidateDetail();
setBlockTeamId(null);
setBlockKey(null);
} catch (e: unknown) {
alert(`Failed to add override: ${e instanceof Error ? e.message : String(e)}`);
} finally {
setOverrideSaving(false);
}
}, [accessToken, toolName, blockScope, blockTeamId, blockKey, invalidateDetail]);
const handleRemoveOverride = useCallback(
async (override: ToolPolicyOverrideRow) => {
if (!accessToken || !toolName) return;
setOverrideSaving(true);
try {
await deleteToolPolicyOverride(accessToken, toolName, {
team_id: override.team_id ?? undefined,
key_hash: override.key_hash ?? undefined,
});
invalidateDetail();
} catch (e: unknown) {
alert(`Failed to remove override: ${e instanceof Error ? e.message : String(e)}`);
} finally {
setOverrideSaving(false);
}
},
[accessToken, toolName, invalidateDetail]
);
if (detailLoading && !detail) {
return (
<div className="flex items-center justify-center py-12">
<Spin size="large" />
</div>
);
}
if (detailError && !detail) {
return (
<div>
<Button type="link" icon={<ArrowLeftOutlined />} onClick={onBack} className="pl-0 mb-4">
Back to Tool Policies
</Button>
<p className="text-red-600">Failed to load tool details.</p>
</div>
);
}
if (!detail) {
return null;
}
const { tool, overrides } = detail;
const inputDesc = policyOptions?.input_policies?.find(
(p) => p.value === tool.input_policy
)?.description;
const outputDesc = policyOptions?.output_policies?.find(
(p) => p.value === tool.output_policy
)?.description;
return (
<div>
<div className="mb-6">
<Button
type="link"
icon={<ArrowLeftOutlined />}
onClick={onBack}
className="pl-0 mb-4"
>
Back to Tool Policies
</Button>
<div className="flex items-start justify-between">
<div>
<div className="flex items-center gap-3 mb-1 flex-wrap">
<ToolOutlined className="text-xl text-gray-400" />
<h1 className="text-xl font-semibold text-gray-900 font-mono">{tool.tool_name}</h1>
<span className="inline-flex items-center px-2.5 py-1 text-xs font-medium rounded-md bg-gray-100 text-gray-700 border border-gray-200">
{tool.origin ?? "—"}
</span>
<span className="inline-flex items-center px-2.5 py-1 text-xs font-medium rounded-md bg-indigo-50 text-indigo-700 border border-indigo-200">
{(tool.call_count ?? 0).toLocaleString()} calls
</span>
</div>
<dl className="mt-3 flex flex-wrap gap-x-6 gap-y-1 text-sm text-gray-600">
{tool.user_agent && (
<div className="flex items-center gap-1.5">
<dt className="font-medium text-gray-500 whitespace-nowrap">User Agent:</dt>
<dd className="font-mono truncate max-w-[40ch]" title={tool.user_agent}>{tool.user_agent}</dd>
</div>
)}
{tool.created_at && (
<div className="flex items-center gap-1.5">
<dt className="font-medium text-gray-500 whitespace-nowrap">First Discovered:</dt>
<dd>{new Date(tool.created_at).toLocaleString()}</dd>
</div>
)}
{tool.last_used_at && (
<div className="flex items-center gap-1.5">
<dt className="font-medium text-gray-500 whitespace-nowrap">Last Used:</dt>
<dd>{new Date(tool.last_used_at).toLocaleString()}</dd>
</div>
)}
</dl>
</div>
</div>
</div>
<div className="space-y-6">
{/* Two-panel policy layout */}
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
<h2 className="text-sm font-semibold text-gray-700 mb-1">Input Policy</h2>
<p className="text-xs text-gray-500 mb-3">
{inputDesc ?? "Controls what data this tool is allowed to accept."}
</p>
<PolicySelect
value={tool.input_policy}
toolName={tool.tool_name}
saving={inputPolicySaving}
onChange={handleInputPolicyChange}
policyType="input"
size="middle"
minWidth={140}
stopPropagation={false}
/>
</section>
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
<h2 className="text-sm font-semibold text-gray-700 mb-1">Output Policy</h2>
<p className="text-xs text-gray-500 mb-3">
{outputDesc ?? "Controls how this tool's output is trusted by downstream tools."}
</p>
<PolicySelect
value={tool.output_policy}
toolName={tool.tool_name}
saving={outputPolicySaving}
onChange={handleOutputPolicyChange}
policyType="output"
size="middle"
minWidth={140}
stopPropagation={false}
/>
</section>
</div>
{overrides.length > 0 && (
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
<h2 className="text-sm font-semibold text-gray-700 mb-3">Blocked for team or key</h2>
<ul className="border rounded-md divide-y divide-gray-100 bg-red-50/30">
{overrides.map((ov) => (
<li
key={ov.override_id}
className="flex items-center justify-between px-3 py-2.5 text-sm"
>
<span className="text-gray-700">
{ov.team_id ? `Team: ${ov.team_id}` : ""}
{ov.team_id && ov.key_hash ? " · " : ""}
{ov.key_hash ? `Key: ${ov.key_alias || ov.key_hash.substring(0, 8)}` : ""}
{!ov.team_id && !ov.key_hash ? "—" : ""}
</span>
<Button
type="link"
danger
size="small"
disabled={overrideSaving}
onClick={() => handleRemoveOverride(ov)}
>
Remove
</Button>
</li>
))}
</ul>
</section>
)}
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
<h2 className="text-sm font-semibold text-gray-700 mb-3">Block for team or key</h2>
<div className="flex flex-col gap-4 max-w-md">
<div>
<span className="text-sm font-medium text-gray-700 block mb-2">Scope</span>
<div className="flex items-center gap-6">
<label className="flex items-center gap-2 cursor-pointer text-sm text-gray-700">
<input
type="radio"
checked={blockScope === "team"}
onChange={() => setBlockScope("team")}
className="align-middle"
/>
Team
</label>
<label className="flex items-center gap-2 cursor-pointer text-sm text-gray-700">
<input
type="radio"
checked={blockScope === "key"}
onChange={() => setBlockScope("key")}
className="align-middle"
/>
Key
</label>
</div>
</div>
<div>
<span className="text-sm font-medium text-gray-700 block mb-2">
{blockScope === "team" ? "Team" : "Key"}
</span>
{blockScope === "team" ? (
<TeamDropdown
teams={teams}
value={blockTeamId ?? undefined}
onChange={(id) => setBlockTeamId(id || null)}
/>
) : (
<Select
placeholder="Select key"
allowClear
showSearch
optionFilterProp="label"
value={blockKey ? blockKey.token : undefined}
onChange={(token) => {
const k = keys.find((x) => x.token === token);
setBlockKey(k ?? null);
}}
options={keys.map((k) => ({
value: k.token,
label: k.key_alias || k.token?.substring?.(0, 12) || k.token,
}))}
className="w-full"
style={{ minWidth: 200 }}
/>
)}
</div>
<Button
type="primary"
danger
disabled={overrideSaving || (blockScope === "team" ? !blockTeamId : !blockKey?.token)}
loading={overrideSaving}
onClick={handleAddOverride}
>
Block for {blockScope}
</Button>
</div>
</section>
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
<h2 className="text-sm font-semibold text-gray-700 mb-3 flex items-center gap-2">
<HistoryOutlined />
Recent logs
</h2>
<LogViewer
guardrailName={tool.tool_name}
filterAction="passed"
logs={logs}
logsLoading={logsLoading}
totalLogs={logsData?.total ?? 0}
accessToken={accessToken}
startDate={logsDateRange.start}
endDate={logsDateRange.end}
/>
</section>
</div>
</div>
);
}

@@ -1,24 +1,46 @@
"use client";
import { Table, TableBody, TableCell, TableHead, TableHeaderCell, TableRow } from "@tremor/react";
import { Select, Switch, Tooltip } from "antd";
import React, { useCallback, useDeferredValue, useEffect, useState } from "react";
import React, { useCallback, useDeferredValue, useEffect, useMemo, useState } from "react";
import { Button, Switch, Tooltip } from "antd";
import { Table, TableHead, TableHeaderCell, TableBody, TableRow, TableCell } from "@tremor/react";
import { TimeCell } from "./view_logs/time_cell";
import type { SortState } from "./common_components/TableHeaderSortDropdown/TableHeaderSortDropdown";
import { TableHeaderSortDropdown } from "./common_components/TableHeaderSortDropdown/TableHeaderSortDropdown";
import FilterComponent, { FilterOption } from "./molecules/filter";
import { fetchToolsList, ToolRow, updateToolPolicy } from "./networking";
import { TimeCell } from "./view_logs/time_cell";
import { MetricCard } from "./GuardrailsMonitor/MetricCard";
import { PolicySelect, INPUT_POLICY_OPTIONS, OUTPUT_POLICY_OPTIONS } from "./ToolPolicies/PolicySelect";
import {
fetchToolsList,
updateToolPolicy,
ToolRow,
} from "./networking";
const POLICY_OPTIONS = [
{ value: "trusted", label: "trusted", color: "#065f46", bg: "#d1fae5", border: "#6ee7b7" },
{ value: "blocked", label: "blocked", color: "#991b1b", bg: "#fee2e2", border: "#fca5a5" },
] as const;
function getUTCDateKey(date: Date): string {
return `${date.getUTCFullYear()}-${String(date.getUTCMonth() + 1).padStart(2, "0")}-${String(date.getUTCDate()).padStart(2, "0")}`;
}
type PolicyValue = "trusted" | "blocked";
function isCreatedInUTCDay(createdAt: string | undefined, utcDateKey: string): boolean {
if (!createdAt) return false;
try {
const d = new Date(createdAt);
return getUTCDateKey(d) === utcDateKey;
} catch {
return false;
}
}
const policyStyle = (p: string) => POLICY_OPTIONS.find((o) => o.value === p) ?? POLICY_OPTIONS[1];
function countToolsInUTCDay(tools: ToolRow[], utcDateKey: string): number {
return tools.filter((t) => isCreatedInUTCDay(t.created_at, utcDateKey)).length;
}
type SortField = "tool_name" | "call_policy" | "team_id" | "key_alias" | "created_at" | "call_count";
function getTrendSubtitle(newToday: number, newYesterday: number): string | undefined {
const diff = newToday - newYesterday;
if (diff === 0) return undefined;
if (diff > 0) return `+${diff} since yesterday`;
return `${diff} since yesterday`;
}
type SortField = "tool_name" | "input_policy" | "output_policy" | "team_id" | "key_alias" | "created_at" | "call_count";
interface FilterValues {
[key: string]: string;
@@ -27,65 +49,16 @@ interface FilterValues {
interface ToolPoliciesProps {
accessToken: string | null;
userRole?: string;
onSelectTool?: (toolName: string) => void;
}
const PolicySelect: React.FC<{
value: string;
toolName: string;
saving: boolean;
onChange: (toolName: string, policy: string) => void;
}> = ({ value, toolName, saving, onChange }) => {
const style = policyStyle(value);
return (
<Select
size="small"
value={value}
disabled={saving}
loading={saving}
onChange={(v) => onChange(toolName, v)}
onClick={(e) => e.stopPropagation()}
style={{
minWidth: 110,
fontWeight: 500,
}}
popupMatchSelectWidth={false}
options={POLICY_OPTIONS.map((o) => ({
value: o.value,
label: (
<span
style={{
display: "inline-flex",
alignItems: "center",
gap: 6,
fontSize: 12,
fontWeight: 500,
color: o.color,
}}
>
<span
style={{
width: 8,
height: 8,
borderRadius: "50%",
backgroundColor: o.color,
display: "inline-block",
flexShrink: 0,
}}
/>
{o.label}
</span>
),
}))}
/>
);
};
export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken, onSelectTool }) => {
const [tools, setTools] = useState<ToolRow[]>([]);
const [loading, setLoading] = useState(true);
const [isFetching, setIsFetching] = useState(false);
const [error, setError] = useState<string | null>(null);
const [saving, setSaving] = useState<string | null>(null);
const [savingInput, setSavingInput] = useState<string | null>(null);
const [savingOutput, setSavingOutput] = useState<string | null>(null);
const [searchTerm, setSearchTerm] = useState("");
const [sortField, setSortField] = useState<SortField>("created_at");
@@ -123,16 +96,29 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
return () => clearInterval(id);
}, [isLiveTail, load]);
const handlePolicyChange = async (toolName: string, newPolicy: string) => {
const handleInputPolicyChange = async (toolName: string, newPolicy: string) => {
if (!accessToken) return;
setSaving(toolName);
setSavingInput(toolName);
try {
await updateToolPolicy(accessToken, toolName, newPolicy);
setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, call_policy: newPolicy } : t)));
await updateToolPolicy(accessToken, toolName, { input_policy: newPolicy });
setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, input_policy: newPolicy } : t)));
} catch (e: any) {
alert(`Failed to update policy: ${e.message}`);
alert(`Failed to update input policy: ${e.message}`);
} finally {
setSaving(null);
setSavingInput(null);
}
};
const handleOutputPolicyChange = async (toolName: string, newPolicy: string) => {
if (!accessToken) return;
setSavingOutput(toolName);
try {
await updateToolPolicy(accessToken, toolName, { output_policy: newPolicy });
setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, output_policy: newPolicy } : t)));
} catch (e: any) {
alert(`Failed to update output policy: ${e.message}`);
} finally {
setSavingOutput(null);
}
};
@@ -157,7 +143,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
setCurrentPage(1);
};
// Build unique team/key options from loaded data
const teamOptions = Array.from(new Set(tools.map((t) => t.team_id).filter(Boolean))).map((v) => ({
label: v as string,
value: v as string,
@@ -169,9 +154,14 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
const filterOptions: FilterOption[] = [
{
name: "Policy",
label: "Policy",
options: POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })),
name: "Input Policy",
label: "Input Policy",
options: INPUT_POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })),
},
{
name: "Output Policy",
label: "Output Policy",
options: OUTPUT_POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })),
},
{
name: "Team Name",
@@ -185,6 +175,39 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
},
];
const { newToday, newYesterday, trendSubtitle, totalTools, blockedCount, activeTeamsCount, needsReviewTools } =
useMemo(() => {
const now = new Date();
const todayKey = getUTCDateKey(now);
const yesterday = new Date(now);
yesterday.setUTCDate(yesterday.getUTCDate() - 1);
const yesterdayKey = getUTCDateKey(yesterday);
const newToday = countToolsInUTCDay(tools, todayKey);
const newYesterday = countToolsInUTCDay(tools, yesterdayKey);
const trendSubtitle = getTrendSubtitle(newToday, newYesterday);
const totalTools = tools.length;
const blockedCount = tools.filter((t) => t.input_policy === "blocked").length;
const activeTeamsCount = new Set(tools.map((t) => t.team_id).filter(Boolean)).size;
const needsReviewTools = tools.filter(
(t) =>
isCreatedInUTCDay(t.created_at, todayKey) &&
t.input_policy === "untrusted"
);
return {
newToday,
newYesterday,
trendSubtitle,
totalTools,
blockedCount,
activeTeamsCount,
needsReviewTools,
};
}, [tools]);
const SortHeader = ({ label, field }: { label: string; field: SortField }) => (
<div className="flex items-center gap-1">
<span>{label}</span>
@@ -203,10 +226,12 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
(t.team_id ?? "").toLowerCase().includes(q) ||
(t.key_alias ?? "").toLowerCase().includes(q) ||
(t.key_hash ?? "").toLowerCase().includes(q) ||
t.call_policy.toLowerCase().includes(q);
t.input_policy.toLowerCase().includes(q) ||
t.output_policy.toLowerCase().includes(q);
if (!matchesSearch) return false;
}
if (activeFilters["Policy"] && t.call_policy !== activeFilters["Policy"]) return false;
if (activeFilters["Input Policy"] && t.input_policy !== activeFilters["Input Policy"]) return false;
if (activeFilters["Output Policy"] && t.output_policy !== activeFilters["Output Policy"]) return false;
if (activeFilters["Team Name"] && t.team_id !== activeFilters["Team Name"]) return false;
if (activeFilters["Key Name"] && t.key_alias !== activeFilters["Key Name"]) return false;
return true;
@@ -223,11 +248,74 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
const totalPages = Math.max(1, Math.ceil(sorted.length / pageSize));
const paginated = sorted.slice((currentPage - 1) * pageSize, currentPage * pageSize);
const scrollToToolRow = (toolId: string) => {
const idx = sorted.findIndex((t) => t.tool_id === toolId);
if (idx >= 0) {
const page = Math.floor(idx / pageSize) + 1;
if (page !== currentPage) setCurrentPage(page);
requestAnimationFrame(() => {
setTimeout(() => {
document.getElementById(`tool-row-${toolId}`)?.scrollIntoView({ behavior: "smooth", block: "center" });
}, 100);
});
}
};
return (
<div className="p-6 w-full">
<div className="w-full">
<h1 className="text-2xl font-semibold text-gray-900 mb-6">Tool Policies</h1>
<div className="grid grid-cols-2 lg:grid-cols-4 gap-4 mb-6">
<MetricCard
label="New Today"
value={newToday}
valueColor="text-green-600"
subtitle={trendSubtitle}
icon={
<svg className="w-4 h-4 text-green-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M13 7h8m0 0v8m0-8l-8 8-4-4-6 6" />
</svg>
}
/>
<MetricCard label="Total Tools Discovered" value={totalTools} />
<MetricCard
label="Blocked Tools"
value={blockedCount}
valueColor={blockedCount > 0 ? "text-red-600" : undefined}
/>
<MetricCard label="Active Teams" value={activeTeamsCount > 0 ? activeTeamsCount : "—"} />
</div>
{needsReviewTools.length > 0 && (
<div className="bg-amber-50 border border-amber-200 rounded-lg p-4 mb-6">
<h2 className="text-sm font-semibold text-amber-900 mb-1">Needs Review</h2>
<p className="text-sm text-amber-800 mb-3">
{needsReviewTools.length} new tool{needsReviewTools.length !== 1 ? "s" : ""} discovered that require
policy decisions.
</p>
<div className="flex flex-wrap gap-2">
{needsReviewTools.map((t) => (
<span
key={t.tool_id}
className="inline-flex items-center gap-2 px-3 py-1.5 bg-white border border-amber-200 rounded-md text-sm"
>
<span className="font-mono text-amber-900 truncate max-w-[200px]" title={t.tool_name}>
{t.tool_name}
</span>
<button
type="button"
onClick={() => scrollToToolRow(t.tool_id)}
className="text-amber-700 hover:text-amber-900 font-medium text-xs whitespace-nowrap"
>
Review
</button>
</span>
))}
</div>
</div>
)}
<div className="bg-white rounded-lg shadow w-full max-w-full box-border">
{/* Toolbar */}
<div className="border-b px-6 py-4 w-full max-w-full box-border">
<div className="flex flex-col md:flex-row items-start md:items-center justify-between space-y-4 md:space-y-0 w-full max-w-full box-border">
<div className="flex flex-wrap items-center gap-3">
@ -311,7 +399,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
</div>
</div>
{/* Filter row */}
<div className="mt-3">
<FilterComponent
options={filterOptions}
@ -322,7 +409,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
</div>
</div>
{/* Auto-refresh banner */}
{isLiveTail && (
<div className="bg-green-50 border-b border-green-100 px-6 py-2 flex items-center justify-between">
<span className="text-sm text-green-700">Auto-refreshing every 15 seconds</span>
@ -336,7 +422,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
<div className="mx-6 mt-4 p-3 bg-red-50 border border-red-200 rounded text-sm text-red-700">{error}</div>
)}
{/* Table */}
<Table className="[&_td]:py-0.5 [&_th]:py-1 w-full">
<TableHead>
<TableRow>
@ -347,7 +432,10 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
<SortHeader label="Tool Name" field="tool_name" />
</TableHeaderCell>
<TableHeaderCell className="py-1 h-8">
<SortHeader label="Policy" field="call_policy" />
<SortHeader label="Input Policy" field="input_policy" />
</TableHeaderCell>
<TableHeaderCell className="py-1 h-8">
<SortHeader label="Output Policy" field="output_policy" />
</TableHeaderCell>
<TableHeaderCell className="py-1 h-8">
<SortHeader label="# Calls" field="call_count" />
@ -359,45 +447,61 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
<TableHeaderCell className="py-1 h-8">
<SortHeader label="Key Name" field="key_alias" />
</TableHeaderCell>
<TableHeaderCell className="py-1 h-8">Origin</TableHeaderCell>
<TableHeaderCell className="py-1 h-8">User Agent</TableHeaderCell>
</TableRow>
</TableHead>
<TableBody>
{loading ? (
<TableRow>
<TableCell colSpan={8} className="h-8 text-center text-gray-500">
<TableCell colSpan={9} className="h-8 text-center text-gray-500">
Loading tools
</TableCell>
</TableRow>
) : paginated.length === 0 ? (
<TableRow>
<TableCell colSpan={8} className="h-8 text-center text-gray-500">
<TableCell colSpan={9} className="h-8 text-center text-gray-500">
No tools discovered yet. Make a chat completion that returns tool_calls to start auto-discovery.
</TableCell>
</TableRow>
) : (
paginated.map((tool) => (
<TableRow key={tool.tool_id} className="h-8 hover:bg-gray-50">
<TableRow key={tool.tool_id} id={`tool-row-${tool.tool_id}`} className="h-8 hover:bg-gray-50">
<TableCell className="py-0.5 max-h-8 overflow-hidden whitespace-nowrap">
<TimeCell utcTime={tool.created_at ?? ""} />
</TableCell>
<TableCell className="py-0.5 max-h-8 overflow-hidden">
<Tooltip title={tool.tool_name}>
<span className="font-mono text-xs max-w-[20ch] truncate block font-medium">
{tool.tool_name}
</span>
</Tooltip>
<button
type="button"
onClick={() => onSelectTool?.(tool.tool_name)}
className="text-left w-full font-mono text-xs max-w-[20ch] truncate block font-medium text-blue-600 hover:text-blue-800 hover:underline focus:outline-none focus:ring-0"
>
<Tooltip title={onSelectTool ? "Click to view details and block for team/key" : tool.tool_name}>
<span>{tool.tool_name}</span>
</Tooltip>
</button>
</TableCell>
<TableCell className="py-0.5 max-h-8">
<PolicySelect
value={tool.call_policy}
value={tool.input_policy}
toolName={tool.tool_name}
saving={saving === tool.tool_name}
onChange={handlePolicyChange}
saving={savingInput === tool.tool_name}
onChange={handleInputPolicyChange}
policyType="input"
/>
</TableCell>
<TableCell className="py-0.5 max-h-8 text-right tabular-nums text-sm font-mono text-gray-700">
{(tool.call_count ?? 0).toLocaleString()}
<TableCell className="py-0.5 max-h-8">
<PolicySelect
value={tool.output_policy}
toolName={tool.tool_name}
saving={savingOutput === tool.tool_name}
onChange={handleOutputPolicyChange}
policyType="output"
/>
</TableCell>
<TableCell className="py-0.5 max-h-8">
<div className="flex items-center justify-end h-8 tabular-nums text-sm font-mono text-gray-700">
{(tool.call_count ?? 0).toLocaleString()}
</div>
</TableCell>
<TableCell className="py-0.5 max-h-8 overflow-hidden whitespace-nowrap">
<Tooltip title={tool.team_id ?? "-"}>
@ -417,8 +521,8 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
</Tooltip>
</TableCell>
<TableCell className="py-0.5 max-h-8 overflow-hidden whitespace-nowrap">
<Tooltip title={tool.origin ?? "-"}>
<span className="max-w-[15ch] truncate block">{tool.origin ?? "-"}</span>
<Tooltip title={tool.user_agent ?? "-"}>
<span className="font-mono max-w-[20ch] truncate block text-xs text-gray-500">{tool.user_agent ?? "-"}</span>
</Tooltip>
</TableCell>
</TableRow>
@ -427,7 +531,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
</TableBody>
</Table>
{/* Bottom pagination (only when > 1 page) */}
{totalPages > 1 && (
<div className="border-t px-6 py-3 flex items-center justify-between text-sm text-gray-600">
<span>
@ -453,6 +556,7 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
</div>
)}
</div>
</div>
);
};
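
Review note on `scrollToToolRow`: the page lookup is plain index arithmetic over the sorted list. The sketch below isolates it so it can be unit-tested apart from the component. The helper names `pageForIndex` and `pageBounds` are hypothetical and do not appear in this PR; the formulas are the ones used in the diff.

```typescript
// Hypothetical helper (not in this PR): maps a zero-based row index in the
// sorted list to the one-based page that contains it, as scrollToToolRow does.
function pageForIndex(idx: number, pageSize: number): number {
  if (idx < 0 || pageSize <= 0) {
    throw new RangeError("idx must be >= 0 and pageSize must be > 0");
  }
  return Math.floor(idx / pageSize) + 1;
}

// Hypothetical helper: the slice bounds for a one-based page, mirroring
// sorted.slice((currentPage - 1) * pageSize, currentPage * pageSize).
function pageBounds(page: number, pageSize: number): [number, number] {
  return [(page - 1) * pageSize, page * pageSize];
}
```

With a page size of 50, for example, row index 49 stays on page 1 and row index 50 moves to page 2, which matches the slice bounds `[50, 100)` for that page.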

@ -0,0 +1,92 @@
"use client";
import React from "react";
import { Select } from "antd";
export const INPUT_POLICY_OPTIONS = [
{ value: "untrusted", label: "untrusted", color: "#92400e", bg: "#fef3c7", border: "#fcd34d" },
{ value: "trusted", label: "trusted", color: "#065f46", bg: "#d1fae5", border: "#6ee7b7" },
{ value: "blocked", label: "blocked", color: "#991b1b", bg: "#fee2e2", border: "#fca5a5" },
] as const;
export const OUTPUT_POLICY_OPTIONS = [
{ value: "untrusted", label: "untrusted", color: "#92400e", bg: "#fef3c7", border: "#fcd34d" },
{ value: "trusted", label: "trusted", color: "#065f46", bg: "#d1fae5", border: "#6ee7b7" },
] as const;
export const POLICY_OPTIONS = INPUT_POLICY_OPTIONS;
export const policyStyle = (p: string) =>
INPUT_POLICY_OPTIONS.find((o) => o.value === p) ?? INPUT_POLICY_OPTIONS[0];
export interface PolicySelectProps {
value: string;
toolName: string;
saving: boolean;
onChange: (toolName: string, policy: string) => void;
policyType?: "input" | "output";
size?: "small" | "middle";
minWidth?: number;
stopPropagation?: boolean;
}
export const PolicySelect: React.FC<PolicySelectProps> = ({
value,
toolName,
saving,
onChange,
policyType = "input",
size = "small",
minWidth = 110,
stopPropagation = true,
}) => {
const options = policyType === "output" ? OUTPUT_POLICY_OPTIONS : INPUT_POLICY_OPTIONS;
const style = policyStyle(value);
return (
<Select
size={size}
value={value}
disabled={saving}
loading={saving}
onChange={(v) => onChange(toolName, v)}
onClick={(e) => stopPropagation && e.stopPropagation()}
style={{
minWidth,
fontWeight: 500,
backgroundColor: style.bg,
borderColor: style.border,
color: style.color,
borderRadius: 999,
fontSize: size === "small" ? 11 : 12,
}}
popupMatchSelectWidth={false}
options={options.map((o) => ({
value: o.value,
label: (
<span
style={{
display: "inline-flex",
alignItems: "center",
gap: 6,
fontSize: 12,
fontWeight: 500,
color: o.color,
}}
>
<span
style={{
width: 8,
height: 8,
borderRadius: "50%",
backgroundColor: o.color,
display: "inline-block",
flexShrink: 0,
}}
/>
{o.label}
</span>
),
}))}
/>
);
};
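
Review note on `policyStyle`: the lookup always searches `INPUT_POLICY_OPTIONS` and falls back to its first entry, so an unrecognized policy string is styled as `untrusted`. The sketch below reproduces that lookup against an abbreviated copy of the option table (colors only) and adds a type guard, `isInputPolicy`, that is hypothetical and not part of this PR; it is one possible way to tell a real match from the fallback.

```typescript
// Abbreviated copy of the option table from the diff (text color only).
const INPUT_POLICY_OPTIONS = [
  { value: "untrusted", color: "#92400e" },
  { value: "trusted", color: "#065f46" },
  { value: "blocked", color: "#991b1b" },
] as const;

type InputPolicy = (typeof INPUT_POLICY_OPTIONS)[number]["value"];

// Same lookup as policyStyle in the diff: unknown values fall back to the
// first entry ("untrusted").
const policyStyle = (p: string) =>
  INPUT_POLICY_OPTIONS.find((o) => o.value === p) ?? INPUT_POLICY_OPTIONS[0];

// Hypothetical type guard (not in this PR): true only for known policy values.
function isInputPolicy(p: string): p is InputPolicy {
  return INPUT_POLICY_OPTIONS.some((o) => o.value === p);
}
```

Because the output options in the diff are a subset of the input options, using the input table for styling both selects gives the same colors for `untrusted` and `trusted`.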

@ -0,0 +1,44 @@
"use client";
import React, { useState } from "react";
import { ToolDetail } from "@/components/ToolDetail";
import { ToolPolicies } from "@/components/ToolPolicies";
type View =
| { type: "overview" }
| { type: "detail"; toolName: string };
interface ToolPoliciesViewProps {
accessToken: string | null;
userRole?: string;
}
export default function ToolPoliciesView({ accessToken, userRole }: ToolPoliciesViewProps) {
const [view, setView] = useState<View>({ type: "overview" });
const handleSelectTool = (toolName: string) => {
setView({ type: "detail", toolName });
};
const handleBack = () => {
setView({ type: "overview" });
};
return (
<div className="p-6 w-full min-w-0 flex-1">
{view.type === "detail" ? (
<ToolDetail
toolName={view.toolName}
onBack={handleBack}
accessToken={accessToken}
/>
) : (
<ToolPolicies
accessToken={accessToken}
userRole={userRole}
onSelectTool={handleSelectTool}
/>
)}
</div>
);
}
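
Review note on the `View` type: because it is a discriminated union on `type`, checking `view.type === "detail"` narrows the value so that `view.toolName` is known to exist. The sketch below shows the same narrowing outside of React; the function `describeView` is hypothetical and only illustrates the pattern.

```typescript
// Same shape as the View union in the diff.
type View =
  | { type: "overview" }
  | { type: "detail"; toolName: string };

// Hypothetical function (not in this PR): each case narrows `view`, so
// `toolName` is only accessible in the "detail" branch.
function describeView(view: View): string {
  switch (view.type) {
    case "overview":
      return "overview";
    case "detail":
      return `detail:${view.toolName}`;
  }
}
```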

@ -10127,7 +10127,8 @@ export interface ToolRow {
tool_id: string;
tool_name: string;
origin?: string;
call_policy: string;
input_policy: string;
output_policy: string;
call_count?: number;
assignments?: Record<string, any>;
key_hash?: string;
@ -10139,6 +10140,37 @@ export interface ToolRow {
updated_by?: string;
}
export interface ToolPolicyOption {
value: string;
label: string;
description: string;
}
export interface ToolPolicyOptionsResponse {
input_policies: ToolPolicyOption[];
output_policies: ToolPolicyOption[];
}
export const fetchToolPolicyOptions = async (
accessToken: string
): Promise<ToolPolicyOptionsResponse> => {
const url = proxyBaseUrl
? `${proxyBaseUrl}/v1/tool/policy/options`
: `/v1/tool/policy/options`;
const response = await fetch(url, {
method: "GET",
headers: {
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
});
if (!response.ok) {
const errorData = await response.text();
throw new Error(errorData);
}
return response.json();
};
export const fetchToolsList = async (accessToken: string): Promise<ToolRow[]> => {
const url = proxyBaseUrl ? `${proxyBaseUrl}/v1/tool/list` : `/v1/tool/list`;
const response = await fetch(url, {
@ -10156,19 +10188,137 @@ export const fetchToolsList = async (accessToken: string): Promise<ToolRow[]> =>
return data.tools ?? [];
};
export const updateToolPolicy = async (
export interface ToolPolicyOverrideRow {
override_id: string;
tool_name: string;
team_id?: string | null;
key_hash?: string | null;
input_policy: string;
key_alias?: string | null;
created_at?: string;
updated_at?: string;
}
export interface ToolDetailResponse {
tool: ToolRow;
overrides: ToolPolicyOverrideRow[];
}
export interface ToolUsageLogEntry {
id: string;
timestamp: string;
model?: string | null;
spend?: number | null;
total_tokens?: number | null;
input_snippet?: string | null;
}
export interface ToolUsageLogsResponse {
logs: ToolUsageLogEntry[];
total: number;
page: number;
page_size: number;
}
export const getToolUsageLogs = async (
accessToken: string,
toolName: string,
callPolicy: string
): Promise<ToolRow> => {
const url = proxyBaseUrl ? `${proxyBaseUrl}/v1/tool/policy` : `/v1/tool/policy`;
const response = await fetch(url, {
method: "POST",
options: { page?: number; pageSize?: number; startDate?: string; endDate?: string }
): Promise<ToolUsageLogsResponse> => {
const encoded = encodeURIComponent(toolName);
const url = proxyBaseUrl
? `${proxyBaseUrl}/v1/tool/${encoded}/logs`
: `/v1/tool/${encoded}/logs`;
const params = new URLSearchParams();
if (options.page != null) params.append("page", String(options.page));
if (options.pageSize != null) params.append("page_size", String(options.pageSize));
if (options.startDate) params.append("start_date", options.startDate);
if (options.endDate) params.append("end_date", options.endDate);
const fullUrl = params.toString() ? `${url}?${params.toString()}` : url;
const response = await fetch(fullUrl, {
method: "GET",
headers: {
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
});
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
throw new Error(deriveErrorMessage(errorData));
}
return response.json();
};
export const fetchToolDetail = async (
accessToken: string,
toolName: string
): Promise<ToolDetailResponse> => {
const encoded = encodeURIComponent(toolName);
const url = proxyBaseUrl
? `${proxyBaseUrl}/v1/tool/${encoded}/detail`
: `/v1/tool/${encoded}/detail`;
const response = await fetch(url, {
method: "GET",
headers: {
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
body: JSON.stringify({ tool_name: toolName, call_policy: callPolicy }),
});
if (!response.ok) {
const errorData = await response.text();
throw new Error(errorData);
}
return response.json();
};
export const updateToolPolicy = async (
accessToken: string,
toolName: string,
policies: { input_policy?: string; output_policy?: string },
options?: { team_id?: string | null; key_hash?: string | null; key_alias?: string | null }
): Promise<ToolRow> => {
const url = proxyBaseUrl ? `${proxyBaseUrl}/v1/tool/policy` : `/v1/tool/policy`;
const body: Record<string, string | undefined | null> = {
tool_name: toolName,
};
if (policies.input_policy != null) body.input_policy = policies.input_policy;
if (policies.output_policy != null) body.output_policy = policies.output_policy;
if (options?.team_id != null) body.team_id = options.team_id || undefined;
if (options?.key_hash != null) body.key_hash = options.key_hash || undefined;
if (options?.key_alias != null) body.key_alias = options.key_alias || undefined;
const response = await fetch(url, {
method: "POST",
headers: {
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
body: JSON.stringify(body),
});
if (!response.ok) {
const errorData = await response.text();
throw new Error(errorData);
}
return response.json();
};
export const deleteToolPolicyOverride = async (
accessToken: string,
toolName: string,
params: { team_id?: string | null; key_hash?: string | null }
): Promise<{ deleted: boolean; tool_name: string }> => {
const encoded = encodeURIComponent(toolName);
const q = new URLSearchParams();
if (params.team_id != null && params.team_id !== "") q.set("team_id", params.team_id);
if (params.key_hash != null && params.key_hash !== "") q.set("key_hash", params.key_hash);
const query = q.toString();
const url = proxyBaseUrl
? `${proxyBaseUrl}/v1/tool/${encoded}/overrides${query ? `?${query}` : ""}`
: `/v1/tool/${encoded}/overrides${query ? `?${query}` : ""}`;
const response = await fetch(url, {
method: "DELETE",
headers: {
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
},
});
if (!response.ok) {
const errorData = await response.text();