diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 9dbe1fc5c6..18bab49eb4 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -1819,14 +1819,17 @@ async def initialize_pass_through_endpoints( verbose_proxy_logger.debug("initializing pass through endpoints") from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes - from litellm.proxy.proxy_server import app, general_settings, premium_user + from litellm.proxy.proxy_server import ( + app, + config_passthrough_endpoints, + premium_user, + ) ## get combined pass-through endpoints from db + config - config_pass_through_endpoints = general_settings.get("pass_through_endpoints") combined_pass_through_endpoints: List[Union[Dict, PassThroughGenericEndpoint]] - if config_pass_through_endpoints is not None: + if config_passthrough_endpoints is not None: combined_pass_through_endpoints = _get_combined_pass_through_endpoints( # type: ignore - pass_through_endpoints, config_pass_through_endpoints + pass_through_endpoints, config_passthrough_endpoints ) else: combined_pass_through_endpoints = pass_through_endpoints # type: ignore diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index be0f7d6be0..8f5a046be0 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -259,9 +259,7 @@ from litellm.proxy.management_endpoints.customer_endpoints import ( from litellm.proxy.management_endpoints.internal_user_endpoints import ( router as internal_user_router, ) -from litellm.proxy.management_endpoints.internal_user_endpoints import ( - user_update, -) +from litellm.proxy.management_endpoints.internal_user_endpoints import user_update from litellm.proxy.management_endpoints.key_management_endpoints import ( delete_verification_tokens, duration_in_seconds, @@ -308,9 +306,7 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi from litellm.proxy.openai_files_endpoints.files_endpoints import ( router as openai_files_router, ) -from litellm.proxy.openai_files_endpoints.files_endpoints import ( - set_files_config, -) +from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( passthrough_endpoint_router, ) @@ -580,7 +576,7 @@ async def _initialize_shared_aiohttp_session(): ttl_dns_cache=AIOHTTP_TTL_DNS_CACHE, enable_cleanup_closed=True, ) - + session = ClientSession(connector=connector) verbose_proxy_logger.info( f"SESSION REUSE: Created shared aiohttp session for connection pooling (ID: {id(session)})" @@ -723,7 +719,7 @@ async def proxy_startup_event(app: FastAPI): verbose_proxy_logger.info("SESSION REUSE: Closed shared aiohttp session") except Exception as e: verbose_proxy_logger.error(f"Error closing shared aiohttp session: {e}") - + await proxy_shutdown_event() @@ -995,13 +991,16 @@ experimental = False llm_router: Optional[Router] = None llm_model_list: Optional[list] = None general_settings: dict = {} +config_passthrough_endpoints: Optional[List[Dict[str, Any]]] = None callback_settings: dict = {} log_file = "api_log.json" worker_config = None master_key: Optional[str] = None otel_logging = False prisma_client: Optional[PrismaClient] = None -shared_aiohttp_session: Optional["ClientSession"] = None # Global shared session for connection reuse +shared_aiohttp_session: Optional["ClientSession"] = ( + None # Global shared session for connection reuse +) user_api_key_cache = DualCache( default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value ) @@ -1385,31 +1384,33 @@ async def update_cache( # noqa: PLR0915 """ if tags is None or response_cost is None: return - + try: for tag_name in tags: if not tag_name or not isinstance(tag_name, str): continue - + cache_key = f"tag:{tag_name}" # Fetch the existing tag object from cache - existing_tag_obj = await user_api_key_cache.async_get_cache(key=cache_key) + existing_tag_obj = await user_api_key_cache.async_get_cache( + key=cache_key + ) if existing_tag_obj is None: # do nothing if tag not in api key cache continue - + verbose_proxy_logger.debug( f"_update_tag_cache: existing spend for tag={tag_name}: {existing_tag_obj}; response_cost: {response_cost}" ) - + if isinstance(existing_tag_obj, dict): existing_spend = existing_tag_obj.get("spend", 0) or 0 else: existing_spend = getattr(existing_tag_obj, "spend", 0) or 0 - + # Calculate the new cost by adding the existing cost and response_cost new_spend = existing_spend + response_cost - + # Update the spend column for the given tag if isinstance(existing_tag_obj, dict): existing_tag_obj["spend"] = new_spend @@ -1482,6 +1483,7 @@ async def _run_background_health_check(): from litellm.proxy.health_check_utils.shared_health_check_manager import ( SharedHealthCheckManager, ) + shared_health_manager = SharedHealthCheckManager( redis_cache=redis_usage_cache, health_check_ttl=DEFAULT_SHARED_HEALTH_CHECK_TTL, @@ -1502,16 +1504,21 @@ async def _run_background_health_check(): # Use shared health check if available, otherwise fall back to direct health check # Convert health_check_details to bool for perform_shared_health_check (defaults to True if None) - details_bool = health_check_details if health_check_details is not None else True - + details_bool = ( + health_check_details if health_check_details is not None else True + ) + if shared_health_manager is not None: try: - healthy_endpoints, unhealthy_endpoints = await shared_health_manager.perform_shared_health_check( - model_list=_llm_model_list, details=details_bool + healthy_endpoints, unhealthy_endpoints = ( + await shared_health_manager.perform_shared_health_check( + model_list=_llm_model_list, details=details_bool + ) ) except Exception as e: verbose_proxy_logger.error( - "Error in shared health check, falling back to direct health check: %s", str(e) + "Error in shared health check, falling back to direct health check: %s", + str(e), ) healthy_endpoints, unhealthy_endpoints = await perform_health_check( model_list=_llm_model_list, details=health_check_details @@ -1879,7 +1886,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, use_shared_health_check, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings, proxy_batch_polling_interval + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, user_custom_ui_sso_sign_in_handler, use_background_health_checks, use_shared_health_check, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings, proxy_batch_polling_interval, config_passthrough_endpoints config: dict = await self.get_config(config_file_path=config_file_path) @@ -2234,9 +2241,13 @@ class ProxyConfig: ## pass through endpoints if general_settings.get("pass_through_endpoints", None) is not None: + config_passthrough_endpoints = general_settings[ + "pass_through_endpoints" + ] await initialize_pass_through_endpoints( pass_through_endpoints=general_settings["pass_through_endpoints"] ) + ## ADMIN UI ACCESS ## ui_access_mode = general_settings.get( "ui_access_mode", "all" @@ -3055,7 +3066,9 @@ class ProxyConfig: return current_config # For dictionary values, update only non-none values - if isinstance(current_config[param_name], dict) and isinstance(db_param_value, dict): + if isinstance(current_config[param_name], dict) and isinstance( + db_param_value, dict + ): _deep_merge_dicts(current_config[param_name], db_param_value) else: # Non-dict or mismatched types: DB value replaces config (unchanged behavior) diff --git a/tests/pass_through_tests/test_openai_assistants_passthrough.py b/tests/pass_through_tests/test_openai_assistants_passthrough.py index e5783877ec..28568005fd 100644 --- a/tests/pass_through_tests/test_openai_assistants_passthrough.py +++ b/tests/pass_through_tests/test_openai_assistants_passthrough.py @@ -9,9 +9,12 @@ from openai import AssistantEventHandler client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234") + def test_pass_through_file_operations(): # Create a temporary file - with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file: + with tempfile.NamedTemporaryFile( + mode="w+", suffix=".txt", delete=False + ) as temp_file: temp_file.write("This is a test file for the OpenAI Assistants API.") temp_file.flush() @@ -26,6 +29,7 @@ def test_pass_through_file_operations(): delete_file = client.files.delete(file.id) print("file deleted", delete_file) + def test_openai_assistants_e2e_operations(): assistant = client.beta.assistants.create( name="Math Tutor", @@ -98,13 +102,13 @@ def test_openai_assistants_e2e_operations_stream(): stream.until_done() - def test_azure_openai_assistants_e2e_operations_stream(): from openai import AzureOpenAI + client = AzureOpenAI( - base_url="http://0.0.0.0:4000/azure-config-passthrough/openai", + base_url="http://0.0.0.0:4000/azure-config-passthrough/openai", api_key="sk-1234", - api_version="2025-01-01-preview" + api_version="2025-01-01-preview", ) assistant = client.beta.assistants.create( name="Math Tutor", @@ -134,4 +138,4 @@ def test_azure_openai_assistants_e2e_operations_stream(): instructions="Please address the user as Jane Doe. The user has a premium account.", event_handler=EventHandler(), ) as stream: - stream.until_done() \ No newline at end of file + stream.until_done()