* fix(spend-tracking): fall back to direct spend-counter increment when reservation reconcile fails When the reservation-reconcile path in `_reconcile_budget_reservation_for_counter_update` hits a Redis error, it now correctly returns an empty set so that `increment_spend_counters` re-runs the direct increment for the affected counters. Previously, the function logged the failure, invalidated the reserved counters, and still returned the reserved counter keys, which caused the caller to skip the direct increment. With the increment skipped and the counter deleted, the next request reseeded the counter from `LiteLLM_VerificationToken.spend`, a column the batched flusher only updates every few seconds, so the enforced cross-pod spend value collapsed to a stale snapshot and budget gating stopped firing for affected keys. Adds a regression test that exercises the failure path with a flaky redis backend and asserts the actual response cost lands in the shared counter. * fix(register_model): preserve built-in cache pricing when registering custom overrides under unmapped keys When a custom-priced model is registered under a key shape that get_model_info cannot resolve (e.g. litellm_params.model set to bedrock/bedrock/us.anthropic.claude-sonnet-4-6 or another non-canonical alias), register_model previously fell back to an empty existing_model. The merged entry then carried only the fields the user set explicitly (input/output cost, provider) and dropped cache pricing. Downstream the cost calculator defaulted cache_creation_input_token_cost and cache_read_input_token_cost to 0, silently dropping the bulk of the bill for cache-heavy Anthropic traffic. register_model now attempts to resolve a canonical built-in entry by stripping provider prefixes, region prefixes, and provider-specific suffixes before giving up. When a variant resolves, its defaults (notably cache pricing) are inherited while the user's explicit overrides still win. 
When nothing resolves and the user supplied no cache pricing, it logs a warning instead of silently under-billing. * fix(router): inherit built-in cache pricing on deployments with partial custom pricing A deployment configured with only input_cost_per_token and output_cost_per_token under model_info was being registered under its model_info.id with no cache cost fields. The cost calculator then defaulted cache_creation_input_token_cost and cache_read_input_token_cost to 0, silently billing cache_read and cache_creation tokens at zero. For cache-heavy Anthropic traffic this drops the bulk of the bill. When the deployment's litellm_params.model resolves to a built-in cost-map entry, pull the cache pricing fields from there before registering. User-specified cache fields still win on merge; only missing fields are inherited. Pairs with the register_model fallback added earlier in this branch: that handles unmapped key shapes like bedrock/bedrock/x, this handles deploy-id keys whose backend model is mapped. * fix(register_model): inherit only cache pricing on unmapped-key fallback, not provider The unmapped-key fallback in register_model copied the entire resolved built-in entry, so registering openai/command-r-plus inherited the cohere built-in's litellm_provider and get_model_info(custom_llm_provider=openai) could no longer resolve it. Restrict the fallback to the cache-pricing fields, matching the router-side _inherit_builtin_cache_pricing, so the cache-cost dropout stays fixed without clobbering the registered provider. Add a direct unit test for Router._inherit_builtin_cache_pricing so the router coverage check sees it, and pin the fixed spend-counter contract: when reservation reconcile fails the counter must hold the directly incremented cost rather than being left at None.
7939 lines
288 KiB
Python
7939 lines
288 KiB
Python
import asyncio
|
||
import importlib
|
||
import json
|
||
import os
|
||
import socket
|
||
import subprocess
|
||
import sys
|
||
from datetime import datetime, timedelta, timezone
|
||
from pathlib import Path
|
||
from unittest import mock
|
||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||
|
||
import click
|
||
import httpx
|
||
import pytest
|
||
import yaml
|
||
from fastapi import FastAPI
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.testclient import TestClient
|
||
|
||
# Prepend the directory three levels above the current working directory
# to the import path, so it is searched before any other entry.
sys.path.insert(0, os.path.abspath(os.path.join("..", "..", "..")))
|
||
|
||
import litellm
|
||
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||
from litellm.proxy.proxy_server import app, initialize
|
||
from litellm.utils import _invalidate_model_cost_lowercase_map
|
||
|
||
# Canned embeddings-style payload used as the return value of the patched
# router ``aembedding`` call. The vector is the same four floats repeated
# three times (twelve values in total).
example_embedding_result = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [
                -0.006929283495992422,
                -0.005336422007530928,
                -4.547132266452536e-05,
                -0.024047505110502243,
            ]
            * 3,
        }
    ],
    "model": "text-embedding-3-small",
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}
|
||
|
||
|
||
def mock_patch_aembedding():
    """Build a patcher that replaces the proxy router's ``aembedding``.

    Returns:
        The ``mock.patch`` object targeting
        ``litellm.proxy.proxy_server.llm_router.aembedding``; while active,
        the replacement returns ``example_embedding_result``.
    """
    target = "litellm.proxy.proxy_server.llm_router.aembedding"
    return mock.patch(target, return_value=example_embedding_result)
|
||
|
||
|
||
@pytest.fixture(scope="function")
def client_no_auth():
    """Provide a ``TestClient`` for the proxy app initialised from the no-auth config.

    Router config variables are reset first, then the proxy is initialised
    from ``test_configs/test_config_no_auth.yaml`` located next to this file.
    """
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()

    here = os.path.dirname(os.path.abspath(__file__))
    no_auth_config = f"{here}/test_configs/test_config_no_auth.yaml"

    # Original author's note: initialize sets variables on the shared FastAPI
    # app, so tests running in parallel can end up seeing the wrong values.
    asyncio.run(initialize(config=no_auth_config, debug=True))
    return TestClient(app)
|
||
|
||
|
||
def test_login_v2_returns_redirect_url_and_sets_cookie(monkeypatch):
    """POST /v2/login answers 200 with a redirect URL and token, and sets the ``token`` cookie.

    Also checks the arguments passed to the stubbed ``authenticate_user``,
    ``create_ui_token_object`` and ``jwt.encode``.
    """
    login_result = {"user_id": "test-user"}
    prisma_stub = MagicMock()
    authenticate_stub = AsyncMock(return_value=login_result)
    token_object_stub = MagicMock(return_value={"user_id": "test-user"})
    jwt_encode_stub = MagicMock(return_value="signed-token")

    # Dotted target -> replacement, applied in this order.
    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": authenticate_stub,
        "litellm.proxy.auth.login_utils.create_ui_token_object": token_object_stub,
        "jwt.encode": jwt_encode_stub,
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {},
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": prisma_stub,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v2/login",
        json={"username": "alice", "password": "secret"},
    )

    assert response.status_code == 200
    expected_body = {
        "redirect_url": "http://testserver/ui/?login=success",
        "token": "signed-token",
    }
    assert response.json() == expected_body
    assert response.cookies.get("token") == "signed-token"

    authenticate_stub.assert_awaited_once_with(
        username="alice",
        password="secret",
        master_key="test-master-key",
        prisma_client=prisma_stub,
    )
    token_object_stub.assert_called_once_with(
        login_result=login_result,
        general_settings={},
        premium_user=False,
    )
    jwt_encode_stub.assert_called_once_with(
        {"user_id": "test-user"},
        "test-master-key",
        algorithm="HS256",
    )
|
||
|
||
|
||
def test_login_v2_returns_json_on_proxy_exception(monkeypatch):
    """A ProxyException raised during authentication surfaces from /v2/login as a JSON 401.

    The response body must carry the exception's message and type under
    the ``error`` key.
    """
    from litellm.proxy._types import ProxyErrorTypes, ProxyException

    prisma_stub = MagicMock()
    auth_failure = ProxyException(
        message="Invalid credentials",
        type=ProxyErrorTypes.auth_error,
        param="password",
        code=401,
    )
    authenticate_stub = AsyncMock(side_effect=auth_failure)

    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": authenticate_stub,
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.prisma_client": prisma_stub,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v2/login",
        json={"username": "alice", "password": "wrong"},
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    payload = response.json()
    assert "error" in payload
    assert payload["error"]["message"] == "Invalid credentials"
    assert payload["error"]["type"] == "auth_error"
|
||
|
||
|
||
def test_login_v2_returns_json_on_http_exception(monkeypatch):
    """An HTTPException raised during authentication surfaces from /v2/login as a JSON 401.

    The ``error`` entry of the body must be a dict.
    """
    from fastapi import HTTPException

    prisma_stub = MagicMock()
    unauthorized = HTTPException(status_code=401, detail="Unauthorized")
    authenticate_stub = AsyncMock(side_effect=unauthorized)

    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": authenticate_stub,
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.prisma_client": prisma_stub,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v2/login",
        json={"username": "alice", "password": "secret"},
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    payload = response.json()
    assert "error" in payload
    assert isinstance(payload["error"], dict)
|
||
|
||
|
||
def test_login_v2_returns_json_on_unexpected_exception(monkeypatch):
    """A non-proxy exception during authentication surfaces from /v2/login as a JSON 500.

    The original exception text must appear in the error message.
    """
    prisma_stub = MagicMock()
    authenticate_stub = AsyncMock(side_effect=ValueError("Unexpected error"))

    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": authenticate_stub,
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.prisma_client": prisma_stub,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v2/login",
        json={"username": "alice", "password": "secret"},
    )

    assert response.status_code == 500
    assert response.headers["content-type"] == "application/json"
    payload = response.json()
    assert "error" in payload
    assert isinstance(payload["error"], dict)
    assert "Unexpected error" in payload["error"]["message"]
|
||
|
||
|
||
def test_login_v2_returns_json_on_invalid_json_body(monkeypatch):
    """A body that is not valid JSON makes /v2/login answer with a JSON 500 error."""
    monkeypatch.setattr("litellm.proxy.proxy_server.master_key", "test-master-key")

    # Send raw text while still declaring a JSON content type.
    response = TestClient(app).post(
        "/v2/login",
        content="invalid json",
        headers={"Content-Type": "application/json"},
    )

    assert response.status_code == 500
    assert response.headers["content-type"] == "application/json"
    payload = response.json()
    assert "error" in payload
    assert isinstance(payload["error"], dict)
|
||
|
||
|
||
def test_login_v3_rejected_without_control_plane_url(monkeypatch):
    """With empty general settings, /v3/login answers 404 and names ``control_plane_url``."""
    overrides = {
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {},
        "litellm.proxy.proxy_server.prisma_client": MagicMock(),
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v3/login",
        json={"username": "alice", "password": "secret"},
    )

    assert response.status_code == 404
    error_message = response.json()["error"]["message"]
    assert "control_plane_url" in error_message
|
||
|
||
|
||
def test_login_v3_returns_code(monkeypatch):
    """/v3/login answers with a ``code`` and ``expires_in == 60``, and no ``token`` field."""
    config_stub = MagicMock()
    config_stub.worker_registry = []

    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": AsyncMock(
            return_value={"user_id": "test-user"}
        ),
        "litellm.proxy.auth.login_utils.create_ui_token_object": MagicMock(
            return_value={"user_id": "test-user"}
        ),
        "jwt.encode": MagicMock(return_value="signed-token"),
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {
            "control_plane_url": "https://cp.example.com"
        },
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": MagicMock(),
        "litellm.proxy.proxy_server.proxy_config": config_stub,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v3/login",
        json={"username": "alice", "password": "secret"},
    )

    assert response.status_code == 200
    payload = response.json()
    assert "code" in payload
    assert payload["expires_in"] == 60
    # The signed token itself must not be handed out by this endpoint.
    assert "token" not in payload
|
||
|
||
|
||
def test_login_v3_exchange_happy_path(monkeypatch):
    """End to end: the code from /v3/login is redeemed at /v3/login/exchange for the token.

    The exchange response must contain the token and a ``redirect_url`` and
    must set the ``token`` cookie.
    """
    config_stub = MagicMock()
    config_stub.worker_registry = []

    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": AsyncMock(
            return_value={"user_id": "test-user"}
        ),
        "litellm.proxy.auth.login_utils.create_ui_token_object": MagicMock(
            return_value={"user_id": "test-user"}
        ),
        "jwt.encode": MagicMock(return_value="signed-token"),
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {
            "control_plane_url": "https://cp.example.com"
        },
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": MagicMock(),
        "litellm.proxy.proxy_server.proxy_config": config_stub,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    client = TestClient(app)

    # Step 1: log in and obtain the one-time code.
    login_response = client.post(
        "/v3/login",
        json={"username": "alice", "password": "secret"},
    )
    assert login_response.status_code == 200
    one_time_code = login_response.json()["code"]

    # Step 2: redeem the code for the token.
    exchange_response = client.post(
        "/v3/login/exchange",
        json={"code": one_time_code},
    )
    assert exchange_response.status_code == 200
    exchanged = exchange_response.json()
    assert exchanged["token"] == "signed-token"
    assert "redirect_url" in exchanged
    assert exchange_response.cookies.get("token") == "signed-token"
|
||
|
||
|
||
def test_login_v3_exchange_single_use(monkeypatch):
    """A login code works once: the first exchange gets 200, a repeat gets 401."""
    config_stub = MagicMock()
    config_stub.worker_registry = []

    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": AsyncMock(
            return_value={"user_id": "test-user"}
        ),
        "litellm.proxy.auth.login_utils.create_ui_token_object": MagicMock(
            return_value={"user_id": "test-user"}
        ),
        "jwt.encode": MagicMock(return_value="signed-token"),
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {
            "control_plane_url": "https://cp.example.com"
        },
        "litellm.proxy.proxy_server.premium_user": False,
        "litellm.proxy.proxy_server.prisma_client": MagicMock(),
        "litellm.proxy.proxy_server.proxy_config": config_stub,
        "litellm.proxy.utils.get_server_root_path": lambda: "",
        "litellm.proxy.utils.get_proxy_base_url": lambda: None,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    client = TestClient(app)

    login_response = client.post(
        "/v3/login",
        json={"username": "alice", "password": "secret"},
    )
    exchange_body = {"code": login_response.json()["code"]}

    # Redeeming the code the first time is accepted ...
    first_attempt = client.post("/v3/login/exchange", json=exchange_body)
    assert first_attempt.status_code == 200

    # ... and replaying the very same code is rejected.
    second_attempt = client.post("/v3/login/exchange", json=exchange_body)
    assert second_attempt.status_code == 401
|
||
|
||
|
||
def test_login_v3_exchange_invalid_code(monkeypatch):
    """Exchanging a code that was never issued is answered with 401."""
    settings_with_control_plane = {"control_plane_url": "https://cp.example.com"}
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.general_settings",
        settings_with_control_plane,
    )

    response = TestClient(app).post(
        "/v3/login/exchange",
        json={"code": "nonexistent-code"},
    )

    assert response.status_code == 401
|
||
|
||
|
||
def test_login_v3_exchange_rejected_without_control_plane_url(monkeypatch):
    """With empty general settings, /v3/login/exchange answers 404 and names ``control_plane_url``."""
    monkeypatch.setattr("litellm.proxy.proxy_server.general_settings", {})

    response = TestClient(app).post(
        "/v3/login/exchange",
        json={"code": "some-code"},
    )

    assert response.status_code == 404
    error_message = response.json()["error"]["message"]
    assert "control_plane_url" in error_message
|
||
|
||
|
||
def test_login_v3_returns_json_on_proxy_exception(monkeypatch):
    """A ProxyException raised during authentication surfaces from /v3/login as a JSON 401.

    The response body must carry the exception's message and type under
    the ``error`` key.
    """
    from litellm.proxy._types import ProxyErrorTypes, ProxyException

    prisma_stub = MagicMock()
    auth_failure = ProxyException(
        message="Invalid credentials",
        type=ProxyErrorTypes.auth_error,
        param="password",
        code=401,
    )
    authenticate_stub = AsyncMock(side_effect=auth_failure)

    overrides = {
        "litellm.proxy.auth.login_utils.authenticate_user": authenticate_stub,
        "litellm.proxy.proxy_server.master_key": "test-master-key",
        "litellm.proxy.proxy_server.general_settings": {
            "control_plane_url": "https://cp.example.com"
        },
        "litellm.proxy.proxy_server.prisma_client": prisma_stub,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)

    response = TestClient(app).post(
        "/v3/login",
        json={"username": "alice", "password": "wrong"},
    )

    assert response.status_code == 401
    assert response.headers["content-type"] == "application/json"
    payload = response.json()
    assert "error" in payload
    assert payload["error"]["message"] == "Invalid credentials"
    assert payload["error"]["type"] == "auth_error"
|
||
|
||
|
||
def test_fallback_login_has_no_deprecation_banner(client_no_auth):
    """GET /fallback/login serves a form and shows no deprecation banner."""
    response = client_no_auth.get("/fallback/login")
    assert response.status_code == 200

    page = response.text
    for banner_marker in ('<div class="deprecation-banner">', "Deprecated:"):
        assert banner_marker not in page
    assert "<form" in page
|
||
|
||
|
||
@pytest.mark.parametrize(
    "ui_logo_path",
    [
        "/etc/litellm/secret-config.json",
        "/var/secrets/admin.key",
        "/proc/self/environ",
        "relative/path/logo.png",
    ],
)
def test_get_logo_url_does_not_disclose_local_paths(
    client_no_auth, monkeypatch, ui_logo_path
):
    """/get_logo_url must answer with an empty ``logo_url`` when UI_LOGO_PATH is a local path.

    Rationale carried over from the original comments: the endpoint is
    unauthenticated, so echoing a filesystem path would disclose admin-only
    configuration to any caller. Only browser-loadable HTTP(S) URLs should
    be returned; for local paths the dashboard falls back to ``/get_image``.
    """
    monkeypatch.setenv("UI_LOGO_PATH", ui_logo_path)

    response = client_no_auth.get("/get_logo_url")

    assert response.status_code == 200
    body = response.json()
    assert body == {"logo_url": ""}
|
||
|
||
|
||
def test_get_logo_url_returns_https_url(client_no_auth, monkeypatch):
    """An HTTPS UI_LOGO_PATH is returned unchanged by /get_logo_url."""
    monkeypatch.setenv("UI_LOGO_PATH", "https://cdn.public.example/logo.png")

    response = client_no_auth.get("/get_logo_url")

    assert response.status_code == 200
    body = response.json()
    assert body == {"logo_url": "https://cdn.public.example/logo.png"}
|
||
|
||
|
||
def test_get_logo_url_returns_http_url(client_no_auth, monkeypatch):
    """A plain-HTTP UI_LOGO_PATH is returned unchanged by /get_logo_url.

    Per the original comment, such URLs (typically an internal CDN) are
    meant to be loaded directly by the browser.
    """
    monkeypatch.setenv("UI_LOGO_PATH", "http://internal-cdn.corp:8080/logo.png")

    response = client_no_auth.get("/get_logo_url")

    assert response.status_code == 200
    body = response.json()
    assert body == {"logo_url": "http://internal-cdn.corp:8080/logo.png"}
|
||
|
||
|
||
def test_get_logo_url_returns_empty_when_unset(client_no_auth, monkeypatch):
    """Without UI_LOGO_PATH in the environment, /get_logo_url answers with an empty ``logo_url``."""
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    response = client_no_auth.get("/get_logo_url")

    assert response.status_code == 200
    body = response.json()
    assert body == {"logo_url": ""}
|
||
|
||
|
||
def test_sso_key_generate_shows_deprecation_banner(client_no_auth, monkeypatch):
    """GET /sso/key/generate serves HTML that contains the deprecation banner."""
    # The SSO helpers are stubbed so the route returns the HTML form instead
    # of redirecting; premium_user is forced on to bypass the enterprise
    # check (which would otherwise yield 403 Forbidden).
    overrides = {
        "litellm.proxy.management_endpoints.ui_sso.show_missing_vars_in_env": (
            lambda: None
        ),
        "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.get_redirect_url_for_sso": (
            lambda *args, **kwargs: "http://test/redirect"
        ),
        "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler._get_cli_state": (
            lambda *args, **kwargs: None
        ),
        "litellm.proxy.management_endpoints.ui_sso.SSOAuthenticationHandler.should_use_sso_handler": (
            lambda *args, **kwargs: False
        ),
        "litellm.proxy.proxy_server.premium_user": True,
    }
    for target, replacement in overrides.items():
        monkeypatch.setattr(target, replacement)
    monkeypatch.setenv("UI_USERNAME", "admin")

    response = client_no_auth.get("/sso/key/generate")
    assert response.status_code == 200

    page = response.text
    for banner_marker in ('<div class="deprecation-banner">', "Deprecated:"):
        assert banner_marker in page
|
||
|
||
|
||
def test_restructure_ui_html_files_handles_nested_routes(tmp_path):
    """
    Check how ``_restructure_ui_html_files`` rewrites a small UI tree.

    Expected outcome: ``<name>.html`` becomes ``<name>/index.html`` (also for
    nested paths), existing ``index.html`` files keep their content, and
    files under ``_next`` and ``litellm-asset-prefix`` are left alone.

    Note (from the original docstring): this function is always called now,
    both in development and non-root Docker environments.
    """
    from litellm.proxy import proxy_server

    ui_root = tmp_path / "ui"
    ui_root.mkdir()

    # Relative path parts -> file content, created before restructuring.
    seed_files = {
        ("home.html",): "home",
        ("mcp", "oauth", "callback.html"): "callback",
        ("existing", "index.html"): "keep",
        ("_next", "ignore.html"): "asset",
        ("litellm-asset-prefix", "ignore.html"): "asset",
    }
    for parts, content in seed_files.items():
        destination = ui_root.joinpath(*parts)
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_text(content)

    proxy_server._restructure_ui_html_files(str(ui_root))

    def read(*parts):
        return ui_root.joinpath(*parts).read_text()

    assert not (ui_root / "home.html").exists()
    assert read("home", "index.html") == "home"
    assert not (ui_root / "mcp" / "oauth" / "callback.html").exists()
    assert read("mcp", "oauth", "callback", "index.html") == "callback"
    assert read("existing", "index.html") == "keep"
    assert read("_next", "ignore.html") == "asset"
    assert read("litellm-asset-prefix", "ignore.html") == "asset"
|
||
|
||
|
||
def test_ui_extensionless_route_requires_restructure(tmp_path):
    """
    Regression for non-root fallback: /ui/login expects login/index.html.

    Before restructuring, only ``/ui/login.html`` resolves and ``/ui/login``
    is a 404; after ``_restructure_ui_html_files`` runs, ``/ui/login`` is
    served.

    Note (from the original docstring): restructuring always happens now,
    both in development and non-root Docker environments.
    """
    from litellm.proxy import proxy_server

    ui_root = tmp_path / "ui"
    ui_root.mkdir()
    for page_name in ("index", "login"):
        (ui_root / f"{page_name}.html").write_text(page_name)

    static_app = FastAPI()
    static_app.mount("/ui", StaticFiles(directory=str(ui_root), html=True), name="ui")
    client = TestClient(static_app)

    assert client.get("/ui/login.html").status_code == 200
    assert client.get("/ui/login").status_code == 404

    proxy_server._restructure_ui_html_files(str(ui_root))

    after = client.get("/ui/login")
    assert after.status_code == 200
    assert "login" in after.text
|
||
|
||
|
||
def test_admin_ui_export_serves_nested_extensionless_routes():
    """Validate the packaged UI export under ``proxy/_experimental/out``.

    Checks that every nested HTML page (outside ``_next`` and
    ``litellm-asset-prefix``) is named ``index.html``, that the MCP OAuth
    callback page exists, and that a static mount redirects the
    extensionless callback URL (307, query string kept) and then serves HTML.
    """
    out_dir = Path(litellm.__file__).parent / "proxy" / "_experimental" / "out"
    assert out_dir.is_dir(), f"missing UI export at {out_dir}"

    def is_nested_non_index(path):
        # Top-level pages and index.html files are fine; asset trees are ignored.
        if path.parent == out_dir or path.name == "index.html":
            return False
        return "_next" not in path.parts and "litellm-asset-prefix" not in path.parts

    nested_html_offenders = [
        candidate.relative_to(out_dir).as_posix()
        for candidate in out_dir.rglob("*.html")
        if is_nested_non_index(candidate)
    ]
    assert not nested_html_offenders, (
        "Nested routes must be named index.html. Offenders: " f"{nested_html_offenders}"
    )

    callback_index = out_dir / "mcp" / "oauth" / "callback" / "index.html"
    assert callback_index.is_file(), (
        f"MCP OAuth callback page must exist at {callback_index}; "
        "without it /ui/mcp/oauth/callback 404s after Linear redirects back."
    )

    static_app = FastAPI()
    static_app.mount("/ui", StaticFiles(directory=str(out_dir), html=True), name="ui")
    client = TestClient(static_app)

    callback_url = "/ui/mcp/oauth/callback?code=abc&state=xyz"

    # Without following redirects, the trailing-slash redirect is visible.
    redirect = client.get(callback_url, follow_redirects=False)
    assert redirect.status_code == 307
    assert redirect.headers["location"].endswith(
        "/ui/mcp/oauth/callback/?code=abc&state=xyz"
    )

    # With the client's default redirect handling, the page itself is served.
    landed = client.get(callback_url)
    assert landed.status_code == 200
    assert "<html" in landed.text.lower()
|
||
|
||
|
||
def test_restructure_always_happens(monkeypatch):
    """
    Test that restructuring logic always executes regardless of LITELLM_NON_ROOT setting.

    Case 1: LITELLM_NON_ROOT="true" selects the runtime path /var/lib/litellm/ui.
    Case 2: LITELLM_NON_ROOT unset selects the packaged UI path.

    NOTE(review): this test re-implements the path selection locally and never
    calls proxy_server, and ``should_restructure`` is a constant. It pins the
    intended behaviour only; confirm against proxy_server if that changes.

    Args:
        monkeypatch: pytest fixture used to set and clear LITELLM_NON_ROOT.
    """
    runtime_ui_path = "/var/lib/litellm/ui"
    packaged_ui_path = "/some/packaged/ui/path"

    def select_ui_path():
        """Return (is_non_root, ui_path) as derived from LITELLM_NON_ROOT."""
        # Simulates the selection logic attributed to proxy_server.py by the
        # original comments.
        is_non_root = os.getenv("LITELLM_NON_ROOT", "").lower() == "true"
        return is_non_root, (runtime_ui_path if is_non_root else packaged_ui_path)

    # Restructuring always happens now, regardless of ui_path vs packaged_ui_path.
    should_restructure = True

    # Test Case 1: non-root mode uses the runtime UI directory.
    monkeypatch.setenv("LITELLM_NON_ROOT", "true")
    is_non_root, ui_path = select_ui_path()
    assert is_non_root is True
    assert should_restructure is True
    assert ui_path == runtime_ui_path

    # Test Case 2: default mode uses the packaged UI directory.
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    is_non_root, ui_path = select_ui_path()
    assert is_non_root is False
    assert should_restructure is True
    assert ui_path == packaged_ui_path
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_initialize_scheduled_jobs_credentials(monkeypatch):
    """
    Test that get_credentials is only called when store_model_in_db is True
    """
    for env_name in ("DISABLE_PRISMA_SCHEMA_UPDATE", "STORE_MODEL_IN_DB"):
        monkeypatch.delenv(env_name, raising=False)

    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    # Stand-ins for the collaborators handed to the startup hook.
    prisma_stub = MagicMock()
    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    def job_kwargs():
        # A fresh dict (and fresh general_settings) for every invocation.
        return {
            "general_settings": {},
            "prisma_client": prisma_stub,
            "proxy_budget_rescheduler_min_time": 1,
            "proxy_budget_rescheduler_max_time": 2,
            "proxy_batch_write_at": 5,
            "proxy_logging_obj": logging_stub,
        }

    # First run: store_model_in_db is False.
    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(**job_kwargs())

        # get_credentials must not have been touched.
        config_stub.get_credentials.assert_not_called()

    # Second run: store_model_in_db is True.
    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", True),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(**job_kwargs())

        # Exactly one direct call is expected.
        assert config_stub.get_credentials.call_count == 1  # Direct call

        # NOTE(review): the original comment says this verifies that a scheduled
        # job was added, but it only re-inspects the calls recorded on the
        # get_credentials mock itself - confirm whether a scheduler check
        # was intended.
        recorded_call_names = [
            recorded[0] for recorded in config_stub.get_credentials.mock_calls
        ]
        assert len(recorded_call_names) > 0
