fix(proxy): resolve managed video model ids for auth (#29545)

* fix(proxy): resolve managed video model ids for auth

Co-authored-by: Cursor <cursoragent@cursor.com>

* test(proxy): cover character_id router model resolution

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Shivam Rawat 2026-06-02 19:31:36 -07:00 committed by GitHub
parent 8fbdfc7f0d
commit d45e9e4d56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 96 additions and 8 deletions

View File

@ -546,6 +546,7 @@ async def common_checks( # noqa: PLR0915
route=route,
request_headers=_safe_get_request_headers(request=request),
request_query_params=_safe_get_request_query_params(request=request),
llm_router=llm_router,
)
if route in MODEL_DISCOVERY_ROUTES:

View File

@ -1284,7 +1284,9 @@ def _route_uses_model_routing_sources(route: str) -> bool:
def _extract_models_from_managed_resource_id(
resource_id: Any, resource_id_field: Optional[str] = None
resource_id: Any,
resource_id_field: Optional[str] = None,
llm_router: Optional[Router] = None,
) -> List[str]:
if not isinstance(resource_id, str) or not resource_id:
return []
@ -1341,16 +1343,18 @@ def _extract_models_from_managed_resource_id(
)
if resource_id_field == "video_id":
model_id = decode_video_id_with_provider(resource_id).get("model_id")
_append_model_candidates(
candidates=candidates,
value=decode_video_id_with_provider(resource_id).get("model_id"),
value=_resolve_model_id_with_router(model_id, llm_router),
)
else:
model_id = decode_character_id_with_provider(resource_id).get(
"model_id"
)
_append_model_candidates(
candidates=candidates,
value=decode_character_id_with_provider(resource_id).get(
"model_id"
),
value=_resolve_model_id_with_router(model_id, llm_router),
)
except Exception as e:
verbose_proxy_logger.debug(
@ -1360,11 +1364,26 @@ def _extract_models_from_managed_resource_id(
return _dedupe_model_candidates(candidates)
def _resolve_model_id_with_router(
model_id: Optional[str], llm_router: Optional[Router]
) -> Optional[str]:
if model_id is None or llm_router is None:
return model_id
try:
return llm_router.resolve_model_name_from_model_id(model_id) or model_id
except Exception as e:
verbose_proxy_logger.debug(
"Unable to resolve model_id from managed resource ID: %s", str(e)
)
return model_id
def _extract_model_candidates_from_request(
request_data: dict,
route: str,
request_headers: Optional[Mapping[str, Any]] = None,
request_query_params: Optional[Mapping[str, Any]] = None,
llm_router: Optional[Router] = None,
) -> List[str]:
candidates: List[str] = []
uses_model_routing_sources = _route_uses_model_routing_sources(route=route)
@ -1414,7 +1433,9 @@ def _extract_model_candidates_from_request(
_append_model_candidates(
candidates,
_extract_models_from_managed_resource_id(
request_data.get(field), resource_id_field=field
request_data.get(field),
resource_id_field=field,
llm_router=llm_router,
),
)
@ -1436,12 +1457,14 @@ def get_model_from_request(
route: str,
request_headers: Optional[Mapping[str, Any]] = None,
request_query_params: Optional[Mapping[str, Any]] = None,
llm_router: Optional[Router] = None,
) -> Optional[Union[str, List[str]]]:
candidates = _extract_model_candidates_from_request(
request_data=request_data,
route=route,
request_headers=request_headers,
request_query_params=request_query_params,
llm_router=llm_router,
)
model = _format_model_candidates(candidates)

View File

@ -146,12 +146,14 @@ def _get_model_from_request_context(
request_data: dict,
route: str,
request: Optional[Request],
llm_router: Optional[Any] = None,
) -> Optional[Union[str, List[str]]]:
return get_model_from_request(
request_data=request_data,
route=route,
request_headers=_safe_get_request_headers(request=request),
request_query_params=_safe_get_request_query_params(request=request),
llm_router=llm_router,
)
@ -1034,6 +1036,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
request_data=request_data,
route=route,
request=request,
llm_router=llm_router,
)
skip_budget_checks = False
if model is not None and llm_router is not None:
@ -1451,6 +1454,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
request_data=request_data,
route=route,
request=request,
llm_router=llm_router,
)
skip_budget_checks = False
if model is not None and llm_router is not None:
@ -1579,6 +1583,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
request_data=request_data,
route=route,
request=request,
llm_router=llm_router,
)
current_models = _get_model_names_for_budget_checks(
model=current_model
@ -2159,6 +2164,7 @@ def _should_skip_budget_checks(
request_data=request_data,
route=route,
request=request,
llm_router=llm_router,
)
if model is not None and llm_router is not None:
return _is_model_cost_zero(model=model, llm_router=llm_router)
@ -2475,6 +2481,7 @@ async def _enforce_key_and_fallback_model_access(
request_data=request_data,
route=route,
request=request,
llm_router=llm_router,
)
if model is not None:
@ -2616,6 +2623,7 @@ async def _run_post_custom_auth_checks(
request_data=request_data,
route=route,
request=request,
llm_router=llm_router,
)
current_models = _get_model_names_for_budget_checks(model=current_model)

View File

@ -72,7 +72,7 @@ async def reserve_budget_for_request(
return None
if route in {"/models", "/v1/models", "/utils/token_counter"}:
return None
if get_model_from_request(request_body, route) is None:
if get_model_from_request(request_body, route, llm_router=llm_router) is None:
return None
counters = await _get_budget_counters(
@ -797,7 +797,7 @@ def estimate_request_max_cost(
route: str,
llm_router: Optional[Router],
) -> Optional[float]:
model = get_model_from_request(request_body, route)
model = get_model_from_request(request_body, route, llm_router=llm_router)
if model is None:
return None

View File

@ -399,6 +399,62 @@ def test_get_model_from_request_extracts_video_id_model():
)
def test_get_model_from_request_resolves_video_id_model_with_router():
    """A managed ``video_id`` yields the router's model name, not the raw model id."""
    from litellm.types.videos.utils import encode_video_id_with_provider

    deployment_model_id = "veo-3.1-generate-001"
    router_model_name = "gcp/google/veo-3.1-generate-001"

    # Build a managed video ID that embeds the deployment model id.
    raw_provider_id = (
        "projects/test-project/locations/us-central1/publishers/google/models/"
        "veo-3.1-generate-001/operations/operation-id"
    )
    managed_video_id = encode_video_id_with_provider(
        video_id=raw_provider_id,
        provider="vertex_ai",
        model_id=deployment_model_id,
    )

    # Router stub that maps the deployment model id to a router model name.
    llm_router = MagicMock()
    resolver = llm_router.resolve_model_name_from_model_id
    resolver.return_value = router_model_name

    extracted_model = get_model_from_request(
        request_data={"video_id": managed_video_id},
        route="/v1/videos/{video_id}",
        llm_router=llm_router,
    )

    assert extracted_model == router_model_name
    # The router must be queried exactly once, with the id decoded from the video ID.
    resolver.assert_called_once_with(deployment_model_id)
def test_get_model_from_request_resolves_character_id_model_with_router():
    """A managed ``character_id`` yields the router's model name, not the raw model id."""
    from litellm.types.videos.utils import encode_character_id_with_provider

    deployment_model_id = "veo-3.1-generate-001"
    router_model_name = "gcp/google/veo-3.1-generate-001"

    # Build a managed character ID that embeds the deployment model id.
    managed_character_id = encode_character_id_with_provider(
        character_id="character-provider-id",
        provider="vertex_ai",
        model_id=deployment_model_id,
    )

    # Router stub that maps the deployment model id to a router model name.
    llm_router = MagicMock()
    resolver = llm_router.resolve_model_name_from_model_id
    resolver.return_value = router_model_name

    extracted_model = get_model_from_request(
        request_data={"character_id": managed_character_id},
        route="/v1/videos/characters/{character_id}",
        llm_router=llm_router,
    )

    assert extracted_model == router_model_name
    # The router must be queried exactly once, with the id decoded from the character ID.
    resolver.assert_called_once_with(deployment_model_id)
def test_get_model_from_request_only_runs_media_decoders_for_matching_fields():
with (
patch(