fix(proxy): resolve managed video model ids for auth (#29545)
* fix(proxy): resolve managed video model ids for auth

  Co-authored-by: Cursor <cursoragent@cursor.com>

* test(proxy): cover character_id router model resolution

  Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
8fbdfc7f0d
commit
d45e9e4d56
@ -546,6 +546,7 @@ async def common_checks( # noqa: PLR0915
|
||||
route=route,
|
||||
request_headers=_safe_get_request_headers(request=request),
|
||||
request_query_params=_safe_get_request_query_params(request=request),
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
if route in MODEL_DISCOVERY_ROUTES:
|
||||
|
||||
@ -1284,7 +1284,9 @@ def _route_uses_model_routing_sources(route: str) -> bool:
|
||||
|
||||
|
||||
def _extract_models_from_managed_resource_id(
|
||||
resource_id: Any, resource_id_field: Optional[str] = None
|
||||
resource_id: Any,
|
||||
resource_id_field: Optional[str] = None,
|
||||
llm_router: Optional[Router] = None,
|
||||
) -> List[str]:
|
||||
if not isinstance(resource_id, str) or not resource_id:
|
||||
return []
|
||||
@ -1341,16 +1343,18 @@ def _extract_models_from_managed_resource_id(
|
||||
)
|
||||
|
||||
if resource_id_field == "video_id":
|
||||
model_id = decode_video_id_with_provider(resource_id).get("model_id")
|
||||
_append_model_candidates(
|
||||
candidates=candidates,
|
||||
value=decode_video_id_with_provider(resource_id).get("model_id"),
|
||||
value=_resolve_model_id_with_router(model_id, llm_router),
|
||||
)
|
||||
else:
|
||||
model_id = decode_character_id_with_provider(resource_id).get(
|
||||
"model_id"
|
||||
)
|
||||
_append_model_candidates(
|
||||
candidates=candidates,
|
||||
value=decode_character_id_with_provider(resource_id).get(
|
||||
"model_id"
|
||||
),
|
||||
value=_resolve_model_id_with_router(model_id, llm_router),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
@ -1360,11 +1364,26 @@ def _extract_models_from_managed_resource_id(
|
||||
return _dedupe_model_candidates(candidates)
|
||||
|
||||
|
||||
def _resolve_model_id_with_router(
    model_id: Optional[str], llm_router: Optional[Router]
) -> Optional[str]:
    """Translate a decoded model id into the router's model name when possible.

    Args:
        model_id: Model id decoded from a managed resource ID, or ``None``.
        llm_router: Router used for the lookup, or ``None`` to skip resolution.

    Returns:
        The value returned by ``llm_router.resolve_model_name_from_model_id``
        when it is truthy; otherwise ``model_id`` unchanged. ``model_id`` is
        also returned unchanged when either argument is ``None`` or when the
        lookup raises.
    """
    resolved = model_id
    if model_id is not None and llm_router is not None:
        try:
            router_model_name = llm_router.resolve_model_name_from_model_id(
                model_id
            )
            # A falsy lookup result means "no mapping"; keep the raw id.
            if router_model_name:
                resolved = router_model_name
        except Exception as e:
            # Best-effort lookup: a failing router must not break auth, so the
            # error is only logged and the raw id is used as the fallback.
            verbose_proxy_logger.debug(
                "Unable to resolve model_id from managed resource ID: %s", str(e)
            )
    return resolved
|
||||
|
||||
|
||||
def _extract_model_candidates_from_request(
|
||||
request_data: dict,
|
||||
route: str,
|
||||
request_headers: Optional[Mapping[str, Any]] = None,
|
||||
request_query_params: Optional[Mapping[str, Any]] = None,
|
||||
llm_router: Optional[Router] = None,
|
||||
) -> List[str]:
|
||||
candidates: List[str] = []
|
||||
uses_model_routing_sources = _route_uses_model_routing_sources(route=route)
|
||||
@ -1414,7 +1433,9 @@ def _extract_model_candidates_from_request(
|
||||
_append_model_candidates(
|
||||
candidates,
|
||||
_extract_models_from_managed_resource_id(
|
||||
request_data.get(field), resource_id_field=field
|
||||
request_data.get(field),
|
||||
resource_id_field=field,
|
||||
llm_router=llm_router,
|
||||
),
|
||||
)
|
||||
|
||||
@ -1436,12 +1457,14 @@ def get_model_from_request(
|
||||
route: str,
|
||||
request_headers: Optional[Mapping[str, Any]] = None,
|
||||
request_query_params: Optional[Mapping[str, Any]] = None,
|
||||
llm_router: Optional[Router] = None,
|
||||
) -> Optional[Union[str, List[str]]]:
|
||||
candidates = _extract_model_candidates_from_request(
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
request_headers=request_headers,
|
||||
request_query_params=request_query_params,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
model = _format_model_candidates(candidates)
|
||||
|
||||
|
||||
@ -146,12 +146,14 @@ def _get_model_from_request_context(
|
||||
request_data: dict,
|
||||
route: str,
|
||||
request: Optional[Request],
|
||||
llm_router: Optional[Any] = None,
|
||||
) -> Optional[Union[str, List[str]]]:
|
||||
return get_model_from_request(
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
request_headers=_safe_get_request_headers(request=request),
|
||||
request_query_params=_safe_get_request_query_params(request=request),
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
|
||||
@ -1034,6 +1036,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
request=request,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
skip_budget_checks = False
|
||||
if model is not None and llm_router is not None:
|
||||
@ -1451,6 +1454,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
request=request,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
skip_budget_checks = False
|
||||
if model is not None and llm_router is not None:
|
||||
@ -1579,6 +1583,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
request=request,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
current_models = _get_model_names_for_budget_checks(
|
||||
model=current_model
|
||||
@ -2159,6 +2164,7 @@ def _should_skip_budget_checks(
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
request=request,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
if model is not None and llm_router is not None:
|
||||
return _is_model_cost_zero(model=model, llm_router=llm_router)
|
||||
@ -2475,6 +2481,7 @@ async def _enforce_key_and_fallback_model_access(
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
request=request,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
if model is not None:
|
||||
@ -2616,6 +2623,7 @@ async def _run_post_custom_auth_checks(
|
||||
request_data=request_data,
|
||||
route=route,
|
||||
request=request,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
current_models = _get_model_names_for_budget_checks(model=current_model)
|
||||
|
||||
|
||||
@ -72,7 +72,7 @@ async def reserve_budget_for_request(
|
||||
return None
|
||||
if route in {"/models", "/v1/models", "/utils/token_counter"}:
|
||||
return None
|
||||
if get_model_from_request(request_body, route) is None:
|
||||
if get_model_from_request(request_body, route, llm_router=llm_router) is None:
|
||||
return None
|
||||
|
||||
counters = await _get_budget_counters(
|
||||
@ -797,7 +797,7 @@ def estimate_request_max_cost(
|
||||
route: str,
|
||||
llm_router: Optional[Router],
|
||||
) -> Optional[float]:
|
||||
model = get_model_from_request(request_body, route)
|
||||
model = get_model_from_request(request_body, route, llm_router=llm_router)
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
|
||||
@ -399,6 +399,62 @@ def test_get_model_from_request_extracts_video_id_model():
|
||||
)
|
||||
|
||||
|
||||
def test_get_model_from_request_resolves_video_id_model_with_router():
    """The model id encoded in a ``video_id`` is resolved through the router."""
    from litellm.types.videos.utils import encode_video_id_with_provider

    encoded_model_id = "veo-3.1-generate-001"
    router_model_name = "gcp/google/veo-3.1-generate-001"

    provider_video_id = (
        "projects/test-project/locations/us-central1/publishers/google/models/"
        "veo-3.1-generate-001/operations/operation-id"
    )
    managed_video_id = encode_video_id_with_provider(
        video_id=provider_video_id,
        provider="vertex_ai",
        model_id=encoded_model_id,
    )

    # The mocked router maps the encoded model id to a router model name.
    mock_router = MagicMock()
    mock_router.resolve_model_name_from_model_id.return_value = router_model_name

    resolved_model = get_model_from_request(
        request_data={"video_id": managed_video_id},
        route="/v1/videos/{video_id}",
        llm_router=mock_router,
    )

    assert resolved_model == router_model_name
    mock_router.resolve_model_name_from_model_id.assert_called_once_with(
        encoded_model_id
    )
|
||||
|
||||
|
||||
def test_get_model_from_request_resolves_character_id_model_with_router():
    """The model id encoded in a ``character_id`` is resolved through the router."""
    from litellm.types.videos.utils import encode_character_id_with_provider

    encoded_model_id = "veo-3.1-generate-001"
    router_model_name = "gcp/google/veo-3.1-generate-001"

    managed_character_id = encode_character_id_with_provider(
        character_id="character-provider-id",
        provider="vertex_ai",
        model_id=encoded_model_id,
    )

    # The mocked router maps the encoded model id to a router model name.
    mock_router = MagicMock()
    mock_router.resolve_model_name_from_model_id.return_value = router_model_name

    resolved_model = get_model_from_request(
        request_data={"character_id": managed_character_id},
        route="/v1/videos/characters/{character_id}",
        llm_router=mock_router,
    )

    assert resolved_model == router_model_name
    mock_router.resolve_model_name_from_model_id.assert_called_once_with(
        encoded_model_id
    )
|
||||
|
||||
|
||||
def test_get_model_from_request_only_runs_media_decoders_for_matching_fields():
|
||||
with (
|
||||
patch(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user