|
||
|
||
|
||
def test_update_config_fields_deep_merge_db_wins():
    """
    Check the deep-merge result of `ProxyConfig._update_config_fields`.

    The assertions cover four things: DB values replace conflicting nested
    values, aliases that exist only in the DB are added, `None` values from
    the DB leave the current value in place, and sibling `router_settings`
    keys are kept.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    newer_sonnet = "claude-sonnet-4-20250514"

    # Aliases as currently configured: one pinned to an older model (and
    # hidden), plus an extra alias the DB only touches with a None value.
    existing_aliases = {
        "claude-sonnet-4": {
            "model": "claude-sonnet-4-20240219",
            "hidden": True,
        },
        "legacy-sonnet": {
            "model": "claude-2.1",
            "hidden": True,
        },
    }
    current_config = {
        "router_settings": {
            "routing_mode": "cost_optimized",
            "model_group_alias": existing_aliases,
        }
    }

    # Aliases coming from the DB: a conflicting entry, a brand-new entry, and
    # an entry carrying None (which must not clobber the current True).
    db_aliases = {
        "claude-sonnet-4": {
            "model": newer_sonnet,
            "hidden": False,
        },
        "claude-sonnet-latest": {
            "model": newer_sonnet,
            "hidden": True,
        },
        "legacy-sonnet": {"hidden": None},
    }
    db_param_value = {"model_group_alias": db_aliases}

    merged_config = ProxyConfig()._update_config_fields(
        current_config=current_config,
        param_name="router_settings",
        db_param_value=db_param_value,
    )

    router_settings = merged_config["router_settings"]
    merged_aliases = router_settings["model_group_alias"]

    # Conflicting alias: both nested fields take the DB value.
    conflicting = merged_aliases["claude-sonnet-4"]
    assert conflicting["model"] == "claude-sonnet-4-20250514"
    assert conflicting["hidden"] is False

    # Alias that only exists in the DB is added verbatim.
    assert "claude-sonnet-latest" in merged_aliases
    added = merged_aliases["claude-sonnet-latest"]
    assert added["model"] == "claude-sonnet-4-20250514"
    assert added["hidden"] is True

    # A None coming from the DB leaves the existing values in place.
    untouched = merged_aliases["legacy-sonnet"]
    assert untouched["model"] == "claude-2.1"
    assert untouched["hidden"] is True

    # Keys outside model_group_alias survive the merge.
    assert router_settings["routing_mode"] == "cost_optimized"
|
||
|
||
|
||
def test_get_config_custom_callback_api_env_vars(monkeypatch):
    """
    GET /get/config/callbacks must report both generic-logger env vars for
    the `custom_callback_api` callback when both are set in the config.
    """
    from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth

    # Config with custom_callback_api enabled and both generic logger env vars set.
    config_data = {
        "litellm_settings": {"success_callback": ["custom_callback_api"]},
        "general_settings": {},
        "environment_variables": {
            "GENERIC_LOGGER_ENDPOINT": "https://callback.example.com",
            "GENERIC_LOGGER_HEADERS": "Auth: token",
        },
    }

    # Stub the router settings and the config loader.
    router_stub = MagicMock()
    router_stub.get_settings.return_value = {}
    monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", router_stub)
    monkeypatch.setattr(proxy_config, "get_config", AsyncMock(return_value=config_data))

    def _fake_auth():
        return MagicMock()

    # Bypass the auth dependency, restoring the previous overrides afterwards.
    saved_overrides = app.dependency_overrides.copy()
    app.dependency_overrides[user_api_key_auth] = _fake_auth

    http_client = TestClient(app)
    try:
        response = http_client.get("/get/config/callbacks")
    finally:
        app.dependency_overrides = saved_overrides

    assert response.status_code == 200

    # Pick the first callback entry named custom_callback_api, if any.
    custom_cb = None
    for entry in response.json()["callbacks"]:
        if entry["name"] == "custom_callback_api":
            custom_cb = entry
            break

    assert custom_cb is not None
    assert custom_cb["variables"] == {
        "GENERIC_LOGGER_ENDPOINT": "https://callback.example.com",
        "GENERIC_LOGGER_HEADERS": "Auth: token",
    }
|
||
|
||
|
||
# Mock Prisma
class MockPrisma:
    """Minimal Prisma client stand-in: stores its constructor arguments and
    exposes no-op async connect/disconnect methods."""

    def __init__(self, database_url=None, proxy_logging_obj=None, http_client=None):
        # The arguments are only stored; nothing is opened or contacted here.
        self.http_client = http_client
        self.proxy_logging_obj = proxy_logging_obj
        self.database_url = database_url

    async def connect(self):
        """Do nothing; awaiting it yields None."""
        return None

    async def disconnect(self):
        """Do nothing; awaiting it yields None."""
        return None


# Shared instance used as the patched return value of `_setup_prisma_client`.
mock_prisma = MockPrisma()
|
||
|
||
|
||
@patch(
    "litellm.proxy.proxy_server.ProxyStartupEvent._setup_prisma_client",
    return_value=mock_prisma,
)
@pytest.mark.asyncio
async def test_aaaproxy_startup_master_key(mock_prisma, monkeypatch, tmp_path):
    """
    Check the master_key seen after proxy startup for three setups: a key in
    config.yaml, the LITELLM_MASTER_KEY environment variable, and a config
    value of the form `os.environ/<NAME>`.
    """
    import yaml
    from fastapi import FastAPI

    # Import happens here - this is when the module probably reads the config path
    from litellm.proxy.proxy_server import proxy_startup_event

    # Swap the Prisma client class for the in-file stand-in.
    monkeypatch.setattr("litellm.proxy.proxy_server.PrismaClient", MockPrisma)

    app = FastAPI()
    config_path = tmp_path / "config.yaml"

    def _write_config(settings):
        # Overwrite the temporary config file with the given settings.
        with open(config_path, "w") as config_file:
            yaml.dump(settings, config_file)

    async def _startup_and_check(expected_key):
        # Run the startup context and compare the module-level master_key.
        async with proxy_startup_event(app):
            from litellm.proxy.proxy_server import master_key

            assert master_key == expected_key

    # Test Case 1: Master key from config.yaml
    test_master_key = "sk-12345"
    _write_config({"general_settings": {"master_key": test_master_key}})

    print(f"SET ENV VARIABLE - CONFIG_FILE_PATH, str(config_path): {str(config_path)}")
    monkeypatch.setenv("CONFIG_FILE_PATH", str(config_path))
    print(f"config_path: {config_path}")
    print(f"os.getenv('CONFIG_FILE_PATH'): {os.getenv('CONFIG_FILE_PATH')}")
    await _startup_and_check(test_master_key)

    # Test Case 2: Master key from environment variable (config has no key)
    test_env_master_key = "sk-test-67890"
    _write_config({"general_settings": {}})

    monkeypatch.setenv("LITELLM_MASTER_KEY", test_env_master_key)
    print("test_env_master_key: {}".format(test_env_master_key))
    await _startup_and_check(test_env_master_key)

    # Test Case 3: Master key with os.environ prefix
    test_resolved_key = "sk-resolved-key"
    _write_config({"general_settings": {"master_key": "os.environ/CUSTOM_MASTER_KEY"}})

    monkeypatch.setenv("CUSTOM_MASTER_KEY", test_resolved_key)
    await _startup_and_check(test_resolved_key)
|
||
|
||
|
||
def test_team_info_masking():
    """
    The exception raised by `_get_team_config` must not contain the team's
    secret values.

    Ref: https://huntr.com/bounties/661b388a-44d8-4ad5-862b-4dc5b80be30a
    """
    from litellm.proxy.proxy_server import ProxyConfig

    # Team entry carrying sensitive values.
    sensitive_team = {
        "success_callback": "['langfuse', 's3']",
        "langfuse_secret": "secret-test-key",
        "langfuse_public_key": "public-test-key",
    }

    config = ProxyConfig()
    with pytest.raises(Exception) as exc_info:
        config._get_team_config(
            team_id="test_dev",
            all_teams_config=[sensitive_team],
        )

    print("Got exception: {}".format(exc_info.value))
    error_text = str(exc_info.value)
    for secret in ("secret-test-key", "public-test-key"):
        assert secret not in error_text
|
||
|
||
|
||
def test_embedding_input_array_of_tokens(client_no_auth):
    """
    Token-array embedding input must reach the router's `aembedding`
    unchanged (no decoding of the token ids).

    Ref: https://github.com/BerriAI/litellm/issues/10113
    """
    from litellm.proxy import proxy_server

    # The client_no_auth fixture should initialize the router; assert it so
    # router initialization regressions are caught here.
    assert proxy_server.llm_router is not None, (
        "llm_router is None after client_no_auth fixture initialized. "
        "This indicates a router initialization issue that should be investigated."
    )

    token_ids = [[2046, 13269, 158208]]

    try:
        with mock.patch.object(
            proxy_server.llm_router,
            "aembedding",
            return_value=example_embedding_result,
        ) as mock_aembedding:
            response = client_no_auth.post(
                "/v1/embeddings",
                json={"model": "vllm_embed_model", "input": token_ids},
            )

            # aembedding must be called exactly once, with the input untouched.
            mock_aembedding.assert_called_once()
            forwarded = mock_aembedding.call_args.kwargs
            assert forwarded["model"] == "vllm_embed_model"
            assert forwarded["input"] == [[2046, 13269, 158208]]

            assert response.status_code == 200
            embedding_len = len(response.json()["data"][0]["embedding"])
            print(embedding_len)
            # this usually has len==1536 so
            assert embedding_len > 10
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_all_team_models():
    """
    Exercise get_all_team_models with "*", an explicit team list, an empty
    team list, and a router that returns None for some models.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import get_all_team_models

    def _team_row(team_id, models, team_alias):
        # DB row stand-in whose model_dump() yields the team fields.
        row = MagicMock()
        row.model_dump.return_value = {
            "team_id": team_id,
            "models": models,
            "team_alias": team_alias,
        }
        return row

    team_one = _team_row("team1", ["gpt-4", "gpt-3.5-turbo"], "Team 1")
    team_two = _team_row("team2", ["claude-3", "gpt-4"], "Team 2")

    # Deployments the router stub returns per model name.
    gpt4_deployments = [
        {"model_info": {"id": "gpt-4-model-1"}},
        {"model_info": {"id": "gpt-4-model-2"}},
    ]
    deployments_by_model = {
        "gpt-4": gpt4_deployments,
        "gpt-3.5-turbo": [{"model_info": {"id": "gpt-3.5-turbo-model-1"}}],
        "claude-3": [{"model_info": {"id": "claude-3-model-1"}}],
    }

    # Prisma stub: only db.litellm_teamtable.find_many is awaited.
    team_table = MagicMock()
    team_table.find_many = AsyncMock()
    prisma = MagicMock()
    prisma.db = MagicMock()
    prisma.db.litellm_teamtable = team_table

    def mock_get_model_list(model_name, team_id=None):
        # Unknown model names yield None.
        return deployments_by_model.get(model_name)

    def mock_get_model_list_with_none(model_name, team_id=None):
        # Only gpt-4 resolves; everything else yields None.
        return gpt4_deployments if model_name == "gpt-4" else None

    def mock_team_table_constructor(**kwargs):
        # Stand-in for LiteLLM_TeamTable(**fields).
        instance = MagicMock()
        instance.team_id = kwargs["team_id"]
        instance.models = kwargs["models"]
        instance.access_group_ids = kwargs.get("access_group_ids")
        return instance

    router = MagicMock()
    router.get_model_list.side_effect = mock_get_model_list

    # Test Case 1: user_teams = "*" (all teams)
    team_table.find_many.return_value = [team_one, team_two]

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as team_table_cls:
        team_table_cls.side_effect = mock_team_table_constructor

        result = await get_all_team_models(
            user_teams="*",
            prisma_client=prisma,
            llm_router=router,
        )

        # "*" means no where clause is passed to find_many.
        team_table.find_many.assert_called_with()

        # The router is queried once per (team, model) pair.
        router.get_model_list.assert_has_calls(
            [
                mock.call(model_name="gpt-4", team_id="team1"),
                mock.call(model_name="gpt-3.5-turbo", team_id="team1"),
                mock.call(model_name="claude-3", team_id="team2"),
                mock.call(model_name="gpt-4", team_id="team2"),
            ],
            any_order=True,
        )

    # Test Case 2: user_teams = specific list
    team_table.reset_mock()
    router.reset_mock()
    router.get_model_list.side_effect = mock_get_model_list

    # The specific-team query only returns team1.
    team_table.find_many.return_value = [team_one]

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as team_table_cls:
        team_table_cls.side_effect = mock_team_table_constructor

        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=prisma,
            llm_router=router,
        )

        # A team list is translated into an "in" where clause.
        team_table.find_many.assert_called_with(
            where={"team_id": {"in": ["team1"]}}
        )

        # Only team1's models are looked up.
        router.get_model_list.assert_has_calls(
            [
                mock.call(model_name="gpt-4", team_id="team1"),
                mock.call(model_name="gpt-3.5-turbo", team_id="team1"),
            ],
            any_order=True,
        )

    # Test Case 3: Empty teams list
    team_table.reset_mock()
    router.reset_mock()
    team_table.find_many.return_value = []

    result = await get_all_team_models(
        user_teams=[],
        prisma_client=prisma,
        llm_router=router,
    )

    # find_many receives the empty "in" list, and no teams gives an empty mapping.
    team_table.find_many.assert_called_with(where={"team_id": {"in": []}})
    assert result == {}

    # Test Case 4: Router returns None for some models
    team_table.reset_mock()
    router.reset_mock()
    team_table.find_many.return_value = [team_one]
    router.get_model_list.side_effect = mock_get_model_list_with_none

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as team_table_cls:
        team_table_cls.side_effect = mock_team_table_constructor

        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=prisma,
            llm_router=router,
        )

        # A None from the router must be tolerated.
        assert isinstance(result, dict)
        print("result: ", result)
        assert result == {"gpt-4-model-1": ["team1"], "gpt-4-model-2": ["team1"]}
