diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index e67702a797..3eb536d085 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1854,22 +1854,14 @@ async def increment_spend_counters(
     Awaited (not create_task) in the cost callback, so the counter is updated
     before the next request's auth check runs.
     """
-    reserved_counter_keys: Set[str] = set()
-    if budget_reservation is not None:
-        from litellm.proxy.spend_tracking.budget_reservation import (
-            get_reserved_counter_keys,
-            reconcile_budget_reservation,
-        )
-
-        reserved_counter_keys = get_reserved_counter_keys(
-            budget_reservation=budget_reservation
-        )
-        await reconcile_budget_reservation(
-            budget_reservation=budget_reservation,
-            actual_cost=response_cost or 0.0,
-        )
+    reserved_counter_keys = await _reconcile_budget_reservation_for_counter_update(
+        budget_reservation=budget_reservation,
+        response_cost=response_cost,
+    )
 
     if response_cost is None or response_cost == 0:
+        if budget_reservation is not None:
+            budget_reservation["finalized"] = True
         return
 
     if token is not None:
@@ -1989,6 +1981,31 @@ async def increment_spend_counters(
         response_cost=response_cost,
         reserved_counter_keys=reserved_counter_keys,
     )
+    if budget_reservation is not None:
+        budget_reservation["finalized"] = True
+
+
+async def _reconcile_budget_reservation_for_counter_update(
+    budget_reservation: Optional[dict],
+    response_cost: Optional[float],
+) -> Set[str]:
+    if budget_reservation is None:
+        return set()
+
+    from litellm.proxy.spend_tracking.budget_reservation import (
+        get_reserved_counter_keys,
+        reconcile_budget_reservation,
+    )
+
+    reserved_counter_keys = get_reserved_counter_keys(
+        budget_reservation=budget_reservation
+    )
+    await reconcile_budget_reservation(
+        budget_reservation=budget_reservation,
+        actual_cost=response_cost or 0.0,
+        finalize=False,
+    )
+    return reserved_counter_keys
 
 
 async def _increment_end_user_and_tag_spend_counters(
diff --git a/litellm/proxy/spend_tracking/budget_reservation.py b/litellm/proxy/spend_tracking/budget_reservation.py
index 25aa95e1cb..77c6f9ae63 100644
--- a/litellm/proxy/spend_tracking/budget_reservation.py
+++ b/litellm/proxy/spend_tracking/budget_reservation.py
@@ -151,6 +151,7 @@ async def reserve_budget_for_request(
 async def reconcile_budget_reservation(
     budget_reservation: Optional[dict],
     actual_cost: Optional[float],
+    finalize: bool = True,
 ) -> None:
     if not budget_reservation or budget_reservation.get("finalized") is True:
         return
@@ -162,7 +163,8 @@ async def reconcile_budget_reservation(
         actual_cost=actual,
         default_reserved_cost=reserved_cost,
     )
-    budget_reservation["finalized"] = True
+    if finalize:
+        budget_reservation["finalized"] = True
 
 
 async def release_budget_reservation(budget_reservation: Optional[dict]) -> None:
diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py
index 9dfa6bc193..dd6ef04bc9 100644
--- a/tests/test_litellm/proxy/test_proxy_server.py
+++ b/tests/test_litellm/proxy/test_proxy_server.py
@@ -5324,6 +5324,108 @@ async def test_window_spend_counter_redis_clean_miss_skips_stale_in_memory():
         ps.prisma_client = orig_prisma
 
 
+@pytest.mark.asyncio
+async def test_increment_spend_counters_finalizes_after_unreserved_increments():
+    from litellm.caching.dual_cache import DualCache
+    from litellm.proxy.proxy_server import increment_spend_counters
+
+    counter_cache = DualCache()
+    counter_cache.in_memory_cache.set_cache(
+        key="spend:key:key-finalize-after-increments",
+        value=0.5,
+    )
+    budget_reservation = {
+        "reserved_cost": 0.5,
+        "entries": [
+            {
+                "counter_key": "spend:key:key-finalize-after-increments",
+                "entity_type": "Key",
+                "entity_id": "key-finalize-after-increments",
+                "reserved_cost": 0.5,
+                "applied_adjustment": 0.0,
+            }
+        ],
+        "finalized": False,
+    }
+    incremented_counters = []
+
+    async def assert_reservation_not_finalized_yet(**kwargs):
+        assert budget_reservation["finalized"] is False
+        incremented_counters.append(kwargs["counter_key"])
+
+    import litellm.proxy.proxy_server as ps
+
+    orig_counter, orig_user = ps.spend_counter_cache, ps.user_api_key_cache
+    ps.spend_counter_cache = counter_cache
+    ps.user_api_key_cache = DualCache()
+    try:
+        with patch(
+            "litellm.proxy.proxy_server._init_and_increment_spend_counter",
+            new=AsyncMock(side_effect=assert_reservation_not_finalized_yet),
+        ):
+            await increment_spend_counters(
+                token="key-finalize-after-increments",
+                team_id="team-finalize-after-increments",
+                user_id=None,
+                response_cost=0.25,
+                budget_reservation=budget_reservation,
+            )
+
+        assert incremented_counters == ["spend:team:team-finalize-after-increments"]
+        assert budget_reservation["finalized"] is True
+        assert counter_cache.in_memory_cache.get_cache(
+            key="spend:key:key-finalize-after-increments"
+        ) == pytest.approx(0.25)
+    finally:
+        ps.spend_counter_cache = orig_counter
+        ps.user_api_key_cache = orig_user
+
+
+@pytest.mark.asyncio
+async def test_increment_spend_counters_finalizes_none_cost_reservation():
+    from litellm.caching.dual_cache import DualCache
+    from litellm.proxy.proxy_server import increment_spend_counters
+
+    counter_cache = DualCache()
+    counter_cache.in_memory_cache.set_cache(
+        key="spend:key:key-finalize-none-cost",
+        value=0.5,
+    )
+    budget_reservation = {
+        "reserved_cost": 0.5,
+        "entries": [
+            {
+                "counter_key": "spend:key:key-finalize-none-cost",
+                "entity_type": "Key",
+                "entity_id": "key-finalize-none-cost",
+                "reserved_cost": 0.5,
+                "applied_adjustment": 0.0,
+            }
+        ],
+        "finalized": False,
+    }
+
+    import litellm.proxy.proxy_server as ps
+
+    orig_counter = ps.spend_counter_cache
+    ps.spend_counter_cache = counter_cache
+    try:
+        await increment_spend_counters(
+            token="key-finalize-none-cost",
+            team_id=None,
+            user_id=None,
+            response_cost=None,
+            budget_reservation=budget_reservation,
+        )
+
+        assert budget_reservation["finalized"] is True
+        assert counter_cache.in_memory_cache.get_cache(
+            key="spend:key:key-finalize-none-cost"
+        ) == pytest.approx(0.0)
+    finally:
+        ps.spend_counter_cache = orig_counter
+
+
 @pytest.mark.asyncio
 async def test_increment_spend_counter_invalidates_stale_cache_on_redis_failure():
     from litellm.caching.dual_cache import DualCache