fix: always retain config models
This commit is contained in:
parent
8eb58db989
commit
6bd722bba4
@ -1819,14 +1819,17 @@ async def initialize_pass_through_endpoints(
|
||||
|
||||
verbose_proxy_logger.debug("initializing pass through endpoints")
|
||||
from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes
|
||||
from litellm.proxy.proxy_server import app, general_settings, premium_user
|
||||
from litellm.proxy.proxy_server import (
|
||||
app,
|
||||
config_passthrough_endpoints,
|
||||
premium_user,
|
||||
)
|
||||
|
||||
## get combined pass-through endpoints from db + config
|
||||
config_pass_through_endpoints = general_settings.get("pass_through_endpoints")
|
||||
combined_pass_through_endpoints: List[Union[Dict, PassThroughGenericEndpoint]]
|
||||
if config_pass_through_endpoints is not None:
|
||||
if config_passthrough_endpoints is not None:
|
||||
combined_pass_through_endpoints = _get_combined_pass_through_endpoints( # type: ignore
|
||||
pass_through_endpoints, config_pass_through_endpoints
|
||||
pass_through_endpoints, config_passthrough_endpoints
|
||||
)
|
||||
else:
|
||||
combined_pass_through_endpoints = pass_through_endpoints # type: ignore
|
||||
|
||||
@ -259,9 +259,7 @@ from litellm.proxy.management_endpoints.customer_endpoints import (
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
router as internal_user_router,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
user_update,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_verification_tokens,
|
||||
duration_in_seconds,
|
||||
@ -308,9 +306,7 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||
router as openai_files_router,
|
||||
)
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||
set_files_config,
|
||||
)
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
passthrough_endpoint_router,
|
||||
)
|
||||
@ -580,7 +576,7 @@ async def _initialize_shared_aiohttp_session():
|
||||
ttl_dns_cache=AIOHTTP_TTL_DNS_CACHE,
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
|
||||
|
||||
session = ClientSession(connector=connector)
|
||||
verbose_proxy_logger.info(
|
||||
f"SESSION REUSE: Created shared aiohttp session for connection pooling (ID: {id(session)})"
|
||||
@ -723,7 +719,7 @@ async def proxy_startup_event(app: FastAPI):
|
||||
verbose_proxy_logger.info("SESSION REUSE: Closed shared aiohttp session")
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error closing shared aiohttp session: {e}")
|
||||
|
||||
|
||||
await proxy_shutdown_event()
|
||||
|
||||
|
||||
@ -995,13 +991,16 @@ experimental = False
|
||||
llm_router: Optional[Router] = None
|
||||
llm_model_list: Optional[list] = None
|
||||
general_settings: dict = {}
|
||||
config_passthrough_endpoints: Optional[List[Dict[str, Any]]] = None
|
||||
callback_settings: dict = {}
|
||||
log_file = "api_log.json"
|
||||
worker_config = None
|
||||
master_key: Optional[str] = None
|
||||
otel_logging = False
|
||||
prisma_client: Optional[PrismaClient] = None
|
||||
shared_aiohttp_session: Optional["ClientSession"] = None # Global shared session for connection reuse
|
||||
shared_aiohttp_session: Optional["ClientSession"] = (
|
||||
None # Global shared session for connection reuse
|
||||
)
|
||||
user_api_key_cache = DualCache(
|
||||
default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
|
||||
)
|
||||
@ -1385,31 +1384,33 @@ async def update_cache( # noqa: PLR0915
|
||||
"""
|
||||
if tags is None or response_cost is None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
for tag_name in tags:
|
||||
if not tag_name or not isinstance(tag_name, str):
|
||||
continue
|
||||
|
||||
|
||||
cache_key = f"tag:{tag_name}"
|
||||
# Fetch the existing tag object from cache
|
||||
existing_tag_obj = await user_api_key_cache.async_get_cache(key=cache_key)
|
||||
existing_tag_obj = await user_api_key_cache.async_get_cache(
|
||||
key=cache_key
|
||||
)
|
||||
if existing_tag_obj is None:
|
||||
# do nothing if tag not in api key cache
|
||||
continue
|
||||
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"_update_tag_cache: existing spend for tag={tag_name}: {existing_tag_obj}; response_cost: {response_cost}"
|
||||
)
|
||||
|
||||
|
||||
if isinstance(existing_tag_obj, dict):
|
||||
existing_spend = existing_tag_obj.get("spend", 0) or 0
|
||||
else:
|
||||
existing_spend = getattr(existing_tag_obj, "spend", 0) or 0
|
||||
|
||||
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
|
||||
# Update the spend column for the given tag
|
||||
if isinstance(existing_tag_obj, dict):
|
||||
existing_tag_obj["spend"] = new_spend
|
||||
@ -1482,6 +1483,7 @@ async def _run_background_health_check():
|
||||
from litellm.proxy.health_check_utils.shared_health_check_manager import (
|
||||
SharedHealthCheckManager,
|
||||
)
|
||||
|
||||
shared_health_manager = SharedHealthCheckManager(
|
||||
redis_cache=redis_usage_cache,
|
||||
health_check_ttl=DEFAULT_SHARED_HEALTH_CHECK_TTL,
|
||||
@ -1502,16 +1504,21 @@ async def _run_background_health_check():
|
||||
|
||||
# Use shared health check if available, otherwise fall back to direct health check
|
||||
# Convert health_check_details to bool for perform_shared_health_check (defaults to True if None)
|
||||
details_bool = health_check_details if health_check_details is not None else True
|
||||
|
||||
details_bool = (
|
||||
health_check_details if health_check_details is not None else True
|
||||
)
|
||||
|
||||
if shared_health_manager is not None:
|
||||
try:
|
||||
healthy_endpoints, unhealthy_endpoints = await shared_health_manager.perform_shared_health_check(
|
||||
model_list=_llm_model_list, details=details_bool
|
||||
healthy_endpoints, unhealthy_endpoints = (
|
||||
await shared_health_manager.perform_shared_health_check(
|
||||
model_list=_llm_model_list, details=details_bool
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"Error in shared health check, falling back to direct health check: %s", str(e)
|
||||
"Error in shared health check, falling back to direct health check: %s",
|
||||
str(e),
|
||||
)
|
||||
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
|
||||
model_list=_llm_model_list, details=health_check_details
|
||||
@ -1879,7 +1886,7 @@ class ProxyConfig:
|
||||
"""
|
||||
Load config values into proxy global state
|
||||
"""
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, use_shared_health_check, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings, proxy_batch_polling_interval
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, use_shared_health_check, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings, proxy_batch_polling_interval, config_passthrough_endpoints
|
||||
|
||||
config: dict = await self.get_config(config_file_path=config_file_path)
|
||||
|
||||
@ -2234,9 +2241,13 @@ class ProxyConfig:
|
||||
|
||||
## pass through endpoints
|
||||
if general_settings.get("pass_through_endpoints", None) is not None:
|
||||
config_passthrough_endpoints = general_settings[
|
||||
"pass_through_endpoints"
|
||||
]
|
||||
await initialize_pass_through_endpoints(
|
||||
pass_through_endpoints=general_settings["pass_through_endpoints"]
|
||||
)
|
||||
|
||||
## ADMIN UI ACCESS ##
|
||||
ui_access_mode = general_settings.get(
|
||||
"ui_access_mode", "all"
|
||||
@ -3055,7 +3066,9 @@ class ProxyConfig:
|
||||
return current_config
|
||||
|
||||
# For dictionary values, update only non-none values
|
||||
if isinstance(current_config[param_name], dict) and isinstance(db_param_value, dict):
|
||||
if isinstance(current_config[param_name], dict) and isinstance(
|
||||
db_param_value, dict
|
||||
):
|
||||
_deep_merge_dicts(current_config[param_name], db_param_value)
|
||||
else:
|
||||
# Non-dict or mismatched types: DB value replaces config (unchanged behavior)
|
||||
|
||||
@ -9,9 +9,12 @@ from openai import AssistantEventHandler
|
||||
|
||||
client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")
|
||||
|
||||
|
||||
def test_pass_through_file_operations():
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w+", suffix=".txt", delete=False
|
||||
) as temp_file:
|
||||
temp_file.write("This is a test file for the OpenAI Assistants API.")
|
||||
temp_file.flush()
|
||||
|
||||
@ -26,6 +29,7 @@ def test_pass_through_file_operations():
|
||||
delete_file = client.files.delete(file.id)
|
||||
print("file deleted", delete_file)
|
||||
|
||||
|
||||
def test_openai_assistants_e2e_operations():
|
||||
assistant = client.beta.assistants.create(
|
||||
name="Math Tutor",
|
||||
@ -98,13 +102,13 @@ def test_openai_assistants_e2e_operations_stream():
|
||||
stream.until_done()
|
||||
|
||||
|
||||
|
||||
def test_azure_openai_assistants_e2e_operations_stream():
|
||||
from openai import AzureOpenAI
|
||||
|
||||
client = AzureOpenAI(
|
||||
base_url="http://0.0.0.0:4000/azure-config-passthrough/openai",
|
||||
base_url="http://0.0.0.0:4000/azure-config-passthrough/openai",
|
||||
api_key="sk-1234",
|
||||
api_version="2025-01-01-preview"
|
||||
api_version="2025-01-01-preview",
|
||||
)
|
||||
assistant = client.beta.assistants.create(
|
||||
name="Math Tutor",
|
||||
@ -134,4 +138,4 @@ def test_azure_openai_assistants_e2e_operations_stream():
|
||||
instructions="Please address the user as Jane Doe. The user has a premium account.",
|
||||
event_handler=EventHandler(),
|
||||
) as stream:
|
||||
stream.until_done()
|
||||
stream.until_done()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user