|
||
|
||
|
||
def test_add_team_models_to_all_models():
    """
    A team whose models list is ["all-proxy-models"] is mapped onto the
    router deployment without a team_id, and not onto the deployment that
    carries another team's team_id.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_team_models_to_all_models

    team = MagicMock(spec=LiteLLM_TeamTable)
    team.team_id = "team1"
    team.models = ["all-proxy-models"]

    # One deployment owned by another team, one without any team_id.
    deployments = [
        {"model_info": {"id": "gpt-4-model-1", "team_id": "team2"}},
        {"model_info": {"id": "gpt-4-model-2"}},
    ]
    router_stub = MagicMock()
    router_stub.get_model_list.return_value = deployments

    team_models = _add_team_models_to_all_models(
        team_db_objects_typed=[team],
        llm_router=router_stub,
    )
    assert team_models == {"gpt-4-model-2": {"team1"}}
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_apply_search_filter_matches_team_public_model_name():
    """
    Regression test: team BYOK models persist an internal model_name
    (e.g. `model_name_{team_id}_{uuid}`) and expose the user-facing name as
    `model_info.team_public_model_name`. The /v2/model/info search filter
    has to match on that public name so BYOK rows show up in results.
    """
    from litellm.proxy.proxy_server import _apply_search_filter_to_models

    byok_model = {
        "model_name": "model_name_team-abc-123_4a6b8",
        "litellm_params": {"model": "claude-sonnet-4-5"},
        "model_info": {
            "id": "byok-id-1",
            "team_id": "team-abc-123",
            "team_public_model_name": "team-claude-sonnet",
            "db_model": True,
        },
    }
    unrelated_model = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4"},
        "model_info": {"id": "normal-id-1", "db_model": False},
    }

    async def _search(term):
        # Run the filter over both models without a DB client; return matches only.
        matches, _ = await _apply_search_filter_to_models(
            all_models=[byok_model, unrelated_model],
            search=term,
            prisma_client=None,
            proxy_config=MagicMock(),
        )
        return matches

    # A term matching only team_public_model_name still returns the BYOK row.
    matched_ids = {m["model_info"]["id"] for m in await _search("claude")}
    assert "byok-id-1" in matched_ids
    assert "normal-id-1" not in matched_ids

    # Searching by the internal model_name keeps working.
    by_internal_name = await _search("model_name_team-abc-123")
    assert [m["model_info"]["id"] for m in by_internal_name] == ["byok-id-1"]

    # A term that matches nothing yields an empty list.
    assert await _search("gemini") == []
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_apply_search_filter_scopes_byok_to_caller_teams():
    """
    Regression test: `/v2/model/info?search=...` must not leak BYOK rows
    from teams the caller is not a member of. Even with a bounded
    `model_name`-contains DB query, a non-admin caller could otherwise see
    other teams' BYOK rows that happen to match by internal name. The
    post-fetch team scope drops those.
    """
    from litellm.proxy.proxy_server import _apply_search_filter_to_models

    def _byok_model_info(model_id, team_id, public_name):
        return {
            "id": model_id,
            "team_id": team_id,
            "team_public_model_name": public_name,
            "db_model": True,
        }

    def _router_byok(model_name, model_id, team_id, public_name):
        # In-router BYOK deployment dict.
        return {
            "model_name": model_name,
            "litellm_params": {"model": "claude-sonnet"},
            "model_info": _byok_model_info(model_id, team_id, public_name),
        }

    def _db_byok_row(model_name, model_id, team_id, public_name):
        # DB-only BYOK row, exposed through attributes like a Prisma record.
        row = MagicMock()
        row.model_id = model_id
        row.model_name = model_name
        row.model_info = _byok_model_info(model_id, team_id, public_name)
        return row

    def _caller(role, user_id):
        auth = MagicMock(spec=UserAPIKeyAuth)
        auth.user_role = role
        auth.user_id = user_id
        return auth

    # In-router BYOK rows: one in the caller's team, one in someone else's.
    caller_team_byok = _router_byok(
        "model_name_team-mine_internal", "byok-mine", "team-mine", "claude-sonnet-prod"
    )
    other_team_byok = _router_byok(
        "model_name_team-other_internal",
        "byok-other",
        "team-other",
        "claude-sonnet-staging",
    )
    # Non-team row stays in the router-side result regardless of teams.
    public_model = {
        "model_name": "claude-public",
        "litellm_params": {"model": "claude-sonnet"},
        "model_info": {"id": "public-id", "db_model": False},
    }

    # DB-only BYOK rows fetched by the over-broad JSON branch.
    db_caller_row = _db_byok_row(
        "model_name_team-mine_db", "byok-db-mine", "team-mine", "Claude DB Mine"
    )
    db_other_row = _db_byok_row(
        "model_name_team-other_db", "byok-db-other", "team-other", "Claude DB Other"
    )

    prisma_client = MagicMock()
    model_table = prisma_client.db.litellm_proxymodeltable
    model_table.count = AsyncMock(return_value=2)
    model_table.find_many = AsyncMock(return_value=[db_caller_row, db_other_row])

    # The caller's user row places them in team-mine only.
    caller_user_row = MagicMock()
    caller_user_row.teams = ["team-mine"]
    prisma_client.db.litellm_usertable.find_unique = AsyncMock(
        return_value=caller_user_row
    )

    def _decrypt_rows(rows):
        # Turn DB rows into deployment dicts, one fresh dict per row.
        return [
            {
                "model_name": r.model_name,
                "model_info": r.model_info,
                "litellm_params": {"model": "claude-sonnet"},
            }
            for r in rows
        ]

    proxy_config = MagicMock()
    proxy_config.decrypt_model_list_from_db = _decrypt_rows

    async def _search_as(caller):
        return await _apply_search_filter_to_models(
            all_models=[caller_team_byok, other_team_byok, public_model],
            search="claude",
            prisma_client=prisma_client,
            proxy_config=proxy_config,
            user_api_key_dict=caller,
        )

    non_admin = _caller(LitellmUserRoles.INTERNAL_USER, "user-mine")
    filtered, total_count = await _search_as(non_admin)

    filtered_ids = {m["model_info"]["id"] for m in filtered}
    for expected_id in ("byok-mine", "byok-db-mine", "public-id"):
        assert expected_id in filtered_ids
    assert "byok-other" not in filtered_ids, (
        "router-side BYOK from another team must be dropped from search "
        "when caller doesn't belong to that team"
    )
    assert "byok-db-other" not in filtered_ids, (
        "DB-only BYOK from another team must be dropped from search when "
        "caller doesn't belong to that team"
    )
    # total_count is router_models_count (2: caller_team_byok + public_model,
    # other_team_byok dropped router-side) + DB count (2 from the mocked
    # `count()`). The DB count is the *unscoped* match count; non-admin team
    # scoping applies only to the returned page so the count can be
    # over-reported, but it must never under-report (callers can paginate
    # within the bound).
    assert total_count == 4

    # Admins keep the un-scoped view across teams.
    admin = _caller(LitellmUserRoles.PROXY_ADMIN, "admin-1")
    filtered_admin, _ = await _search_as(admin)

    admin_ids = {m["model_info"]["id"] for m in filtered_admin}
    assert "byok-other" in admin_ids
    assert "byok-db-other" in admin_ids
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_apply_search_filter_bounds_db_fetch_by_page_and_cap():
    """
    Regression test: a broad search term must not force a full BYOK-table
    read + decrypt on each request.

    * Unsorted searches: `find_many(take=N)` where N is just enough to fill
      the current page after counting router-side matches.
    * Sorted searches: `find_many(take=cap)` falls back to
      `_SORTED_SEARCH_DB_FETCH_CAP` so ordering still works across a large
      match set without scanning the whole table.
    """
    from litellm.proxy.proxy_server import (
        _SORTED_SEARCH_DB_FETCH_CAP,
        _apply_search_filter_to_models,
    )

    # The DB reports 10_000 matches but hands back no rows.
    prisma_client = MagicMock()
    model_table = prisma_client.db.litellm_proxymodeltable
    model_table.count = AsyncMock(return_value=10_000)
    model_table.find_many = AsyncMock(return_value=[])

    def _decrypt_nothing(rows):
        return []

    proxy_config = MagicMock()
    proxy_config.decrypt_model_list_from_db = _decrypt_nothing

    async def _search_first_page(sort_by):
        # page=1, size=50, no router-side models.
        await _apply_search_filter_to_models(
            all_models=[],
            search="model",
            prisma_client=prisma_client,
            proxy_config=proxy_config,
            page=1,
            size=50,
            sort_by=sort_by,
        )

    # Unsorted: page=1, size=50, no router-side matches -> take must be 50.
    await _search_first_page(None)
    unsorted_take = model_table.find_many.call_args.kwargs["take"]
    assert unsorted_take == 50, "unsorted search must take just one page's worth of rows"

    # Sorted: still bounded, but by the hard cap rather than the page.
    model_table.find_many.reset_mock()
    await _search_first_page("model_name")
    sorted_take = model_table.find_many.call_args.kwargs["take"]
    assert sorted_take == _SORTED_SEARCH_DB_FETCH_CAP
    assert sorted_take < 10_000, "sorted search must cap below the full match set"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_filter_models_by_team_id_excludes_viewer_direct_access():
    """
    Regression test: when the UI picks a specific team in the Current Team
    selector, the model list must show only that team's BYOK rows + the
    models assigned to the team. The admin viewer's `direct_access` flag
    (set on every non-team model upstream) must NOT widen the team's visible
    set, or selecting team-111 still shows every public model.
    """
    from litellm.proxy.proxy_server import _filter_models_by_team_id

    def _team_byok(team_id, backend_model, public_name):
        # BYOK deployment owned by (and reachable via) a single team.
        return {
            "model_name": f"model_name_{team_id}_uuid",
            "litellm_params": {"model": backend_model},
            "model_info": {
                "id": f"byok-{team_id}",
                "team_id": team_id,
                "team_public_model_name": public_name,
                "access_via_team_ids": [team_id],
            },
        }

    public_model = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4"},
        "model_info": {
            "id": "public-id",
            # admin viewer has direct_access on this public model
            "direct_access": True,
            # team-111 is NOT in access_via_team_ids -> shouldn't show for team-111
            "access_via_team_ids": ["team-222"],
        },
    }
    team111_byok = _team_byok("team-111", "claude-sonnet", "team-claude")
    team222_byok = _team_byok("team-222", "claude-haiku", "team-haiku")

    # Team row for team-111; its models list does not include the BYOK's
    # internal name.
    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team-111",
        "team_alias": "Team 111",
        "models": ["some-other-model"],
        "access_group_ids": None,
    }
    prisma = MagicMock()
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)
    prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[])

    router = MagicMock()
    router.get_model_access_groups = MagicMock(return_value={})
    # team-111 only resolves "some-other-model", which has no deployments
    router.get_model_list = MagicMock(return_value=[])

    filtered = await _filter_models_by_team_id(
        all_models=[public_model, team111_byok, team222_byok],
        team_id="team-111",
        prisma_client=prisma,
        llm_router=router,
    )
    visible_ids = sorted(entry["model_info"]["id"] for entry in filtered)

    assert "byok-team-111" in visible_ids, "team-111's own BYOK must always be visible"
    assert "byok-team-222" not in visible_ids, "must not leak other teams' BYOK"
    assert (
        "public-id" not in visible_ids
    ), "viewer's direct_access must not widen the team's visible set"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_filter_models_by_team_id_rejects_non_member():
    """
    Regression test: /v2/model/info?teamId=X includes BYOK rows solely on
    `model_info.team_id == X`. Without an auth check, any authenticated user
    could enumerate another team's BYOK metadata by guessing its id. Callers
    that are neither proxy admins nor members of `team_id` must get 403.
    """
    from fastapi import HTTPException

    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
    from litellm.proxy.proxy_server import _filter_models_by_team_id

    team111_byok = {
        "model_name": "model_name_team-111_uuid",
        "litellm_params": {"model": "claude"},
        "model_info": {"id": "byok-team-111", "team_id": "team-111"},
    }

    # The caller's user row lists team-222 only.
    membership = MagicMock()
    membership.teams = ["team-222"]
    prisma = MagicMock()
    prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=membership)

    outsider = UserAPIKeyAuth(
        user_id="alice",
        user_role=LitellmUserRoles.INTERNAL_USER,
        api_key="sk-test",
    )

    with pytest.raises(HTTPException) as excinfo:
        await _filter_models_by_team_id(
            all_models=[team111_byok],
            team_id="team-111",
            prisma_client=prisma,
            llm_router=MagicMock(),
            user_api_key_dict=outsider,
        )
    assert excinfo.value.status_code == 403
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_filter_models_by_team_id_allows_team_member():
    """
    A caller who IS a member of `team_id` must be allowed to filter, and
    should get that team's BYOK rows back.
    """
    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
    from litellm.proxy.proxy_server import _filter_models_by_team_id

    team111_byok = {
        "model_name": "model_name_team-111_uuid",
        "litellm_params": {"model": "claude"},
        "model_info": {"id": "byok-team-111", "team_id": "team-111"},
    }

    # The caller's user row includes team-111.
    membership = MagicMock()
    membership.teams = ["team-111", "team-999"]

    # Team row for team-111 with an empty models list.
    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team-111",
        "team_alias": "Team 111",
        "models": [],
        "access_group_ids": None,
    }

    prisma = MagicMock()
    prisma.db.litellm_usertable.find_unique = AsyncMock(return_value=membership)
    prisma.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)
    prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[])

    router = MagicMock()
    router.get_model_access_groups = MagicMock(return_value={})
    router.get_model_list = MagicMock(return_value=[team111_byok])

    member = UserAPIKeyAuth(
        user_id="bob",
        user_role=LitellmUserRoles.INTERNAL_USER,
        api_key="sk-test",
    )

    visible = await _filter_models_by_team_id(
        all_models=[team111_byok],
        team_id="team-111",
        prisma_client=prisma,
        llm_router=router,
        user_api_key_dict=member,
    )
    assert [entry["model_info"]["id"] for entry in visible] == ["byok-team-111"]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_caller_byok_team_scope_treats_view_only_admin_as_unscoped():
    """
    Regression test: `PROXY_ADMIN_VIEW_ONLY` is an admin role
    ("can login, view all own keys, view all spend"). Search results for
    this role must show BYOK rows across all teams, not be silently scoped
    to the user-id's `teams` field — that path narrows results to whatever
    teams the admin happens to be a member of, regressing pre-PR behavior.
    """
    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
    from litellm.proxy.proxy_server import _get_caller_byok_team_scope

    view_only_admin = UserAPIKeyAuth(
        user_id="view-admin",
        user_role=LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        api_key="sk-test",
    )

    # The test expects None, which its assertion message calls "unscoped".
    team_scope = await _get_caller_byok_team_scope(
        user_api_key_dict=view_only_admin,
        prisma_client=MagicMock(),
    )
    assert team_scope is None, "PROXY_ADMIN_VIEW_ONLY must be unscoped, like PROXY_ADMIN"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_access_group_models_to_team_models():
    """
    Models reachable through a team's access groups are merged into team_models.

    Setup: team1 has models=["gpt-4"] and access_group_ids=["premium"]; the
    "premium" access group contains ["claude-3", "gemini"]. After resolution
    team1 should see gpt-4 (direct) plus claude-3/gemini (via the group).
    Four further teams (no access groups, empty access-group list, empty
    models, and the all-proxy-models sentinel) are expected to be skipped.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    def make_team(team_id, models, access_group_ids):
        # Spec'd mock standing in for a LiteLLM_TeamTable row.
        team = MagicMock(spec=LiteLLM_TeamTable)
        team.team_id = team_id
        team.models = models
        team.access_group_ids = access_group_ids
        return team

    teams = [
        # Specific models AND access groups: the only eligible team.
        make_team("team1", ["gpt-4"], ["premium"]),
        # No access groups: should be skipped.
        make_team("team2", ["gpt-4"], None),
        # Empty access_group_ids list: should be skipped.
        make_team("team2b", ["gpt-4"], []),
        # Empty models (all access): should be skipped.
        make_team("team3", [], ["premium"]),
        # all-proxy-models sentinel (all access): should be skipped.
        make_team("team4", ["all-proxy-models"], ["premium"]),
    ]

    # Router stub: model name -> single deployment, None for anything else.
    deployment_id_by_model = {"claude-3": "claude-3-id", "gemini": "gemini-id"}

    def fake_get_model_list(model_name, team_id=None):
        deployment_id = deployment_id_by_model.get(model_name)
        if deployment_id is None:
            return None
        return [{"model_info": {"id": deployment_id}}]

    router = MagicMock()
    router.get_model_list.side_effect = fake_get_model_list

    # team_models as it would look before access groups are resolved
    # (e.g., produced by _add_team_models_to_all_models).
    preexisting_team_models = {"gpt-4-id": {"team1"}}

    # Prisma stub: the batch find_many yields one access-group row.
    premium_row = MagicMock()
    premium_row.access_group_id = "premium"
    premium_row.access_model_names = ["claude-3", "gemini"]

    prisma = MagicMock()
    prisma.db.litellm_accessgrouptable.find_many = AsyncMock(
        return_value=[premium_row]
    )

    result = await _add_access_group_models_to_team_models(
        team_db_objects_typed=teams,
        llm_router=router,
        prisma_client=prisma,
        team_models=preexisting_team_models,
    )

    # Exactly one batch query, restricted to the eligible team's group IDs.
    find_many = prisma.db.litellm_accessgrouptable.find_many
    find_many.assert_called_once()
    requested_ids = find_many.call_args.kwargs["where"]["access_group_id"]["in"]
    assert set(requested_ids) == {"premium"}

    # The pre-existing direct model is still attributed to team1.
    assert "gpt-4-id" in result
    assert "team1" in result["gpt-4-id"]

    # Access-group models were added for team1.
    for deployment_id in ("claude-3-id", "gemini-id"):
        assert deployment_id in result
        assert "team1" in result[deployment_id]

    # None of the skipped teams picked up the access-group models.
    for skipped_team in ["team2", "team2b", "team3", "team4"]:
        assert skipped_team not in result.get("claude-3-id", set())
        assert skipped_team not in result.get("gemini-id", set())
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_access_group_models_multiple_teams_shared_group():
    """
    Teams that share an access group each receive its models, and the access
    groups are fetched with a single batch DB query.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    def make_team(team_id, models, access_group_ids):
        team = MagicMock(spec=LiteLLM_TeamTable)
        team.team_id = team_id
        team.models = models
        team.access_group_ids = access_group_ids
        return team

    def make_group_row(group_id, model_names):
        row = MagicMock()
        row.access_group_id = group_id
        row.access_model_names = model_names
        return row

    first_team = make_team("team-a", ["gpt-4"], ["shared-group"])
    second_team = make_team("team-b", ["gpt-3.5"], ["shared-group", "extra-group"])

    # Router stub: model name -> single deployment, None for anything else.
    deployment_id_by_model = {"claude-3": "claude-3-id", "gemini": "gemini-id"}

    def fake_get_model_list(model_name, team_id=None):
        deployment_id = deployment_id_by_model.get(model_name)
        if deployment_id is None:
            return None
        return [{"model_info": {"id": deployment_id}}]

    router = MagicMock()
    router.get_model_list.side_effect = fake_get_model_list

    prisma = MagicMock()
    prisma.db.litellm_accessgrouptable.find_many = AsyncMock(
        return_value=[
            make_group_row("shared-group", ["claude-3"]),
            make_group_row("extra-group", ["gemini"]),
        ]
    )

    result = await _add_access_group_models_to_team_models(
        team_db_objects_typed=[first_team, second_team],
        llm_router=router,
        prisma_client=prisma,
        team_models={},
    )

    # One batch query covering both groups.
    find_many = prisma.db.litellm_accessgrouptable.find_many
    find_many.assert_called_once()
    requested_ids = set(find_many.call_args.kwargs["where"]["access_group_id"]["in"])
    assert requested_ids == {"shared-group", "extra-group"}

    # claude-3 comes from the shared group, so both teams get it.
    assert "claude-3-id" in result
    assert "team-a" in result["claude-3-id"]
    assert "team-b" in result["claude-3-id"]

    # gemini comes from extra-group, which only team-b belongs to.
    assert "gemini-id" in result
    assert "team-b" in result["gemini-id"]
    assert "team-a" not in result["gemini-id"]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_access_group_models_no_eligible_teams():
    """
    With no team carrying access groups, find_many must never be called and
    the supplied team_models mapping comes back unchanged.
    """
    from litellm.proxy._types import LiteLLM_TeamTable
    from litellm.proxy.proxy_server import _add_access_group_models_to_team_models

    lone_team = MagicMock(spec=LiteLLM_TeamTable)
    lone_team.team_id = "team1"
    lone_team.models = ["gpt-4"]
    lone_team.access_group_ids = None

    prisma = MagicMock()
    prisma.db.litellm_accessgrouptable.find_many = AsyncMock()

    result = await _add_access_group_models_to_team_models(
        team_db_objects_typed=[lone_team],
        llm_router=MagicMock(),
        prisma_client=prisma,
        team_models={"existing-id": {"team1"}},
    )

    # The access-group table was not queried at all.
    prisma.db.litellm_accessgrouptable.find_many.assert_not_called()

    # The mapping that was passed in is returned with the same contents.
    assert result == {"existing-id": {"team1"}}
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_all_team_models_with_access_groups():
    """
    End-to-end check that get_all_team_models includes access-group models.

    Setup: the user is on team1, which has models=["gpt-4"] and
    access_group_ids=["premium"]; the "premium" group holds ["claude-3"].
    The result should list both the gpt-4 and claude-3 deployments for team1.
    """
    from litellm.proxy.proxy_server import get_all_team_models

    team_row = MagicMock()
    team_row.model_dump.return_value = {
        "team_id": "team1",
        "models": ["gpt-4"],
        "team_alias": "Team 1",
        "access_group_ids": ["premium"],
    }

    # Access-group row handed back by the batch find_many.
    premium_row = MagicMock()
    premium_row.access_group_id = "premium"
    premium_row.access_model_names = ["claude-3"]

    prisma = MagicMock()
    db = MagicMock()
    prisma.db = db
    db.litellm_teamtable = MagicMock()
    db.litellm_teamtable.find_many = AsyncMock(return_value=[team_row])
    db.litellm_accessgrouptable = MagicMock()
    db.litellm_accessgrouptable.find_many = AsyncMock(return_value=[premium_row])

    # Router stub: model name -> single deployment, None for anything else.
    deployment_id_by_model = {
        "gpt-4": "gpt-4-deploy-1",
        "claude-3": "claude-3-deploy-1",
    }

    def fake_get_model_list(model_name, team_id=None):
        deployment_id = deployment_id_by_model.get(model_name)
        if deployment_id is None:
            return None
        return [{"model_info": {"id": deployment_id}}]

    router = MagicMock()
    router.get_model_list.side_effect = fake_get_model_list

    def build_team_table(**kwargs):
        # Replacement constructor: copies only the three fields set here
        # onto a plain mock instead of building a real LiteLLM_TeamTable.
        instance = MagicMock()
        instance.team_id = kwargs["team_id"]
        instance.models = kwargs["models"]
        instance.access_group_ids = kwargs.get("access_group_ids")
        return instance

    with patch("litellm.proxy.proxy_server.LiteLLM_TeamTable") as team_table_cls:
        team_table_cls.side_effect = build_team_table

        result = await get_all_team_models(
            user_teams=["team1"],
            prisma_client=prisma,
            llm_router=router,
        )

    # gpt-4 arrives through the team's direct models.
    assert "gpt-4-deploy-1" in result
    assert "team1" in result["gpt-4-deploy-1"]

    # claude-3 arrives through the access group.
    assert "claude-3-deploy-1" in result
    assert "team1" in result["claude-3-deploy-1"]

    # Values are lists (return type is Dict[str, List[str]]).
    assert isinstance(result["gpt-4-deploy-1"], list)
    assert isinstance(result["claude-3-deploy-1"], list)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_delete_deployment_type_mismatch():
    """
    _delete_deployment must cope with int-vs-str model ID mismatches.

    Models 12345678 and 12345679 must NOT be deleted when they exist in both
    combined_id_list (as integers) and router_model_ids (as strings). This
    reproduces the bug where the type mismatch caused valid models to be
    deleted.
    """
    from unittest.mock import MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Config whose model_info IDs are integers.
    # NOTE(review): this mock is replaced further down, before the function
    # under test runs, so these integer IDs never reach it. Confirm which of
    # the two get_config mocks is the intended one.
    proxy_config.get_config = MagicMock(
        return_value={
            "model_list": [
                {
                    "model_name": "openai-gpt-4o",
                    "litellm_params": {"model": "gpt-4o"},
                    "model_info": {"id": 12345678},
                },
                {
                    "model_name": "openai-gpt-4o",
                    "litellm_params": {"model": "gpt-4o"},
                    "model_info": {"id": 12345679},
                },
            ]
        }
    )

    # Router reports its IDs as strings: the source of the type mismatch.
    router = MagicMock()
    router.get_model_ids.return_value = [
        "a96e12e76b36a57cfae57a41288eb41567629cac89b4828c6f7074afc3534695",
        "a40186dd0fdb9b7282380277d7f57044d29de95bfbfcd7f4322b3493702d5cd3",
        "12345678",  # String ID
        "12345679",  # String ID
    ]

    # Record every ID the router is asked to delete.
    deleted_ids = []

    def mock_delete_deployment(id):
        deleted_ids.append(id)
        return True  # Simulate successful deletion

    router.delete_deployment = MagicMock(side_effect=mock_delete_deployment)

    # get_config now yields an empty config (no config models).
    async def mock_get_config(config_file_path):
        return {}

    proxy_config.get_config = MagicMock(side_effect=mock_get_config)

    # Swap in the module-level router and config path for the call.
    with (
        patch("litellm.proxy.proxy_server.llm_router", router),
        patch("litellm.proxy.proxy_server.user_config_file_path", "test_config.yaml"),
    ):
        deleted_count = await proxy_config._delete_deployment(db_models=[])

    # Models 12345678 and 12345679 should NOT be deleted: they exist in
    # combined_id_list (as integers) even though the router has them as
    # strings.
    # NOTE(review): the original comment here said the other two models
    # "should" be deleted, but the assertion expects zero deletions. Confirm
    # the intended expectation.
    assert deleted_count == 0, f"Expected 0 deletions, got {deleted_count}"

    assert (
        "12345678" not in deleted_ids
    ), f"Model 12345678 should NOT be deleted. Deleted IDs: {deleted_ids}"
    assert (
        "12345679" not in deleted_ids
    ), f"Model 12345679 should NOT be deleted. Deleted IDs: {deleted_ids}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_config_from_file(tmp_path, monkeypatch):
    """
    Exercise ProxyConfig._get_config_from_file.

    Covers: a valid YAML file, a non-existent file, no file path, an empty
    file, and falling back to the global user_config_file_path.

    Args:
        tmp_path: pytest fixture providing a per-test temporary directory.
        monkeypatch: pytest fixture used to set the module-level
            `user_config_file_path`.
    """
    import re

    import yaml

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Test Case 1: Valid YAML config file exists
    test_config = {
        "model_list": [{"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}],
        "general_settings": {"master_key": "sk-test"},
        "router_settings": {"enable_pre_call_checks": True},
        "litellm_settings": {"drop_params": True},
    }

    config_file = tmp_path / "test_config.yaml"
    config_file.write_text(yaml.dump(test_config))

    # Clear global user_config_file_path for this test
    monkeypatch.setattr("litellm.proxy.proxy_server.user_config_file_path", None)

    result = await proxy_config._get_config_from_file(str(config_file))
    assert result == test_config

    # Verify that user_config_file_path was set (re-imported here so the
    # current module-level value is read, not a stale binding).
    from litellm.proxy.proxy_server import user_config_file_path

    assert user_config_file_path == str(config_file)

    # Test Case 2: File path provided but file doesn't exist
    non_existent_file = tmp_path / "non_existent.yaml"

    # `match` is a regular expression, so the path is escaped: an unescaped
    # path would treat "." as a wildcard and would be an invalid pattern for
    # paths containing backslashes.
    with pytest.raises(
        Exception,
        match=re.escape(f"Config file not found: {non_existent_file}"),
    ):
        await proxy_config._get_config_from_file(str(non_existent_file))

    # Test Case 3: No file path provided (should return default config)
    monkeypatch.setattr("litellm.proxy.proxy_server.user_config_file_path", None)

    expected_default = {
        "model_list": [],
        "general_settings": {},
        "router_settings": {},
        "litellm_settings": {},
    }

    result = await proxy_config._get_config_from_file(None)
    assert result == expected_default

    # Test Case 4: Empty YAML file (should raise exception for None config)
    empty_file = tmp_path / "empty_config.yaml"
    empty_file.write_text("")  # Empty content loads as None

    with pytest.raises(Exception, match=re.escape("Config cannot be None or Empty.")):
        await proxy_config._get_config_from_file(str(empty_file))

    # Test Case 5: Using global user_config_file_path when no config_file_path provided
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.user_config_file_path", str(config_file)
    )

    result = await proxy_config._get_config_from_file(None)
    assert result == test_config
|
||
|
||
|
||
def test_normalize_datetime_for_sorting():
    """
    Exercise _normalize_datetime_for_sorting.

    Inputs covered: None, ISO-format strings (with "Z", naive, with offset),
    datetime objects (naive, aware non-UTC, aware UTC), an unparseable
    string, and a non-string/non-datetime value.
    """
    from datetime import timedelta

    from litellm.proxy.proxy_server import _normalize_datetime_for_sorting

    def check_is_utc_datetime(value):
        # Checks shared by every successful normalization in this test.
        assert value is not None
        assert isinstance(value, datetime)
        assert value.tzinfo == timezone.utc

    # Test Case 1: None value
    assert _normalize_datetime_for_sorting(None) is None

    # Test Case 2: ISO format string with 'Z' suffix
    normalized = _normalize_datetime_for_sorting("2024-01-15T10:30:00Z")
    check_is_utc_datetime(normalized)
    assert normalized.year == 2024
    assert normalized.month == 1
    assert normalized.day == 15
    assert normalized.hour == 10
    assert normalized.minute == 30

    # Test Case 3: ISO format string without 'Z' suffix (naive)
    normalized = _normalize_datetime_for_sorting("2024-01-15T10:30:00")
    check_is_utc_datetime(normalized)

    # Test Case 4: ISO format string with timezone offset
    normalized = _normalize_datetime_for_sorting("2024-01-15T10:30:00+05:00")
    check_is_utc_datetime(normalized)
    # 10:30 at +05:00 is 05:30 UTC
    assert normalized.hour == 5

    # Test Case 5: Naive datetime object
    normalized = _normalize_datetime_for_sorting(datetime(2024, 1, 15, 10, 30, 0))
    check_is_utc_datetime(normalized)
    assert normalized.year == 2024
    assert normalized.month == 1
    assert normalized.day == 15

    # Test Case 6: Timezone-aware datetime object (non-UTC)
    plus_five = timezone(timedelta(hours=5))
    normalized = _normalize_datetime_for_sorting(
        datetime(2024, 1, 15, 10, 30, 0, tzinfo=plus_five)
    )
    check_is_utc_datetime(normalized)
    # Converted from +05:00 to UTC
    assert normalized.hour == 5

    # Test Case 7: UTC-aware datetime object
    already_utc = datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc)
    normalized = _normalize_datetime_for_sorting(already_utc)
    check_is_utc_datetime(normalized)
    assert normalized == already_utc

    # Test Case 8: Invalid string format
    assert _normalize_datetime_for_sorting("not-a-date") is None

    # Test Case 9: Invalid type (should return None)
    assert _normalize_datetime_for_sorting(12345) is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_proxy_budget_to_db_only_creates_user_no_keys():
    """
    Test that _add_proxy_budget_to_db only creates a user and no keys are added.

    This validates that generate_key_helper_fn is called with table_name="user"
    which should prevent key creation in LiteLLM_VerificationToken table.

    The module-level `litellm.budget_duration` / `litellm.max_budget` settings
    are overridden for the duration of the test and restored afterwards so the
    values do not leak into other tests.
    """
    import asyncio
    from unittest.mock import AsyncMock, patch

    import litellm
    from litellm.proxy.proxy_server import ProxyStartupEvent

    litellm_proxy_budget_name = "litellm-proxy-budget"

    # Mock generate_key_helper_fn to capture its call arguments
    mock_generate_key_helper = AsyncMock(
        return_value={
            "user_id": litellm_proxy_budget_name,
            "max_budget": 100.0,
            "budget_duration": "30d",
            "spend": 0,
            "models": [],
        }
    )

    # Remember the global settings before overriding them.
    original_budget_duration = getattr(litellm, "budget_duration", None)
    original_max_budget = litellm.max_budget
    litellm.budget_duration = "30d"
    litellm.max_budget = 100.0
    try:
        # Patch generate_key_helper_fn in proxy_server where it's being called from
        with patch(
            "litellm.proxy.proxy_server.generate_key_helper_fn",
            mock_generate_key_helper,
        ):
            # Call the function under test
            ProxyStartupEvent._add_proxy_budget_to_db(litellm_proxy_budget_name)

            # Allow async task to complete
            await asyncio.sleep(0.1)

            # Verify that generate_key_helper_fn was called
            mock_generate_key_helper.assert_called_once()
            call_args = mock_generate_key_helper.call_args

            # Verify critical parameters that prevent key creation
            assert call_args.kwargs["request_type"] == "user"
            assert call_args.kwargs["table_name"] == "user"
            assert call_args.kwargs["user_id"] == litellm_proxy_budget_name
            assert call_args.kwargs["max_budget"] == 100.0
            assert call_args.kwargs["budget_duration"] == "30d"
            assert call_args.kwargs["query_type"] == "update_data"
    finally:
        litellm.budget_duration = original_budget_duration
        litellm.max_budget = original_max_budget
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_proxy_budget_to_db_backfills_budget_reset_at():
    """
    Test that _upsert_proxy_budget_with_reset_at_backfill issues a conditional
    update_many with `WHERE budget_reset_at IS NULL` to backfill the column on
    rows that pre-existed without a reset schedule. Without this, the proxy
    admin row stays at NULL and reset_budget_for_litellm_users never matches
    it (NULL < now() is unknown in SQL), so the global proxy budget never
    resets.

    The module-level `litellm.budget_duration` / `litellm.max_budget` settings
    are overridden for the duration of the test and restored afterwards so the
    values do not leak into other tests.
    """
    from datetime import datetime, timezone
    from unittest.mock import AsyncMock, MagicMock, patch

    import litellm
    from litellm.proxy.proxy_server import ProxyStartupEvent

    litellm_proxy_budget_name = "litellm-proxy-budget"

    mock_prisma = MagicMock()
    mock_prisma.db.litellm_usertable.update_many = AsyncMock(return_value={"count": 1})

    mock_generate_key_helper = AsyncMock(
        return_value={
            "user_id": litellm_proxy_budget_name,
            "max_budget": 100.0,
            "budget_duration": "30d",
            "spend": 0,
            "models": [],
        }
    )

    # Remember the global settings before overriding them.
    original_budget_duration = getattr(litellm, "budget_duration", None)
    original_max_budget = litellm.max_budget
    litellm.budget_duration = "30d"
    litellm.max_budget = 100.0
    try:
        with (
            patch(
                "litellm.proxy.proxy_server.generate_key_helper_fn",
                mock_generate_key_helper,
            ),
            patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
        ):
            await ProxyStartupEvent._upsert_proxy_budget_with_reset_at_backfill(
                litellm_proxy_budget_name
            )
    finally:
        litellm.budget_duration = original_budget_duration
        litellm.max_budget = original_max_budget

    # Upsert ran with the configured budget
    mock_generate_key_helper.assert_called_once()

    # Backfill update_many ran with the conditional WHERE
    mock_prisma.db.litellm_usertable.update_many.assert_called_once()
    backfill_call = mock_prisma.db.litellm_usertable.update_many.call_args
    assert backfill_call.kwargs["where"]["user_id"] == litellm_proxy_budget_name
    assert backfill_call.kwargs["where"]["budget_reset_at"] is None

    # The backfilled value must be a real future datetime — anything else and
    # reset_budget_for_litellm_users would still skip the row.
    backfilled_reset_at = backfill_call.kwargs["data"]["budget_reset_at"]
    assert isinstance(backfilled_reset_at, datetime)
    assert backfilled_reset_at > datetime.now(timezone.utc)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_custom_ui_sso_sign_in_handler_config_loading():
    """
    A custom_ui_sso_sign_in_handler named in the config must end up in the
    module-level user_custom_ui_sso_sign_in_handler variable after load_config.
    """
    import os
    import tempfile
    from unittest.mock import MagicMock, patch

    import yaml

    from litellm.proxy.proxy_server import ProxyConfig

    handler_path = "custom_hooks.custom_ui_sso_hook.custom_ui_sso_sign_in_handler"

    # Config that names the custom SSO sign-in handler.
    test_config = {
        "general_settings": {"custom_ui_sso_sign_in_handler": handler_path},
        "model_list": [],
        "router_settings": {},
        "litellm_settings": {},
    }

    # Written with delete=False so the path stays valid after the handle is
    # closed; it is removed in the `finally` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmp:
        yaml.dump(test_config, tmp)
        config_file_path = tmp.name

    # Object that the patched get_instance_fn hands back as "the handler".
    fake_handler = MagicMock()

    try:
        with patch(
            "litellm.proxy.proxy_server.get_instance_fn",
            return_value=fake_handler,
        ) as mock_get_instance:
            proxy_config = ProxyConfig()
            # load_config is given a mock router
            await proxy_config.load_config(
                router=MagicMock(), config_file_path=config_file_path
            )

            # get_instance_fn was asked to resolve the configured handler path.
            mock_get_instance.assert_called_with(
                value=handler_path,
                config_file_path=config_file_path,
            )

            # Imported after load_config so the freshly assigned global is read.
            from litellm.proxy.proxy_server import user_custom_ui_sso_sign_in_handler

            assert user_custom_ui_sso_sign_in_handler == fake_handler
    finally:
        # Clean up temporary file
        os.unlink(config_file_path)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_config_max_budget_env_var_coerced_to_float(tmp_path, monkeypatch):
    """
    A max_budget given as os.environ/MAX_BUDGET resolves to a string, so
    load_config has to coerce it to float; otherwise the startup check
    `litellm.max_budget > 0` raises TypeError.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    monkeypatch.setenv("MAX_BUDGET", "10")

    config_path = tmp_path / "config.yaml"
    config_path.write_text(
        yaml.dump(
            {
                "model_list": [],
                "litellm_settings": {"max_budget": "os.environ/MAX_BUDGET"},
            }
        )
    )

    # load_config writes to the global setting; put it back afterwards.
    saved_max_budget = litellm.max_budget
    try:
        await ProxyConfig().load_config(
            router=MagicMock(), config_file_path=str(config_path)
        )

        assert isinstance(litellm.max_budget, float)
        assert litellm.max_budget == 10.0
        assert litellm.max_budget > 0
    finally:
        litellm.max_budget = saved_max_budget
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_direct_and_os_environ():
    """
    _load_environment_variables handles both literal values and values with
    the os.environ/ prefix.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Three literal values of different types plus one os.environ/ reference.
    test_config = {
        "environment_variables": {
            "DIRECT_VAR": "direct_value",
            "NUMERIC_VAR": 12345,
            "BOOL_VAR": True,
            "SECRET_VAR": "os.environ/ACTUAL_SECRET_VAR",
        }
    }

    # Value the patched get_secret_str returns for any lookup.
    resolved_secret = "resolved_secret_value"

    # patch.dict with clear=False keeps the existing environment and undoes
    # this test's additions on exit, so every os.environ assertion has to run
    # inside the block.
    with patch(
        "litellm.proxy.proxy_server.get_secret_str", return_value=resolved_secret
    ) as mock_get_secret, patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(test_config)

        # Literal values land in the environment, stringified where needed.
        assert os.environ["DIRECT_VAR"] == "direct_value"
        assert os.environ["NUMERIC_VAR"] == "12345"  # Should be converted to string
        assert os.environ["BOOL_VAR"] == "True"  # Should be converted to string

        # The os.environ/ reference was resolved and stored.
        assert os.environ["SECRET_VAR"] == resolved_secret

        # get_secret_str received the raw, still-prefixed value.
        mock_get_secret.assert_called_once_with(
            secret_name="os.environ/ACTUAL_SECRET_VAR"
        )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_litellm_license_and_edge_cases():
    """
    _load_environment_variables: LITELLM_LICENSE special handling, plus the
    edge cases of missing/None environment_variables and an unresolvable
    os.environ/ secret.
    """
    from unittest.mock import MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Test Case 1: LITELLM_LICENSE in environment_variables
    config_with_license = {
        "environment_variables": {
            "LITELLM_LICENSE": "test_license_key",
            "OTHER_VAR": "other_value",
        }
    }

    # Stand-in for the module-level _license_check object.
    license_check = MagicMock()
    license_check.is_premium.return_value = True

    with patch(
        "litellm.proxy.proxy_server._license_check", license_check
    ), patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(config_with_license)

        # LITELLM_LICENSE reached the environment.
        assert os.environ["LITELLM_LICENSE"] == "test_license_key"

        # The license checker got the new key and was re-evaluated once.
        assert license_check.license_str == "test_license_key"
        license_check.is_premium.assert_called_once()

    # Test Case 2: No environment_variables in config
    # Expected to be a no-op that returns None without raising.
    outcome = proxy_config._load_environment_variables({})
    assert outcome is None  # Method returns None

    # Test Case 3: environment_variables is None
    # Expected to be a no-op that returns None without raising.
    outcome = proxy_config._load_environment_variables({"environment_variables": None})
    assert outcome is None  # Method returns None

    # Test Case 4: os.environ/ prefix but get_secret_str returns None
    config_unresolvable_secret = {
        "environment_variables": {"FAILED_SECRET": "os.environ/NONEXISTENT_SECRET"}
    }

    with patch(
        "litellm.proxy.proxy_server.get_secret_str", return_value=None
    ), patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(config_unresolvable_secret)

        # A secret that fails to resolve must not be exported.
        assert "FAILED_SECRET" not in os.environ
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_blocks_dangerous_keys():
    """
    Test that _load_environment_variables rejects dangerous env var keys
    like PATH, LD_PRELOAD, PYTHONPATH, etc.

    A non-dangerous key in the same config must still be applied.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    test_config = {
        "environment_variables": {
            "PATH": "/tmp/evil",
            "LD_PRELOAD": "/tmp/evil.so",
            "PYTHONPATH": "/tmp/evil",
            "SAFE_CUSTOM_VAR": "safe_value",
        }
    }

    # patch.dict with clear=False restores os.environ on exit, so the
    # assertions run inside the block.
    with patch.dict(os.environ, {}, clear=False):
        proxy_config._load_environment_variables(test_config)

        # Blocked keys should not be set to the attacker value
        assert os.environ.get("PATH") != "/tmp/evil"
        assert (
            "LD_PRELOAD" not in os.environ or os.environ["LD_PRELOAD"] != "/tmp/evil.so"
        )
        assert os.environ.get("PYTHONPATH") != "/tmp/evil"

        # Safe keys should still be set
        assert os.environ["SAFE_CUSTOM_VAR"] == "safe_value"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_allows_proxy_keys():
    """
    HTTP_PROXY/HTTPS_PROXY must be accepted: they are commonly used in
    corporate environments to route outbound API calls.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    corp_proxy_url = "http://corp-proxy:8080"
    proxy_env_config = {
        "environment_variables": {
            "HTTP_PROXY": corp_proxy_url,
            "HTTPS_PROXY": corp_proxy_url,
        }
    }

    # Assertions run inside patch.dict, which restores os.environ on exit.
    with patch.dict(os.environ, {}, clear=False):
        ProxyConfig()._load_environment_variables(proxy_env_config)

        assert os.environ["HTTP_PROXY"] == "http://corp-proxy:8080"
        assert os.environ["HTTPS_PROXY"] == "http://corp-proxy:8080"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_load_environment_variables_blocks_no_proxy():
    """
    NO_PROXY/no_proxy must be blocked so that proxy-based network monitoring
    cannot be bypassed through the config.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    bypass_config = {
        "environment_variables": {
            "NO_PROXY": "internal-service",
            "no_proxy": "internal-service",
        }
    }

    # Assertions run inside patch.dict, which restores os.environ on exit.
    with patch.dict(os.environ, {}, clear=False):
        ProxyConfig()._load_environment_variables(bypass_config)

        # Neither spelling may carry the configured value.
        assert os.environ.get("NO_PROXY") != "internal-service"
        assert os.environ.get("no_proxy") != "internal-service"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_write_config_to_file(monkeypatch):
    """
    With store_model_in_db set to True, save_config must not write the config
    file; the config is handed to the prisma client's insert_data instead.
    """
    from unittest.mock import AsyncMock, mock_open, patch

    from litellm.proxy.proxy_server import ProxyConfig

    # Non-None prisma client so the DB path is taken.
    db_client = AsyncMock()
    db_client.insert_data = AsyncMock()

    module_overrides = {
        "litellm.proxy.proxy_server.store_model_in_db": True,
        "litellm.proxy.proxy_server.prisma_client": db_client,
        "litellm.proxy.proxy_server.general_settings": {"store_model_in_db": True},
        "litellm.proxy.proxy_server.user_config_file_path": "/tmp/test_config.yaml",
    }
    for target, replacement in module_overrides.items():
        monkeypatch.setattr(target, replacement)

    proxy_config = ProxyConfig()

    # Spy on open() so any attempt to write the file is detected.
    open_spy = mock_open()

    with patch("builtins.open", open_spy), patch("yaml.dump") as dump_spy:
        payload = {"key": "value", "model_list": ["model1", "model2"]}
        await proxy_config.save_config(new_config=payload)

        # No file I/O is expected while store_model_in_db=True.
        open_spy.assert_not_called()
        dump_spy.assert_not_called()

        # The database insert must have happened exactly once instead.
        db_client.insert_data.assert_called_once()

        insert_kwargs = db_client.insert_data.call_args.kwargs
        # model_list should be popped before the config reaches the DB.
        assert insert_kwargs["data"] == {"key": "value"}
        assert insert_kwargs["table_name"] == "config"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_write_config_to_file_when_store_model_in_db_false(monkeypatch):
    """
    With store_model_in_db set to False, save_config must write the config to
    the configured file path via yaml.dump.
    """
    from unittest.mock import mock_open, patch

    from litellm.proxy.proxy_server import ProxyConfig

    config_path = "/tmp/test_config.yaml"

    # prisma_client is None so the file path is taken.
    module_overrides = {
        "litellm.proxy.proxy_server.store_model_in_db": False,
        "litellm.proxy.proxy_server.prisma_client": None,
        "litellm.proxy.proxy_server.general_settings": {"store_model_in_db": False},
        "litellm.proxy.proxy_server.user_config_file_path": config_path,
    }
    for target, replacement in module_overrides.items():
        monkeypatch.setattr(target, replacement)

    proxy_config = ProxyConfig()
    open_spy = mock_open()

    with patch("builtins.open", open_spy), patch("yaml.dump") as dump_spy:
        payload = {"key": "value", "other_key": "other_value"}
        await proxy_config.save_config(new_config=payload)

        # The file must have been opened for writing exactly once.
        open_spy.assert_called_once_with(config_path, "w")

        # yaml.dump must receive the config and the handle yielded by open().
        file_handle = open_spy.return_value.__enter__.return_value
        dump_spy.assert_called_once_with(
            payload,
            file_handle,
            default_flow_style=False,
        )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_midstream_error():
    """
    Test async_data_generator handles a midstream error returned by
    async_post_call_streaming_hook.

    Specifically covers the case where the Azure Content Safety Guardrail
    replaces the third chunk with an error payload: the error must be yielded,
    the stream must still finish with ``data: [DONE]``, and
    post_call_failure_hook must not be invoked.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
    mock_request_data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }

    # Normal streaming chunks; the hook below turns the third into an error.
    mock_chunks = [
        {"choices": [{"delta": {"content": "Hello"}}]},
        {"choices": [{"delta": {"content": " world"}}]},
        {"choices": [{"delta": {"content": " this"}}]},
    ]

    mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)

    async def mock_streaming_iterator(*args, **kwargs):
        """Yield the canned chunks in order."""
        for chunk in mock_chunks:
            yield chunk

    mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = (
        mock_streaming_iterator
    )

    def mock_streaming_hook(*args, **kwargs):
        """Pass chunks through, swapping the third for a guardrail error."""
        chunk = kwargs.get("response")
        if chunk == mock_chunks[2]:
            return 'data: {"error": {"error": "Azure Content Safety Guardrail: Hate crossed severity 2, Got severity: 2"}}'
        return chunk

    mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock(
        side_effect=mock_streaming_hook
    )
    mock_proxy_logging_obj.post_call_failure_hook = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj):
        mock_response = MagicMock()

        yielded_data = []
        stream_error = None
        try:
            async for data in async_data_generator(
                mock_response, mock_user_api_key_dict, mock_request_data
            ):
                yielded_data.append(data)
        except Exception as exc:
            # An exception from the generator is tolerated here (the checks
            # below judge the outcome), but it is kept so that a failing
            # assertion shows the root cause instead of hiding it.
            stream_error = exc

        assert len(yielded_data) >= 3, (
            f"Expected at least 3 chunks, got {len(yielded_data)}: {yielded_data}"
            f" (generator raised: {stream_error!r})"
        )

        # First two chunks should be normal data
        assert yielded_data[0].startswith(
            "data: "
        ), f"First chunk should start with 'data: ', got: {yielded_data[0]}"
        assert yielded_data[1].startswith(
            "data: "
        ), f"Second chunk should start with 'data: ', got: {yielded_data[1]}"

        # The guardrail error and the terminating [DONE] must both be yielded.
        error_found = any(
            "Azure Content Safety Guardrail: Hate crossed severity 2" in data
            for data in yielded_data
        )
        done_found = any("data: [DONE]" in data for data in yielded_data)

        assert error_found, (
            f"Error message should be found in yielded data. Got: {yielded_data}"
            f" (generator raised: {stream_error!r})"
        )
        assert done_found, (
            f"[DONE] message should be found at the end. Got: {yielded_data}"
            f" (generator raised: {stream_error!r})"
        )

        # The streaming hook must have been called once per chunk.
        assert mock_proxy_logging_obj.async_post_call_streaming_hook.call_count == len(
            mock_chunks
        )

        # Not an exception case, so the failure hook must stay untouched.
        mock_proxy_logging_obj.post_call_failure_hook.assert_not_called()
|
||
|
||
|
||
def _has_nested_none_values(obj, path="root"):
    """
    Collect the paths of every None found inside ``obj``.

    Dicts are walked by key, lists and tuples by index, and any other object
    exposing ``__dict__`` by its attributes whose names do not start with an
    underscore. Anything else is treated as a leaf.

    Args:
        obj: The object to check
        path: Label for ``obj`` itself; child paths are built from it as
            ``path.key`` or ``path[index]`` (for debugging)

    Returns:
        List of paths where None values were found, in traversal order
        (empty when there are none)
    """
    if obj is None:
        return [path]

    # Work out the (child_path, child_value) pairs for this container kind.
    if isinstance(obj, dict):
        children = [(f"{path}.{name}", child) for name, child in obj.items()]
    elif isinstance(obj, (list, tuple)):
        children = [(f"{path}[{idx}]", child) for idx, child in enumerate(obj)]
    elif hasattr(obj, "__dict__"):
        # Plain object: look at its attributes, skipping private ones.
        children = [
            (f"{path}.{name}", child)
            for name, child in obj.__dict__.items()
            if not name.startswith("_")
        ]
    else:
        children = []

    found = []
    for child_path, child in children:
        found.extend(_has_nested_none_values(child, child_path))
    return found
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_completion_result_no_nested_none_values():
    """
    Test that the chat_completion result has no None values at any nesting
    depth (exclude_none=True behaviour), while the essential fields survive.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from fastapi import Request, Response

    import litellm
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import chat_completion

    # Optional message fields passed as None; each should be excluded from
    # the serialized result.
    excluded_fields = [
        "function_call",
        "tool_calls",
        "audio",
        "reasoning_content",
        "thinking_blocks",
        "annotations",
    ]

    assistant_message = litellm.Message(
        content="Hello, world!",
        role="assistant",
        **dict.fromkeys(excluded_fields),
    )

    only_choice = litellm.Choices(
        finish_reason="stop",
        index=0,
        message=assistant_message,
        logprobs=None,  # should be excluded when exclude_none=True
    )

    # Build the ModelResponse attribute by attribute.
    model_response = litellm.ModelResponse()
    model_response.id = "test-id"
    model_response.model = "gpt-3.5-turbo"
    model_response.object = "chat.completion"
    model_response.created = 1234567890
    model_response.choices = [only_choice]
    setattr(
        model_response,
        "usage",
        litellm.Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
    )

    # Sanity check: the raw dump must contain None values to begin with.
    none_paths_before = _has_nested_none_values(model_response.model_dump())
    assert (
        len(none_paths_before) > 0
    ), "Mock should have None values before exclude_none=True"

    # Request processing is stubbed to hand back the response built above.
    stub_processor = MagicMock()
    stub_processor.base_process_llm_request = AsyncMock(return_value=model_response)

    fake_request = MagicMock(spec=Request)
    fake_response = MagicMock(spec=Response)
    fake_key_auth = MagicMock(spec=UserAPIKeyAuth)

    with (
        patch(
            "litellm.proxy.proxy_server._read_request_body",
            return_value={"model": "gpt-3.5-turbo", "messages": []},
        ),
        patch(
            "litellm.proxy.proxy_server.ProxyBaseLLMRequestProcessing",
            return_value=stub_processor,
        ),
    ):
        result = await chat_completion(
            request=fake_request,
            fastapi_response=fake_response,
            user_api_key_dict=fake_key_auth,
        )

        # The endpoint is expected to hand back a plain dict.
        assert isinstance(result, dict), f"Expected dict result, got {type(result)}"

        none_paths_after = _has_nested_none_values(result)
        assert (
            len(none_paths_after) == 0
        ), f"Result should not contain nested None values. Found None at: {none_paths_after}"

        # Essential top-level fields are present.
        for required_key in ("id", "model", "object", "created", "choices", "usage"):
            assert required_key in result

        # The single choice carries the expected message.
        assert len(result["choices"]) == 1
        message = result["choices"][0]["message"]
        assert message["content"] == "Hello, world!"
        assert message["role"] == "assistant"

        # None-valued fields must be absent rather than present-and-None.
        for field in excluded_fields:
            assert (
                field not in message
            ), f"Field '{field}' should be excluded when it's None"
