Merge pull request #21877 from BerriAI/litellm_oss_staging_02_22_2026

Litellm oss staging 02 22 2026
This commit is contained in:
Sameer Kankute 2026-02-23 18:50:47 +05:30 committed by GitHub
commit 8decf04d8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 542 additions and 60 deletions

View File

@ -0,0 +1,119 @@
# Gollem Go Agent Framework with LiteLLM
A working example showing how to use [gollem](https://github.com/fugue-labs/gollem), a production-grade Go agent framework, with LiteLLM as a proxy gateway. This lets Go developers access 100+ LLM providers through a single proxy while keeping compile-time type safety for tools and structured output.
## Quick Start
### 1. Start LiteLLM Proxy
```bash
# Simple start with a single model
litellm --model gpt-4o
# Or with the example config for multi-provider access
litellm --config proxy_config.yaml
```
### 2. Run the examples
```bash
# Install Go dependencies
go mod tidy
# Basic agent
go run ./basic
# Agent with type-safe tools
go run ./tools
# Streaming responses
go run ./streaming
```
## Configuration
The included `proxy_config.yaml` sets up three providers through LiteLLM:
```yaml
model_list:
- model_name: gpt-4o # OpenAI
- model_name: claude-sonnet # Anthropic
- model_name: gemini-pro # Google Vertex AI
```
Switch providers in Go by changing a single model-name string — no other code changes needed:
```go
model := openai.NewLiteLLM("http://localhost:4000",
openai.WithModel("gpt-4o"), // OpenAI
// openai.WithModel("claude-sonnet"), // Anthropic
// openai.WithModel("gemini-pro"), // Google
)
```
## Examples
### `basic/` — Basic Agent
Connects gollem to LiteLLM and runs a simple prompt. Demonstrates the `NewLiteLLM` constructor and basic agent creation.
### `tools/` — Type-Safe Tools
Shows gollem's compile-time type-safe tool framework working through LiteLLM's tool-use passthrough. The tool parameters are Go structs with JSON tags — the schema is generated automatically at compile time.
### `streaming/` — Streaming Responses
Real-time token streaming using Go 1.23+ range-over-function iterators, proxied through LiteLLM's SSE passthrough.
## How It Works
Gollem's `openai.NewLiteLLM()` constructor creates an OpenAI-compatible provider pointed at your LiteLLM proxy. Since LiteLLM speaks the OpenAI API protocol, everything works out of the box:
- **Chat completions** — standard request/response
- **Tool use** — LiteLLM passes tool definitions and calls through transparently
- **Streaming** — Server-Sent Events proxied through LiteLLM
- **Structured output** — JSON schema response format works with supporting models
```
Go App (gollem) → LiteLLM Proxy → OpenAI / Anthropic / Google / ...
```
## Why Use This?
- **Type-safe Go**: Compile-time type checking for tools, structured output, and agent configuration — no runtime surprises
- **Single proxy, many models**: Switch between OpenAI, Anthropic, Google, and 100+ other providers by changing a model name string
- **Zero-dependency core**: gollem's core has no external dependencies — just stdlib
- **Single binary deployment**: `go build` produces one binary for your Go application — no pip/venv/Docker needed on the application side (the LiteLLM proxy itself runs separately)
- **Cost tracking & rate limiting**: LiteLLM handles cost tracking, rate limits, and fallbacks at the proxy layer
## Environment Variables
```bash
# Required for providers you want to use (set in LiteLLM config or env)
export OPENAI_API_KEY="sk-..."
export ANTHROPIC_API_KEY="sk-ant-..."
# Optional: point to a non-default LiteLLM proxy
export LITELLM_PROXY_URL="http://localhost:4000"
```
## Troubleshooting
**Connection errors?**
- Make sure LiteLLM is running: `litellm --model gpt-4o`
- Check the URL is correct (default: `http://localhost:4000`)
**Model not found?**
- Verify the model name matches what's configured in LiteLLM
- Run `curl http://localhost:4000/models` to see available models
**Tool calls not working?**
- Ensure the underlying model supports tool use (GPT-4o, Claude, Gemini)
- Check LiteLLM logs for any provider-specific errors
## Learn More
- [gollem GitHub](https://github.com/fugue-labs/gollem)
- [gollem API Reference](https://pkg.go.dev/github.com/fugue-labs/gollem/core)
- [LiteLLM Proxy Docs](https://docs.litellm.ai/docs/simple_proxy)
- [LiteLLM Supported Models](https://docs.litellm.ai/docs/providers)

View File

@ -0,0 +1,41 @@
// Basic gollem agent connected to a LiteLLM proxy.
//
// Usage:
//
// litellm --model gpt-4o # start proxy in another terminal
// go run ./basic
package main
import (
"context"
"fmt"
"log"
"os"
"github.com/fugue-labs/gollem/core"
"github.com/fugue-labs/gollem/provider/openai"
)
// main sends a single prompt through a LiteLLM proxy and prints the reply.
//
// The proxy address is taken from LITELLM_PROXY_URL when that variable is
// non-empty; otherwise http://localhost:4000 is used.
func main() {
	endpoint := os.Getenv("LITELLM_PROXY_URL")
	if endpoint == "" {
		endpoint = "http://localhost:4000"
	}

	// NewLiteLLM builds an OpenAI-compatible provider that targets the
	// proxy. The model name can be any model configured in LiteLLM.
	llm := openai.NewLiteLLM(endpoint, openai.WithModel("gpt-4o"))

	// A string-output agent with a short system prompt.
	assistant := core.NewAgent[string](
		llm,
		core.WithSystemPrompt[string]("You are a helpful assistant. Be concise."),
	)

	reply, err := assistant.Run(
		context.Background(),
		"Explain quantum computing in two sentences.",
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(reply.Output)
}

View File

@ -0,0 +1,5 @@
module github.com/BerriAI/litellm/cookbook/gollem_go_agent_framework
go 1.25.1
require github.com/fugue-labs/gollem v0.1.0

View File

@ -0,0 +1,2 @@
github.com/fugue-labs/gollem v0.1.0 h1:QexYnvkb44QZFEljgAePqMIGZjgsbk0Y5GJ2jYYgfa8=
github.com/fugue-labs/gollem v0.1.0/go.mod h1:htW1YO81uysSKVOkYJtxhGCFrzm+36HBFxEWuECoHKQ=

View File

@ -0,0 +1,16 @@
model_list:
- model_name: gpt-4o
litellm_params:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
- model_name: claude-sonnet
litellm_params:
model: anthropic/claude-sonnet-4-20250514
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: gemini-pro
litellm_params:
model: vertex_ai/gemini-2.0-flash
vertex_project: my-project
vertex_location: us-central1

View File

@ -0,0 +1,56 @@
// Streaming responses from gollem through LiteLLM.
//
// Uses Go 1.23+ range-over-function iterators for real-time token
// streaming via LiteLLM's SSE passthrough.
//
// Usage:
//
// litellm --model gpt-4o
// go run ./streaming
package main
import (
"context"
"fmt"
"log"
"os"
"github.com/fugue-labs/gollem/core"
"github.com/fugue-labs/gollem/provider/openai"
)
// main streams a model response through a LiteLLM proxy, printing each text
// chunk as it arrives and the token usage once the stream is finished.
//
// The proxy address is taken from LITELLM_PROXY_URL when that variable is
// non-empty; otherwise http://localhost:4000 is used.
func main() {
	endpoint := os.Getenv("LITELLM_PROXY_URL")
	if endpoint == "" {
		endpoint = "http://localhost:4000"
	}

	llm := openai.NewLiteLLM(endpoint, openai.WithModel("gpt-4o"))
	assistant := core.NewAgent[string](llm)

	// RunStream hands back a streaming result instead of a finished output.
	stream, err := assistant.RunStream(
		context.Background(),
		"Write a haiku about distributed systems",
	)
	if err != nil {
		log.Fatal(err)
	}

	// StreamText(true) yields deltas; passing false would yield the
	// accumulated text on every step instead.
	fmt.Print("Response: ")
	for chunk, streamErr := range stream.StreamText(true) {
		if streamErr != nil {
			log.Fatal(streamErr)
		}
		fmt.Print(chunk)
	}
	fmt.Println()

	// The final response (including usage) is read after the loop ends.
	final := stream.Response()
	fmt.Printf(
		"\nTokens used: input=%d, output=%d\n",
		final.Usage.InputTokens,
		final.Usage.OutputTokens,
	)
}

View File

@ -0,0 +1,64 @@
// Gollem agent with type-safe tools through LiteLLM.
//
// The tool parameters are Go structs — gollem generates the JSON schema
// automatically at compile time. LiteLLM passes tool definitions through
// transparently to the underlying provider.
//
// Usage:
//
// litellm --model gpt-4o
// go run ./tools
package main
import (
"context"
"fmt"
"log"
"os"
"github.com/fugue-labs/gollem/core"
"github.com/fugue-labs/gollem/provider/openai"
)
// WeatherParams defines the input schema of the get_weather tool via struct
// tags: `json` names each argument and `description` documents it for the
// model.
//
// NOTE(review): the original comment stated that the JSON schema is generated
// at compile time with no runtime reflection. That cannot be confirmed from
// this file — verify against gollem's FuncTool implementation.
type WeatherParams struct {
	// City is the city to look up.
	City string `json:"city" description:"City name to get weather for"`
	// Unit is optional (omitempty); the tool handler in main treats an empty
	// value as fahrenheit.
	Unit string `json:"unit,omitempty" description:"Temperature unit: celsius or fahrenheit"`
}
// main runs a weather assistant whose get_weather tool is defined with a
// typed Go handler and called by the model through a LiteLLM proxy.
//
// The proxy address is taken from LITELLM_PROXY_URL when that variable is
// non-empty; otherwise http://localhost:4000 is used.
func main() {
	proxyURL := "http://localhost:4000"
	if u := os.Getenv("LITELLM_PROXY_URL"); u != "" {
		proxyURL = u
	}

	model := openai.NewLiteLLM(proxyURL,
		openai.WithModel("gpt-4o"),
	)

	// Define a type-safe tool. The function signature enforces correct types.
	weatherTool := core.FuncTool[WeatherParams](
		"get_weather",
		"Get current weather for a city",
		func(ctx context.Context, p WeatherParams) (string, error) {
			// In production, call a real weather API here.
			//
			// Honor the requested unit. Previously the unit was defaulted
			// but never used, so the optional parameter had no effect on
			// the reply. Anything other than "celsius" (including an
			// omitted unit) falls back to fahrenheit.
			if p.Unit == "celsius" {
				return fmt.Sprintf("Weather in %s: 22°C, sunny", p.City), nil
			}
			return fmt.Sprintf("Weather in %s: 72°F, sunny", p.City), nil
		},
	)

	agent := core.NewAgent[string](model,
		core.WithTools[string](weatherTool),
		core.WithSystemPrompt[string]("You are a helpful weather assistant. Use the get_weather tool to answer weather questions."),
	)

	result, err := agent.Run(context.Background(), "What's the weather like in San Francisco and Tokyo?")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(result.Output)
}

View File

@ -22,6 +22,8 @@ litellm_settings:
This ensures that all budget resets happen at midnight in your specified timezone rather than in UTC.
If no timezone is specified, UTC will be used by default.
Any valid [IANA timezone string](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) is supported (powered by Python's `zoneinfo` module). DST transitions are handled automatically.
Common timezone values:
- `UTC` - Coordinated Universal Time

View File

@ -419,6 +419,7 @@ const sidebars = {
"proxy/dynamic_rate_limit",
"proxy/rate_limit_tiers",
"proxy/temporary_budget_increase",
"proxy/budget_reset_and_tz",
],
},
"proxy/caching",

View File

@ -13,6 +13,9 @@ if TYPE_CHECKING:
from litellm.router import Router
CHECK_BATCH_COST_USER_AGENT = "LiteLLM Proxy/CheckBatchCost"
class CheckBatchCost:
def __init__(
self,
@ -27,6 +30,25 @@ class CheckBatchCost:
self.prisma_client: PrismaClient = prisma_client
self.llm_router: Router = llm_router
async def _get_user_info(self, batch_id, user_id) -> dict:
"""
Look up user email and key alias by user_id for enriching the S3 callback metadata.
Returns a dict with user_api_key_user_email and user_api_key_alias (both may be None).
"""
try:
user_row = await self.prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}
)
if user_row is None:
return {}
return {
"user_api_key_user_email": getattr(user_row, "user_email", None),
"user_api_key_alias": getattr(user_row, "user_alias", None),
}
except Exception as e:
verbose_proxy_logger.error(f"CheckBatchCost: could not look up user {user_id} for batch {batch_id}: {e}")
return {}
async def check_batch_cost(self):
"""
Check if the batch JOB has been tracked.
@ -48,10 +70,12 @@ class CheckBatchCost:
get_model_id_from_unified_batch_id,
)
# Look for all batches that have not yet been processed by CheckBatchCost
jobs = await self.prisma_client.db.litellm_managedobjecttable.find_many(
where={
"status": {"in": ["validating", "in_progress", "finalizing"]},
"file_purpose": "batch",
"batch_processed" : False,
"status": {"not_in": ["failed", "expired", "cancelled"]}
}
)
completed_jobs = []
@ -107,6 +131,21 @@ class CheckBatchCost:
f"Batch ID: {batch_id} is complete, tracking cost and usage"
)
# aretrieve_batch is called with the raw provider batch ID, so response.id
# is the raw provider value (e.g. "batch_20260223-0518.234"). We need the
# unified base64 ID in the S3 log so downstream consumers can correlate it
# back to the batch they submitted via the proxy.
#
# CheckBatchCost builds its own LiteLLMLogging object (logging_obj below) and
# calls async_success_handler(result=response) directly. That handler calls
# _build_standard_logging_payload(response, ...) which reads response.id at
# that point — so setting response.id here is sufficient.
#
# The HTTP endpoint does this substitution via the managed files hook
# (async_post_call_success_hook). CheckBatchCost bypasses that hook entirely,
# so we do it explicitly here.
response.id = job.unified_object_id
# This background job runs as default_user_id, so going through the HTTP endpoint
# would trigger check_managed_file_id_access and get 403. Instead, extract the raw
# provider file ID and call afile_content directly with deployment credentials.
@ -171,11 +210,21 @@ class CheckBatchCost:
function_id=str(uuid.uuid4()),
)
creator_user_id = job.created_by
user_info = await self._get_user_info(batch_id, job.created_by)
logging_obj.update_environment_variables(
litellm_params={
# set the user-agent header so that S3 callback consumers can easily identify CheckBatchCost callbacks
"proxy_server_request": {
"headers": {
"user-agent": CHECK_BATCH_COST_USER_AGENT,
}
},
"metadata": {
"user_api_key_user_id": job.created_by or "default-user-id",
}
"user_api_key_user_id": creator_user_id,
**user_info,
},
},
optional_params={},
)
@ -191,8 +240,7 @@ class CheckBatchCost:
completed_jobs.append(job)
if len(completed_jobs) > 0:
# mark the jobs as complete
await self.prisma_client.db.litellm_managedobjecttable.update_many(
where={"id": {"in": [job.id for job in completed_jobs]}},
data={"status": "complete"},
data={"batch_processed": True, "status": "complete"},
)

View File

@ -1086,11 +1086,8 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
self, file_id: str
) -> List[Dict[str, Any]]:
"""
Find batches in non-terminal states that reference this file.
Non-terminal states: validating, in_progress, finalizing
Terminal states: completed, complete, failed, expired, cancelled
Find batches that reference this file and still need cost tracking.
Find batches that are in non-terminal state and have not yet been processed by CheckBatchCost.
Args:
file_id: The unified file ID to check
@ -1121,7 +1118,8 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
batches = await self.prisma_client.db.litellm_managedobjecttable.find_many(
where={
"file_purpose": "batch",
"status": {"in": ["validating", "in_progress", "finalizing"]},
"batch_processed": False,
"status": {"not_in": ["failed", "expired", "cancelled"]}
},
take=MAX_MATCHES_TO_RETURN,
order={"created_at": "desc"},
@ -1205,7 +1203,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
error_message += (
f"To delete this file before complete cost tracking, please delete or cancel the referencing batch(es) first. "
f"Alternatively, wait for all batches to complete processing."
f"Alternatively, wait for all batches to complete and for cost to be computed (batch_processed=true)."
)
raise HTTPException(

View File

@ -0,0 +1,3 @@
-- Add batch_processed column to LiteLLM_ManagedObjectTable
-- Set to true by CheckBatchCost after cost has been computed for a completed batch
ALTER TABLE "LiteLLM_ManagedObjectTable" ADD COLUMN "batch_processed" BOOLEAN NOT NULL DEFAULT false;

View File

@ -813,6 +813,7 @@ model LiteLLM_ManagedObjectTable { // for batches or finetuning jobs which use t
file_object Json // Stores the OpenAIFileObject
file_purpose String // either 'batch' or 'fine-tune'
status String? // check if batch cost has been tracked
batch_processed Boolean @default(false) // set to true by CheckBatchCost after cost is computed
created_at DateTime @default(now())
created_by String?
updated_at DateTime @updatedAt

View File

@ -8,8 +8,9 @@ duration_in_seconds is used in diff parts of the code base, example
import re
import time
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta, timezone, tzinfo
from typing import Optional, Tuple
from zoneinfo import ZoneInfo
def _extract_from_regex(duration: str) -> Tuple[int, str]:
@ -116,7 +117,7 @@ def get_next_standardized_reset_time(
- Next reset time at a standardized interval in the specified timezone
"""
# Set up timezone and normalize current time
current_time, timezone = _setup_timezone(current_time, timezone_str)
current_time, tz = _setup_timezone(current_time, timezone_str)
# Parse duration
value, unit = _parse_duration(duration)
@ -131,7 +132,7 @@ def get_next_standardized_reset_time(
# Handle different time units
if unit == "d":
return _handle_day_reset(current_time, base_midnight, value, timezone)
return _handle_day_reset(current_time, base_midnight, value, tz)
elif unit == "h":
return _handle_hour_reset(current_time, base_midnight, value)
elif unit == "m":
@ -147,22 +148,13 @@ def get_next_standardized_reset_time(
def _setup_timezone(
current_time: datetime, timezone_str: str = "UTC"
) -> Tuple[datetime, timezone]:
) -> Tuple[datetime, tzinfo]:
"""Set up timezone and normalize current time to that timezone."""
try:
if timezone_str is None:
tz = timezone.utc
tz: tzinfo = timezone.utc
else:
# Map common timezone strings to their UTC offsets
timezone_map = {
"US/Eastern": timezone(timedelta(hours=-4)), # EDT
"US/Pacific": timezone(timedelta(hours=-7)), # PDT
"Asia/Kolkata": timezone(timedelta(hours=5, minutes=30)), # IST
"Asia/Bangkok": timezone(timedelta(hours=7)), # ICT (Indochina Time)
"Europe/London": timezone(timedelta(hours=1)), # BST
"UTC": timezone.utc,
}
tz = timezone_map.get(timezone_str, timezone.utc)
tz = ZoneInfo(timezone_str)
except Exception:
# If timezone is invalid, fall back to UTC
tz = timezone.utc
@ -190,7 +182,7 @@ def _parse_duration(duration: str) -> Tuple[Optional[int], Optional[str]]:
def _handle_day_reset(
current_time: datetime, base_midnight: datetime, value: int, timezone: timezone
current_time: datetime, base_midnight: datetime, value: int, tz: tzinfo
) -> datetime:
"""Handle day-based reset times."""
# Handle zero value - immediate expiration
@ -215,7 +207,7 @@ def _handle_day_reset(
minute=0,
second=0,
microsecond=0,
tzinfo=timezone,
tzinfo=tz,
)
else:
next_reset = datetime(
@ -226,7 +218,7 @@ def _handle_day_reset(
minute=0,
second=0,
microsecond=0,
tzinfo=timezone,
tzinfo=tz,
)
return next_reset
else: # Custom day value - next interval is value days from current

View File

@ -6,7 +6,17 @@ import logging
import threading
import time
import traceback
from typing import Any, Callable, Dict, List, Optional, Union, cast
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Union,
cast,
)
import anyio
import httpx
@ -151,10 +161,10 @@ class CustomStreamWrapper:
self.is_function_call = self.check_is_function_call(logging_obj=logging_obj)
self.created: Optional[int] = None
def __iter__(self):
def __iter__(self) -> Iterator["ModelResponseStream"]:
return self
def __aiter__(self):
def __aiter__(self) -> AsyncIterator["ModelResponseStream"]:
return self
async def aclose(self):
@ -1726,7 +1736,7 @@ class CustomStreamWrapper:
model_response.choices[0].finish_reason = "tool_calls"
return model_response
def __next__(self): # noqa: PLR0915
def __next__(self) -> "ModelResponseStream": # noqa: PLR0915
cache_hit = False
if (
self.custom_llm_provider is not None
@ -1748,7 +1758,7 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(
f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk.decode('utf-8', errors='replace') if isinstance(chunk, bytes) else chunk}; custom_llm_provider: {self.custom_llm_provider}"
)
response: Optional[ModelResponseStream] = self.chunk_creator(
chunk=chunk
@ -1900,7 +1910,7 @@ class CustomStreamWrapper:
return self.completion_stream
async def __anext__(self): # noqa: PLR0915
async def __anext__(self) -> "ModelResponseStream": # noqa: PLR0915
cache_hit = False
if (
self.custom_llm_provider is not None
@ -1996,9 +2006,7 @@ class CustomStreamWrapper:
else:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
processed_chunk: Optional[
ModelResponseStream
] = self.chunk_creator(chunk=chunk)
processed_chunk = self.chunk_creator(chunk=chunk)
if processed_chunk is None:
continue

View File

@ -1,27 +1,23 @@
from datetime import datetime, timezone
import litellm
from litellm.litellm_core_utils.duration_parser import get_next_standardized_reset_time
def get_budget_reset_timezone():
"""
Get the budget reset timezone from general_settings.
Get the budget reset timezone from litellm_settings.
Falls back to UTC if not specified.
litellm_settings values are set as attributes on the litellm module
by proxy_server.py at startup (via setattr(litellm, key, value)).
"""
# Import at function level to avoid circular imports
from litellm.proxy.proxy_server import general_settings
if general_settings:
litellm_settings = general_settings.get("litellm_settings", {})
if litellm_settings and "timezone" in litellm_settings:
return litellm_settings["timezone"]
return "UTC"
return getattr(litellm, "timezone", None) or "UTC"
def get_budget_reset_time(budget_duration: str):
"""
Get the budget reset time from general_settings.
Get the budget reset time based on the configured timezone.
Falls back to UTC if not specified.
"""

View File

@ -812,6 +812,7 @@ model LiteLLM_ManagedObjectTable { // for batches or finetuning jobs which use t
file_object Json // Stores the OpenAIFileObject
file_purpose String // either 'batch' or 'fine-tune'
status String? // check if batch cost has been tracked
batch_processed Boolean @default(false) // set to true by CheckBatchCost after cost is computed
created_at DateTime @default(now())
created_by String?
updated_at DateTime @updatedAt

View File

@ -12,4 +12,7 @@ exclude = ["litellm/types/*", "litellm/__init__.py", "litellm/proxy/example_conf
"litellm/llms/anthropic/chat/__init__.py" = ["F401"]
"litellm/llms/azure_ai/embed/__init__.py" = ["F401"]
"litellm/llms/azure_ai/rerank/__init__.py" = ["F401"]
"litellm/llms/bedrock/chat/__init__.py" = ["F401"]
"litellm/llms/bedrock/chat/__init__.py" = ["F401"]
"litellm/proxy/utils.py" = ["F401", "PLR0915"]
"litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py" = ["PLR0915"]
"litellm/proxy/guardrails/guardrail_hooks/guardrail_benchmarks/test_eval.py" = ["PLR0915"]

View File

@ -813,6 +813,7 @@ model LiteLLM_ManagedObjectTable { // for batches or finetuning jobs which use t
file_object Json // Stores the OpenAIFileObject
file_purpose String // either 'batch' or 'fine-tune'
status String? // check if batch cost has been tracked
batch_processed Boolean @default(false) // set to true by CheckBatchCost after cost is computed
created_at DateTime @default(now())
created_by String?
updated_at DateTime @updatedAt

View File

@ -91,7 +91,9 @@ class TestStandardizedResetTime(unittest.TestCase):
# Test Bangkok timezone (UTC+7): 5:30 AM next day, so next reset is midnight the day after
bangkok = ZoneInfo("Asia/Bangkok")
bangkok_expected = datetime(2023, 5, 17, 0, 0, 0, tzinfo=bangkok)
bangkok_result = get_next_standardized_reset_time("1d", base_time, "Asia/Bangkok")
bangkok_result = get_next_standardized_reset_time(
"1d", base_time, "Asia/Bangkok"
)
self.assertEqual(bangkok_result, bangkok_expected)
def test_edge_cases(self):
@ -125,6 +127,62 @@ class TestStandardizedResetTime(unittest.TestCase):
)
self.assertEqual(invalid_tz_result, invalid_tz_expected)
def test_iana_timezones_previously_unsupported(self):
    """Daily resets resolve correctly for IANA zones the old hardcoded map lacked."""
    # Reference instant: 2023-05-15 15:00:00 UTC
    base_time = datetime(2023, 5, 15, 15, 0, 0, tzinfo=timezone.utc)

    # (zone name, expected day-of-May for the next local midnight)
    cases = [
        # 15:00 UTC is exactly 00:00 JST on May 16 (boundary) -> May 17
        ("Asia/Tokyo", 17),
        # 15:00 UTC is 01:00 AEST on May 16 -> May 17
        ("Australia/Sydney", 17),
        # 15:00 UTC is 10:00 CDT on May 15 -> May 16
        ("America/Chicago", 16),
    ]
    for zone_name, expected_day in cases:
        zone = ZoneInfo(zone_name)
        expected_reset = datetime(2023, 5, expected_day, 0, 0, 0, tzinfo=zone)
        actual_reset = get_next_standardized_reset_time("1d", base_time, zone_name)
        self.assertEqual(actual_reset, expected_reset)
def test_dst_fall_back(self):
    """A daily reset computed just before the fall-back lands on the next local midnight."""
    # US/Eastern leaves DST on 2023-11-05. 05:30 UTC is 01:30 EDT, i.e.
    # before clocks move back (06:00 UTC becomes 01:00 EST).
    zone_name = "US/Eastern"
    just_before_fall_back = datetime(2023, 11, 5, 5, 30, 0, tzinfo=timezone.utc)

    # Next local midnight: Nov 6 00:00 EST
    expected_reset = datetime(2023, 11, 6, 0, 0, 0, tzinfo=ZoneInfo(zone_name))
    actual_reset = get_next_standardized_reset_time(
        "1d", just_before_fall_back, zone_name
    )
    self.assertEqual(actual_reset, expected_reset)
def test_dst_spring_forward(self):
    """A daily reset computed just before the spring-forward lands on the next local midnight."""
    # US/Eastern enters DST on 2023-03-12. 06:30 UTC is 01:30 EST, i.e.
    # before clocks jump ahead (07:00 UTC becomes 03:00 EDT).
    zone_name = "US/Eastern"
    just_before_spring_forward = datetime(2023, 3, 12, 6, 30, 0, tzinfo=timezone.utc)

    # Next local midnight: Mar 13 00:00 EDT
    expected_reset = datetime(2023, 3, 13, 0, 0, 0, tzinfo=ZoneInfo(zone_name))
    actual_reset = get_next_standardized_reset_time(
        "1d", just_before_spring_forward, zone_name
    )
    self.assertEqual(actual_reset, expected_reset)
if __name__ == "__main__":
unittest.main()

View File

@ -1,18 +1,17 @@
import asyncio
import json
import os
import sys
import time
from datetime import datetime, timedelta, timezone
import pytest
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from zoneinfo import ZoneInfo
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.proxy.common_utils.timezone_utils import get_budget_reset_time
import litellm
from litellm.proxy.common_utils.timezone_utils import (
get_budget_reset_time,
get_budget_reset_timezone,
)
def test_get_budget_reset_time():
@ -33,3 +32,71 @@ def test_get_budget_reset_time():
# Verify budget_reset_at is set to first of next month
assert get_budget_reset_time(budget_duration="1mo") == expected_reset_at
def test_get_budget_reset_timezone_reads_litellm_attr():
    """
    get_budget_reset_timezone returns the value stored on litellm.timezone.
    """
    saved_tz = getattr(litellm, "timezone", None)
    try:
        litellm.timezone = "Asia/Tokyo"
        assert get_budget_reset_timezone() == "Asia/Tokyo"
    finally:
        # Put the module attribute back the way it was found.
        if saved_tz is not None:
            litellm.timezone = saved_tz
        elif hasattr(litellm, "timezone"):
            del litellm.timezone
def test_get_budget_reset_timezone_fallback_utc():
    """
    With no litellm.timezone attribute at all, get_budget_reset_timezone returns "UTC".
    """
    saved_tz = getattr(litellm, "timezone", None)
    try:
        # Make sure the attribute is absent for the duration of the check.
        if hasattr(litellm, "timezone"):
            del litellm.timezone
        resolved = get_budget_reset_timezone()
        assert resolved == "UTC"
    finally:
        if saved_tz is not None:
            litellm.timezone = saved_tz
def test_get_budget_reset_timezone_fallback_on_none():
    """
    With litellm.timezone explicitly set to None, get_budget_reset_timezone returns "UTC".
    """
    saved_tz = getattr(litellm, "timezone", None)
    try:
        litellm.timezone = None
        resolved = get_budget_reset_timezone()
        assert resolved == "UTC"
    finally:
        # Put the module attribute back the way it was found.
        if saved_tz is not None:
            litellm.timezone = saved_tz
        elif hasattr(litellm, "timezone"):
            del litellm.timezone
def test_get_budget_reset_time_respects_timezone():
    """
    get_budget_reset_time honors the configured timezone: a "1d" budget
    resets at midnight as seen in that timezone.
    """
    saved_tz = getattr(litellm, "timezone", None)
    try:
        litellm.timezone = "Asia/Tokyo"
        reset_at = get_budget_reset_time(budget_duration="1d")

        # Viewed in Tokyo local time, the reset must fall exactly on midnight.
        local_reset = reset_at.astimezone(ZoneInfo("Asia/Tokyo"))
        assert local_reset.hour == 0
        assert local_reset.minute == 0
        assert local_reset.second == 0
    finally:
        # Put the module attribute back the way it was found.
        if saved_tz is not None:
            litellm.timezone = saved_tz
        elif hasattr(litellm, "timezone"):
            del litellm.timezone