fix: always retain config models

This commit is contained in:
Krrish Dholakia 2025-10-11 16:09:33 -07:00
parent 8eb58db989
commit 6bd722bba4
3 changed files with 52 additions and 32 deletions

View File

@ -1819,14 +1819,17 @@ async def initialize_pass_through_endpoints(
verbose_proxy_logger.debug("initializing pass through endpoints")
from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes
from litellm.proxy.proxy_server import app, general_settings, premium_user
from litellm.proxy.proxy_server import (
app,
config_passthrough_endpoints,
premium_user,
)
## get combined pass-through endpoints from db + config
config_pass_through_endpoints = general_settings.get("pass_through_endpoints")
combined_pass_through_endpoints: List[Union[Dict, PassThroughGenericEndpoint]]
if config_pass_through_endpoints is not None:
if config_passthrough_endpoints is not None:
combined_pass_through_endpoints = _get_combined_pass_through_endpoints( # type: ignore
pass_through_endpoints, config_pass_through_endpoints
pass_through_endpoints, config_passthrough_endpoints
)
else:
combined_pass_through_endpoints = pass_through_endpoints # type: ignore

View File

@ -259,9 +259,7 @@ from litellm.proxy.management_endpoints.customer_endpoints import (
from litellm.proxy.management_endpoints.internal_user_endpoints import (
router as internal_user_router,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import (
user_update,
)
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
from litellm.proxy.management_endpoints.key_management_endpoints import (
delete_verification_tokens,
duration_in_seconds,
@ -308,9 +306,7 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi
from litellm.proxy.openai_files_endpoints.files_endpoints import (
router as openai_files_router,
)
from litellm.proxy.openai_files_endpoints.files_endpoints import (
set_files_config,
)
from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
passthrough_endpoint_router,
)
@ -580,7 +576,7 @@ async def _initialize_shared_aiohttp_session():
ttl_dns_cache=AIOHTTP_TTL_DNS_CACHE,
enable_cleanup_closed=True,
)
session = ClientSession(connector=connector)
verbose_proxy_logger.info(
f"SESSION REUSE: Created shared aiohttp session for connection pooling (ID: {id(session)})"
@ -723,7 +719,7 @@ async def proxy_startup_event(app: FastAPI):
verbose_proxy_logger.info("SESSION REUSE: Closed shared aiohttp session")
except Exception as e:
verbose_proxy_logger.error(f"Error closing shared aiohttp session: {e}")
await proxy_shutdown_event()
@ -995,13 +991,16 @@ experimental = False
llm_router: Optional[Router] = None
llm_model_list: Optional[list] = None
general_settings: dict = {}
config_passthrough_endpoints: Optional[List[Dict[str, Any]]] = None
callback_settings: dict = {}
log_file = "api_log.json"
worker_config = None
master_key: Optional[str] = None
otel_logging = False
prisma_client: Optional[PrismaClient] = None
shared_aiohttp_session: Optional["ClientSession"] = None # Global shared session for connection reuse
shared_aiohttp_session: Optional["ClientSession"] = (
None # Global shared session for connection reuse
)
user_api_key_cache = DualCache(
default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
)
@ -1385,31 +1384,33 @@ async def update_cache( # noqa: PLR0915
"""
if tags is None or response_cost is None:
return
try:
for tag_name in tags:
if not tag_name or not isinstance(tag_name, str):
continue
cache_key = f"tag:{tag_name}"
# Fetch the existing tag object from cache
existing_tag_obj = await user_api_key_cache.async_get_cache(key=cache_key)
existing_tag_obj = await user_api_key_cache.async_get_cache(
key=cache_key
)
if existing_tag_obj is None:
# do nothing if tag not in api key cache
continue
verbose_proxy_logger.debug(
f"_update_tag_cache: existing spend for tag={tag_name}: {existing_tag_obj}; response_cost: {response_cost}"
)
if isinstance(existing_tag_obj, dict):
existing_spend = existing_tag_obj.get("spend", 0) or 0
else:
existing_spend = getattr(existing_tag_obj, "spend", 0) or 0
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
# Update the spend column for the given tag
if isinstance(existing_tag_obj, dict):
existing_tag_obj["spend"] = new_spend
@ -1482,6 +1483,7 @@ async def _run_background_health_check():
from litellm.proxy.health_check_utils.shared_health_check_manager import (
SharedHealthCheckManager,
)
shared_health_manager = SharedHealthCheckManager(
redis_cache=redis_usage_cache,
health_check_ttl=DEFAULT_SHARED_HEALTH_CHECK_TTL,
@ -1502,16 +1504,21 @@ async def _run_background_health_check():
# Use shared health check if available, otherwise fall back to direct health check
# Convert health_check_details to bool for perform_shared_health_check (defaults to True if None)
details_bool = health_check_details if health_check_details is not None else True
details_bool = (
health_check_details if health_check_details is not None else True
)
if shared_health_manager is not None:
try:
healthy_endpoints, unhealthy_endpoints = await shared_health_manager.perform_shared_health_check(
model_list=_llm_model_list, details=details_bool
healthy_endpoints, unhealthy_endpoints = (
await shared_health_manager.perform_shared_health_check(
model_list=_llm_model_list, details=details_bool
)
)
except Exception as e:
verbose_proxy_logger.error(
"Error in shared health check, falling back to direct health check: %s", str(e)
"Error in shared health check, falling back to direct health check: %s",
str(e),
)
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
model_list=_llm_model_list, details=health_check_details
@ -1879,7 +1886,7 @@ class ProxyConfig:
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, use_shared_health_check, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings, proxy_batch_polling_interval
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, use_shared_health_check, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings, proxy_batch_polling_interval, config_passthrough_endpoints
config: dict = await self.get_config(config_file_path=config_file_path)
@ -2234,9 +2241,13 @@ class ProxyConfig:
## pass through endpoints
if general_settings.get("pass_through_endpoints", None) is not None:
config_passthrough_endpoints = general_settings[
"pass_through_endpoints"
]
await initialize_pass_through_endpoints(
pass_through_endpoints=general_settings["pass_through_endpoints"]
)
## ADMIN UI ACCESS ##
ui_access_mode = general_settings.get(
"ui_access_mode", "all"
@ -3055,7 +3066,9 @@ class ProxyConfig:
return current_config
# For dictionary values, update only non-none values
if isinstance(current_config[param_name], dict) and isinstance(db_param_value, dict):
if isinstance(current_config[param_name], dict) and isinstance(
db_param_value, dict
):
_deep_merge_dicts(current_config[param_name], db_param_value)
else:
# Non-dict or mismatched types: DB value replaces config (unchanged behavior)

View File

@ -9,9 +9,12 @@ from openai import AssistantEventHandler
client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")
def test_pass_through_file_operations():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
with tempfile.NamedTemporaryFile(
mode="w+", suffix=".txt", delete=False
) as temp_file:
temp_file.write("This is a test file for the OpenAI Assistants API.")
temp_file.flush()
@ -26,6 +29,7 @@ def test_pass_through_file_operations():
delete_file = client.files.delete(file.id)
print("file deleted", delete_file)
def test_openai_assistants_e2e_operations():
assistant = client.beta.assistants.create(
name="Math Tutor",
@ -98,13 +102,13 @@ def test_openai_assistants_e2e_operations_stream():
stream.until_done()
def test_azure_openai_assistants_e2e_operations_stream():
from openai import AzureOpenAI
client = AzureOpenAI(
base_url="http://0.0.0.0:4000/azure-config-passthrough/openai",
base_url="http://0.0.0.0:4000/azure-config-passthrough/openai",
api_key="sk-1234",
api_version="2025-01-01-preview"
api_version="2025-01-01-preview",
)
assistant = client.beta.assistants.create(
name="Math Tutor",
@ -134,4 +138,4 @@ def test_azure_openai_assistants_e2e_operations_stream():
instructions="Please address the user as Jane Doe. The user has a premium account.",
event_handler=EventHandler(),
) as stream:
stream.until_done()
stream.until_done()