|
||
|
||
|
||
# ============================================================================
# Price Data Reload Tests
# ============================================================================
|
||
|
||
|
||
class TestPriceDataReloadAPI:
    """Test cases for price data reload API endpoints"""

    @pytest.fixture
    def client_with_auth(self):
        """
        Yield a TestClient whose requests authenticate as a proxy admin.

        The proxy is initialized from test_configs/test_config_no_auth.yaml and
        the user_api_key_auth dependency is overridden with an admin mock. On
        teardown the override is put back to what it was before the fixture
        ran, so neither the admin override nor a non-admin override installed
        by a test leaks into later tests.
        """
        from litellm.proxy._types import LitellmUserRoles
        from litellm.proxy.proxy_server import cleanup_router_config_variables

        cleanup_router_config_variables()
        filepath = os.path.dirname(os.path.abspath(__file__))
        config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
        asyncio.run(initialize(config=config_fp, debug=True))

        # Mock admin user authentication
        mock_auth = MagicMock()
        mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN

        had_override = user_api_key_auth in app.dependency_overrides
        previous_override = app.dependency_overrides.get(user_api_key_auth)
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
        try:
            yield TestClient(app)
        finally:
            if had_override:
                app.dependency_overrides[user_api_key_auth] = previous_override
            else:
                app.dependency_overrides.pop(user_api_key_auth, None)

    @staticmethod
    def _use_non_admin_auth():
        """Replace the auth override with a caller whose role is not admin."""
        mock_auth = MagicMock()
        mock_auth.user_role = "user"  # Non-admin role
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth

    @staticmethod
    def _assert_admin_required(response):
        """Assert that ``response`` is the 403 'admin role required' rejection."""
        assert response.status_code == 403
        data = response.json()
        assert "Access denied" in data["detail"]
        assert "Admin role required" in data["detail"]

    def test_reload_model_cost_map_admin_access(self, client_with_auth):
        """Test that admin users can access the reload endpoint"""
        # Save the original model_cost so the endpoint's direct assignment
        # (litellm.model_cost = new_model_cost_map) does not contaminate
        # subsequent tests running in the same worker process.
        original_model_cost = litellm.model_cost.copy()
        try:
            with (
                patch(
                    "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
                ) as mock_get_map,
                patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
            ):
                mock_get_map.return_value = {
                    "gpt-3.5-turbo": {"input_cost_per_token": 0.001}
                }
                # Mock the database connection
                mock_prisma.db.litellm_config.find_unique = AsyncMock(
                    return_value=None
                )
                mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

                response = client_with_auth.post("/reload/model_cost_map")

                assert response.status_code == 200
                data = response.json()
                assert data["status"] == "success"
                assert "message" in data
                assert "timestamp" in data
                assert "models_count" in data
                # The new implementation immediately reloads and returns the count
                assert (
                    "Price data reloaded successfully! 1 models updated."
                    in data["message"]
                )
                assert data["models_count"] == 1
        finally:
            # Restore the full model cost map so subsequent tests are not affected
            litellm.model_cost = original_model_cost
            _invalidate_model_cost_lowercase_map()

    def test_reload_model_cost_map_non_admin_access(self, client_with_auth):
        """Test that non-admin users cannot access the reload endpoint"""
        self._use_non_admin_auth()

        response = client_with_auth.post("/reload/model_cost_map")

        self._assert_admin_required(response)

    def test_get_model_cost_map_public_access(self, client_no_auth):
        """Test that the model cost map endpoint is publicly accessible"""
        with patch(
            "litellm.model_cost", {"gpt-3.5-turbo": {"input_cost_per_token": 0.001}}
        ):
            response = client_no_auth.get("/public/litellm_model_cost_map")

            assert response.status_code == 200
            data = response.json()
            assert "gpt-3.5-turbo" in data

    def test_reload_model_cost_map_error_handling(self, client_with_auth):
        """Test error handling in the reload endpoint"""
        with (
            patch(
                "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
            ) as mock_get_map,
            patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
        ):
            mock_get_map.side_effect = Exception("Network error")

            # Mock the database connection
            mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=None)
            mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

            response = client_with_auth.post("/reload/model_cost_map")

            # The new implementation immediately reloads and fails on error
            assert response.status_code == 500
            data = response.json()
            assert "Failed to reload model cost map" in data["detail"]

    def test_schedule_model_cost_map_reload_admin_access(self, client_with_auth):
        """Test that admin users can schedule periodic reload"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            # Mock database upsert
            mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

            response = client_with_auth.post("/schedule/model_cost_map_reload?hours=6")

            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "success"
            assert data["interval_hours"] == 6
            assert "message" in data
            assert "timestamp" in data

    def test_schedule_model_cost_map_reload_non_admin_access(self, client_with_auth):
        """Test that non-admin users cannot schedule periodic reload"""
        self._use_non_admin_auth()

        response = client_with_auth.post("/schedule/model_cost_map_reload?hours=6")

        self._assert_admin_required(response)

    def test_schedule_model_cost_map_reload_invalid_hours(self, client_with_auth):
        """Test that invalid hours parameter is rejected"""
        response = client_with_auth.post("/schedule/model_cost_map_reload?hours=0")

        assert response.status_code == 400
        data = response.json()
        assert "Hours must be greater than 0" in data["detail"]

    def test_cancel_model_cost_map_reload_admin_access(self, client_with_auth):
        """Test that admin users can cancel periodic reload"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            # Mock database delete
            mock_prisma.db.litellm_config.delete = AsyncMock(return_value=None)

            response = client_with_auth.delete("/schedule/model_cost_map_reload")

            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "success"
            assert "message" in data
            assert "timestamp" in data

    def test_cancel_model_cost_map_reload_non_admin_access(self, client_with_auth):
        """Test that non-admin users cannot cancel periodic reload"""
        self._use_non_admin_auth()

        response = client_with_auth.delete("/schedule/model_cost_map_reload")

        self._assert_admin_required(response)

    def test_get_model_cost_map_reload_status_admin_access(self, client_with_auth):
        """Test that admin users can get reload status"""
        with (
            patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma,
            # Mock the last reload time and current time
            patch(
                "litellm.proxy.proxy_server.last_model_cost_map_reload",
                "2024-01-01T06:00:00",
            ),
            patch("litellm.proxy.proxy_server.datetime") as mock_datetime,
        ):
            # Mock database config record
            mock_config = MagicMock()
            mock_config.param_value = {"interval_hours": 6, "force_reload": False}
            mock_prisma.db.litellm_config.find_unique = AsyncMock(
                return_value=mock_config
            )

            # Mock current time to be 1 hour after last reload
            mock_datetime.utcnow.return_value = datetime(2024, 1, 1, 7, 0, 0)
            mock_datetime.fromisoformat = datetime.fromisoformat

            response = client_with_auth.get("/schedule/model_cost_map_reload/status")

            assert response.status_code == 200
            data = response.json()
            assert data["scheduled"] is True
            assert data["interval_hours"] == 6
            assert data["last_run"] == "2024-01-01T06:00:00"
            assert data["next_run"] == "2024-01-01T12:00:00"

    def test_get_model_cost_map_reload_status_non_admin_access(self, client_with_auth):
        """Test that non-admin users cannot get reload status"""
        self._use_non_admin_auth()

        response = client_with_auth.get("/schedule/model_cost_map_reload/status")

        self._assert_admin_required(response)

    def test_get_model_cost_map_reload_status_no_config(self, client_with_auth):
        """Test that status returns not scheduled when no config exists"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=None)

            response = client_with_auth.get("/schedule/model_cost_map_reload/status")

            assert response.status_code == 200
            data = response.json()
            assert data["scheduled"] is False
            assert data["interval_hours"] is None
            assert data["last_run"] is None
            assert data["next_run"] is None

    def test_get_model_cost_map_reload_status_no_interval(self, client_with_auth):
        """Test that status returns not scheduled when no interval is configured"""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            # Mock config with no interval
            mock_config = MagicMock()
            mock_config.param_value = {"interval_hours": None, "force_reload": False}
            mock_prisma.db.litellm_config.find_unique = AsyncMock(
                return_value=mock_config
            )

            response = client_with_auth.get("/schedule/model_cost_map_reload/status")

            assert response.status_code == 200
            data = response.json()
            assert data["scheduled"] is False
            assert data["interval_hours"] is None
            assert data["last_run"] is None
            assert data["next_run"] is None
|
||
|
||
|
||
class TestPriceDataReloadIntegration:
|
||
"""Integration tests for the complete price data reload feature"""
|
||
|
||
@pytest.fixture(autouse=True)
def _flush_litellm_config_cache(self):
    """Flush litellm_config_cache before and after every test in the class."""
    from litellm.proxy.utils import litellm_config_cache as config_cache

    config_cache.flush_cache()
    yield
    config_cache.flush_cache()
|
||
|
||
@pytest.fixture
def client_with_auth(self):
    """
    Yield a TestClient whose requests authenticate as a proxy admin.

    The proxy is initialized from test_configs/test_config_no_auth.yaml and
    the user_api_key_auth dependency is overridden with an admin mock. On
    teardown the override is put back to what it was before the fixture ran,
    so the admin override does not leak into later tests.
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    asyncio.run(initialize(config=config_fp, debug=True))

    # Mock admin user authentication
    mock_auth = MagicMock()
    mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN

    had_override = user_api_key_auth in app.dependency_overrides
    previous_override = app.dependency_overrides.get(user_api_key_auth)
    app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
    try:
        yield TestClient(app)
    finally:
        if had_override:
            app.dependency_overrides[user_api_key_auth] = previous_override
        else:
            app.dependency_overrides.pop(user_api_key_auth, None)
|
||
|
||
def test_complete_reload_flow(self, client_with_auth):
    """Reload the cost map through the API, then fetch the public cost map."""
    # Cost map handed back by the patched get_model_cost_map.
    replacement_costs = {
        "gpt-3.5-turbo": {
            "input_cost_per_token": 0.001,
            "output_cost_per_token": 0.002,
        },
        "gpt-4": {"input_cost_per_token": 0.03, "output_cost_per_token": 0.06},
    }

    saved_model_cost = litellm.model_cost.copy()
    try:
        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
        ) as mock_get_map, patch(
            "litellm.proxy.proxy_server.prisma_client"
        ) as mock_prisma:
            mock_get_map.return_value = replacement_costs

            # Mock the database connection
            config_table = mock_prisma.db.litellm_config
            config_table.find_unique = AsyncMock(return_value=None)
            config_table.upsert = AsyncMock(return_value=None)

            # Test reload endpoint
            reload_response = client_with_auth.post("/reload/model_cost_map")
            assert reload_response.status_code == 200

            # Test get endpoint
            map_response = client_with_auth.get("/public/litellm_model_cost_map")
            assert map_response.status_code == 200
    finally:
        # Put the real cost map back for the tests that follow.
        litellm.model_cost = saved_model_cost
        _invalidate_model_cost_lowercase_map()
|
||
|
||
def test_distributed_reload_check_function(self):
    """Run _check_and_reload_model_cost_map against three stored-config states."""
    from litellm.proxy.proxy_server import ProxyConfig
    from litellm.proxy.utils import litellm_config_cache

    proxy_config = ProxyConfig()
    mock_prisma = MagicMock()

    def stub_stored_config(record):
        # _check_and_reload_model_cost_map routes through get_config_param,
        # which calls prisma.get_generic_data on a cache miss.
        mock_prisma.db.litellm_config.find_unique = AsyncMock(return_value=record)
        mock_prisma.get_generic_data = AsyncMock(return_value=record)

    def run_check():
        asyncio.run(proxy_config._check_and_reload_model_cost_map(mock_prisma))

    # Test case 1: No config in database -> should return early without reloading
    stub_stored_config(None)
    run_check()

    # Test case 2: Config with interval but not time to reload
    litellm_config_cache.flush_cache()
    mock_config = MagicMock()
    mock_config.param_value = {"interval_hours": 6, "force_reload": False}
    stub_stored_config(mock_config)

    # Mock current time and last reload time
    with patch(
        "litellm.proxy.proxy_server.last_model_cost_map_reload",
        "2024-01-01T06:00:00",
    ), patch("litellm.proxy.proxy_server.datetime") as mock_datetime:
        # 1 hour later
        mock_datetime.utcnow.return_value = datetime(2024, 1, 1, 7, 0, 0)

        # Should not reload (only 1 hour passed, need 6)
        run_check()

    # Test case 3: Config with force reload
    litellm_config_cache.flush_cache()
    mock_config.param_value = {"interval_hours": 6, "force_reload": True}
    stub_stored_config(mock_config)
    mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

    saved_model_cost = litellm.model_cost.copy()
    try:
        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
        ) as mock_get_map:
            mock_get_map.return_value = {
                "gpt-3.5-turbo": {"input_cost_per_token": 0.001}
            }

            # Should reload due to force flag
            run_check()

            # Verify force_reload was reset to False
            upsert_mock = mock_prisma.db.litellm_config.upsert
            upsert_mock.assert_called()
            # The param_value is now a JSON string, so we need to parse it
            written = json.loads(
                upsert_mock.call_args[1]["data"]["update"]["param_value"]
            )
            assert written["force_reload"] == False
            assert written.get("interval_hours") == 6
    finally:
        litellm.model_cost = saved_model_cost
        _invalidate_model_cost_lowercase_map()
|
||
|
||
def test_distributed_reload_preserves_interval_hours(self):
    """Test that _check_and_reload_model_cost_map keeps interval_hours after a reload.

    Regression test: the update branch of the upsert was previously dropping
    interval_hours, causing scheduled reloads to self-destruct after first execution.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Stored config with interval_hours=24; force_reload=True triggers the reload.
    stored_config = MagicMock()
    stored_config.param_value = {"interval_hours": 24, "force_reload": True}

    mock_prisma = MagicMock()
    config_table = mock_prisma.db.litellm_config
    config_table.find_unique = AsyncMock(return_value=stored_config)
    # _check_and_reload_model_cost_map now reads through get_generic_data.
    mock_prisma.get_generic_data = AsyncMock(return_value=stored_config)
    config_table.upsert = AsyncMock(return_value=None)

    saved_model_cost = litellm.model_cost.copy()
    try:
        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
        ) as mock_get_map:
            mock_get_map.return_value = {"gpt-4": {"input_cost_per_token": 0.001}}

            asyncio.run(proxy_config._check_and_reload_model_cost_map(mock_prisma))

            # Verify the upsert update branch preserves interval_hours
            mock_prisma.db.litellm_config.upsert.assert_called()
            upsert_call = mock_prisma.db.litellm_config.upsert.call_args
            written = json.loads(upsert_call[1]["data"]["update"]["param_value"])
            assert written["force_reload"] == False
            assert written["interval_hours"] == 24, (
                "interval_hours must be preserved in the update branch; "
                "dropping it causes the schedule to self-destruct"
            )
    finally:
        litellm.model_cost = saved_model_cost
        _invalidate_model_cost_lowercase_map()
|
||
|
||
def test_manual_reload_preserves_interval_hours(self):
    """Test that manual reload via /reload/model_cost_map preserves existing interval_hours.

    Regression test: the manual reload endpoint was overwriting param_value with
    only force_reload=True, dropping any existing interval_hours schedule.

    The admin auth override installed here and litellm.model_cost are both
    restored when the test finishes, so neither leaks into later tests.
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    asyncio.run(initialize(config=config_fp, debug=True))

    mock_auth = MagicMock()
    mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN

    # Remember the state this test is about to change.
    had_override = user_api_key_auth in app.dependency_overrides
    previous_override = app.dependency_overrides.get(user_api_key_auth)
    original_model_cost = litellm.model_cost.copy()

    app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
    try:
        client = TestClient(app)

        with patch(
            "litellm.litellm_core_utils.get_model_cost_map.get_model_cost_map"
        ) as mock_get_map, patch(
            "litellm.proxy.proxy_server.prisma_client"
        ) as mock_prisma:
            mock_get_map.return_value = {"gpt-4": {"input_cost_per_token": 0.001}}

            # Simulate existing config with a schedule
            mock_existing = MagicMock()
            mock_existing.param_value = {
                "interval_hours": 12,
                "force_reload": False,
            }
            mock_prisma.db.litellm_config.find_unique = AsyncMock(
                return_value=mock_existing
            )
            mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

            response = client.post("/reload/model_cost_map")
            assert response.status_code == 200

            # Verify interval_hours was preserved in the upsert
            mock_prisma.db.litellm_config.upsert.assert_called()
            call_args = mock_prisma.db.litellm_config.upsert.call_args
            param_value_json = call_args[1]["data"]["update"]["param_value"]
            param_value_dict = json.loads(param_value_json)
            assert param_value_dict["force_reload"] is True
            assert param_value_dict["interval_hours"] == 12, (
                "interval_hours must be preserved when manual reload sets force_reload; "
                "dropping it destroys any existing schedule"
            )
    finally:
        # Undo the auth override first; these dict operations cannot fail.
        if had_override:
            app.dependency_overrides[user_api_key_auth] = previous_override
        else:
            app.dependency_overrides.pop(user_api_key_auth, None)

        litellm.model_cost = original_model_cost
        _invalidate_model_cost_lowercase_map()
|
||
|
||
def test_anthropic_beta_headers_reload_preserves_interval_hours(self):
    """Scheduled beta-header reload must keep ``interval_hours`` in its upsert.

    Regression test: the update branch of the upsert was dropping
    interval_hours, identical to the model cost map bug.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # Stored reload config: a 12h schedule plus force_reload=True so the
    # reload path is taken.
    stored_config = MagicMock()
    stored_config.param_value = {"interval_hours": 12, "force_reload": True}

    upsert_mock = AsyncMock(return_value=None)
    prisma = MagicMock()
    prisma.db.litellm_config.find_unique = AsyncMock(return_value=stored_config)
    # _check_and_reload_anthropic_beta_headers now reads through get_generic_data.
    prisma.get_generic_data = AsyncMock(return_value=stored_config)
    prisma.db.litellm_config.upsert = upsert_mock

    with patch(
        "litellm.anthropic_beta_headers_manager.reload_beta_headers_config",
        return_value={"anthropic": {"beta_header": "test-value"}},
    ):
        asyncio.run(proxy_config._check_and_reload_anthropic_beta_headers(prisma))

    # The update branch of the upsert carries a JSON-encoded param_value.
    upsert_mock.assert_called()
    update_payload = upsert_mock.call_args.kwargs["data"]["update"]
    written = json.loads(update_payload["param_value"])
    assert written["force_reload"] == False
    assert written["interval_hours"] == 12, (
        "interval_hours must be preserved in the update branch; "
        "dropping it causes the schedule to self-destruct"
    )
|
||
|
||
def test_anthropic_beta_headers_manual_reload_preserves_interval_hours(self):
    """Test that manual reload via /reload/anthropic_beta_headers preserves existing interval_hours.

    Regression test: the manual reload endpoint was overwriting param_value with
    only force_reload=True, dropping any existing interval_hours schedule.

    The admin auth override installed on ``app.dependency_overrides`` is put
    back to its previous state when the test finishes, so it cannot leak into
    tests that run afterwards.
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.proxy_server import cleanup_router_config_variables

    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    asyncio.run(initialize(config=config_fp, debug=True))

    mock_auth = MagicMock()
    mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN

    # Remember whatever override (if any) was installed before this test so
    # the exact prior state can be restored in the finally block.
    overrides = app.dependency_overrides
    had_previous_override = user_api_key_auth in overrides
    previous_override = overrides.get(user_api_key_auth)
    overrides[user_api_key_auth] = lambda: mock_auth

    try:
        client = TestClient(app)

        with patch(
            "litellm.anthropic_beta_headers_manager.reload_beta_headers_config"
        ) as mock_reload:
            mock_reload.return_value = {"anthropic": {"beta_header": "test-value"}}

            with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
                # Simulate existing config with a schedule
                mock_existing = MagicMock()
                mock_existing.param_value = {
                    "interval_hours": 8,
                    "force_reload": False,
                }
                mock_prisma.db.litellm_config.find_unique = AsyncMock(
                    return_value=mock_existing
                )
                mock_prisma.db.litellm_config.upsert = AsyncMock(return_value=None)

                response = client.post("/reload/anthropic_beta_headers")
                assert response.status_code == 200

                # Verify interval_hours was preserved in the upsert
                mock_prisma.db.litellm_config.upsert.assert_called()
                call_args = mock_prisma.db.litellm_config.upsert.call_args
                param_value_json = call_args[1]["data"]["update"]["param_value"]
                param_value_dict = json.loads(param_value_json)
                assert param_value_dict["force_reload"] == True
                assert param_value_dict["interval_hours"] == 8, (
                    "interval_hours must be preserved when manual reload sets force_reload; "
                    "dropping it destroys any existing schedule"
                )
    finally:
        if had_previous_override:
            overrides[user_api_key_auth] = previous_override
        else:
            overrides.pop(user_api_key_auth, None)
|
||
|
||
def test_config_file_parsing(self):
    """Test parsing of config file with reload settings.

    Parses an inline YAML document with ``yaml.safe_load`` and checks that
    ``model_cost_map_reload_interval`` sits under ``general_settings`` and
    that both ``model_list`` entries are present.
    """
    # Nesting is significant in YAML: the settings keys must be indented
    # under ``general_settings`` and each model's fields under its list item,
    # otherwise ``general_settings`` parses as null.
    config_content = """
general_settings:
  master_key: sk-1234
  model_cost_map_reload_interval: 21600

model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
  - model_name: gpt-4
    litellm_params:
      model: gpt-4
"""

    # Parse the config
    config = yaml.safe_load(config_content)

    # Verify the reload setting is present
    assert "general_settings" in config
    assert "model_cost_map_reload_interval" in config["general_settings"]
    assert config["general_settings"]["model_cost_map_reload_interval"] == 21600

    # Verify models are present
    assert "model_list" in config
    assert len(config["model_list"]) == 2
|
||
|
||
def test_database_config_storage(self):
    """Check the upsert payload shape used for the scheduled-reload config.

    The call is made directly against a mocked prisma client; the test then
    inspects the keyword arguments the mock recorded.
    """
    config_key = "model_cost_map_reload_config"

    upsert = AsyncMock(return_value=None)
    prisma = MagicMock()
    prisma.db.litellm_config.upsert = upsert

    # Same call the schedule endpoint would issue.
    payload = {
        "create": {
            "param_name": config_key,
            "param_value": {"interval_hours": 6, "force_reload": False},
        },
        "update": {
            "param_value": {"interval_hours": 6, "force_reload": False},
        },
    }
    asyncio.run(
        prisma.db.litellm_config.upsert(
            where={"param_name": config_key},
            data=payload,
        )
    )

    upsert.assert_called_once()
    recorded = upsert.call_args.kwargs
    created_value = recorded["data"]["create"]["param_value"]
    assert recorded["where"]["param_name"] == "model_cost_map_reload_config"
    assert created_value["interval_hours"] == 6
    assert created_value["force_reload"] == False
