diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 6ea4bd80e2..e8d05031a5 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -546,6 +546,7 @@ async def common_checks( # noqa: PLR0915 route=route, request_headers=_safe_get_request_headers(request=request), request_query_params=_safe_get_request_query_params(request=request), + llm_router=llm_router, ) if route in MODEL_DISCOVERY_ROUTES: diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index 80840c2742..71cf5197de 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -1284,7 +1284,9 @@ def _route_uses_model_routing_sources(route: str) -> bool: def _extract_models_from_managed_resource_id( - resource_id: Any, resource_id_field: Optional[str] = None + resource_id: Any, + resource_id_field: Optional[str] = None, + llm_router: Optional[Router] = None, ) -> List[str]: if not isinstance(resource_id, str) or not resource_id: return [] @@ -1341,16 +1343,18 @@ def _extract_models_from_managed_resource_id( ) if resource_id_field == "video_id": + model_id = decode_video_id_with_provider(resource_id).get("model_id") _append_model_candidates( candidates=candidates, - value=decode_video_id_with_provider(resource_id).get("model_id"), + value=_resolve_model_id_with_router(model_id, llm_router), ) else: + model_id = decode_character_id_with_provider(resource_id).get( + "model_id" + ) _append_model_candidates( candidates=candidates, - value=decode_character_id_with_provider(resource_id).get( - "model_id" - ), + value=_resolve_model_id_with_router(model_id, llm_router), ) except Exception as e: verbose_proxy_logger.debug( @@ -1360,11 +1364,26 @@ def _extract_models_from_managed_resource_id( return _dedupe_model_candidates(candidates) +def _resolve_model_id_with_router( + model_id: Optional[str], llm_router: Optional[Router] +) -> Optional[str]: + if model_id is None or llm_router is None: + return model_id + try: + return llm_router.resolve_model_name_from_model_id(model_id) or model_id + except Exception as e: + verbose_proxy_logger.debug( + "Unable to resolve model_id from managed resource ID: %s", str(e) + ) + return model_id + + def _extract_model_candidates_from_request( request_data: dict, route: str, request_headers: Optional[Mapping[str, Any]] = None, request_query_params: Optional[Mapping[str, Any]] = None, + llm_router: Optional[Router] = None, ) -> List[str]: candidates: List[str] = [] uses_model_routing_sources = _route_uses_model_routing_sources(route=route) @@ -1414,7 +1433,9 @@ def _extract_model_candidates_from_request( _append_model_candidates( candidates, _extract_models_from_managed_resource_id( - request_data.get(field), resource_id_field=field + request_data.get(field), + resource_id_field=field, + llm_router=llm_router, ), ) @@ -1436,12 +1457,14 @@ def get_model_from_request( route: str, request_headers: Optional[Mapping[str, Any]] = None, request_query_params: Optional[Mapping[str, Any]] = None, + llm_router: Optional[Router] = None, ) -> Optional[Union[str, List[str]]]: candidates = _extract_model_candidates_from_request( request_data=request_data, route=route, request_headers=request_headers, request_query_params=request_query_params, + llm_router=llm_router, ) model = _format_model_candidates(candidates) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 2f9a141140..2e6cd1f8e7 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -146,12 +146,14 @@ def _get_model_from_request_context( request_data: dict, route: str, request: Optional[Request], + llm_router: Optional[Any] = None, ) -> Optional[Union[str, List[str]]]: return get_model_from_request( request_data=request_data, route=route, request_headers=_safe_get_request_headers(request=request), request_query_params=_safe_get_request_query_params(request=request), + llm_router=llm_router, ) @@ -1034,6 +1036,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 request_data=request_data, route=route, request=request, + llm_router=llm_router, ) skip_budget_checks = False if model is not None and llm_router is not None: @@ -1451,6 +1454,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 request_data=request_data, route=route, request=request, + llm_router=llm_router, ) skip_budget_checks = False if model is not None and llm_router is not None: @@ -1579,6 +1583,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 request_data=request_data, route=route, request=request, + llm_router=llm_router, ) current_models = _get_model_names_for_budget_checks( model=current_model @@ -2159,6 +2164,7 @@ def _should_skip_budget_checks( request_data=request_data, route=route, request=request, + llm_router=llm_router, ) if model is not None and llm_router is not None: return _is_model_cost_zero(model=model, llm_router=llm_router) @@ -2475,6 +2481,7 @@ async def _enforce_key_and_fallback_model_access( request_data=request_data, route=route, request=request, + llm_router=llm_router, ) if model is not None: @@ -2616,6 +2623,7 @@ async def _run_post_custom_auth_checks( request_data=request_data, route=route, request=request, + llm_router=llm_router, ) current_models = _get_model_names_for_budget_checks(model=current_model) diff --git a/litellm/proxy/spend_tracking/budget_reservation.py b/litellm/proxy/spend_tracking/budget_reservation.py index 200a17e236..eb8af3b073 100644 --- a/litellm/proxy/spend_tracking/budget_reservation.py +++ b/litellm/proxy/spend_tracking/budget_reservation.py @@ -72,7 +72,7 @@ async def reserve_budget_for_request( return None if route in {"/models", "/v1/models", "/utils/token_counter"}: return None - if get_model_from_request(request_body, route) is None: + if get_model_from_request(request_body, route, llm_router=llm_router) is None: return None counters = await _get_budget_counters( @@ -797,7 +797,7 @@ def estimate_request_max_cost( route: str, llm_router: Optional[Router], ) -> Optional[float]: - model = get_model_from_request(request_body, route) + model = get_model_from_request(request_body, route, llm_router=llm_router) if model is None: return None diff --git a/tests/test_litellm/proxy/auth/test_auth_utils.py b/tests/test_litellm/proxy/auth/test_auth_utils.py index 60cf50efc7..d4ca55ca16 100644 --- a/tests/test_litellm/proxy/auth/test_auth_utils.py +++ b/tests/test_litellm/proxy/auth/test_auth_utils.py @@ -399,6 +399,62 @@ def test_get_model_from_request_extracts_video_id_model(): ) +def test_get_model_from_request_resolves_video_id_model_with_router(): + from litellm.types.videos.utils import encode_video_id_with_provider + + provider_video_id = ( + "projects/test-project/locations/us-central1/publishers/google/models/" + "veo-3.1-generate-001/operations/operation-id" + ) + video_id = encode_video_id_with_provider( + video_id=provider_video_id, + provider="vertex_ai", + model_id="veo-3.1-generate-001", + ) + llm_router = MagicMock() + llm_router.resolve_model_name_from_model_id.return_value = ( + "gcp/google/veo-3.1-generate-001" + ) + + assert ( + get_model_from_request( + request_data={"video_id": video_id}, + route="/v1/videos/{video_id}", + llm_router=llm_router, + ) + == "gcp/google/veo-3.1-generate-001" + ) + llm_router.resolve_model_name_from_model_id.assert_called_once_with( + "veo-3.1-generate-001" + ) + + +def test_get_model_from_request_resolves_character_id_model_with_router(): + from litellm.types.videos.utils import encode_character_id_with_provider + + character_id = encode_character_id_with_provider( + character_id="character-provider-id", + provider="vertex_ai", + model_id="veo-3.1-generate-001", + ) + llm_router = MagicMock() + llm_router.resolve_model_name_from_model_id.return_value = ( + "gcp/google/veo-3.1-generate-001" + ) + + assert ( + get_model_from_request( + request_data={"character_id": character_id}, + route="/v1/videos/characters/{character_id}", + llm_router=llm_router, + ) + == "gcp/google/veo-3.1-generate-001" + ) + llm_router.resolve_model_name_from_model_id.assert_called_once_with( + "veo-3.1-generate-001" + ) + + def test_get_model_from_request_only_runs_media_decoders_for_matching_fields(): with ( patch(