|
||
|
||
def test_manual_reload_force_flag(self):
    """Check that the manual-reload upsert payload sets ``force_reload``.

    The call is made directly against a mocked prisma client; the test then
    inspects the ``update`` branch the mock recorded.
    """
    config_key = "model_cost_map_reload_config"

    upsert = AsyncMock(return_value=None)
    prisma = MagicMock()
    prisma.db.litellm_config.upsert = upsert

    # Same call the manual reload endpoint would issue.
    payload = {
        "create": {
            "param_name": config_key,
            "param_value": {"interval_hours": None, "force_reload": True},
        },
        "update": {"param_value": {"force_reload": True}},
    }
    asyncio.run(
        prisma.db.litellm_config.upsert(
            where={"param_name": config_key},
            data=payload,
        )
    )

    upsert.assert_called_once()
    updated_value = upsert.call_args.kwargs["data"]["update"]["param_value"]
    assert updated_value["force_reload"] == True
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_router_settings_from_db_config_merge_logic():
    """
    Check how _add_router_settings_from_db_config combines its two sources.

    Router settings from the config file and from the database record are
    both supplied; the keyword arguments handed to ``update_settings`` are
    then compared against the expected combination, including one nested
    dictionary present on both sides.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    router = MagicMock()
    router.update_settings = MagicMock()

    # Settings coming from the config file.
    file_settings = {
        "routing_strategy": "usage-based-routing",
        "model_group_alias": {"gpt-4": "openai-gpt-4"},
        "enable_pre_call_checks": True,
        "timeout": 30,
        "nested_config": {"setting1": "config_value1", "setting2": "config_value2"},
    }

    # Settings coming from the database record.
    db_record = MagicMock()
    db_record.param_value = {
        "routing_strategy": "least-busy",  # also in the file -> DB wins
        "retry_delay": 2,  # DB only
        "nested_config": {
            "setting2": "db_value2",  # also in the file -> DB wins
            "setting3": "db_value3",  # DB only
        },
    }

    find_first = AsyncMock(return_value=db_record)
    prisma = MagicMock()
    prisma.db.litellm_config.find_first = find_first

    await proxy_config._add_router_settings_from_db_config(
        config_data={"router_settings": file_settings},
        llm_router=router,
        prisma_client=prisma,
    )

    # The DB must be queried for the router_settings row exactly once.
    find_first.assert_called_once_with(where={"param_name": "router_settings"})

    router.update_settings.assert_called_once()
    merged = router.update_settings.call_args.kwargs

    # DB value overrides the file value.
    assert merged["routing_strategy"] == "least-busy"

    # File-only values survive.
    assert merged["model_group_alias"] == {"gpt-4": "openai-gpt-4"}
    assert merged["enable_pre_call_checks"] == True
    assert merged["timeout"] == 30

    # DB-only value is added.
    assert merged["retry_delay"] == 2

    # The nested dict is expected to be combined key by key.
    assert merged["nested_config"] == {
        "setting1": "config_value1",
        "setting2": "db_value2",
        "setting3": "db_value3",
    }
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_router_settings_from_db_config_edge_cases():
    """
    Walk _add_router_settings_from_db_config through its edge cases.

    Covers: missing router, missing prisma client, no DB row, no file
    settings, nothing on either side, and a DB row whose param_value is not
    a dict.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    router = MagicMock()
    router.update_settings = MagicMock()

    prisma = MagicMock()

    def set_db_record(record):
        # Swap in a fresh find_first mock returning ``record``.
        prisma.db.litellm_config.find_first = AsyncMock(return_value=record)

    # Case 1: router is None -> nothing may be applied.
    await proxy_config._add_router_settings_from_db_config(
        config_data={"router_settings": {"test": "value"}},
        llm_router=None,
        prisma_client=MagicMock(),
    )
    router.update_settings.assert_not_called()

    # Case 2: prisma client is None -> nothing may be applied.
    await proxy_config._add_router_settings_from_db_config(
        config_data={"router_settings": {"test": "value"}},
        llm_router=router,
        prisma_client=None,
    )
    router.update_settings.assert_not_called()

    # Case 3: DB has no router_settings row -> file settings only.
    set_db_record(None)
    await proxy_config._add_router_settings_from_db_config(
        config_data={"router_settings": {"routing_strategy": "usage-based"}},
        llm_router=router,
        prisma_client=prisma,
    )
    router.update_settings.assert_called_once_with(routing_strategy="usage-based")
    router.reset_mock()

    # Case 4: file has no router_settings -> DB settings only.
    db_only_record = MagicMock()
    db_only_record.param_value = {"db_setting": "db_value"}
    set_db_record(db_only_record)
    await proxy_config._add_router_settings_from_db_config(
        config_data={},
        llm_router=router,
        prisma_client=prisma,
    )
    router.update_settings.assert_called_once_with(db_setting="db_value")
    router.reset_mock()

    # Case 5: nothing in the file and nothing in the DB -> no update at all.
    set_db_record(None)
    await proxy_config._add_router_settings_from_db_config(
        config_data={}, llm_router=router, prisma_client=prisma
    )
    router.update_settings.assert_not_called()

    # Case 6: DB row exists but param_value is not a dict -> file settings only.
    invalid_record = MagicMock()
    invalid_record.param_value = "not_a_dict"
    set_db_record(invalid_record)
    await proxy_config._add_router_settings_from_db_config(
        config_data={"router_settings": {"config_setting": "config_value"}},
        llm_router=router,
        prisma_client=prisma,
    )
    router.update_settings.assert_called_once_with(config_setting="config_value")
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_router_settings_shallow_merge_behavior():
    """
    Pin how a nested dict present in both config and DB settings is combined.

    Despite the "shallow" in this test's name, the expectation asserted below
    is that the nested dict is merged key by key: keys that exist only in the
    config file are kept, overlapping keys take the DB value, and DB-only keys
    are added. Top-level scalars take the DB value.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    mock_router = MagicMock()
    mock_router.update_settings = MagicMock()

    # Config with nested dictionary
    config_data = {
        "router_settings": {
            "nested_setting": {
                "key1": "config_value1",
                "key2": "config_value2",
                "key3": "config_value3",
            },
            "top_level": "config_top",
        }
    }

    # DB config that partially overlaps the nested dictionary
    mock_db_config = MagicMock()
    mock_db_config.param_value = {
        "nested_setting": {
            "key2": "db_value2",  # Override existing key
            "key4": "db_value4",  # Add new key
            # key1 and key3 exist only in the config and are expected to survive
        },
        "top_level": "db_top",  # Override top level
    }

    mock_prisma_client = MagicMock()
    mock_prisma_client.db.litellm_config.find_first = AsyncMock(
        return_value=mock_db_config
    )

    await proxy_config._add_router_settings_from_db_config(
        config_data=config_data,
        llm_router=mock_router,
        prisma_client=mock_prisma_client,
    )

    # Fail with a clear message if the router was never updated, instead of a
    # TypeError from subscripting a ``None`` call_args below.
    mock_router.update_settings.assert_called_once()

    # Get the merged settings
    call_args = mock_router.update_settings.call_args
    merged_settings = call_args[1]

    # Config-only keys (key1, key3) are kept, key2 takes the DB value and
    # key4 is added from the DB.
    expected_nested = {
        "key1": "config_value1",
        "key3": "config_value3",
        "key2": "db_value2",
        "key4": "db_value4",
    }

    assert merged_settings["nested_setting"] == expected_nested
    assert merged_settings["top_level"] == "db_top"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_model_info_v1_oci_secrets_not_leaked():
    """
    model_info_v1 must mask OCI credentials and leave ordinary params alone.

    A single OCI deployment is served through a mocked router; the response
    is checked for masked sensitive fields, untouched non-sensitive fields,
    and the absence of the raw secret values anywhere in its string form.
    """
    from unittest.mock import MagicMock, patch

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import model_info_v1

    # Raw secret values; reused below to prove they never reach the response.
    oci_key = "ocid1.api_key.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk"
    oci_fingerprint = "aa:bb:cc:dd:ee:ff:11:22:33:44:55:66:77:88:99:00"
    oci_tenancy = "ocid1.tenancy.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk"
    oci_key_file = "/path/to/oci_api_key.pem"

    caller = MagicMock(spec=UserAPIKeyAuth)
    caller.user_id = "test-user"
    caller.api_key = "test-key"
    caller.team_models = []
    caller.models = ["oci-grok-test"]

    deployment = {
        "model_name": "oci-grok-test",
        "litellm_params": {
            "model": "oci/xai.grok-4",
            "oci_key": oci_key,
            "oci_region": "us-phoenix-1",
            "oci_user": "ocid1.user.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "oci_fingerprint": oci_fingerprint,
            "oci_tenancy": oci_tenancy,
            "oci_key_file": oci_key_file,
            "oci_compartment_id": "ocid1.compartment.oc1..aaaaaaaa7kbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbkbk",
            "drop_params": True,
        },
        "model_info": {"mode": "completion", "id": "test-model-id"},
    }

    router = MagicMock()
    router.get_model_names.return_value = ["oci-grok-test"]
    router.get_model_access_groups.return_value = {}
    router.get_model_list.return_value = [deployment]

    # Replace the proxy_server module globals for the duration of the call.
    with (
        patch("litellm.proxy.proxy_server.llm_router", router),
        patch("litellm.proxy.proxy_server.llm_model_list", [deployment]),
        patch(
            "litellm.proxy.proxy_server.general_settings",
            {"infer_model_from_keys": False},
        ),
        patch("litellm.proxy.proxy_server.user_model", None),
    ):
        result = await model_info_v1(
            user_api_key_dict=caller, litellm_model_id=None
        )

    # Exactly one model entry comes back.
    assert "data" in result
    assert len(result["data"]) == 1

    litellm_params = result["data"][0]["litellm_params"]

    # Sensitive OCI fields carry the mask marker.
    for field in ("oci_key", "oci_fingerprint", "oci_tenancy", "oci_key_file"):
        assert "****" in litellm_params[field], f"{field} should be masked"

    # Non-sensitive fields pass through unchanged.
    assert (
        litellm_params["model"] == "oci/xai.grok-4"
    ), "model field should not be masked"
    assert (
        litellm_params["oci_region"] == "us-phoenix-1"
    ), "oci_region should not be masked"
    assert litellm_params["drop_params"] is True, "drop_params should not be masked"

    # The model field specifically is never masked (this was the original issue).
    assert "****" not in litellm_params["model"], "model field should never be masked"
    assert litellm_params["model"].startswith(
        "oci/"
    ), "model should retain its full value"

    # None of the raw secrets may appear anywhere in the response.
    result_str = str(result)
    assert oci_key not in result_str
    assert oci_fingerprint not in result_str
    assert oci_tenancy not in result_str
    assert oci_key_file not in result_str
|
||
|
||
|
||
def test_add_callback_from_db_to_in_memory_litellm_callbacks():
    """
    Check which callback-manager method is used for each event-type combination.

    Success-only, failure-only and success+failure registrations are each
    expected to hit a different ``add_litellm_*`` method; a callback already
    listed in ``existing_callbacks`` must not be added again.
    """
    from unittest.mock import MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    callback_manager = MagicMock()

    # (callback, event_types, manager method expected to receive it)
    registrations = [
        ("prometheus", ["success"], "add_litellm_success_callback"),
        ("langfuse", ["failure"], "add_litellm_failure_callback"),
        ("s3", ["success", "failure"], "add_litellm_callback"),
    ]

    with patch("litellm.proxy.proxy_server.litellm") as mock_litellm:
        mock_litellm._known_custom_logger_compatible_callbacks = []
        mock_litellm.logging_callback_manager = callback_manager

        # Cases 1-3: a new callback is routed by its event types.
        for callback, event_types, expected_method in registrations:
            proxy_config._add_callback_from_db_to_in_memory_litellm_callbacks(
                callback=callback,
                event_types=event_types,
                existing_callbacks=[],
            )
            getattr(callback_manager, expected_method).assert_called_once_with(
                callback
            )
            callback_manager.reset_mock()

        # Case 4: an already-registered callback is left alone.
        proxy_config._add_callback_from_db_to_in_memory_litellm_callbacks(
            callback="prometheus",
            event_types=["success"],
            existing_callbacks=["prometheus"],
        )
        callback_manager.add_litellm_success_callback.assert_not_called()
|
||
|
||
|
||
def test_should_load_db_object_with_supported_db_objects():
    """
    Exercise _should_load_db_object against several general_settings values.

    Each scenario patches ``general_settings`` and lists, per object type,
    whether loading from the database is expected to be allowed.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    def check(settings, expectations):
        # Patch general_settings, then assert every expectation in order.
        with patch("litellm.proxy.proxy_server.general_settings", settings):
            for object_type, expected in expectations.items():
                outcome = proxy_config._should_load_db_object(
                    object_type=object_type
                )
                assert outcome is expected

    # Case 1: supported_db_objects not set -> everything is loaded.
    check(
        {},
        {
            "models": True,
            "mcp": True,
            "guardrails": True,
            "vector_stores": True,
        },
    )

    # Case 2: only MCP listed.
    check(
        {"supported_db_objects": ["mcp"]},
        {
            "models": False,
            "mcp": True,
            "guardrails": False,
            "vector_stores": False,
            "prompts": False,
        },
    )

    # Case 3: several types listed.
    check(
        {"supported_db_objects": ["mcp", "guardrails", "vector_stores"]},
        {
            "models": False,
            "mcp": True,
            "guardrails": True,
            "vector_stores": True,
            "prompts": False,
        },
    )

    # Case 4: value is not a list -> falls back to loading everything.
    check(
        {"supported_db_objects": "invalid_type"},
        {"models": True, "mcp": True},
    )

    # Case 5: empty list -> nothing is loaded.
    check(
        {"supported_db_objects": []},
        {"models": False, "mcp": False, "guardrails": False},
    )

    # Case 6: every available object type listed.
    every_type = [
        "models",
        "mcp",
        "guardrails",
        "vector_stores",
        "pass_through_endpoints",
        "prompts",
        "model_cost_map",
    ]
    check(
        {"supported_db_objects": list(every_type)},
        {object_type: True for object_type in every_type},
    )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_tag_cache_update_called():
    """
    Test that update_cache updates tag cache when tags are provided.

    ``proxy_server.user_api_key_cache`` is swapped for a fresh ``DualCache``
    via ``patch.object`` so the module global is restored when the test
    ends, rather than being left pointing at this test's cache.
    """
    from litellm.caching.caching import DualCache
    from litellm.proxy import proxy_server

    cache = DualCache()

    mock_tag_obj = {
        "tag_name": "test-tag",
        "spend": 10.0,
    }

    with patch.object(proxy_server, "user_api_key_cache", cache):
        with patch.object(
            cache, "async_get_cache", new=AsyncMock(return_value=mock_tag_obj)
        ) as mock_get_cache:
            with patch.object(
                cache, "async_set_cache_pipeline", new=AsyncMock()
            ) as mock_set_cache:
                await proxy_server.update_cache(
                    token=None,
                    user_id=None,
                    end_user_id=None,
                    team_id=None,
                    response_cost=5.0,
                    parent_otel_span=None,
                    tags=["test-tag"],
                )

                # Yield to the event loop before inspecting the mocks
                # (presumably the cache write runs in a background task).
                await asyncio.sleep(0.1)

                mock_get_cache.assert_awaited_once_with(key="tag:test-tag")
                mock_set_cache.assert_awaited_once()

                call_args = mock_set_cache.call_args
                cache_list = call_args.kwargs["cache_list"]

                # One (key, value) pair: existing spend 10.0 + cost 5.0.
                assert len(cache_list) == 1
                cache_key, cache_value = cache_list[0]
                assert cache_key == "tag:test-tag"
                assert cache_value["spend"] == 15.0
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_tag_cache_update_multiple_tags():
    """
    Test that multiple tags are updated in cache.

    ``proxy_server.user_api_key_cache`` is swapped for a fresh ``DualCache``
    via ``patch.object`` so the module global is restored when the test
    ends, rather than being left pointing at this test's cache.
    """
    from litellm.caching.caching import DualCache
    from litellm.proxy import proxy_server

    cache = DualCache()

    mock_tag1_obj = {"tag_name": "tag1", "spend": 10.0}
    mock_tag2_obj = {"tag_name": "tag2", "spend": 20.0}

    async def mock_get_cache_side_effect(key):
        # Serve the two known tag objects; anything else is a cache miss.
        if key == "tag:tag1":
            return mock_tag1_obj
        elif key == "tag:tag2":
            return mock_tag2_obj
        return None

    with patch.object(proxy_server, "user_api_key_cache", cache):
        with patch.object(
            cache,
            "async_get_cache",
            new=AsyncMock(side_effect=mock_get_cache_side_effect),
        ) as mock_get_cache:
            with patch.object(
                cache, "async_set_cache_pipeline", new=AsyncMock()
            ) as mock_set_cache:
                await proxy_server.update_cache(
                    token=None,
                    user_id=None,
                    end_user_id=None,
                    team_id=None,
                    response_cost=5.0,
                    parent_otel_span=None,
                    tags=["tag1", "tag2"],
                )

                # Yield to the event loop before inspecting the mocks
                # (presumably the cache write runs in a background task).
                await asyncio.sleep(0.1)

                assert mock_get_cache.call_count == 2
                mock_set_cache.assert_awaited_once()

                call_args = mock_set_cache.call_args
                cache_list = call_args.kwargs["cache_list"]

                assert len(cache_list) == 2

                # Each tag's spend grows by the 5.0 response cost.
                tag_updates = {
                    cache_key: cache_value for cache_key, cache_value in cache_list
                }
                assert "tag:tag1" in tag_updates
                assert "tag:tag2" in tag_updates
                assert tag_updates["tag:tag1"]["spend"] == 15.0
                assert tag_updates["tag:tag2"]["spend"] == 25.0
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db():
    """
    _init_sso_settings_in_db should read the SSO row, uppercase its keys and
    forward the result to _decrypt_and_set_db_env_variables.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # SSO row as stored in the database (lowercase keys).
    sso_row = MagicMock()
    sso_row.sso_settings = {
        "google_client_id": "test-client-id",
        "google_client_secret": "test-client-secret",
        "microsoft_client_id": "ms-client-id",
        "microsoft_client_secret": "ms-client-secret",
    }

    find_unique = AsyncMock(return_value=sso_row)
    prisma = MagicMock()
    prisma.db.litellm_ssoconfig.find_unique = find_unique

    with patch.object(
        proxy_config, "_decrypt_and_set_db_env_variables"
    ) as decrypt_and_set:
        await proxy_config._init_sso_settings_in_db(prisma_client=prisma)

    # The row is looked up by its fixed id.
    find_unique.assert_awaited_once_with(where={"id": "sso_config"})

    # The settings are forwarded once, as the environment_variables kwarg.
    decrypt_and_set.assert_called_once()
    forwarded = decrypt_and_set.call_args.kwargs["environment_variables"]

    expected_env = {
        "GOOGLE_CLIENT_ID": "test-client-id",
        "GOOGLE_CLIENT_SECRET": "test-client-secret",
        "MICROSOFT_CLIENT_ID": "ms-client-id",
        "MICROSOFT_CLIENT_SECRET": "ms-client-secret",
    }

    # Every key arrives uppercased ...
    for env_name in expected_env:
        assert env_name in forwarded

    # ... with its value untouched ...
    for env_name, env_value in expected_env.items():
        assert forwarded[env_name] == env_value

    # ... and the lowercase originals are gone.
    for lowercase_name in ("google_client_id", "microsoft_client_id"):
        assert lowercase_name not in forwarded
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_no_settings():
    """
    With no SSO row in the database, _init_sso_settings_in_db must still
    query for it but must not call _decrypt_and_set_db_env_variables.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # The lookup finds nothing.
    find_unique = AsyncMock(return_value=None)
    prisma = MagicMock()
    prisma.db.litellm_ssoconfig.find_unique = find_unique

    with patch.object(
        proxy_config, "_decrypt_and_set_db_env_variables"
    ) as decrypt_and_set:
        await proxy_config._init_sso_settings_in_db(prisma_client=prisma)

    find_unique.assert_awaited_once_with(where={"id": "sso_config"})
    decrypt_and_set.assert_not_called()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_error_handling():
    """
    A database error raised by the SSO lookup must not escape
    _init_sso_settings_in_db; if it does, the test fails explicitly.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    config_under_test = ProxyConfig()

    # Prisma stub whose SSO lookup always raises.
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_ssoconfig.find_unique = AsyncMock(
        side_effect=Exception("Database connection error")
    )

    # Reaching the end of the try body without an exception is the success
    # condition; anything propagated is reported as a test failure.
    try:
        await config_under_test._init_sso_settings_in_db(prisma_client=prisma_stub)
    except Exception as e:
        pytest.fail(
            f"Exception should have been caught and logged, but was raised: {e}"
        )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_empty_settings():
    """
    An SSO row whose sso_settings is an empty dict is still forwarded:
    _decrypt_and_set_db_env_variables is called once with
    environment_variables == {}.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    from litellm.proxy.proxy_server import ProxyConfig

    config_under_test = ProxyConfig()

    # Stored SSO row carrying no settings at all.
    stored_row = MagicMock()
    stored_row.sso_settings = {}

    sso_lookup = AsyncMock(return_value=stored_row)
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_ssoconfig.find_unique = sso_lookup

    with patch.object(
        config_under_test, "_decrypt_and_set_db_env_variables"
    ) as decrypt_spy:
        await config_under_test._init_sso_settings_in_db(prisma_client=prisma_stub)

    # The row was looked up exactly once, by its fixed id.
    sso_lookup.assert_awaited_once_with(where={"id": "sso_config"})

    # The loader ran once and received an empty mapping.
    decrypt_spy.assert_called_once()
    forwarded_env = decrypt_spy.call_args.kwargs["environment_variables"]
    assert forwarded_env == {}
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_retries_on_transport_error():
    """`_init_sso_settings_in_db` recovers from a single ClientNotConnectedError.

    The first lookup raises, the reconnect stub reports success, and the test
    expects exactly one reconnect (with the SSO-specific reason), a second
    lookup, and one call to the env-var loader.
    """
    import prisma

    from litellm.proxy.proxy_server import ProxyConfig

    config_under_test = ProxyConfig()
    stored_row = MagicMock()
    stored_row.sso_settings = {"GOOGLE_CLIENT_ID": "xxx"}

    lookup_attempts = 0

    async def _fail_once_then_return_row(**kwargs):
        # Only the very first lookup simulates a dropped DB connection.
        nonlocal lookup_attempts
        lookup_attempts += 1
        if lookup_attempts == 1:
            raise prisma.errors.ClientNotConnectedError()
        return stored_row

    prisma_stub = MagicMock()
    prisma_stub.db.litellm_ssoconfig.find_unique = AsyncMock(
        side_effect=_fail_once_then_return_row
    )
    prisma_stub.attempt_db_reconnect = AsyncMock(return_value=True)
    prisma_stub._db_auth_reconnect_timeout_seconds = 2.0
    prisma_stub._db_auth_reconnect_lock_timeout_seconds = 0.1

    with patch.object(
        config_under_test, "_decrypt_and_set_db_env_variables"
    ) as decrypt_spy:
        await config_under_test._init_sso_settings_in_db(prisma_client=prisma_stub)

    # One failed lookup plus one successful retry.
    assert lookup_attempts == 2
    prisma_stub.attempt_db_reconnect.assert_awaited_once()
    reconnect_kwargs = prisma_stub.attempt_db_reconnect.await_args.kwargs
    assert reconnect_kwargs["reason"] == "init_sso_settings_in_db_lookup_failure"
    decrypt_spy.assert_called_once()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_propagates_when_reconnect_fails():
    """A failed reconnect must neither crash nor trigger repeated reconnects.

    Every lookup raises ClientNotConnectedError and the reconnect stub returns
    False. The call under test is expected to return normally, having
    attempted the reconnect exactly once.
    """
    import prisma

    from litellm.proxy.proxy_server import ProxyConfig

    config_under_test = ProxyConfig()

    always_disconnected = AsyncMock(
        side_effect=prisma.errors.ClientNotConnectedError()
    )
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_ssoconfig.find_unique = always_disconnected
    prisma_stub.attempt_db_reconnect = AsyncMock(return_value=False)
    prisma_stub._db_auth_reconnect_timeout_seconds = 2.0
    prisma_stub._db_auth_reconnect_lock_timeout_seconds = 0.1

    # No pytest.raises here on purpose: an escaping exception fails the test.
    await config_under_test._init_sso_settings_in_db(prisma_client=prisma_stub)

    prisma_stub.attempt_db_reconnect.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_hashicorp_vault_config_override_retries_on_transport_error():
    """`_init_hashicorp_vault_config_override` recovers from one
    ClientNotConnectedError.

    The first config-override lookup raises, the reconnect stub reports
    success, and the second lookup returns None. The test expects two lookups
    and a single reconnect carrying the vault-specific reason.
    """
    import prisma

    from litellm.proxy.proxy_server import ProxyConfig

    config_under_test = ProxyConfig()
    config_under_test._last_hashicorp_vault_config = None

    lookup_attempts = 0

    async def _fail_once_then_return_none(**kwargs):
        # Only the very first lookup simulates a dropped DB connection.
        nonlocal lookup_attempts
        lookup_attempts += 1
        if lookup_attempts == 1:
            raise prisma.errors.ClientNotConnectedError()
        return None  # No config row stored in the DB.

    prisma_stub = MagicMock()
    prisma_stub.db.litellm_configoverrides.find_unique = AsyncMock(
        side_effect=_fail_once_then_return_none
    )
    prisma_stub.attempt_db_reconnect = AsyncMock(return_value=True)
    prisma_stub._db_auth_reconnect_timeout_seconds = 2.0
    prisma_stub._db_auth_reconnect_lock_timeout_seconds = 0.1

    await config_under_test._init_hashicorp_vault_config_override(
        prisma_client=prisma_stub
    )

    # One failed lookup plus one successful retry.
    assert lookup_attempts == 2
    prisma_stub.attempt_db_reconnect.assert_awaited_once()
    reconnect_kwargs = prisma_stub.attempt_db_reconnect.await_args.kwargs
    expected_reason = "init_hashicorp_vault_config_override_lookup_failure"
    assert reconnect_kwargs["reason"] == expected_reason
|
||
|
||
|
||
def test_update_config_fields_uppercases_env_vars(monkeypatch):
    """
    Lower-case environment variable names coming from the DB must end up
    upper-cased both in the returned config and in os.environ.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    # Start from a clean slate for both spellings of each variable.
    for stale_name in ("DD_API_KEY", "DD_SITE", "dd_api_key", "dd_site"):
        monkeypatch.delenv(stale_name, raising=False)

    db_env = {"dd_api_key": "test-api-key", "dd_site": "us5.datadoghq.com"}

    merged_config = ProxyConfig()._update_config_fields(
        current_config={},
        param_name="environment_variables",
        db_param_value=db_env,
    )

    applied_env = merged_config.get("environment_variables", {})

    # Returned config uses upper-case keys ...
    assert applied_env["DD_API_KEY"] == "test-api-key"
    assert applied_env["DD_SITE"] == "us5.datadoghq.com"

    # ... and so does the process environment.
    assert os.environ.get("DD_API_KEY") == "test-api-key"
    assert os.environ.get("DD_SITE") == "us5.datadoghq.com"
|
||
|
||
|
||
def test_encrypt_env_variables_for_db_is_idempotent(monkeypatch):
    """
    Regression: re-submitting an already-encrypted value must not add another
    encryption layer.

    The output of _encrypt_env_variables_for_db is fed back into itself
    several times; after every pass a single decrypt_value_helper call must
    recover the original plaintext. The write path must also leave os.environ
    untouched.
    """
    from litellm.proxy.common_utils.encrypt_decrypt_utils import (
        decrypt_value_helper,
    )
    from litellm.proxy.proxy_server import ProxyConfig

    monkeypatch.setenv("LITELLM_SALT_KEY", "sk-test-salt-key")
    monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False)

    config_under_test = ProxyConfig()
    plaintext = "pk-langfuse-secret-value"

    def _decrypt_once(env_vars):
        # Exactly one decryption layer, keyed on the variable name.
        return decrypt_value_helper(
            value=env_vars["LANGFUSE_PUBLIC_KEY"], key="LANGFUSE_PUBLIC_KEY"
        )

    # Pass 1: plaintext goes in, ciphertext comes out.
    first_pass = config_under_test._encrypt_env_variables_for_db(
        {"LANGFUSE_PUBLIC_KEY": plaintext}
    )
    assert first_pass["LANGFUSE_PUBLIC_KEY"] != plaintext
    assert _decrypt_once(first_pass) == plaintext

    # Pass 2: the ciphertext is re-submitted; it must not be wrapped again.
    second_pass = config_under_test._encrypt_env_variables_for_db(first_pass)
    assert _decrypt_once(second_pass) == plaintext

    # Passes 3 and 4: further re-saves still leave a single layer.
    third_pass = config_under_test._encrypt_env_variables_for_db(second_pass)
    fourth_pass = config_under_test._encrypt_env_variables_for_db(third_pass)
    for resaved in (third_pass, fourth_pass):
        assert _decrypt_once(resaved) == plaintext

    # Encrypting for storage must not export the value to the process env.
    assert os.environ.get("LANGFUSE_PUBLIC_KEY") is None
|
||
|
||
|
||
def test_get_prompt_spec_for_db_prompt_with_versions():
    """
    _get_prompt_spec_for_db_prompt must name the resulting spec
    "<prompt_id>.v<version>" for each stored prompt version.
    """
    from unittest.mock import MagicMock

    from litellm.proxy.proxy_server import ProxyConfig

    def _db_row(payload):
        # Stand-in for a DB record: only model_dump() is stubbed.
        row = MagicMock()
        row.model_dump.return_value = payload
        return row

    version_one = _db_row(
        {
            "id": "uuid-1",
            "prompt_id": "chat_prompt",
            "version": 1,
            "litellm_params": '{"prompt_id": "chat_prompt", "prompt_integration": "dotprompt", "model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "v1 content"}]}',
            "prompt_info": '{"prompt_type": "db"}',
            "created_at": "2024-01-01T00:00:00",
            "updated_at": "2024-01-01T00:00:00",
        }
    )
    version_two = _db_row(
        {
            "id": "uuid-2",
            "prompt_id": "chat_prompt",
            "version": 2,
            "litellm_params": '{"prompt_id": "chat_prompt", "prompt_integration": "dotprompt", "model": "gpt-4", "messages": [{"role": "user", "content": "v2 content"}]}',
            "prompt_info": '{"prompt_type": "db"}',
            "created_at": "2024-01-02T00:00:00",
            "updated_at": "2024-01-02T00:00:00",
        }
    )

    config_under_test = ProxyConfig()

    spec_one = config_under_test._get_prompt_spec_for_db_prompt(db_prompt=version_one)
    assert spec_one.prompt_id == "chat_prompt.v1"

    spec_two = config_under_test._get_prompt_spec_for_db_prompt(db_prompt=version_two)
    assert spec_two.prompt_id == "chat_prompt.v2"
|
||
|
||
|
||
def test_root_redirect_when_docs_url_not_root_and_redirect_url_set(monkeypatch):
    """
    With DOCS_URL=/docs and ROOT_REDIRECT_URL=/ui, GET "/" must answer with a
    307 whose Location header is the redirect URL.

    The test removes any existing "/" route from the shared app and registers
    its own redirect handler under the same condition the real implementation
    is said to use.
    """
    from fastapi.responses import RedirectResponse

    from litellm.proxy.proxy_server import cleanup_router_config_variables
    from litellm.proxy.utils import _get_docs_url

    cleanup_router_config_variables()
    here = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{here}/test_configs/test_config_no_auth.yaml"

    # Docs on a non-root path is what enables the redirect logic.
    monkeypatch.setenv("DOCS_URL", "/docs")
    test_redirect_url = "/ui"
    monkeypatch.setenv("ROOT_REDIRECT_URL", test_redirect_url)

    asyncio.run(initialize(config=config_fp, debug=True))

    docs_url = _get_docs_url()
    root_redirect_url = os.getenv("ROOT_REDIRECT_URL")

    def _shadows_root(route):
        # A route competes for "/" when it serves GET there, or when it has
        # no `methods` attribute at all (catch-all style routes).
        if not (hasattr(route, "path") and route.path == "/"):
            return False
        if hasattr(route, "methods"):
            return "GET" in route.methods
        return True

    # Collect first, then remove, so the route list is not mutated mid-scan.
    for shadowing_route in [r for r in app.routes if _shadows_root(r)]:
        app.routes.remove(shadowing_route)

    # Register the redirect under the same condition as the implementation.
    if docs_url != "/" and root_redirect_url:

        @app.get("/", include_in_schema=False)
        async def root_redirect():
            return RedirectResponse(url=root_redirect_url)

    http = TestClient(app)
    reply = http.get("/", follow_redirects=False)
    assert reply.status_code == 307
    assert reply.headers["location"] == test_redirect_url
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_image_non_root_uses_var_lib_assets_dir(monkeypatch):
    """
    With LITELLM_NON_ROOT=true, get_image must create
    /var/lib/litellm/assets when that directory does not exist yet.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    monkeypatch.setenv("LITELLM_NON_ROOT", "true")
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    assets_dir = "/var/lib/litellm/assets"

    def exists_side_effect(path):
        # Only the assets dir itself is "missing", which forces makedirs.
        return path != assets_dir

    forced_env = {"UI_LOGO_PATH": "", "LITELLM_NON_ROOT": "true"}

    def getenv_side_effect(key, default=""):
        # Pin the two variables under test; everything else gets its default.
        return forced_env.get(key, default)

    with (
        patch("litellm.proxy.proxy_server.os.makedirs") as makedirs_spy,
        patch(
            "litellm.proxy.proxy_server.os.path.exists", side_effect=exists_side_effect
        ),
        patch("litellm.proxy.proxy_server.os.access", return_value=True),
        patch("litellm.proxy.proxy_server.os.getenv", side_effect=getenv_side_effect),
        patch("litellm.proxy.proxy_server.FileResponse"),
    ):
        await get_image()

        # The non-root assets directory was created exactly once.
        makedirs_spy.assert_called_once_with(assets_dir, exist_ok=True)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_image_non_root_fallback_to_default_logo(monkeypatch):
    """
    In the non-root case, when nothing exists under /var/lib/litellm/assets,
    get_image must create the directory, probe for logo.jpg inside it, and
    still return a FileResponse.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    monkeypatch.setenv("LITELLM_NON_ROOT", "true")
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    # Every path handed to os.path.exists, in call order.
    probed_paths = []

    def exists_side_effect(path):
        probed_paths.append(path)
        # Anything under the assets dir is reported missing so that makedirs
        # runs, the logo fallback triggers, and no cached file short-circuits.
        return "/var/lib/litellm/assets" not in path

    forced_env = {"UI_LOGO_PATH": "", "LITELLM_NON_ROOT": "true"}

    def getenv_side_effect(key, default=""):
        # Pin the two variables under test; everything else gets its default.
        return forced_env.get(key, default)

    with (
        patch("litellm.proxy.proxy_server.os.makedirs") as makedirs_spy,
        patch(
            "litellm.proxy.proxy_server.os.path.exists", side_effect=exists_side_effect
        ),
        patch("litellm.proxy.proxy_server.os.access", return_value=True),
        patch("litellm.proxy.proxy_server.os.getenv", side_effect=getenv_side_effect),
        patch("litellm.proxy.proxy_server.FileResponse") as file_response_spy,
    ):
        await get_image()

        # The non-root assets directory was created exactly once.
        makedirs_spy.assert_called_once_with("/var/lib/litellm/assets", exist_ok=True)

        # The logo inside the assets directory was probed for.
        assets_logo_path = "/var/lib/litellm/assets/logo.jpg"
        assert any(
            assets_logo_path in str(probed) for probed in probed_paths
        ), f"Should check if {assets_logo_path} exists"

        # A response was still produced (from the fallback logo).
        assert file_response_spy.called, "FileResponse should be called"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_image_root_case_uses_current_dir(monkeypatch):
    """
    Without LITELLM_NON_ROOT, get_image must not create
    /var/lib/litellm/assets and must still return a FileResponse.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.delenv("UI_LOGO_PATH", raising=False)

    # Both variables read as empty, i.e. "not set".
    forced_env = {"UI_LOGO_PATH": "", "LITELLM_NON_ROOT": ""}

    def getenv_side_effect(key, default=""):
        return forced_env.get(key, default)

    with (
        patch("litellm.proxy.proxy_server.os.makedirs") as makedirs_spy,
        patch("litellm.proxy.proxy_server.os.path.exists", return_value=True),
        patch("litellm.proxy.proxy_server.os.getenv", side_effect=getenv_side_effect),
        patch("litellm.proxy.proxy_server.FileResponse") as file_response_spy,
    ):
        await get_image()

        # No makedirs call may mention the non-root assets directory.
        non_root_dir_calls = [
            recorded
            for recorded in makedirs_spy.call_args_list
            if "/var/lib/litellm/assets" in str(recorded)
        ]
        assert (
            len(non_root_dir_calls) == 0
        ), "Should not create /var/lib/litellm/assets for root case"

        assert file_response_spy.called, "FileResponse should be called"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_image_custom_local_logo_bypasses_cache(monkeypatch, tmp_path):
    """
    When UI_LOGO_PATH names an existing local file, get_image must serve that
    exact file.

    Regression guard: per the original note, the cache check used to run
    before UI_LOGO_PATH was read, so a pre-existing cached_logo.jpg could be
    returned instead of the user's custom logo.
    """
    from litellm.proxy.proxy_server import get_image

    custom_logo = tmp_path / "custom_logo.jpg"
    custom_logo.write_bytes(b"\xff\xd8\xff custom logo")

    monkeypatch.setenv("UI_LOGO_PATH", str(custom_logo))
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.delenv("LITELLM_ASSETS_PATH", raising=False)

    # Paths handed to FileResponse, in call order.
    served_paths = []

    def _record_file_response(path, **kwargs):
        served_paths.append(path)
        return MagicMock()

    with patch(
        "litellm.proxy.proxy_server.FileResponse", side_effect=_record_file_response
    ):
        await get_image()

    assert len(served_paths) == 1, "FileResponse should be called exactly once"
    assert served_paths[0] == str(custom_logo.resolve()), (
        f"Expected custom logo path, got {served_paths[0]}. "
        "A stale cached_logo.jpg may have been returned instead."
    )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_image_default_logo_ignores_stale_cache(monkeypatch, tmp_path):
    """
    With UI_LOGO_PATH unset, a cached_logo.jpg sitting in LITELLM_ASSETS_PATH
    must not be served; some other path ending in "logo.jpg" is expected.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    # Plant a stale cache file in the assets directory.
    stale_cache = tmp_path / "cached_logo.jpg"
    stale_cache.write_bytes(b"\xff\xd8\xff cached logo")

    monkeypatch.delenv("UI_LOGO_PATH", raising=False)
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))

    # Paths handed to FileResponse, in call order.
    served_paths = []

    def _record_file_response(path, **kwargs):
        served_paths.append(path)
        return MagicMock()

    with patch(
        "litellm.proxy.proxy_server.FileResponse", side_effect=_record_file_response
    ):
        await get_image()

    assert len(served_paths) == 1, "FileResponse should be called exactly once"

    served_path = served_paths[0]
    assert served_path != str(stale_cache.resolve())
    assert served_path.endswith("logo.jpg")
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_image_custom_logo_missing_falls_through_to_default(
    monkeypatch, tmp_path
):
    """
    When UI_LOGO_PATH names a local file that does not exist, get_image must
    not fail and must not serve that path; a path ending in "logo.jpg" is
    served instead.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    # Deliberately never created on disk.
    missing_logo = tmp_path / "nonexistent_logo.jpg"

    monkeypatch.setenv("UI_LOGO_PATH", str(missing_logo))
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))

    # Paths handed to FileResponse, in call order.
    served_paths = []

    def _record_file_response(path, **kwargs):
        served_paths.append(path)
        return MagicMock()

    with patch(
        "litellm.proxy.proxy_server.FileResponse", side_effect=_record_file_response
    ):
        await get_image()

    assert len(served_paths) == 1, "FileResponse should be called exactly once"

    served_path = served_paths[0]
    assert served_path != str(
        missing_logo
    ), "Should not attempt to serve a non-existent custom logo"
    assert served_path.endswith("logo.jpg")
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_image_custom_logo_missing_no_cache_serves_default(
    monkeypatch, tmp_path
):
    """
    When UI_LOGO_PATH names a missing file and the assets directory holds no
    cached_logo.jpg either, get_image must serve a path ending in "logo.jpg"
    rather than the missing custom path.
    """
    from unittest.mock import patch

    from litellm.proxy.proxy_server import get_image

    # Deliberately never created on disk; tmp_path stays empty, so no cache.
    missing_logo = tmp_path / "nonexistent_logo.jpg"

    monkeypatch.setenv("UI_LOGO_PATH", str(missing_logo))
    monkeypatch.delenv("LITELLM_NON_ROOT", raising=False)
    monkeypatch.setenv("LITELLM_ASSETS_PATH", str(tmp_path))

    # Paths handed to FileResponse, in call order.
    served_paths = []

    def _record_file_response(path, **kwargs):
        served_paths.append(path)
        return MagicMock()

    with patch(
        "litellm.proxy.proxy_server.FileResponse", side_effect=_record_file_response
    ):
        await get_image()

    assert len(served_paths) == 1, "FileResponse should be called exactly once"

    served_path = served_paths[0]
    assert served_path != str(
        missing_logo
    ), "Should not attempt to serve a non-existent custom logo"
    assert served_path.endswith(
        "logo.jpg"
    ), f"Expected fallback to default logo.jpg, got {served_path}"
|
||
|
||
|
||
def test_get_config_normalizes_string_callbacks(monkeypatch):
    """
    /get/config/callbacks must cope with callback settings given as a bare
    string, as None, or as a list, and report each callback under its type.
    """
    from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth

    # One setting per shape: plain string, None, and a proper list.
    litellm_settings = {
        "success_callback": "langfuse",
        "failure_callback": None,
        "callbacks": ["prometheus", "datadog"],
    }
    stub_config = {
        "litellm_settings": litellm_settings,
        "general_settings": {},
        "environment_variables": {},
    }

    router_stub = MagicMock()
    router_stub.get_settings.return_value = {}
    monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", router_stub)
    monkeypatch.setattr(proxy_config, "get_config", AsyncMock(return_value=stub_config))

    # Bypass auth for this request only; the snapshot is restored afterwards.
    saved_overrides = app.dependency_overrides.copy()
    app.dependency_overrides[user_api_key_auth] = lambda: MagicMock()

    http = TestClient(app)
    try:
        reply = http.get("/get/config/callbacks")
    finally:
        app.dependency_overrides = saved_overrides

    assert reply.status_code == 200
    reported = reply.json()["callbacks"]

    def _names_of_type(callback_type):
        return [cb["name"] for cb in reported if cb.get("type") == callback_type]

    on_success = _names_of_type("success")
    on_failure = _names_of_type("failure")
    on_both = _names_of_type("success_and_failure")

    assert "langfuse" in on_success
    assert len(on_failure) == 0
    assert "prometheus" in on_both
    assert "datadog" in on_both
|
||
|
||
|
||
def test_deep_merge_dicts_skips_none_and_empty_lists(monkeypatch):
    """
    When _update_config_fields merges DB values into general_settings, None
    values and empty lists must not overwrite existing entries, new keys must
    be added, and nested dicts must be merged key by key.
    """
    from litellm.proxy.proxy_server import ProxyConfig

    existing_settings = {
        "max_parallel_requests": 10,
        "allowed_models": ["gpt-3.5-turbo", "gpt-4"],
        "nested": {
            "key1": "value1",
            "key2": "value2",
        },
    }
    current_config = {"general_settings": existing_settings}

    db_param_value = {
        "max_parallel_requests": None,  # must be ignored
        "allowed_models": [],  # must be ignored
        "new_key": "new_value",  # must be added
        "nested": {
            "key1": "updated_value1",  # must override
            "key3": "value3",  # must be added next to key2
        },
    }

    merged = ProxyConfig()._update_config_fields(
        current_config, "general_settings", db_param_value
    )
    merged_settings = merged["general_settings"]

    # Top-level keys.
    assert merged_settings["max_parallel_requests"] == 10
    assert merged_settings["allowed_models"] == ["gpt-3.5-turbo", "gpt-4"]
    assert merged_settings["new_key"] == "new_value"

    # Nested dict is merged rather than replaced.
    merged_nested = merged_settings["nested"]
    assert merged_nested["key1"] == "updated_value1"
    assert merged_nested["key2"] == "value2"
    assert merged_nested["key3"] == "value3"
|
||
|
||
|
||
class TestInvitationEndpoints:
    """Tests for /invitation/new and /invitation/delete endpoints."""

    @pytest.fixture
    def client_with_auth(self):
        """Yield a TestClient whose auth dependency resolves to a proxy admin.

        The app's dependency overrides are snapshotted before the fake admin
        is installed and restored on teardown, so neither this fixture nor a
        test that swaps in another identity leaks an auth override into tests
        that run afterwards.
        """
        from litellm.proxy._types import LitellmUserRoles
        from litellm.proxy.proxy_server import cleanup_router_config_variables

        cleanup_router_config_variables()
        filepath = os.path.dirname(os.path.abspath(__file__))
        config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
        asyncio.run(initialize(config=config_fp, debug=True))

        mock_auth = MagicMock()
        mock_auth.user_id = "admin-user-id"
        mock_auth.user_role = LitellmUserRoles.PROXY_ADMIN
        mock_auth.api_key = "sk-test"

        original_overrides = app.dependency_overrides.copy()
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth
        try:
            yield TestClient(app)
        finally:
            # Undo every override made while the fixture was active,
            # including ones installed by the test body itself.
            app.dependency_overrides = original_overrides

    @pytest.mark.parametrize(
        "endpoint,payload,mock_return",
        [
            (
                "/invitation/new",
                {"user_id": "target-user-123"},
                {
                    "id": "inv-123",
                    "user_id": "target-user-123",
                    "is_accepted": False,
                    "accepted_at": None,
                    "expires_at": "2025-02-18T00:00:00",
                    "created_at": "2025-02-11T00:00:00",
                    "created_by": "admin-user-id",
                    "updated_at": "2025-02-11T00:00:00",
                    "updated_by": "admin-user-id",
                },
            ),
            (
                "/invitation/delete",
                {"invitation_id": "inv-456"},
                {
                    "id": "inv-456",
                    "user_id": "target-user-123",
                    "is_accepted": False,
                    "accepted_at": None,
                    "expires_at": "2025-02-18T00:00:00",
                    "created_at": "2025-02-11T00:00:00",
                    "created_by": "admin-user-id",
                    "updated_at": "2025-02-11T00:00:00",
                    "updated_by": "admin-user-id",
                },
            ),
        ],
    )
    def test_invitation_endpoints_proxy_admin_success(
        self, client_with_auth, endpoint, payload, mock_return
    ):
        """Proxy admin can successfully create and delete invitations."""
        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            mock_prisma.db.litellm_invitationlink = MagicMock()
            if endpoint == "/invitation/new":
                # Creation goes through the helper, so stub the helper itself.
                mock_create = AsyncMock(return_value=mock_return)
                with patch(
                    "litellm.proxy.management_helpers.user_invitation.create_invitation_for_user",
                    mock_create,
                ):
                    response = client_with_auth.post(endpoint, json=payload)
            else:
                # Deletion looks the invitation up first, then deletes it.
                mock_prisma.db.litellm_invitationlink.find_unique = AsyncMock(
                    return_value={**mock_return, "created_by": "admin-user-id"}
                )
                mock_prisma.db.litellm_invitationlink.delete = AsyncMock(
                    return_value=mock_return
                )
                response = client_with_auth.post(endpoint, json=payload)

            assert response.status_code == 200
            data = response.json()
            assert data["id"] == mock_return["id"]
            assert data["user_id"] == mock_return["user_id"]

    @pytest.mark.parametrize(
        "endpoint,payload",
        [
            ("/invitation/new", {"user_id": "target-user-123"}),
            ("/invitation/delete", {"invitation_id": "inv-456"}),
        ],
    )
    def test_invitation_endpoints_non_admin_denied(
        self, client_with_auth, endpoint, payload
    ):
        """Non-admin users cannot access invitation endpoints."""
        from litellm.proxy._types import LitellmUserRoles

        mock_auth = MagicMock()
        mock_auth.user_id = "regular-user"
        mock_auth.user_role = LitellmUserRoles.INTERNAL_USER
        mock_auth.api_key = "sk-regular"
        # Replace the fixture's admin identity; the fixture teardown restores
        # the pre-test overrides, so this does not outlive the test.
        app.dependency_overrides[user_api_key_auth] = lambda: mock_auth

        with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
            mock_prisma.db.litellm_invitationlink = MagicMock()
            # Avoid triggering async DB calls in _user_has_admin_privileges
            with patch(
                "litellm.proxy.proxy_server._user_has_admin_privileges",
                new_callable=AsyncMock,
                return_value=False,
            ):
                response = client_with_auth.post(endpoint, json=payload)

        assert response.status_code == 400
        body = response.json()
        # ProxyException handler returns {"error": {...}}, HTTPException returns {"detail": {...}}
        error_content = body.get("error", body.get("detail", body))
        assert "not allowed" in str(error_content).lower()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_early_exit():
    """
    When async_data_generator is closed after only one chunk has been
    consumed (as on a client disconnect), it must await response.aclose()
    exactly once.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth_stub = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }

    # Three chunks are available; the test will only ever read the first.
    streamed_chunks = [
        {"choices": [{"delta": {"content": "Hello"}}]},
        {"choices": [{"delta": {"content": " world"}}]},
        {"choices": [{"delta": {"content": " more"}}]},
    ]

    async def _replay_chunks(*args, **kwargs):
        for chunk in streamed_chunks:
            yield chunk

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.async_post_call_streaming_iterator_hook = _replay_chunks
    # The per-chunk hook hands back whatever chunk it was given.
    logging_stub.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_stub.post_call_failure_hook = AsyncMock()

    # Upstream response whose aclose() we want to see awaited.
    upstream_response = MagicMock()
    upstream_response.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub):
        stream = async_data_generator(upstream_response, auth_stub, request_payload)

        # Read a single chunk, then walk away from the stream.
        first_chunk = await stream.__anext__()
        assert first_chunk.startswith("data: ")

        # Closing the generator early stands in for the client disconnecting.
        await stream.aclose()

    # The upstream response must have been closed exactly once.
    upstream_response.aclose.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_uses_direct_stream_fast_path_without_callbacks():
    """
    With every streaming-callback predicate reporting False,
    async_data_generator should iterate the provider stream directly and
    never touch the per-chunk hook machinery.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth_stub = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    stream_chunks = [
        {"choices": [{"delta": {"content": text}}]} for text in ("Hello", " world")
    ]

    class MockStream:
        """Minimal async-iterable provider stream yielding the canned chunks."""

        def __aiter__(self):
            return self._stream()

        async def _stream(self):
            for item in stream_chunks:
                yield item

        async def aclose(self):
            pass

    upstream = MockStream()
    # Shadow aclose with a mock so the await can be asserted on.
    upstream.aclose = AsyncMock()

    logging_stub = MagicMock(spec=ProxyLogging)
    for predicate in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_stub, predicate).return_value = False
    logging_stub.async_post_call_streaming_iterator_hook = MagicMock()
    logging_stub.async_post_call_streaming_hook = AsyncMock()
    logging_stub.post_call_failure_hook = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub),
        patch.object(ProxyLogging, "_fire_deferred_stream_logging") as deferred_logging,
    ):
        emitted = [
            item
            async for item in async_data_generator(
                upstream, auth_stub, request_payload
            )
        ]

    # Normalise to text so bytes and str frames can be compared uniformly.
    emitted_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item for item in emitted
    ]
    data_frames = [frame for frame in emitted_text if frame.startswith("data: {")]
    assert len(data_frames) == 2
    assert emitted_text[-1] == "data: [DONE]\n\n"
    logging_stub.async_post_call_streaming_iterator_hook.assert_not_called()
    logging_stub.async_post_call_streaming_hook.assert_not_awaited()
    deferred_logging.assert_called_once_with(request_payload)
    upstream.aclose.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_passes_through_google_native_sse_bytes():
    """
    Raw SSE bytes from a Google-native stream must be forwarded as-is rather
    than re-wrapped into ``data: b'data: {...}'``.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth_stub = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gemini-2.0-flash",
        "messages": [{"role": "user", "content": "test"}],
    }
    # Three shapes: complete SSE event, SSE event lacking the blank-line
    # terminator, and a bare payload with no ``data:`` prefix at all.
    sse_event = b'data: {"candidates": [{"content": "hi"}]}\n\n'
    unterminated_event = b'data: {"candidates": [{"content": "there"}]}'
    bare_payload = b'{"partial": true}'

    class MockStream:
        """Async-iterable stream yielding the three byte payloads in order."""

        def __aiter__(self):
            return self._stream()

        async def _stream(self):
            for blob in (sse_event, unterminated_event, bare_payload):
                yield blob

        async def aclose(self):
            pass

    upstream = MockStream()
    upstream.aclose = AsyncMock()

    logging_stub = MagicMock(spec=ProxyLogging)
    for predicate in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_stub, predicate).return_value = False
    logging_stub.async_post_call_streaming_iterator_hook = MagicMock()
    logging_stub.async_post_call_streaming_hook = AsyncMock()
    logging_stub.post_call_failure_hook = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub),
        patch.object(ProxyLogging, "_fire_deferred_stream_logging"),
    ):
        emitted = [
            item
            async for item in async_data_generator(
                upstream, auth_stub, request_payload
            )
        ]

    emitted_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item for item in emitted
    ]
    # Already-terminated event: untouched.
    assert emitted_text[0] == sse_event.decode("utf-8")
    # Missing terminator: only the blank line is appended.
    assert emitted_text[1] == unterminated_event.decode("utf-8") + "\n\n"
    # Bare payload: wrapped into a proper SSE data frame.
    assert emitted_text[2] == "data: " + bare_payload.decode("utf-8") + "\n\n"
    # No bytes repr may leak into the output.
    assert "b'data:" not in "".join(emitted_text)
    assert emitted_text[-1] == "data: [DONE]\n\n"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_google_genai_stream_omits_openai_done():
    """
    When the request carries ``_litellm_skip_openai_stream_done``, the stream
    must end without the OpenAI-style ``data: [DONE]`` sentinel.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth_stub = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gemini-2.0-flash",
        "_litellm_skip_openai_stream_done": True,
    }
    sse_event = b'data: {"candidates": [{"content": {"parts": [{"text": "Hi"}]}}]}\n\n'

    class MockStream:
        """Async-iterable stream yielding a single Gemini SSE event."""

        def __aiter__(self):
            return self._stream()

        async def _stream(self):
            yield sse_event

        async def aclose(self):
            pass

    upstream = MockStream()
    upstream.aclose = AsyncMock()

    logging_stub = MagicMock(spec=ProxyLogging)
    for predicate in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_stub, predicate).return_value = False
    logging_stub.async_post_call_streaming_iterator_hook = MagicMock()
    logging_stub.async_post_call_streaming_hook = AsyncMock()
    logging_stub.post_call_failure_hook = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub),
        patch.object(ProxyLogging, "_fire_deferred_stream_logging"),
    ):
        emitted = [
            item
            async for item in async_data_generator(
                upstream, auth_stub, request_payload
            )
        ]

    emitted_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item for item in emitted
    ]
    # Exactly the one upstream event, and nothing appended after it.
    assert emitted_text == [sse_event.decode("utf-8")]
    assert "[DONE]" not in "".join(emitted_text)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_google_genai_stream_forwards_error_without_done():
    """An error frame must still reach the client when OpenAI [DONE] is skipped."""
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    error_frame = 'data: {"error": {"message": "stream failed"}}\n\n'
    auth_stub = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gemini-2.0-flash",
        "_litellm_skip_openai_stream_done": True,
    }

    class MockStream:
        """Async-iterable stream yielding a single pre-formatted error frame."""

        def __aiter__(self):
            return self._stream()

        async def _stream(self):
            yield error_frame

        async def aclose(self):
            pass

    upstream = MockStream()
    upstream.aclose = AsyncMock()

    logging_stub = MagicMock(spec=ProxyLogging)
    for predicate in (
        "has_streaming_callbacks",
        "needs_iterator_wrap",
        "needs_per_chunk_streaming_hook",
    ):
        getattr(logging_stub, predicate).return_value = False
    logging_stub.async_post_call_streaming_iterator_hook = MagicMock()
    logging_stub.async_post_call_streaming_hook = AsyncMock()
    logging_stub.post_call_failure_hook = AsyncMock()

    with (
        patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub),
        patch.object(ProxyLogging, "_fire_deferred_stream_logging"),
    ):
        emitted = [
            item
            async for item in async_data_generator(
                upstream, auth_stub, request_payload
            )
        ]

    emitted_text = [
        item.decode("utf-8") if isinstance(item, bytes) else item for item in emitted
    ]
    # The error frame passes through unchanged and is the only output.
    assert emitted_text == [error_frame]
    assert "[DONE]" not in "".join(emitted_text)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_normal_completion():
    """
    A stream that runs to completion must end with [DONE] and still await
    response.aclose().
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth_stub = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    stream_chunks = [{"choices": [{"delta": {"content": "Hello"}}]}]

    async def fake_iterator_hook(*args, **kwargs):
        # Stand-in for the iterator hook: replays the canned chunks.
        for item in stream_chunks:
            yield item

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.async_post_call_streaming_iterator_hook = fake_iterator_hook
    # The per-chunk hook echoes back whatever it receives as ``response``.
    logging_stub.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_stub.post_call_failure_hook = AsyncMock()

    upstream = MagicMock()
    upstream.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub):
        emitted = [
            item
            async for item in async_data_generator(
                upstream, auth_stub, request_payload
            )
        ]

    # Normal completion is signalled by the [DONE] sentinel...
    assert any("[DONE]" in frame for frame in emitted)
    # ...and the upstream response must be closed regardless.
    upstream.aclose.assert_awaited_once()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_midstream_error():
    """
    If the upstream iterator raises after its first chunk, the generator must
    surface an error frame and still await response.aclose().
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    auth_stub = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }

    async def failing_iterator_hook(*args, **kwargs):
        # One good chunk, then a simulated upstream failure.
        yield {"choices": [{"delta": {"content": "Hello"}}]}
        raise RuntimeError("upstream connection reset")

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.async_post_call_streaming_iterator_hook = failing_iterator_hook
    # The per-chunk hook echoes back whatever it receives as ``response``.
    logging_stub.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_stub.post_call_failure_hook = AsyncMock()

    upstream = MagicMock()
    upstream.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub):
        emitted = [
            item
            async for item in async_data_generator(
                upstream, auth_stub, request_payload
            )
        ]

    # Expect the data chunk followed by at least one error frame.
    assert len(emitted) >= 2
    assert any("error" in frame for frame in emitted)
    # The failure must not prevent the upstream response from being closed.
    upstream.aclose.assert_awaited_once()
|
||
|
||
|
||
# ============================================================================
|
||
# store_model_in_db DB Config Override Tests
|
||
# ============================================================================
|
||
|
||
|
||
def test_store_model_in_db_in_config_general_settings():
    """
    ConfigGeneralSettings must declare store_model_in_db and accept True,
    False, an explicit None, or no value at all (which reads back as None).
    """
    from litellm.proxy._types import ConfigGeneralSettings

    assert "store_model_in_db" in ConfigGeneralSettings.model_fields

    # Explicitly supplied values round-trip unchanged.
    for supplied in (True, False, None):
        settings = ConfigGeneralSettings(store_model_in_db=supplied)
        assert settings.store_model_in_db is supplied

    # Omitting the field yields None.
    assert ConfigGeneralSettings().store_model_in_db is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_true():
    """
    Verify _update_general_settings sets the module-level store_model_in_db
    to True, and mirrors it into general_settings, when the DB
    general_settings carry store_model_in_db=True.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    # The patch targets are plain values (False / {}), so the context managers
    # are not bound to names: nothing useful would be captured.
    with (
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.general_settings", {}),
    ):
        await proxy_config._update_general_settings(
            db_general_settings={"store_model_in_db": True}
        )

        # Assert while the patches are active; leaving the ``with`` block
        # restores the original module attributes.
        assert ps.store_model_in_db is True
        assert ps.general_settings["store_model_in_db"] is True
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_false():
    """
    A DB general_settings entry of store_model_in_db=False must flip the
    module-level flag to False and be mirrored into general_settings.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyConfig

    config_under_test = ProxyConfig()

    # Start from the opposite value so the change is observable.
    flag_patch = patch("litellm.proxy.proxy_server.store_model_in_db", True)
    settings_patch = patch("litellm.proxy.proxy_server.general_settings", {})
    with flag_patch, settings_patch:
        await config_under_test._update_general_settings(
            db_general_settings={"store_model_in_db": False}
        )

        # Checked inside the block: the patches are undone on exit.
        assert ps.store_model_in_db is False
        assert ps.general_settings["store_model_in_db"] is False
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_string_normalization():
    """
    String spellings of booleans ("true", "True", "false") coming from the DB
    must be normalized to real bools by _update_general_settings.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyConfig

    config_under_test = ProxyConfig()

    # (flag value before the call, value stored in the DB, expected flag after)
    scenarios = (
        (False, "true", True),
        (False, "True", True),
        (True, "false", False),
    )
    for starting_flag, db_value, expected_flag in scenarios:
        with (
            patch("litellm.proxy.proxy_server.store_model_in_db", starting_flag),
            patch("litellm.proxy.proxy_server.general_settings", {}),
        ):
            await config_under_test._update_general_settings(
                db_general_settings={"store_model_in_db": db_value}
            )

            # Checked inside the block: the patches are undone on exit.
            assert ps.store_model_in_db is expected_flag
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_update_general_settings_store_model_in_db_none_keeps_current():
    """
    A DB value of None for store_model_in_db must leave the current
    module-level flag untouched, whichever value it currently holds.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyConfig

    config_under_test = ProxyConfig()

    # Run once with the flag at True and once at False; None must keep both.
    for current_flag in (True, False):
        with (
            patch("litellm.proxy.proxy_server.store_model_in_db", current_flag),
            patch("litellm.proxy.proxy_server.general_settings", {}),
        ):
            await config_under_test._update_general_settings(
                db_general_settings={"store_model_in_db": None}
            )

            # Checked inside the block: the patches are undone on exit.
            assert ps.store_model_in_db is current_flag
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_store_model_in_db_db_override_when_config_false():
    """
    With store_model_in_db starting at False and the DB general_settings row
    carrying True, initialize_scheduled_background_jobs must flip the flag to
    True and go on to load deployments and credentials.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    # DB row returned by find_first: general_settings with the flag enabled.
    db_record = MagicMock()
    db_record.param_value = {"store_model_in_db": True}
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_config.find_first = AsyncMock(return_value=db_record)

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    job_kwargs = {
        "general_settings": {},
        "prisma_client": prisma_stub,
        "proxy_budget_rescheduler_min_time": 1,
        "proxy_budget_rescheduler_max_time": 2,
        "proxy_batch_write_at": 5,
        "proxy_logging_obj": logging_stub,
    }

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=False),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(**job_kwargs)

        # The DB value wins over the initial False.
        assert ps.store_model_in_db is True

        # Because the flag is now True, both loaders run exactly once.
        assert config_stub.add_deployment.call_count == 1
        assert config_stub.get_credentials.call_count == 1
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_store_model_in_db_db_check_skipped_when_already_true(monkeypatch):
    """
    When store_model_in_db is already True, startup must keep it True and
    still load deployments.

    The early DB check is expected to be skipped in this case, but find_first
    may also be called by add_deployment, so the query itself is not asserted.
    """
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)

    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    prisma_stub = MagicMock()
    prisma_stub.db.litellm_config.find_first = AsyncMock(return_value=None)

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    job_kwargs = {
        "general_settings": {},
        "prisma_client": prisma_stub,
        "proxy_budget_rescheduler_min_time": 1,
        "proxy_budget_rescheduler_max_time": 2,
        "proxy_batch_write_at": 5,
        "proxy_logging_obj": logging_stub,
    }

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", True),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=True),
    ):
        await ProxyStartupEvent.initialize_scheduled_background_jobs(**job_kwargs)

        # Only the outcome is verified: flag unchanged, deployments loaded once.
        assert ps.store_model_in_db is True
        assert config_stub.add_deployment.call_count == 1
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_store_model_in_db_db_failure_graceful(monkeypatch):
    """
    A failing DB lookup during startup must not propagate: the call completes,
    store_model_in_db stays False, and no deployments are loaded.
    """
    monkeypatch.delenv("STORE_MODEL_IN_DB", raising=False)

    import litellm.proxy.proxy_server as ps
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging

    # Every find_first call raises, simulating an unreachable database.
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_config.find_first = AsyncMock(
        side_effect=Exception("DB connection error")
    )

    logging_stub = MagicMock(spec=ProxyLogging)
    logging_stub.slack_alerting_instance = MagicMock()
    config_stub = AsyncMock()

    job_kwargs = {
        "general_settings": {},
        "prisma_client": prisma_stub,
        "proxy_budget_rescheduler_min_time": 1,
        "proxy_budget_rescheduler_max_time": 2,
        "proxy_batch_write_at": 5,
        "proxy_logging_obj": logging_stub,
    }

    with (
        patch("litellm.proxy.proxy_server.proxy_config", config_stub),
        patch("litellm.proxy.proxy_server.store_model_in_db", False),
        patch("litellm.proxy.proxy_server.get_secret_bool", return_value=False),
    ):
        # Reaching the assertions at all proves no exception escaped.
        await ProxyStartupEvent.initialize_scheduled_background_jobs(**job_kwargs)

        # The flag keeps its initial value...
        assert ps.store_model_in_db is False

        # ...so the deployment loader is never invoked.
        config_stub.add_deployment.assert_not_called()
|
||
|
||
|
||
# =====================================================================
|
||
# Spend counter tests (v2 — Redis-backed spend counters)
|
||
# =====================================================================
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_reads_redis_first():
    """get_current_spend should return the Redis value over the in-memory one."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_key = "spend:key:test"
    counters = DualCache()

    # Local in-memory layer holds a stale figure.
    counters.in_memory_cache.set_cache(key=counter_key, value=0.30)

    # Redis layer holds the cross-pod authoritative figure.
    redis_stub = AsyncMock()
    redis_stub.async_get_cache = AsyncMock(return_value=0.90)
    counters.redis_cache = redis_stub

    # patch.object swaps the module-level cache and restores it on exit.
    with patch.object(ps, "spend_counter_cache", counters):
        observed = await get_current_spend(
            counter_key=counter_key,
            fallback_spend=0.0,
        )

        # Redis (0.90) wins over in-memory (0.30).
        assert observed == 0.90
        redis_stub.async_get_cache.assert_called_once_with(key="spend:key:test")
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_fallback_to_in_memory():
    """Without a Redis layer, get_current_spend reads the in-memory counter."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    # A bare DualCache has no redis_cache configured.
    counters = DualCache()
    counters.in_memory_cache.set_cache(key="spend:key:test", value=0.50)

    # patch.object swaps the module-level cache and restores it on exit.
    with patch.object(ps, "spend_counter_cache", counters):
        observed = await get_current_spend(
            counter_key="spend:key:test",
            fallback_spend=0.0,
        )
        assert observed == 0.50
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_initializes_and_increments():
    """The key counter seeds from the cached key object's spend, then accumulates.

    The token is passed pre-hashed, simulating production where the auth flow
    hashes metadata["user_api_key"] before it reaches the cost callback.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy._types import LiteLLM_VerificationTokenView, hash_token
    from litellm.proxy.proxy_server import increment_spend_counters

    key_objects = DualCache()
    counters = DualCache()

    # Hash up front, as the auth flow would have done.
    token_hash = hash_token("sk-test-token-for-counter")
    counter_key = f"spend:key:{token_hash}"

    # Cached key object carrying spend that already exists in the DB.
    key_objects.in_memory_cache.set_cache(
        key=token_hash,
        value=LiteLLM_VerificationTokenView(
            token=token_hash,
            spend=5.0,
            max_budget=10.0,
        ),
    )

    # patch.object swaps both module-level caches and restores them on exit.
    with (
        patch.object(ps, "user_api_key_cache", key_objects),
        patch.object(ps, "spend_counter_cache", counters),
    ):
        # First call: counter is seeded from 5.0, then 0.50 is added.
        await increment_spend_counters(
            token=token_hash,
            team_id=None,
            user_id=None,
            response_cost=0.50,
        )
        assert counters.in_memory_cache.get_cache(key=counter_key) == 5.50

        # Second call: counter already exists, so only the delta is added.
        await increment_spend_counters(
            token=token_hash,
            team_id=None,
            user_id=None,
            response_cost=0.25,
        )
        assert counters.in_memory_cache.get_cache(key=counter_key) == 5.75
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_team_and_member():
    """Team and team-member spend must be tracked in separate counters.

    A single increment_spend_counters call with a team_id and user_id is
    expected to seed each counter from its own cached object (team spend 2.0,
    membership spend 1.0) and add the response cost to both.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy._types import LiteLLM_TeamTable

    key_cache = DualCache()
    counter_cache = DualCache()

    # Cached team object
    team_obj = LiteLLM_TeamTable(team_id="team-1", spend=2.0)
    key_cache.in_memory_cache.set_cache(key="team_id:team-1", value=team_obj)

    # Cached team membership
    key_cache.in_memory_cache.set_cache(
        key="team_membership:user-1:team-1",
        value={"user_id": "user-1", "team_id": "team-1", "spend": 1.0},
    )

    import litellm.proxy.proxy_server as ps

    original_key_cache = ps.user_api_key_cache
    original_counter_cache = ps.spend_counter_cache
    ps.user_api_key_cache = key_cache
    ps.spend_counter_cache = counter_cache

    try:
        from litellm.proxy.proxy_server import increment_spend_counters

        await increment_spend_counters(
            token=None,
            team_id="team-1",
            user_id="user-1",
            response_cost=0.30,
        )

        # 0.30 has no exact binary representation, so compare the sums with a
        # tolerance instead of exact float equality. A missing counter (None)
        # still fails the comparison.
        team_counter = counter_cache.in_memory_cache.get_cache(key="spend:team:team-1")
        assert team_counter == pytest.approx(2.30)

        member_counter = counter_cache.in_memory_cache.get_cache(
            key="spend:team_member:user-1:team-1"
        )
        assert member_counter == pytest.approx(1.30)
    finally:
        # Always restore the module-level caches for subsequent tests.
        ps.user_api_key_cache = original_key_cache
        ps.spend_counter_cache = original_counter_cache
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_and_increment_spend_counter_reseeds_from_db_on_counter_miss():
    """A missing Redis counter must be reseeded from the DB spend, not from
    the stale cached object, so the next increment builds on the correct
    base value."""
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_spend_counter

    increment_log: list = []

    async def log_increment(key, value, ttl=None, **kwargs):
        # Record every increment so the test can inspect what was written.
        increment_log.append({"key": key, "value": value, "ttl": ttl})
        return value

    redis_stub = AsyncMock()
    redis_stub.async_increment = AsyncMock(side_effect=log_increment)
    redis_stub.async_get_cache = AsyncMock(return_value=None)  # counter missing
    redis_stub.async_set_cache = AsyncMock(return_value=True)  # SET NX wins
    counters = DualCache()
    counters.redis_cache = redis_stub

    # The DB reports spend=42.0 (authoritative); the cached team object says
    # 10.0 (stale). The seed must be 42, not 10.
    team_row = MagicMock()
    team_row.spend = 42.0
    prisma_stub = MagicMock()
    prisma_stub.db.litellm_teamtable.find_unique = AsyncMock(return_value=team_row)

    cached_team = MagicMock()
    cached_team.spend = 10.0
    stale_objects = DualCache()
    stale_objects.in_memory_cache.set_cache(key="team_id:team-9", value=cached_team)

    # patch.object swaps the three module globals and restores them on exit.
    with (
        patch.object(ps, "user_api_key_cache", stale_objects),
        patch.object(ps, "spend_counter_cache", counters),
        patch.object(ps, "prisma_client", prisma_stub),
    ):
        await _init_and_increment_spend_counter(
            counter_key="spend:team:team-9",
            source_cache_key="team_id:team-9",
            increment=1.5,
        )

        prisma_stub.db.litellm_teamtable.find_unique.assert_awaited_once_with(
            where={"team_id": "team-9"}
        )
        # The seed goes through SET NX with the DB spend (42), never through
        # an increment of 42.
        redis_stub.async_set_cache.assert_awaited_once_with(
            key="spend:team:team-9", value=42.0, nx=True
        )
        # Only the per-request delta (1.5) is sent as an increment.
        observed_writes = [(entry["key"], entry["value"]) for entry in increment_log]
        assert observed_writes == [("spend:team:team-9", 1.5)]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_primary_spend_counter_redis_concurrent_seed_does_not_double_seed():
    """Racing cross-pod seeds must leave Redis at db_spend, not 2 * db_spend.

    Two ``SpendCounterReseed.coalesced`` calls run concurrently against one
    shared fake Redis, each through its own ``DualCache`` (one per simulated
    pod). The fake SET NX admits a single writer, so the test expects one
    winner, one rejected loser, and a read-back of the winner's value.

    ``_get_lock`` is patched to return a brand-new ``asyncio.Lock`` on every
    call: the real per-counter lock is per-process and would not coordinate
    separate pods, so sharing one here would serialize the two calls.
    """
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    counter_key = "spend:team:team-concurrent-seed"
    db_spend = 506.0
    redis_store: dict = {}
    set_results: list = []
    # Mutable tallies shared by the fakes below.
    tally = {"db_reads": 0, "sets_completed": 0, "gets_after_set": 0}

    async def redis_set_cache(key, value, nx=False, **_):
        # Yield first, then do the membership check and the write without
        # another await in between. That mirrors an atomic SET NX: whoever
        # resumes first wins, the other resumes once the key exists and
        # loses. Yielding after the check would let both callers pass it.
        await asyncio.sleep(0)
        rejected = nx and key in redis_store
        if not rejected:
            redis_store[key] = float(value)
            tally["sets_completed"] += 1
        set_results.append(not rejected)
        return not rejected

    async def redis_get_cache(key):
        # Only reads issued after some SET NX has completed are counted;
        # those are the loser's fallback reads this test wants to observe.
        if tally["sets_completed"] > 0:
            tally["gets_after_set"] += 1
        return redis_store.get(key)

    async def slow_find_unique(**_):
        tally["db_reads"] += 1
        # Give the other "pod" a turn so both read the DB before any write.
        await asyncio.sleep(0)
        return MagicMock(spend=db_spend)

    async def fresh_lock(_counter_key):
        return asyncio.Lock()

    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get_cache)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teamtable.find_unique = AsyncMock(
        side_effect=slow_find_unique
    )

    # One DualCache per simulated pod, both backed by the same fake Redis.
    pods = []
    for _ in range(2):
        pod_cache = DualCache()
        pod_cache.redis_cache = fake_redis
        pods.append(pod_cache)

    with patch.object(SpendCounterReseed, "_get_lock", side_effect=fresh_lock):
        results = await asyncio.gather(
            *(
                SpendCounterReseed.coalesced(
                    prisma_client=fake_prisma,
                    spend_counter_cache=pod_cache,
                    counter_key=counter_key,
                )
                for pod_cache in pods
            )
        )

    assert all(r == 506.0 for r in results), results
    assert redis_store[counter_key] == pytest.approx(506.0), redis_store

    # Both pods hit the DB and both attempted SET NX ...
    assert tally["db_reads"] == 2
    assert fake_redis.async_set_cache.await_count == 2
    nx_attempts = [
        attempt
        for attempt in fake_redis.async_set_cache.await_args_list
        if attempt.kwargs.get("nx") is True
    ]
    assert len(nx_attempts) == 2
    # ... but exactly one write landed and one was rejected.
    assert sorted(set_results) == [
        False,
        True,
    ], f"expected exactly one SET NX winner and one loser, got {set_results}"

    # The loser must have gone back to Redis for the winner's value.
    assert (
        tally["gets_after_set"] >= 1
    ), "loser branch (else: read back winner's value) was never exercised"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reseed_spend_from_db_user_and_org_prefixes():
    """``SpendCounterReseed.from_db`` reads user and org spend from their tables.

    ``spend:user:*`` and ``spend:org:*`` keys are expected to query the user
    and organization tables. ``spend:end_user:*`` and ``spend:tag:*`` keys
    are expected to yield ``None`` without awaiting their tables: those
    counters rely on already-fetched auth objects passed as
    ``fallback_spend``, so the helper must not add per-request DB reads.
    """
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    fake_prisma = MagicMock()
    db = fake_prisma.db
    db.litellm_usertable.find_unique = AsyncMock(return_value=MagicMock(spend=17.0))
    db.litellm_organizationtable.find_unique = AsyncMock(
        return_value=MagicMock(spend=305.0)
    )
    # Stubbed only so the test can prove they are never awaited.
    db.litellm_endusertable.find_unique = AsyncMock()
    db.litellm_tagtable.find_unique = AsyncMock()

    user_spend = await SpendCounterReseed.from_db(fake_prisma, "spend:user:alice")
    assert user_spend == 17.0
    db.litellm_usertable.find_unique.assert_awaited_once_with(
        where={"user_id": "alice"}
    )

    end_user_spend = await SpendCounterReseed.from_db(
        fake_prisma,
        "spend:end_user:customer-1",
    )
    assert end_user_spend is None
    db.litellm_endusertable.find_unique.assert_not_awaited()

    tag_spend = await SpendCounterReseed.from_db(fake_prisma, "spend:tag:paid-tag")
    assert tag_spend is None
    db.litellm_tagtable.find_unique.assert_not_awaited()

    org_spend = await SpendCounterReseed.from_db(fake_prisma, "spend:org:acme")
    assert org_spend == 305.0
    db.litellm_organizationtable.find_unique.assert_awaited_once_with(
        where={"organization_id": "acme"}
    )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reseed_spend_from_db_skips_window_variant_keys():
    """Window counter keys must not trigger a DB lookup in ``from_db``.

    Keys shaped like ``spend:*:window:{duration}`` share a prefix with the
    primary counters but have no matching DB row, so ``from_db`` is expected
    to return ``None`` for them and leave the table mocks un-awaited.
    """
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    fake_prisma = MagicMock()
    key_lookup = fake_prisma.db.litellm_verificationtoken.find_unique = AsyncMock()
    team_lookup = fake_prisma.db.litellm_teamtable.find_unique = AsyncMock()

    for window_key in (
        "spend:key:sk-abc:window:1h",
        "spend:team:team-1:window:1d",
    ):
        assert await SpendCounterReseed.from_db(fake_prisma, window_key) is None

    key_lookup.assert_not_awaited()
    team_lookup.assert_not_awaited()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_reseeds_from_spend_logs_on_counter_miss():
    """A missing window counter is seeded from the spend-logs aggregate.

    The mocked ``group_by`` reports 2.25 of spend for the key inside the
    window; after a 0.5 increment the in-memory counter is expected to hold
    2.75. No Redis backend is attached to the cache in this test.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    window_key = "spend:key:key-window:window:1h"
    window_start = datetime.now(timezone.utc) - timedelta(hours=1)
    counter_cache = DualCache()

    fake_prisma = MagicMock()
    group_by = AsyncMock(
        return_value=[{"api_key": "key-window", "_sum": {"spend": 2.25}}]
    )
    fake_prisma.db.litellm_spendlogs.group_by = group_by

    # patch.object restores both module globals on exit, pass or fail.
    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", fake_prisma
    ):
        await _init_and_increment_window_spend_counter(
            counter_key=window_key,
            entity_type="Key",
            entity_id="key-window",
            window_start=window_start,
            increment=0.5,
        )

        group_by.assert_awaited_once_with(
            by=["api_key"],
            where={"api_key": "key-window", "startTime": {"gte": window_start}},
            sum={"spend": True},
        )
        cached = counter_cache.in_memory_cache.get_cache(key=window_key)
        assert cached == pytest.approx(2.75)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_init_spend_counter_redis_clean_miss_skips_stale_in_memory():
    """A Redis miss must reseed from the DB, ignoring a stale in-memory value.

    The in-memory layer is preloaded with 10.0, the fake Redis reports no
    value, and the DB row reports 42.0. After a 1.5 increment, both the fake
    Redis store and the in-memory cache are expected to hold 43.5.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_spend_counter

    counter_key = "spend:team:team-stale-local"
    counter_cache = DualCache()
    counter_cache.in_memory_cache.set_cache(key=counter_key, value=10.0)

    redis_store: dict = {}

    async def redis_increment(key, value, **_):
        updated = (redis_store.get(key) or 0.0) + value
        redis_store[key] = updated
        return updated

    async def redis_set_cache(key, value, nx=False, **_):
        # Emulate SET NX: refuse to overwrite an existing key when nx is set.
        if nx and key in redis_store:
            return False
        redis_store[key] = float(value)
        return True

    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(return_value=None)
    fake_redis.async_increment = AsyncMock(side_effect=redis_increment)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    team_lookup = AsyncMock(return_value=MagicMock(spend=42.0))
    fake_prisma.db.litellm_teamtable.find_unique = team_lookup

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", fake_prisma
    ), patch.object(ps, "user_api_key_cache", DualCache()):
        await _init_and_increment_spend_counter(
            counter_key=counter_key,
            source_cache_key="team_id:team-stale-local",
            increment=1.5,
        )

        team_lookup.assert_awaited_once_with(where={"team_id": "team-stale-local"})
        # Seed via SET NX (42) + delta via INCRBYFLOAT (1.5) = 43.5.
        assert redis_store[counter_key] == pytest.approx(43.5)
        local_value = counter_cache.in_memory_cache.get_cache(key=counter_key)
        assert local_value == pytest.approx(43.5)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_redis_clean_miss_skips_stale_in_memory():
    """A Redis miss on a window counter must reseed from spend logs.

    The in-memory layer is preloaded with 100.0, the fake Redis reports no
    value, and the spend-logs aggregate reports 2.25. After a 0.5 increment,
    both the fake Redis store and the in-memory cache are expected to hold
    2.75, i.e. the stale 100.0 is not used.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    entity_id = "key-window-stale-local"
    counter_key = "spend:key:key-window-stale-local:window:1h"
    window_start = datetime.now(timezone.utc) - timedelta(hours=1)

    counter_cache = DualCache()
    counter_cache.in_memory_cache.set_cache(key=counter_key, value=100.0)

    redis_store: dict = {}

    async def redis_increment(key, value, **_):
        updated = (redis_store.get(key) or 0.0) + value
        redis_store[key] = updated
        return updated

    async def redis_set_cache(key, value, **_):
        # This fake never overwrites: every set behaves like SET NX.
        if key in redis_store:
            return False
        redis_store[key] = value
        return True

    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(return_value=None)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)
    fake_redis.async_increment = AsyncMock(side_effect=redis_increment)
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    group_by = AsyncMock(
        return_value=[{"api_key": "key-window-stale-local", "_sum": {"spend": 2.25}}]
    )
    fake_prisma.db.litellm_spendlogs.group_by = group_by

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", fake_prisma
    ):
        await _init_and_increment_window_spend_counter(
            counter_key=counter_key,
            entity_type="Key",
            entity_id=entity_id,
            window_start=window_start,
            increment=0.5,
        )

        group_by.assert_awaited_once_with(
            by=["api_key"],
            where={
                "api_key": "key-window-stale-local",
                "startTime": {"gte": window_start},
            },
            sum={"spend": True},
        )
        assert redis_store[counter_key] == pytest.approx(2.75)
        local_value = counter_cache.in_memory_cache.get_cache(key=counter_key)
        assert local_value == pytest.approx(2.75)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_redis_concurrent_seed_does_not_double_seed():
    """Losing the SET NX race must not add the DB seed on top of the counter.

    Simulated scenario: the fake Redis already holds 2.75 for the counter,
    but the first two reads report a miss, and the SET NX attempt returns
    ``False`` (someone else seeded first). The spend-logs aggregate reports
    2.25. After a 0.5 increment the store is expected to hold 3.25
    (2.75 + 0.5), and the in-memory cache is expected to match.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    counter_key = "spend:key:key-window-concurrent-seed:window:1h"
    window_start = datetime.now(timezone.utc) - timedelta(hours=1)
    counter_cache = DualCache()

    redis_store = {counter_key: 2.75}
    redis_reads = 0

    async def redis_get_cache(key):
        nonlocal redis_reads
        redis_reads += 1
        # The first two reads pretend the key is absent; later reads are real.
        return redis_store.get(key) if redis_reads > 2 else None

    async def redis_increment(key, value, **_):
        updated = (redis_store.get(key) or 0.0) + value
        redis_store[key] = updated
        return updated

    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get_cache)
    fake_redis.async_set_cache = AsyncMock(return_value=False)
    fake_redis.async_increment = AsyncMock(side_effect=redis_increment)
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_spendlogs.group_by = AsyncMock(
        return_value=[
            {"api_key": "key-window-concurrent-seed", "_sum": {"spend": 2.25}}
        ]
    )

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", fake_prisma
    ):
        await _init_and_increment_window_spend_counter(
            counter_key=counter_key,
            entity_type="Key",
            entity_id="key-window-concurrent-seed",
            window_start=window_start,
            increment=0.5,
        )

        fake_redis.async_set_cache.assert_awaited_once_with(
            key=counter_key,
            value=2.25,
            nx=True,
        )
        assert redis_store[counter_key] == pytest.approx(3.25)
        local_value = counter_cache.in_memory_cache.get_cache(key=counter_key)
        assert local_value == pytest.approx(3.25)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_skips_invalid_window_start():
    """With ``window_start=None`` no window counter value is expected.

    After the call, the in-memory cache must still have nothing stored under
    the counter key.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _init_and_increment_window_spend_counter

    counter_key = "spend:key:key-invalid-window:window:not-a-duration"
    counter_cache = DualCache()

    with patch.object(ps, "spend_counter_cache", counter_cache):
        await _init_and_increment_window_spend_counter(
            counter_key=counter_key,
            entity_type="Key",
            entity_id="key-invalid-window",
            window_start=None,
            increment=0.5,
        )

        stored = counter_cache.in_memory_cache.get_cache(key=counter_key)
        assert stored is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_window_spend_counter_does_not_seed_zero_when_db_unavailable():
    """Without a prisma client the window counter must stay uninitialized.

    ``_ensure_window_spend_counter_initialized`` is expected to return
    ``False`` and to leave the in-memory cache empty for the counter key,
    rather than writing a value.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _ensure_window_spend_counter_initialized

    counter_key = "spend:key:key-window-db-unavailable:window:1h"
    counter_cache = DualCache()

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", None
    ):
        initialized = await _ensure_window_spend_counter_initialized(
            counter_key=counter_key,
            entity_type="Key",
            entity_id="key-window-db-unavailable",
            window_start=datetime.now(timezone.utc) - timedelta(hours=1),
        )

        assert initialized is False
        assert counter_cache.in_memory_cache.get_cache(key=counter_key) is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_finalizes_after_unreserved_increments():
    """The reservation is finalized only after unreserved counters are bumped.

    The key counter is covered by the budget reservation; the team counter is
    not. ``_init_and_increment_spend_counter`` is replaced by a stub that
    asserts the reservation is still unfinalized whenever it runs and records
    the counter it was invoked for. Afterwards the reservation is expected to
    be finalized and the key counter to read 0.25 (it started at 0.5).
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    reserved_key = "spend:key:key-finalize-after-increments"
    counter_cache = DualCache()
    counter_cache.in_memory_cache.set_cache(key=reserved_key, value=0.5)

    reserved_entry = {
        "counter_key": "spend:key:key-finalize-after-increments",
        "entity_type": "Key",
        "entity_id": "key-finalize-after-increments",
        "reserved_cost": 0.5,
        "applied_adjustment": 0.0,
    }
    budget_reservation = {
        "reserved_cost": 0.5,
        "entries": [reserved_entry],
        "finalized": False,
    }
    incremented_counters = []

    async def assert_reservation_not_finalized_yet(**kwargs):
        assert budget_reservation["finalized"] is False
        incremented_counters.append(kwargs["counter_key"])

    increment_stub = AsyncMock(side_effect=assert_reservation_not_finalized_yet)

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "user_api_key_cache", DualCache()
    ):
        with patch(
            "litellm.proxy.proxy_server._init_and_increment_spend_counter",
            new=increment_stub,
        ):
            await increment_spend_counters(
                token="key-finalize-after-increments",
                team_id="team-finalize-after-increments",
                user_id=None,
                response_cost=0.25,
                budget_reservation=budget_reservation,
            )

        # Only the team counter went through the direct increment path.
        assert incremented_counters == ["spend:team:team-finalize-after-increments"]
        assert budget_reservation["finalized"] is True
        reconciled = counter_cache.in_memory_cache.get_cache(key=reserved_key)
        assert reconciled == pytest.approx(0.25)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_finalizes_none_cost_reservation():
    """A reservation is still finalized when ``response_cost`` is ``None``.

    The key counter starts at 0.5 with a matching 0.5 reservation. After the
    call the reservation is expected to be marked finalized and the counter
    to read 0.0.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    reserved_key = "spend:key:key-finalize-none-cost"
    counter_cache = DualCache()
    counter_cache.in_memory_cache.set_cache(key=reserved_key, value=0.5)

    reserved_entry = {
        "counter_key": "spend:key:key-finalize-none-cost",
        "entity_type": "Key",
        "entity_id": "key-finalize-none-cost",
        "reserved_cost": 0.5,
        "applied_adjustment": 0.0,
    }
    budget_reservation = {
        "reserved_cost": 0.5,
        "entries": [reserved_entry],
        "finalized": False,
    }

    with patch.object(ps, "spend_counter_cache", counter_cache):
        await increment_spend_counters(
            token="key-finalize-none-cost",
            team_id=None,
            user_id=None,
            response_cost=None,
            budget_reservation=budget_reservation,
        )

        assert budget_reservation["finalized"] is True
        remaining = counter_cache.in_memory_cache.get_cache(key=reserved_key)
        assert remaining == pytest.approx(0.0)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counters_falls_back_to_direct_increment_on_bad_reserved_counter():
    """A failed reservation reconcile must still record the real response cost.

    The reserved key counter is absent from the cache here, which is the
    failure setup for the reconcile step. The test expects one warning to be
    logged, the reservation to be finalized, and the counter to hold the
    directly incremented 0.25. Leaving it at ``None`` would let the next
    request reseed a stale value from the DB and silently stop budget
    gating, which is the regression this test pins.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import increment_spend_counters

    reserved_key = "spend:key:key-bad-reserved-counter"
    counter_cache = DualCache()

    reserved_entry = {
        "counter_key": "spend:key:key-bad-reserved-counter",
        "entity_type": "Key",
        "entity_id": "key-bad-reserved-counter",
        "reserved_cost": 0.5,
        "applied_adjustment": 0.0,
    }
    budget_reservation = {
        "reserved_cost": 0.5,
        "entries": [reserved_entry],
        "finalized": False,
    }

    with patch.object(ps, "spend_counter_cache", counter_cache):
        with patch(
            "litellm.proxy.proxy_server.verbose_proxy_logger.warning"
        ) as mock_warning:
            await increment_spend_counters(
                token="key-bad-reserved-counter",
                team_id=None,
                user_id=None,
                response_cost=0.25,
                budget_reservation=budget_reservation,
            )

        mock_warning.assert_called_once()
        assert budget_reservation["finalized"] is True
        recorded_cost = counter_cache.in_memory_cache.get_cache(key=reserved_key)
        assert recorded_cost == 0.25
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_increment_spend_counter_invalidates_stale_cache_on_redis_failure():
    """A failing Redis increment must propagate and drop the local counter.

    The fake Redis raises ``RuntimeError`` from ``async_increment``. The test
    expects that error to reach the caller, the preloaded in-memory value
    (4.0) to be gone afterwards, and ``async_delete_cache`` to have been
    awaited once for the same key.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import _increment_spend_counter_cache

    counter_key = "spend:team:redis-fail"
    counter_cache = DualCache()
    counter_cache.in_memory_cache.set_cache(key=counter_key, value=4.0)

    fake_redis = AsyncMock()
    fake_redis.async_increment = AsyncMock(side_effect=RuntimeError("redis down"))
    fake_redis.async_delete_cache = AsyncMock()
    counter_cache.redis_cache = fake_redis

    with patch.object(ps, "spend_counter_cache", counter_cache):
        with pytest.raises(RuntimeError):
            await _increment_spend_counter_cache(
                counter_key=counter_key,
                increment=0.5,
            )

        assert counter_cache.in_memory_cache.get_cache(key=counter_key) is None
        fake_redis.async_delete_cache.assert_awaited_once_with(key=counter_key)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_reseeds_from_db_when_counter_missing():
    """
    A cold counter (no Redis value, no in-memory value) must be read from
    the DB, not from the caller's fallback. The DB row here says 362.0 while
    the fallback is a stale 30.0; returning the fallback would let a request
    through against an outdated in-process `team_membership.spend` after
    every Redis TTL expiry.

    The test also expects the counter to be warmed: one SET NX of 362.0 on
    the fake Redis and the same value in the in-memory cache.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_key = "spend:team_member:user-1:team-1"
    counter_cache = DualCache()
    recorded_seeds: list = []

    async def record_set_cache(key, value, nx=False, **kwargs):
        # Keep (key, value, nx) so the seed write can be checked below.
        recorded_seeds.append((key, value, nx))
        return True

    fake_redis = AsyncMock()
    fake_redis.async_set_cache = AsyncMock(side_effect=record_set_cache)
    fake_redis.async_get_cache = AsyncMock(return_value=None)
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        return_value=MagicMock(spend=362.0)
    )

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", fake_prisma
    ):
        spend = await get_current_spend(
            counter_key=counter_key,
            fallback_spend=30.0,
        )
        assert spend == 362.0, (
            f"expected DB reseed to return 362.0, got {spend} "
            f"(fallback would have returned 30.0 and caused bypass)"
        )
        # Counter warmed via SET NX so subsequent reads are fast.
        assert ("spend:team_member:user-1:team-1", 362.0, True) in recorded_seeds
        warmed = counter_cache.in_memory_cache.get_cache(key=counter_key)
        assert warmed == pytest.approx(362.0)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_uses_fallback_when_db_unavailable():
    """
    With no prisma client and no counter in Redis or memory, the read path
    is expected to return the caller-supplied fallback (15.5) instead of
    raising.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_cache = DualCache()
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(return_value=None)
    counter_cache.redis_cache = fake_redis

    # prisma_client=None simulates the DB being unavailable.
    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", None
    ):
        spend = await get_current_spend(
            counter_key="spend:team_member:user-1:team-1",
            fallback_spend=15.5,
        )
        assert spend == 15.5
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_coalesces_concurrent_reseeds():
    """
    Five concurrent reads of the same cold counter on one pod are expected
    to produce a single DB query. Every caller should still get the DB
    value (100.0); the later ones are expected to find the counter already
    warmed once they get past the lock.
    """
    import asyncio as _asyncio

    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_key = "spend:team_member:user-1:team-coalesce"
    counter_cache = DualCache()
    redis_store: dict = {}
    db_call_count = 0

    async def slow_find_unique(**kwargs):
        # Count the query, then stall briefly so the concurrent callers
        # genuinely overlap while the first one holds the lock.
        nonlocal db_call_count
        db_call_count += 1
        await _asyncio.sleep(0.05)
        return MagicMock(spend=100.0)

    async def redis_get(key, **_):
        return redis_store.get(key)

    async def redis_increment(key, value, **_):
        updated = (redis_store.get(key) or 0.0) + value
        redis_store[key] = updated
        return updated

    async def redis_set_cache(key, value, nx=False, **_):
        if nx and key in redis_store:
            return False
        redis_store[key] = float(value)
        return True

    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get)
    fake_redis.async_increment = AsyncMock(side_effect=redis_increment)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=slow_find_unique
    )

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", fake_prisma
    ):
        pending_reads = [
            get_current_spend(counter_key=counter_key, fallback_spend=0.0)
            for _ in range(5)
        ]
        results = await _asyncio.gather(*pending_reads)

        assert results == [100.0] * 5, f"all callers should see DB value, got {results}"
        assert (
            db_call_count == 1
        ), f"expected exactly 1 DB query for 5 concurrent reseeds, got {db_call_count}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_uses_db_zero_over_stale_fallback():
    """
    A DB spend of exactly 0 (e.g. right after a budget period reset) is a
    real answer and must beat a stale non-zero fallback. In production the
    fallback is the in-process team_membership.spend, which can still carry
    the pre-reset value on other pods.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_cache = DualCache()
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(return_value=None)
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        return_value=MagicMock(spend=0.0)
    )

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", fake_prisma
    ):
        spend = await get_current_spend(
            counter_key="spend:team_member:user-1:team-after-reset",
            fallback_spend=42.0,
        )
        assert (
            spend == 0.0
        ), f"DB authoritative 0 must override stale fallback 42, got {spend}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_concurrent_read_and_write_paths_share_one_db_query():
    """
    `get_current_spend` (read path) and `_init_and_increment_spend_counter`
    (write path) both reseed a cold counter from the DB. Running a read, a
    write and another read concurrently for the same counter is expected to
    cost exactly one DB query, which requires the two paths to share the
    per-counter lock.
    """
    import asyncio as _asyncio

    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import (
        _init_and_increment_spend_counter,
        get_current_spend,
    )

    counter_key = "spend:team_member:user-1:team-cross-path"
    counter_cache = DualCache()
    redis_store: dict = {}
    db_call_count = 0

    async def slow_find_unique(**kwargs):
        # Stall so all three coroutines are in flight during the DB read.
        nonlocal db_call_count
        db_call_count += 1
        await _asyncio.sleep(0.05)
        return MagicMock(spend=50.0)

    async def redis_get(key, **_):
        return redis_store.get(key)

    async def redis_increment(key, value, **_):
        updated = (redis_store.get(key) or 0.0) + value
        redis_store[key] = updated
        return updated

    async def redis_set_cache(key, value, nx=False, **_):
        if nx and key in redis_store:
            return False
        redis_store[key] = float(value)
        return True

    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=redis_get)
    fake_redis.async_increment = AsyncMock(side_effect=redis_increment)
    fake_redis.async_set_cache = AsyncMock(side_effect=redis_set_cache)
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=slow_find_unique
    )

    with patch.object(ps, "spend_counter_cache", counter_cache), patch.object(
        ps, "prisma_client", fake_prisma
    ), patch.object(ps, "user_api_key_cache", DualCache()):
        results = await _asyncio.gather(
            get_current_spend(counter_key=counter_key, fallback_spend=0.0),
            _init_and_increment_spend_counter(
                counter_key=counter_key,
                source_cache_key="ignored",
                increment=1.5,
            ),
            get_current_spend(counter_key=counter_key, fallback_spend=0.0),
        )
        assert (
            db_call_count == 1
        ), f"expected 1 DB query for concurrent read+write+read, got {db_call_count}"
        # A reader may observe the counter before or after the writer's
        # 1.5 increment lands, so both the bare seed and seed + increment
        # are acceptable.
        assert results[0] in (50.0, 51.5), f"got {results[0]}"
        assert results[2] in (50.0, 51.5), f"got {results[2]}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reseed_locks_dict_is_bounded():
    """
    `SpendCounterReseed._locks` must stay bounded by the configured cap.

    With `SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE` patched down to 5, requesting
    locks for 7 distinct counter keys must leave exactly 5 entries: the two
    oldest keys are gone and the most recent key is still present.
    """
    import litellm.constants as constants
    import litellm.proxy.db.spend_counter_reseed as scr
    from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed

    patched_cap = 5
    keys_to_insert = 7

    # Take every snapshot BEFORE touching shared state, so a failure while
    # reading one of them cannot leave the class or the constants half-patched.
    orig_locks = SpendCounterReseed._locks.copy()
    orig_max = constants.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE
    orig_module_max = scr.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE
    try:
        SpendCounterReseed._locks.clear()
        constants.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE = patched_cap
        # The class reads the constant via module-level import, so patch the
        # module-level name on the spend_counter_reseed module too.
        scr.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE = patched_cap

        for i in range(keys_to_insert):
            await SpendCounterReseed._get_lock(f"spend:key:test-key-{i}")
        assert (
            len(SpendCounterReseed._locks) == 5
        ), f"got {len(SpendCounterReseed._locks)}"
        # Oldest two evicted
        assert "spend:key:test-key-0" not in SpendCounterReseed._locks
        assert "spend:key:test-key-1" not in SpendCounterReseed._locks
        # Most recent retained
        assert "spend:key:test-key-6" in SpendCounterReseed._locks
    finally:
        constants.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE = orig_max
        scr.SPEND_COUNTER_RESEED_LOCKS_MAX_SIZE = orig_module_max
        SpendCounterReseed._locks.clear()
        SpendCounterReseed._locks.update(orig_locks)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reseed_warms_cache_even_on_zero_db_spend():
    """
    A DB spend of 0.0 must still warm the counter cache.

    The fake DB row reports `spend = 0.0`. Two consecutive reads of the same
    counter must both return 0.0, trigger only one DB lookup, and leave
    0.0 stored under the counter key in the fake redis store.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_key = "spend:team_member:user-1:team-zero-warm"
    redis_store: dict = {}
    db_call_count = 0

    # Dict-backed stand-ins for the three redis operations the test wires up.
    async def _store_get(key, **_):
        return redis_store.get(key)

    async def _store_increment(key, value, **_):
        current = redis_store.get(key) or 0.0
        redis_store[key] = current + value
        return redis_store[key]

    async def _store_set(key, value, nx=False, **_):
        if nx and key in redis_store:
            return False
        redis_store[key] = float(value)
        return True

    async def find_unique(**kwargs):
        nonlocal db_call_count
        db_call_count += 1
        return MagicMock(spend=0.0)

    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=_store_get)
    fake_redis.async_increment = AsyncMock(side_effect=_store_increment)
    fake_redis.async_set_cache = AsyncMock(side_effect=_store_set)

    counter_cache = DualCache()
    counter_cache.redis_cache = fake_redis

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        side_effect=find_unique
    )

    orig_counter = ps.spend_counter_cache
    orig_prisma = ps.prisma_client
    ps.spend_counter_cache = counter_cache
    ps.prisma_client = fake_prisma
    try:
        # Cold cache: this read goes to the DB and comes back with 0.
        spend1 = await get_current_spend(counter_key=counter_key, fallback_spend=0.0)
        # The counter should now be warmed at 0, so no second DB query.
        spend2 = await get_current_spend(counter_key=counter_key, fallback_spend=0.0)

        assert spend1 == 0.0 and spend2 == 0.0
        assert (
            db_call_count == 1
        ), f"second read should hit warmed cache, got {db_call_count} DB queries"
        assert redis_store.get(counter_key) == 0.0, "cache must be warmed at 0"
    finally:
        ps.spend_counter_cache = orig_counter
        ps.prisma_client = orig_prisma
|
||
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# /config/update — critical paths only.
|
||
#
|
||
# These exercise the four behaviors that broke or changed in the rewrite of
|
||
# update_config (litellm/proxy/proxy_server.py): targeted per-section writes,
|
||
# the removal of the store_model_in_db gate, env var encryption, and the
|
||
# success_callback / litellm_settings merge semantics. All other branches
|
||
# (auth, missing-DB, slack auto-enable, router_settings merge) are covered
|
||
# implicitly or by upstream tests.
|
||
# -----------------------------------------------------------------------------
|
||
|
||
|
||
class _FakeRow:
|
||
def __init__(self, param_name, param_value):
|
||
self.param_name = param_name
|
||
self.param_value = param_value
|
||
|
||
|
||
class _FakeLitellmConfig:
|
||
def __init__(self, initial_rows=None):
|
||
self.rows = dict(initial_rows or {})
|
||
self.upsert_calls: list = []
|
||
self.find_first = AsyncMock(side_effect=self._find_first)
|
||
self.upsert = AsyncMock(side_effect=self._upsert)
|
||
|
||
async def _find_first(self, where=None):
|
||
if where and "param_name" in where:
|
||
name = where["param_name"]
|
||
if name in self.rows:
|
||
return _FakeRow(name, self.rows[name])
|
||
return None
|
||
|
||
async def _upsert(self, where=None, data=None):
|
||
name = where["param_name"]
|
||
raw = data["update"]["param_value"]
|
||
value = json.loads(raw) if isinstance(raw, str) else raw
|
||
self.rows[name] = value
|
||
self.upsert_calls.append((name, value))
|
||
|
||
|
||
class _FakePrismaClient:
    """Fake prisma client whose `db.litellm_config` is a `_FakeLitellmConfig`.

    Everything else under `db` is a `MagicMock`, and `jsonify_object` hands
    its argument back unchanged.
    """

    def __init__(self, initial_rows=None):
        db = mock.MagicMock()
        db.litellm_config = _FakeLitellmConfig(initial_rows=initial_rows)
        self.db = db
        # Identity passthrough, kept as an instance attribute.
        self.jsonify_object = lambda obj: obj
|
||
|
||
|
||
@pytest.fixture
def _update_config_setup(monkeypatch):
    """Yield a factory that installs fakes for the /config/update endpoint.

    Calling the factory as ``_install(initial_rows=None, store_model_in_db=True)``:

    * replaces ``proxy_server.prisma_client`` with a ``_FakePrismaClient``
      seeded from ``initial_rows``;
    * sets ``proxy_server.store_model_in_db``;
    * swaps ``encrypt_value_helper`` for a fake that prefixes ``"enc:"``;
    * stubs ``invalidate_config_param`` and ``proxy_config.add_deployment``;
    * overrides the auth dependency with a PROXY_ADMIN user.

    It returns ``(client, prisma, restore)``. ``restore()`` puts
    ``app.dependency_overrides`` back to its pre-install state. Tests may
    call it themselves, but the fixture also runs every outstanding restore
    at teardown, so a test that forgets cannot leak the admin override into
    later tests. The ``monkeypatch`` fixture undoes the attribute patches.
    """
    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth as auth_dep

    # One restore callback per _install() call, replayed at teardown.
    pending_restores: list = []

    def _install(initial_rows=None, store_model_in_db=True):
        prisma = _FakePrismaClient(initial_rows=initial_rows)
        monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma)
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.store_model_in_db", store_model_in_db
        )
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.encrypt_value_helper",
            lambda value, **_: f"enc:{value}",
        )
        monkeypatch.setattr(
            "litellm.proxy.proxy_server.invalidate_config_param",
            AsyncMock(return_value=None),
        )
        from litellm.proxy.proxy_server import proxy_config as real_proxy_config

        monkeypatch.setattr(
            real_proxy_config, "add_deployment", AsyncMock(return_value=None)
        )

        original_overrides = app.dependency_overrides.copy()
        app.dependency_overrides[auth_dep] = lambda: UserAPIKeyAuth(
            user_id="test_admin",
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
        )
        client = TestClient(app)

        def _restore():
            # Idempotent: always reinstates the same pre-install snapshot.
            app.dependency_overrides = original_overrides

        pending_restores.append(_restore)
        return client, prisma, _restore

    yield _install

    # Undo in reverse order so the earliest snapshot ends up installed last.
    for restore in reversed(pending_restores):
        restore()
|
||
|
||
|
||
def test_update_config_writes_only_sent_section(_update_config_setup):
    """Posting only ``general_settings`` must upsert that row and no other.

    The pre-seeded ``litellm_settings`` and ``environment_variables`` rows
    must come out of the request exactly as they went in.
    """
    seeded_rows = {
        "litellm_settings": {"drop_params": True},
        "environment_variables": {"FOO": "enc:bar"},
    }
    client, prisma, restore = _update_config_setup(initial_rows=seeded_rows)
    try:
        payload = {"general_settings": {"store_prompts_in_spend_logs": True}}
        response = client.post("/config/update", json=payload)
        assert response.status_code == 200

        config_table = prisma.db.litellm_config
        written = {name for name, _ in config_table.upsert_calls}
        assert written == {"general_settings"}
        # Compare against fresh literals, not the seed dicts, so an in-place
        # mutation of a seeded value would still be caught.
        assert config_table.rows["litellm_settings"] == {"drop_params": True}
        assert config_table.rows["environment_variables"] == {"FOO": "enc:bar"}
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_env_var_round_trip_not_double_encrypted(
    _update_config_setup, monkeypatch
):
    """Re-submitting stored ciphertext must not add a second encryption layer.

    Scenario: a plaintext env var is saved (stored as ``enc:<value>`` by the
    fake encryptor), then the stored ciphertext is POSTed back unchanged, as
    a UI that reads config and re-saves it would do. After the second save
    the value must still carry exactly one ``enc:`` prefix, and a key that
    was never sent must be unchanged.

    The fake encrypt (``enc:`` prefix, installed by the fixture) and the
    fake decrypt defined here are inverses, so a decrypt-then-encrypt path
    round-trips exactly. Double encryption would show up as ``enc:enc:...``.
    """
    prefix = "enc:"

    def _fake_decrypt(
        value, key=None, exception_type="error", return_original_value=False
    ):
        # Strip one prefix layer from recognised ciphertext.
        if isinstance(value, str) and value.startswith("enc:"):
            return value[len(prefix) :]
        if return_original_value:
            return value
        return None

    monkeypatch.setattr(
        "litellm.proxy.proxy_server.decrypt_value_helper", _fake_decrypt
    )

    client, prisma, restore = _update_config_setup(
        initial_rows={"environment_variables": {"PREEXISTING_KEY": "enc:keepme"}}
    )
    try:
        # Save #1: plaintext goes in, one layer of encryption at rest.
        first = client.post(
            "/config/update",
            json={"environment_variables": {"LANGFUSE_SECRET_KEY": "sk-secret"}},
        )
        assert first.status_code == 200
        stored = prisma.db.litellm_config.rows["environment_variables"]
        assert stored["LANGFUSE_SECRET_KEY"] == "enc:sk-secret"

        # Save #2: send the stored ciphertext straight back, unchanged.
        resubmitted = {"LANGFUSE_SECRET_KEY": stored["LANGFUSE_SECRET_KEY"]}
        second = client.post(
            "/config/update", json={"environment_variables": resubmitted}
        )
        assert second.status_code == 200
        stored = prisma.db.litellm_config.rows["environment_variables"]

        # Still a single layer; the bug would give "enc:enc:sk-secret".
        assert stored["LANGFUSE_SECRET_KEY"] == "enc:sk-secret"
        decrypted = _fake_decrypt(
            stored["LANGFUSE_SECRET_KEY"], return_original_value=True
        )
        assert decrypted == "sk-secret"

        # The key that was never sent is preserved byte-for-byte.
        assert stored["PREEXISTING_KEY"] == "enc:keepme"
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_can_flip_store_model_in_db_when_currently_false(
    _update_config_setup,
):
    """With ``store_model_in_db`` currently False, a request turning it on
    must succeed and persist ``True`` in the ``general_settings`` row."""
    client, prisma, restore = _update_config_setup(store_model_in_db=False)
    try:
        payload = {"general_settings": {"store_model_in_db": True}}
        response = client.post("/config/update", json=payload)
        assert response.status_code == 200

        general_settings = prisma.db.litellm_config.rows["general_settings"]
        assert general_settings["store_model_in_db"] is True
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_environment_variables_encrypted_before_write(
    _update_config_setup,
):
    """A plaintext env var sent to /config/update must land in the stored
    row already passed through the (fake, ``enc:``-prefixing) encryptor."""
    client, prisma, restore = _update_config_setup()
    try:
        payload = {"environment_variables": {"OPENAI_API_KEY": "sk-secret"}}
        response = client.post("/config/update", json=payload)
        assert response.status_code == 200

        stored = prisma.db.litellm_config.rows["environment_variables"]
        assert stored == {"OPENAI_API_KEY": "enc:sk-secret"}
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_litellm_settings_request_wins_for_non_callback_keys(
    _update_config_setup,
):
    """For a non-callback key the request value replaces the stored one:
    ``drop_params`` goes from True to False, while ``set_verbose``, which
    the request does not mention, keeps its stored True."""
    existing = {"litellm_settings": {"drop_params": True, "set_verbose": True}}
    client, prisma, restore = _update_config_setup(initial_rows=existing)
    try:
        payload = {"litellm_settings": {"drop_params": False}}
        response = client.post("/config/update", json=payload)
        assert response.status_code == 200

        stored = prisma.db.litellm_config.rows["litellm_settings"]
        assert stored["drop_params"] is False
        assert stored["set_verbose"] is True
    finally:
        restore()
|
||
|
||
|
||
def test_update_config_success_callback_normalizes_existing_mixed_case(
    _update_config_setup,
):
    """Stored callback names in mixed case must be lowercased before being
    merged with the incoming list.

    The row starts as ``["Langfuse", "SQS"]`` and the request sends
    ``["langfuse"]``. The stored result must be exactly the lowercase set
    ``{"langfuse", "sqs"}``, with no case-variant duplicate of ``langfuse``.
    """
    existing = {"litellm_settings": {"success_callback": ["Langfuse", "SQS"]}}
    client, prisma, restore = _update_config_setup(initial_rows=existing)
    try:
        payload = {"litellm_settings": {"success_callback": ["langfuse"]}}
        response = client.post("/config/update", json=payload)
        assert response.status_code == 200

        settings_row = prisma.db.litellm_config.rows["litellm_settings"]
        stored = settings_row["success_callback"]
        assert set(stored) == {"langfuse", "sqs"}
    finally:
        restore()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Lazy feature loading (LazyFeatureMiddleware) — verifies that optional
|
||
# routers are NOT imported at module load and ARE imported on first request
|
||
# to a matching path prefix. The same module isn't re-imported on subsequent
|
||
# requests.
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLazyFeatureRegistry:
    """Shape checks on the lazy-feature registry, to catch accidental edits."""

    def test_registry_entries_have_required_fields(self):
        """Every registry entry is a populated `LazyFeature`."""
        from litellm.proxy._lazy_features import LAZY_FEATURES, LazyFeature

        assert len(LAZY_FEATURES) > 0
        for entry in LAZY_FEATURES:
            assert isinstance(entry, LazyFeature)
            assert entry.name
            assert entry.module_path
            assert entry.path_prefixes
            # Prefixes are compared against URL paths, so each needs a
            # leading slash.
            assert all(prefix.startswith("/") for prefix in entry.path_prefixes)
            assert callable(entry.register_fn)

    def test_registry_names_unique(self):
        """No two registry entries share a name."""
        from litellm.proxy._lazy_features import LAZY_FEATURES

        names = [entry.name for entry in LAZY_FEATURES]
        distinct = set(names)
        assert len(names) == len(distinct), "duplicate feature names"

    def test_matches_covers_prefix_and_suffix(self):
        """``matches`` accepts a path on either its prefix or its suffix.

        A feature with prefix ``/a2a`` and suffix ``/message/send`` must
        claim ``/v1/a2a/{id}/message/send`` (suffix only) as well as paths
        that start with ``/a2a``. Paths matching neither are rejected.
        """
        from litellm.proxy._lazy_features import LazyFeature

        feat = LazyFeature(
            name="a2a",
            module_path="json",
            path_prefixes=("/a2a",),
            path_suffixes=("/message/send",),
        )
        claimed = (
            "/a2a/abc/message/send",
            "/v1/a2a/abc/message/send",
            "/a2a/abc/.well-known/agent-card.json",
        )
        rejected = ("/v1/a2a/discover", "/unrelated")

        for path in claimed:
            assert feat.matches(path)
        for path in rejected:
            assert not feat.matches(path)
|
||
|
||
|
||
class TestLazyFeaturesNotImportedAtStartup:
|
||
"""
|
||
The whole point of the refactor: gated feature modules must NOT be
|
||
present in `sys.modules` immediately after `proxy_server` imports.
|
||
"""
|
||
|
||
def test_heavy_modules_absent_at_startup(self):
|
||
# Static scan of proxy_server.py source — catches any top-level
|
||
# `from <lazy_module> import` that would defeat lazy loading.
|
||
# Importing proxy_server in a subprocess and diffing sys.modules
|
||
# would also work, but takes 60-120 s and flakes on slow CI runners.
|
||
import re
|
||
from pathlib import Path
|
||
|
||
from litellm.proxy._lazy_features import LAZY_FEATURES
|
||
|
||
proxy_server_src = (
|
||
Path(__file__).resolve().parents[3] / "litellm/proxy/proxy_server.py"
|
||
).read_text()
|
||
|
||
leaks = []
|
||
for feat in LAZY_FEATURES:
|
||
# Anchor at column 0 — indented imports inside function bodies
|
||
# are fine (deferred until the function runs).
|
||
pattern = (
|
||
rf"^(from\s+{re.escape(feat.module_path)}\s+import|"
|
||
rf"import\s+{re.escape(feat.module_path)})"
|
||
)
|
||
if re.search(pattern, proxy_server_src, re.MULTILINE):
|
||
leaks.append(feat.module_path)
|
||
|
||
assert not leaks, (
|
||
"proxy_server.py top-level imports a lazy feature module — these "
|
||
f"should be loaded via LazyFeatureMiddleware: {leaks}"
|
||
)
|
||
|
||
|
||
class TestLazyFeatureMiddleware:
    """Exercises `LazyFeatureMiddleware` on its own, with a stub downstream app."""

    # -- shared ASGI scaffolding ------------------------------------------

    @staticmethod
    async def _ok_downstream(scope, receive, send):
        """Stub ASGI app: reply 200 with an empty body."""
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b""})

    @staticmethod
    async def _empty_receive():
        """ASGI receive callable yielding one complete, empty request body."""
        return {"type": "http.request", "body": b"", "more_body": False}

    @staticmethod
    def _get_scope(path):
        """Build a bare HTTP GET scope for ``path``."""
        return {"type": "http", "path": path, "method": "GET", "headers": []}

    @staticmethod
    def _middleware_for(feature):
        """Wrap the stub downstream app in a middleware that knows ``feature``."""
        from fastapi import FastAPI

        from litellm.proxy._lazy_features import LazyFeatureMiddleware

        return LazyFeatureMiddleware(
            TestLazyFeatureMiddleware._ok_downstream,
            fastapi_app=FastAPI(),
            features=(feature,),
        )

    # -- tests --------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_first_request_triggers_load_subsequent_does_not(self):
        """The first matching request registers the feature; later matching
        or non-matching requests do not register it again."""
        from litellm.proxy._lazy_features import LazyFeature

        loads = []

        def fake_register(app, module):
            loads.append(getattr(module, "__name__", "?"))

        feat = LazyFeature(
            name="dummy",
            module_path="json",  # any always-importable stdlib module
            path_prefixes=("/dummy",),
            register_fn=fake_register,
        )
        mw = self._middleware_for(feat)

        sent: list = []

        async def send(message):
            sent.append(message)

        # First request under the prefix: register_fn runs.
        await mw(self._get_scope("/dummy/x"), self._empty_receive, send)
        assert loads == ["json"]

        # Second request under the prefix: must NOT register again.
        sent.clear()
        await mw(self._get_scope("/dummy/y"), self._empty_receive, send)
        assert loads == ["json"], "register_fn called twice for the same feature"

        # A path outside the prefix triggers nothing.
        await mw(self._get_scope("/unrelated"), self._empty_receive, send)
        assert loads == ["json"]

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "server_root_path,request_path,should_load,case",
        [
            # SERVER_ROOT_PATH set: incoming path includes prefix → strip and match.
            ("/api/v1", "/api/v1/dummy/x", True, "root_path strip + match"),
            # Trailing-slash env var must be normalized.
            ("/api/v1/", "/api/v1/dummy/x", True, "trailing-slash env normalization"),
            # Reverse proxy already stripped the prefix → original path still matches.
            ("/api/v1", "/dummy/x", True, "pre-stripped path still loads"),
            # No SERVER_ROOT_PATH set → unchanged behavior.
            ("", "/dummy/x", True, "no root path"),
            # SERVER_ROOT_PATH=/ must be a no-op (not strip every leading slash).
            ("/", "/dummy/x", True, "root_path='/' is no-op"),
            # Boundary check: /apiv2 must not match root /api.
            ("/api", "/apiv2/foo", False, "boundary check prevents false match"),
            # Genuine non-match under root_path.
            ("/api/v1", "/api/v1/unrelated", False, "unrelated path under root"),
        ],
    )
    async def test_root_path_handling(
        self, monkeypatch, server_root_path, request_path, should_load, case
    ):
        """
        With SERVER_ROOT_PATH set, the feature must load for paths under
        `<root>/dummy` and for already-stripped `/dummy` paths, and must not
        load for look-alike roots or unrelated paths. See the parametrize
        table for each case.
        """
        from litellm.proxy._lazy_features import LazyFeature

        monkeypatch.setenv("SERVER_ROOT_PATH", server_root_path)

        loads = []

        def fake_register(app, module):
            loads.append(getattr(module, "__name__", "?"))

        feat = LazyFeature(
            name=f"dummy_{case}",
            module_path="json",
            path_prefixes=("/dummy",),
            register_fn=fake_register,
        )
        mw = self._middleware_for(feat)

        async def send(message):
            pass

        await mw(self._get_scope(request_path), self._empty_receive, send)

        if not should_load:
            assert loads == [], f"{case}: feature must not load"
        else:
            assert loads == ["json"], f"{case}: expected feature to load"

    @pytest.mark.asyncio
    async def test_concurrent_first_requests_only_register_once(self):
        """
        Five requests to the same prefix launched together via
        `asyncio.gather` must produce exactly one `register_fn` call.
        """
        from litellm.proxy._lazy_features import LazyFeature

        loads = []

        def slow_register(app, module):
            loads.append(getattr(module, "__name__", "?"))

        feat = LazyFeature(
            name="dummy_concurrent",
            module_path="json",
            path_prefixes=("/dummy_c",),
            register_fn=slow_register,
        )
        mw = self._middleware_for(feat)

        sent: list = []

        async def send(message):
            sent.append(message)

        async def hit():
            await mw(self._get_scope("/dummy_c/x"), self._empty_receive, send)

        await asyncio.gather(hit(), hit(), hit(), hit(), hit())
        assert loads == [
            "json"
        ], f"expected one registration despite concurrent first hits, got {loads}"

    @pytest.mark.asyncio
    async def test_failing_import_does_not_loop(self):
        """
        A feature whose `register_fn` raises must only be attempted once:
        three requests to its prefix yield a single `register_fn` call, so
        a broken feature does not add its failure cost to every request.
        """
        from litellm.proxy._lazy_features import LazyFeature

        attempts = []

        def fail_register(app, module):
            attempts.append("called")
            raise RuntimeError("boom")

        feat = LazyFeature(
            name="failing",
            module_path="json",
            path_prefixes=("/fail",),
            register_fn=fail_register,
        )
        mw = self._middleware_for(feat)

        sent: list = []

        async def send(message):
            sent.append(message)

        for _ in range(3):
            await mw(self._get_scope("/fail/x"), self._empty_receive, send)
        assert attempts == [
            "called"
        ], f"failing register_fn should be invoked once, not on every request; got {attempts}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_redis_clean_miss_skips_stale_in_memory():
    """A clean Redis miss must reseed from the DB, not from in-memory cache.

    Setup: the in-memory layer holds 30.0 for the counter (standing in for
    one pod's partial view), Redis answers the read with None without
    raising, and the DB row reports 500.0. The read must return 500.0.
    Returning the in-memory 30.0 would under-report spend and let budget
    enforcement pass when it should not.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_key = "spend:team_member:user-1:team-1"
    counter_cache = DualCache()

    # In-memory layer: a stale, pod-local figure.
    counter_cache.in_memory_cache.set_cache(key=counter_key, value=30.0)

    # Redis layer: reachable, but the key is simply not there.
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(return_value=None)
    fake_redis.async_increment = AsyncMock(return_value=500.0)
    counter_cache.redis_cache = fake_redis

    # DB: the authoritative total.
    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = AsyncMock(
        return_value=MagicMock(spend=500.0)
    )

    orig_counter = ps.spend_counter_cache
    orig_prisma = ps.prisma_client
    ps.spend_counter_cache = counter_cache
    ps.prisma_client = fake_prisma
    try:
        spend = await get_current_spend(counter_key=counter_key, fallback_spend=0.0)
        assert spend == 500.0, (
            f"expected DB-authoritative 500.0 on clean Redis miss, got {spend} "
            f"(stale per-pod in-memory $30 would have caused multi-pod bypass)"
        )
    finally:
        ps.spend_counter_cache = orig_counter
        ps.prisma_client = orig_prisma
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_current_spend_redis_error_falls_back_to_in_memory():
    """When the Redis read raises, the in-memory value is returned and the
    DB is not queried.

    Setup: in-memory holds 42.0, the Redis read raises `ConnectionError`,
    and the DB would report 999.0. The read must return 42.0 and the DB
    lookup must never be awaited.
    """
    import litellm.proxy.proxy_server as ps
    from litellm.caching.dual_cache import DualCache
    from litellm.proxy.proxy_server import get_current_spend

    counter_key = "spend:team_member:user-1:team-1"
    counter_cache = DualCache()
    counter_cache.in_memory_cache.set_cache(key=counter_key, value=42.0)

    # Redis layer: every read blows up.
    fake_redis = AsyncMock()
    fake_redis.async_get_cache = AsyncMock(side_effect=ConnectionError("redis down"))
    counter_cache.redis_cache = fake_redis

    db_lookup = AsyncMock(return_value=MagicMock(spend=999.0))
    fake_prisma = MagicMock()
    fake_prisma.db.litellm_teammembership.find_unique = db_lookup

    orig_counter = ps.spend_counter_cache
    orig_prisma = ps.prisma_client
    ps.spend_counter_cache = counter_cache
    ps.prisma_client = fake_prisma
    try:
        spend = await get_current_spend(counter_key=counter_key, fallback_spend=0.0)
        assert spend == 42.0, (
            f"expected in-memory fallback 42.0 on Redis error, got {spend} "
            f"(should not have hit DB when Redis errored)"
        )
        # The in-memory hit short-circuits, so the DB must stay untouched.
        fake_prisma.db.litellm_teammembership.find_unique.assert_not_awaited()
    finally:
        ps.spend_counter_cache = orig_counter
        ps.prisma_client = orig_prisma
|
||
|
||
|
||
def test_realtime_websocket_route_aliases_registered():
    """Each realtime path alias must be wired up in three places.

    For each of ``/openai/v1/realtime``, ``/v1/realtime`` and ``/realtime``
    the test requires that:

    * a `WebSocketRoute` with that path is registered on the app;
    * the path appears in `LiteLLMRoutes.openai_routes`;
    * `API_ROUTE_TO_CALL_TYPES` maps it to `[CallTypes.arealtime]`.

    The assertion messages spell out what breaks when each one is missing.
    """
    from starlette.routing import WebSocketRoute

    from litellm.proxy._types import LiteLLMRoutes
    from litellm.proxy.proxy_server import app
    from litellm.types.utils import API_ROUTE_TO_CALL_TYPES, CallTypes

    realtime_aliases = ("/openai/v1/realtime", "/v1/realtime", "/realtime")
    openai_routes = LiteLLMRoutes.openai_routes.value
    websocket_paths = set()
    for route in app.routes:
        if isinstance(route, WebSocketRoute):
            websocket_paths.add(route.path)

    for expected in realtime_aliases:
        assert expected in websocket_paths, (
            f"{expected!r} missing from registered WebSocket routes; the "
            f"realtime endpoint will 405 for clients hitting this path."
        )
        assert expected in openai_routes, (
            f"{expected!r} missing from LiteLLMRoutes.openai_routes; "
            f"non-admin / team / key-scoped users will get 403 on this path."
        )
        resolved_call_types = API_ROUTE_TO_CALL_TYPES.get(expected)
        assert resolved_call_types == [CallTypes.arealtime], (
            f"{expected!r} missing from API_ROUTE_TO_CALL_TYPES; call-type "
            f"resolution will return None and break call-type-aware features."
        )
|
||
|
||
|
||
class TestTransformRequestBannedParams:
    """
    /utils/transform_request must reject banned request-body params.

    Parameters such as aws_sts_endpoint or api_base could otherwise let an
    authenticated user point the server's credential resolution at an
    endpoint they control. These tests assert the endpoint answers 400 for
    each banned param, even for a non-admin internal user.
    """

    @pytest.fixture
    def client(self):
        """Test client authenticated as an INTERNAL_USER; overrides restored after."""
        saved_overrides = app.dependency_overrides.copy()
        internal_user = UserAPIKeyAuth(
            user_id="test-internal",
            user_role=LitellmUserRoles.INTERNAL_USER,
        )
        app.dependency_overrides[user_api_key_auth] = lambda: internal_user
        try:
            yield TestClient(app)
        finally:
            app.dependency_overrides = saved_overrides

    @pytest.mark.parametrize(
        "banned",
        [
            "aws_sts_endpoint",
            "api_base",
            "aws_web_identity_token",
            "vertex_credentials",
        ],
    )
    def test_banned_params_rejected_for_all_users(self, client, banned):
        """A request body carrying the banned param gets HTTP 400."""
        request_body = {
            "model": "gpt-3.5-turbo",
            banned: "https://attacker.example",
        }
        response = client.post(
            "/utils/transform_request",
            json={"call_type": "completion", "request_body": request_body},
        )
        assert response.status_code == 400, (
            f"Expected 400 for banned param '{banned}', "
            f"got {response.status_code}: {response.json()}"
        )
|
||
|
||
|
||
class TestSortModelsByDisplayName:
    """Regression: team BYOK rows persist an internal `model_name` like
    `model_name_{team_id}_{uuid}` and expose the user-facing name via
    `model_info.team_public_model_name`. Sorting must use the displayed
    name so BYOK rows interleave with non-BYOK rows alphabetically —
    otherwise they clump at the end on their opaque IDs even though the
    UI shows them under a normal-looking name.

    Each fixture is chosen so that ordering by the displayed name and
    ordering by the raw internal `model_name` produce different results;
    a test whose two orderings coincide could never catch the regression.
    """

    @staticmethod
    def _displayed_names(models):
        """Return the name shown for each row, in order.

        That is `model_info.team_public_model_name` when it is truthy,
        otherwise the row's `model_name`.
        """
        return [
            m["model_info"].get("team_public_model_name") or m["model_name"]
            for m in models
        ]

    def test_byok_models_sort_by_team_public_model_name(self):
        """Ascending sort places a BYOK row by its public name."""
        from litellm.proxy.proxy_server import _sort_models

        models = [
            {"model_name": "claude-haiku-4-5", "model_info": {}},
            {
                # Opaque internal name; UI displays team_public_model_name.
                "model_name": "model_name_team-1_abc123",
                "model_info": {"team_public_model_name": "anthropic/claude"},
            },
            {"model_name": "gpt-4o", "model_info": {}},
        ]

        sorted_models = _sort_models(
            all_models=models, sort_by="model_name", sort_order="asc"
        )

        # Sorting on the raw internal name would put the BYOK row last.
        assert self._displayed_names(sorted_models) == [
            "anthropic/claude",
            "claude-haiku-4-5",
            "gpt-4o",
        ]

    def test_byok_models_sort_descending_by_display_name(self):
        """Descending sort places a BYOK row by its public name."""
        from litellm.proxy.proxy_server import _sort_models

        models = [
            {"model_name": "claude-haiku-4-5", "model_info": {}},
            {
                # The internal name ("model_name_...") sorts FIRST in
                # descending order while the public name ("anthropic/...")
                # sorts LAST, so a sort keyed on the internal name fails here.
                "model_name": "model_name_team-1_zzz",
                "model_info": {"team_public_model_name": "anthropic/claude"},
            },
            {"model_name": "gpt-4o", "model_info": {}},
        ]

        sorted_models = _sort_models(
            all_models=models, sort_by="model_name", sort_order="desc"
        )

        assert self._displayed_names(sorted_models) == [
            "gpt-4o",
            "claude-haiku-4-5",
            "anthropic/claude",
        ]

    def test_empty_team_public_model_name_falls_back_to_model_name(self):
        """An empty public name must be treated like a missing one."""
        # Empty string for team_public_model_name (not None) must still
        # fall back to model_name — otherwise BYOK rows with a blank
        # display name would sort to the top.
        from litellm.proxy.proxy_server import _sort_models

        models = [
            # Listed first and carrying the later name: if the blank display
            # name were used as the key ("" sorts before everything), or if
            # no sorting happened at all, "beta" would stay ahead of "alpha".
            {"model_name": "beta", "model_info": {"team_public_model_name": ""}},
            {"model_name": "alpha", "model_info": {}},
        ]

        sorted_models = _sort_models(
            all_models=models, sort_by="model_name", sort_order="asc"
        )

        assert [m["model_name"] for m in sorted_models] == ["alpha", "beta"]
|