Litellm fix GitHub action testing (#11163)

* test: add __init__.py files

* refactor: rename test folder to avoid naming conflict

* test: update workflows

* test: update tests

* test: update imports

* test: update tests

* test: remove unused import

* ci(test-litellm.yml): add pytest retry to github workflow

* test: fix test
This commit is contained in:
Krish Dholakia 2025-05-26 14:41:42 -07:00 committed by GitHub
parent 9a35c41462
commit ef42461c1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
170 changed files with 1371 additions and 793 deletions

View File

@ -28,6 +28,7 @@ jobs:
- name: Install dependencies
run: |
poetry install --with dev,proxy-dev --extras proxy
poetry run pip install "pytest-retry==1.6.3"
poetry run pip install pytest-xdist
- name: Setup litellm-enterprise as local package
run: |
@ -36,4 +37,4 @@ jobs:
cd ..
- name: Run tests
run: |
poetry run pytest tests/litellm -x -vv -n 4
poetry run pytest tests/test_litellm -x -vv -n 4

View File

@ -24,7 +24,7 @@ repos:
rev: 7.0.0 # The version of flake8 to use
hooks:
- id: flake8
exclude: ^litellm/tests/|^litellm/proxy/tests/|^litellm/tests/litellm/|^tests/litellm/
exclude: ^litellm/tests/|^litellm/proxy/tests/|^litellm/tests/test_litellm/|^tests/test_litellm/
additional_dependencies: [flake8-print]
files: (litellm/|litellm_proxy_extras/|enterprise/).*\.py
- repo: https://github.com/python-poetry/poetry

View File

@ -4392,7 +4392,7 @@
"supports_function_calling": true,
"supports_response_schema": true,
"supports_tool_choice": true,
"deprecation_date": "2025-1-6"
"deprecation_date": "2025-01-06"
},
"groq/llama3-groq-8b-8192-tool-use-preview": {
"max_tokens": 8192,
@ -4405,7 +4405,7 @@
"supports_function_calling": true,
"supports_response_schema": true,
"supports_tool_choice": true,
"deprecation_date": "2025-1-6"
"deprecation_date": "2025-01-06"
},
"groq/qwen-qwq-32b": {
"max_tokens": 128000,

View File

@ -1,207 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock, ANY
import json
import datetime
import sys
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
from litellm.integrations.athina import AthinaLogger
class TestAthinaLogger(unittest.TestCase):
def setUp(self):
# Set up environment variables for testing
self.env_patcher = patch.dict('os.environ', {
'ATHINA_API_KEY': 'test-api-key',
'ATHINA_BASE_URL': 'https://test.athina.ai'
})
self.env_patcher.start()
self.logger = AthinaLogger()
# Setup common test variables
self.start_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
self.end_time = datetime.datetime(2023, 1, 1, 12, 0, 1)
self.print_verbose = MagicMock()
def tearDown(self):
self.env_patcher.stop()
def test_init(self):
"""Test the initialization of AthinaLogger"""
self.assertEqual(self.logger.athina_api_key, 'test-api-key')
self.assertEqual(self.logger.athina_logging_url, 'https://test.athina.ai/api/v1/log/inference')
self.assertEqual(self.logger.headers, {
'athina-api-key': 'test-api-key',
'Content-Type': 'application/json'
})
@patch('litellm.module_level_client.post')
def test_log_event_success(self, mock_post):
"""Test successful logging of an event"""
# Setup mock response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.text = "Success"
mock_post.return_value = mock_response
# Create test data
kwargs = {
'model': 'gpt-4',
'messages': [{'role': 'user', 'content': 'Hello'}],
'stream': False,
'litellm_params': {
'metadata': {
'environment': 'test-environment',
'prompt_slug': 'test-prompt',
'customer_id': 'test-customer',
'customer_user_id': 'test-user',
'session_id': 'test-session',
'external_reference_id': 'test-ext-ref',
'context': 'test-context',
'expected_response': 'test-expected',
'user_query': 'test-query',
'tags': ['test-tag'],
'user_feedback': 'test-feedback',
'model_options': {'test-opt': 'test-val'},
'custom_attributes': {'test-attr': 'test-val'}
}
}
}
response_obj = MagicMock()
response_obj.model_dump.return_value = {
'id': 'resp-123',
'choices': [{'message': {'content': 'Hi there'}}],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 5,
'total_tokens': 15
}
}
# Call the method
self.logger.log_event(kwargs, response_obj, self.start_time, self.end_time, self.print_verbose)
# Verify the results
mock_post.assert_called_once()
call_args = mock_post.call_args
self.assertEqual(call_args[0][0], 'https://test.athina.ai/api/v1/log/inference')
self.assertEqual(call_args[1]['headers'], self.logger.headers)
# Parse and verify the sent data
sent_data = json.loads(call_args[1]['data'])
self.assertEqual(sent_data['language_model_id'], 'gpt-4')
self.assertEqual(sent_data['prompt'], kwargs['messages'])
self.assertEqual(sent_data['prompt_tokens'], 10)
self.assertEqual(sent_data['completion_tokens'], 5)
self.assertEqual(sent_data['total_tokens'], 15)
self.assertEqual(sent_data['response_time'], 1000) # 1 second = 1000ms
self.assertEqual(sent_data['customer_id'], 'test-customer')
self.assertEqual(sent_data['session_id'], 'test-session')
self.assertEqual(sent_data['environment'], 'test-environment')
self.assertEqual(sent_data['prompt_slug'], 'test-prompt')
self.assertEqual(sent_data['external_reference_id'], 'test-ext-ref')
self.assertEqual(sent_data['context'], 'test-context')
self.assertEqual(sent_data['expected_response'], 'test-expected')
self.assertEqual(sent_data['user_query'], 'test-query')
self.assertEqual(sent_data['tags'], ['test-tag'])
self.assertEqual(sent_data['user_feedback'], 'test-feedback')
self.assertEqual(sent_data['model_options'], {'test-opt': 'test-val'})
self.assertEqual(sent_data['custom_attributes'], {'test-attr': 'test-val'})
# Verify the print_verbose was called
self.print_verbose.assert_called_once_with("Athina Logger Succeeded - Success")
@patch('litellm.module_level_client.post')
def test_log_event_error_response(self, mock_post):
"""Test handling of error response from the API"""
# Setup mock error response
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.text = "Bad Request"
mock_post.return_value = mock_response
# Create test data
kwargs = {
'model': 'gpt-4',
'messages': [{'role': 'user', 'content': 'Hello'}],
'stream': False
}
response_obj = MagicMock()
response_obj.model_dump.return_value = {
'id': 'resp-123',
'choices': [{'message': {'content': 'Hi there'}}],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 5,
'total_tokens': 15
}
}
# Call the method
self.logger.log_event(kwargs, response_obj, self.start_time, self.end_time, self.print_verbose)
# Verify print_verbose was called with error message
self.print_verbose.assert_called_once_with("Athina Logger Error - Bad Request, 400")
@patch('litellm.module_level_client.post')
def test_log_event_exception(self, mock_post):
"""Test handling of exceptions during logging"""
# Setup mock to raise exception
mock_post.side_effect = Exception("Test exception")
# Create test data
kwargs = {
'model': 'gpt-4',
'messages': [{'role': 'user', 'content': 'Hello'}],
'stream': False
}
response_obj = MagicMock()
response_obj.model_dump.return_value = {}
# Call the method
self.logger.log_event(kwargs, response_obj, self.start_time, self.end_time, self.print_verbose)
# Verify print_verbose was called with exception info
self.print_verbose.assert_called_once()
self.assertIn("Athina Logger Error - Test exception", self.print_verbose.call_args[0][0])
@patch('litellm.module_level_client.post')
def test_log_event_with_tools(self, mock_post):
"""Test logging with tools/functions data"""
# Setup mock response
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
# Create test data with tools
kwargs = {
'model': 'gpt-4',
'messages': [{'role': 'user', 'content': "What's the weather?"}],
'stream': False,
'optional_params': {
'tools': [{'type': 'function', 'function': {'name': 'get_weather'}}]
}
}
response_obj = MagicMock()
response_obj.model_dump.return_value = {
'id': 'resp-123',
'usage': {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}
}
# Call the method
self.logger.log_event(kwargs, response_obj, self.start_time, self.end_time, self.print_verbose)
# Verify the results
sent_data = json.loads(mock_post.call_args[1]['data'])
self.assertEqual(sent_data['tools'], [{'type': 'function', 'function': {'name': 'get_weather'}}])
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

View File

@ -1 +0,0 @@
"""Tests for the LiteLLM Proxy Client CLI package."""

View File

View File

@ -16,7 +16,7 @@ from litellm.caching.redis_cache import RedisCache
@pytest.fixture
def redis_no_ping():
"""Patch RedisCache initialization to prevent async ping tasks from being created"""
with patch('asyncio.get_running_loop') as mock_get_loop:
with patch("asyncio.get_running_loop") as mock_get_loop:
# Either raise an exception or return a mock that will handle the task creation
mock_get_loop.side_effect = RuntimeError("No running event loop")
yield
@ -64,32 +64,32 @@ async def test_redis_client_init_with_socket_timeout(monkeypatch, redis_no_ping)
async def test_redis_cache_async_batch_get_cache(monkeypatch, redis_no_ping):
monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
redis_cache = RedisCache()
# Create an AsyncMock for the Redis client
mock_redis_instance = AsyncMock()
# Make sure the mock can be used as an async context manager
mock_redis_instance.__aenter__.return_value = mock_redis_instance
mock_redis_instance.__aexit__.return_value = None
# Setup the return value for mget
mock_redis_instance.mget.return_value = [
b'{"key1": "value1"}',
None,
b'{"key3": "value3"}'
b'{"key3": "value3"}',
]
test_keys = ["key1", "key2", "key3"]
with patch.object(
redis_cache, "init_async_client", return_value=mock_redis_instance
):
# Call async_batch_get_cache
result = await redis_cache.async_batch_get_cache(key_list=test_keys)
# Verify mget was called with the correct keys
mock_redis_instance.mget.assert_called_once()
# Check that results were properly decoded
assert result["key1"] == {"key1": "value1"}
assert result["key2"] is None

View File

@ -1,6 +1,6 @@
import os
import sys
from unittest.mock import MagicMock, patch, AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -13,27 +13,30 @@ sys.path.insert(
def test_redis_semantic_cache_initialization(monkeypatch):
# Mock the redisvl import
semantic_cache_mock = MagicMock()
with patch.dict("sys.modules", {
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
"redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=MagicMock())
}):
with patch.dict(
"sys.modules",
{
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
"redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=MagicMock()),
},
):
from litellm.caching.redis_semantic_cache import RedisSemanticCache
# Set environment variables
monkeypatch.setenv("REDIS_HOST", "localhost")
monkeypatch.setenv("REDIS_PORT", "6379")
monkeypatch.setenv("REDIS_PASSWORD", "test_password")
# Initialize the cache with a similarity threshold
redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
# Verify the semantic cache was initialized with correct parameters
assert redis_semantic_cache.similarity_threshold == 0.8
# Use pytest.approx for floating point comparison to handle precision issues
assert redis_semantic_cache.distance_threshold == pytest.approx(0.2, abs=1e-10)
assert redis_semantic_cache.embedding_model == "text-embedding-ada-002"
# Test initialization with missing similarity_threshold
with pytest.raises(ValueError, match="similarity_threshold must be provided"):
RedisSemanticCache()
@ -43,42 +46,48 @@ def test_redis_semantic_cache_get_cache(monkeypatch):
# Mock the redisvl import and embedding function
semantic_cache_mock = MagicMock()
custom_vectorizer_mock = MagicMock()
with patch.dict("sys.modules", {
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
"redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=custom_vectorizer_mock)
}):
with patch.dict(
"sys.modules",
{
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
"redisvl.utils.vectorize": MagicMock(
CustomTextVectorizer=custom_vectorizer_mock
),
},
):
from litellm.caching.redis_semantic_cache import RedisSemanticCache
# Set environment variables
monkeypatch.setenv("REDIS_HOST", "localhost")
monkeypatch.setenv("REDIS_PORT", "6379")
monkeypatch.setenv("REDIS_PASSWORD", "test_password")
# Initialize cache
redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
# Mock the llmcache.check method to return a result
mock_result = [
{
"prompt": "What is the capital of France?",
"response": '{"content": "Paris is the capital of France."}',
"vector_distance": 0.1 # Distance of 0.1 means similarity of 0.9
"vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9
}
]
redis_semantic_cache.llmcache.check = MagicMock(return_value=mock_result)
# Mock the embedding function
with patch("litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}):
with patch(
"litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
):
# Test get_cache with a message
result = redis_semantic_cache.get_cache(
key="test_key",
messages=[{"content": "What is the capital of France?"}]
key="test_key", messages=[{"content": "What is the capital of France?"}]
)
# Verify result is properly parsed
assert result == {"content": "Paris is the capital of France."}
# Verify llmcache.check was called
redis_semantic_cache.llmcache.check.assert_called_once()
@ -88,43 +97,50 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch):
# Mock the redisvl import
semantic_cache_mock = MagicMock()
custom_vectorizer_mock = MagicMock()
with patch.dict("sys.modules", {
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
"redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=custom_vectorizer_mock)
}):
with patch.dict(
"sys.modules",
{
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
"redisvl.utils.vectorize": MagicMock(
CustomTextVectorizer=custom_vectorizer_mock
),
},
):
from litellm.caching.redis_semantic_cache import RedisSemanticCache
# Set environment variables
monkeypatch.setenv("REDIS_HOST", "localhost")
monkeypatch.setenv("REDIS_PORT", "6379")
monkeypatch.setenv("REDIS_PASSWORD", "test_password")
# Initialize cache
redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
# Mock the async methods
mock_result = [
{
"prompt": "What is the capital of France?",
"response": '{"content": "Paris is the capital of France."}',
"vector_distance": 0.1 # Distance of 0.1 means similarity of 0.9
"vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9
}
]
redis_semantic_cache.llmcache.acheck = AsyncMock(return_value=mock_result)
redis_semantic_cache._get_async_embedding = AsyncMock(return_value=[0.1, 0.2, 0.3])
redis_semantic_cache._get_async_embedding = AsyncMock(
return_value=[0.1, 0.2, 0.3]
)
# Test async_get_cache with a message
result = await redis_semantic_cache.async_get_cache(
key="test_key",
messages=[{"content": "What is the capital of France?"}],
metadata={}
metadata={},
)
# Verify result is properly parsed
assert result == {"content": "Paris is the capital of France."}
# Verify methods were called
redis_semantic_cache._get_async_embedding.assert_called_once()
redis_semantic_cache.llmcache.acheck.assert_called_once()
redis_semantic_cache.llmcache.acheck.assert_called_once()

View File

@ -162,13 +162,13 @@ class TestSlackAlerting(unittest.TestCase):
self.assertEqual(event, "soft_budget_crossed")
self.assertTrue("Total Soft Budget" in event_message)
# Calling update_values with alerting args should try to start the periodic task
# Calling update_values with alerting args should try to start the periodic task
@patch("asyncio.create_task")
def test_update_values_starts_periodic_task(self, mock_create_task):
# Make it do nothing (or return a dummy future)
mock_create_task.return_value = AsyncMock() # prevents awaiting errors
assert(self.slack_alerting.periodic_started == False)
assert self.slack_alerting.periodic_started == False
self.slack_alerting.update_values(alerting_args={"slack_alerting": "True"})
assert(self.slack_alerting.periodic_started == True)
assert self.slack_alerting.periodic_started == True

View File

@ -5,7 +5,6 @@ from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger
class TestArizePhoenixConfig(unittest.TestCase):
@patch.dict(
"os.environ",
{

View File

@ -7,15 +7,17 @@ from typing import Optional
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import litellm
import pytest
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_logger import CustomLogger
import litellm
from litellm.integrations._types.open_inference import (
SpanAttributes,
MessageAttributes,
SpanAttributes,
ToolCallAttributes,
)
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Choices, StandardCallbackDynamicParams
@ -25,6 +27,7 @@ def test_arize_set_attributes():
Ensures that the correct span attributes are being added during a request.
"""
from unittest.mock import MagicMock
from litellm.types.utils import ModelResponse
span = MagicMock() # Mocked tracing span to test attribute setting

View File

@ -1,18 +1,20 @@
import os
import sys
import pytest
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
sys.path.insert(0, os.path.abspath("../..")) # Adds the parent directory to the system-path
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
from litellm.integrations.agentops.agentops import AgentOps, AgentOpsConfig
@pytest.fixture
def mock_auth_response():
return {
"token": "test_jwt_token",
"project_id": "test_project_id"
}
return {"token": "test_jwt_token", "project_id": "test_project_id"}
@pytest.fixture
def agentops_config():
@ -21,16 +23,20 @@ def agentops_config():
api_key="test_api_key",
service_name="test_service",
deployment_environment="test_env",
auth_endpoint="https://api.agentops.ai/v3/auth/token"
auth_endpoint="https://api.agentops.ai/v3/auth/token",
)
def test_agentops_config_from_env():
"""Test that AgentOpsConfig correctly reads from environment variables"""
with patch.dict(os.environ, {
"AGENTOPS_API_KEY": "test_key",
"AGENTOPS_SERVICE_NAME": "test_service",
"AGENTOPS_ENVIRONMENT": "test_env"
}):
with patch.dict(
os.environ,
{
"AGENTOPS_API_KEY": "test_key",
"AGENTOPS_SERVICE_NAME": "test_service",
"AGENTOPS_ENVIRONMENT": "test_env",
},
):
config = AgentOpsConfig.from_env()
assert config.api_key == "test_key"
assert config.service_name == "test_service"
@ -38,6 +44,7 @@ def test_agentops_config_from_env():
assert config.endpoint == "https://otlp.agentops.cloud/v1/traces"
assert config.auth_endpoint == "https://api.agentops.ai/v3/auth/token"
def test_agentops_config_defaults():
"""Test that AgentOpsConfig uses correct default values"""
config = AgentOpsConfig()
@ -47,52 +54,64 @@ def test_agentops_config_defaults():
assert config.endpoint == "https://otlp.agentops.cloud/v1/traces"
assert config.auth_endpoint == "https://api.agentops.ai/v3/auth/token"
@patch('litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token')
@patch("litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token")
def test_fetch_auth_token_success(mock_fetch_auth_token, mock_auth_response):
"""Test successful JWT token fetch"""
mock_fetch_auth_token.return_value = mock_auth_response
config = AgentOpsConfig(api_key="test_key")
agentops = AgentOps(config=config)
mock_fetch_auth_token.assert_called_once_with("test_key", "https://api.agentops.ai/v3/auth/token")
assert agentops.resource_attributes.get("project.id") == mock_auth_response.get("project_id")
@patch('litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token')
mock_fetch_auth_token.assert_called_once_with(
"test_key", "https://api.agentops.ai/v3/auth/token"
)
assert agentops.resource_attributes.get("project.id") == mock_auth_response.get(
"project_id"
)
@patch("litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token")
def test_fetch_auth_token_failure(mock_fetch_auth_token):
"""Test failed JWT token fetch"""
mock_fetch_auth_token.side_effect = Exception("Failed to fetch auth token: Unauthorized")
mock_fetch_auth_token.side_effect = Exception(
"Failed to fetch auth token: Unauthorized"
)
config = AgentOpsConfig(api_key="test_key")
agentops = AgentOps(config=config)
mock_fetch_auth_token.assert_called_once()
assert "project.id" not in agentops.resource_attributes
@patch('litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token')
def test_agentops_initialization(mock_fetch_auth_token, agentops_config, mock_auth_response):
@patch("litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token")
def test_agentops_initialization(
mock_fetch_auth_token, agentops_config, mock_auth_response
):
"""Test AgentOps initialization with config"""
mock_fetch_auth_token.return_value = mock_auth_response
agentops = AgentOps(config=agentops_config)
assert agentops.resource_attributes["service.name"] == "test_service"
assert agentops.resource_attributes["deployment.environment"] == "test_env"
assert agentops.resource_attributes["telemetry.sdk.name"] == "agentops"
assert agentops.resource_attributes["project.id"] == "test_project_id"
def test_agentops_initialization_no_auth():
"""Test AgentOps initialization without authentication"""
test_config = AgentOpsConfig(
endpoint="https://otlp.agentops.cloud/v1/traces",
api_key=None, # No API key
service_name="test_service",
deployment_environment="test_env"
deployment_environment="test_env",
)
agentops = AgentOps(config=test_config)
assert agentops.resource_attributes["service.name"] == "test_service"
assert agentops.resource_attributes["deployment.environment"] == "test_env"
assert agentops.resource_attributes["telemetry.sdk.name"] == "agentops"
assert "project.id" not in agentops.resource_attributes
assert "project.id" not in agentops.resource_attributes

View File

@ -0,0 +1,220 @@
import datetime
import json
import os
import sys
import unittest
from unittest.mock import ANY, MagicMock, patch
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
from litellm.integrations.athina import AthinaLogger
class TestAthinaLogger(unittest.TestCase):
def setUp(self):
# Set up environment variables for testing
self.env_patcher = patch.dict(
"os.environ",
{
"ATHINA_API_KEY": "test-api-key",
"ATHINA_BASE_URL": "https://test.athina.ai",
},
)
self.env_patcher.start()
self.logger = AthinaLogger()
# Setup common test variables
self.start_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
self.end_time = datetime.datetime(2023, 1, 1, 12, 0, 1)
self.print_verbose = MagicMock()
def tearDown(self):
self.env_patcher.stop()
def test_init(self):
"""Test the initialization of AthinaLogger"""
self.assertEqual(self.logger.athina_api_key, "test-api-key")
self.assertEqual(
self.logger.athina_logging_url,
"https://test.athina.ai/api/v1/log/inference",
)
self.assertEqual(
self.logger.headers,
{"athina-api-key": "test-api-key", "Content-Type": "application/json"},
)
@patch("litellm.module_level_client.post")
def test_log_event_success(self, mock_post):
"""Test successful logging of an event"""
# Setup mock response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.text = "Success"
mock_post.return_value = mock_response
# Create test data
kwargs = {
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"stream": False,
"litellm_params": {
"metadata": {
"environment": "test-environment",
"prompt_slug": "test-prompt",
"customer_id": "test-customer",
"customer_user_id": "test-user",
"session_id": "test-session",
"external_reference_id": "test-ext-ref",
"context": "test-context",
"expected_response": "test-expected",
"user_query": "test-query",
"tags": ["test-tag"],
"user_feedback": "test-feedback",
"model_options": {"test-opt": "test-val"},
"custom_attributes": {"test-attr": "test-val"},
}
},
}
response_obj = MagicMock()
response_obj.model_dump.return_value = {
"id": "resp-123",
"choices": [{"message": {"content": "Hi there"}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
}
# Call the method
self.logger.log_event(
kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
)
# Verify the results
mock_post.assert_called_once()
call_args = mock_post.call_args
self.assertEqual(call_args[0][0], "https://test.athina.ai/api/v1/log/inference")
self.assertEqual(call_args[1]["headers"], self.logger.headers)
# Parse and verify the sent data
sent_data = json.loads(call_args[1]["data"])
self.assertEqual(sent_data["language_model_id"], "gpt-4")
self.assertEqual(sent_data["prompt"], kwargs["messages"])
self.assertEqual(sent_data["prompt_tokens"], 10)
self.assertEqual(sent_data["completion_tokens"], 5)
self.assertEqual(sent_data["total_tokens"], 15)
self.assertEqual(sent_data["response_time"], 1000) # 1 second = 1000ms
self.assertEqual(sent_data["customer_id"], "test-customer")
self.assertEqual(sent_data["session_id"], "test-session")
self.assertEqual(sent_data["environment"], "test-environment")
self.assertEqual(sent_data["prompt_slug"], "test-prompt")
self.assertEqual(sent_data["external_reference_id"], "test-ext-ref")
self.assertEqual(sent_data["context"], "test-context")
self.assertEqual(sent_data["expected_response"], "test-expected")
self.assertEqual(sent_data["user_query"], "test-query")
self.assertEqual(sent_data["tags"], ["test-tag"])
self.assertEqual(sent_data["user_feedback"], "test-feedback")
self.assertEqual(sent_data["model_options"], {"test-opt": "test-val"})
self.assertEqual(sent_data["custom_attributes"], {"test-attr": "test-val"})
# Verify the print_verbose was called
self.print_verbose.assert_called_once_with("Athina Logger Succeeded - Success")
@patch("litellm.module_level_client.post")
def test_log_event_error_response(self, mock_post):
"""Test handling of error response from the API"""
# Setup mock error response
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.text = "Bad Request"
mock_post.return_value = mock_response
# Create test data
kwargs = {
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"stream": False,
}
response_obj = MagicMock()
response_obj.model_dump.return_value = {
"id": "resp-123",
"choices": [{"message": {"content": "Hi there"}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
}
# Call the method
self.logger.log_event(
kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
)
# Verify print_verbose was called with error message
self.print_verbose.assert_called_once_with(
"Athina Logger Error - Bad Request, 400"
)
@patch("litellm.module_level_client.post")
def test_log_event_exception(self, mock_post):
"""Test handling of exceptions during logging"""
# Setup mock to raise exception
mock_post.side_effect = Exception("Test exception")
# Create test data
kwargs = {
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"stream": False,
}
response_obj = MagicMock()
response_obj.model_dump.return_value = {}
# Call the method
self.logger.log_event(
kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
)
# Verify print_verbose was called with exception info
self.print_verbose.assert_called_once()
self.assertIn(
"Athina Logger Error - Test exception", self.print_verbose.call_args[0][0]
)
@patch("litellm.module_level_client.post")
def test_log_event_with_tools(self, mock_post):
"""Test logging with tools/functions data"""
# Setup mock response
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
# Create test data with tools
kwargs = {
"model": "gpt-4",
"messages": [{"role": "user", "content": "What's the weather?"}],
"stream": False,
"optional_params": {
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
},
}
response_obj = MagicMock()
response_obj.model_dump.return_value = {
"id": "resp-123",
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
}
# Call the method
self.logger.log_event(
kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
)
# Verify the results
sent_data = json.loads(mock_post.call_args[1]["data"])
self.assertEqual(
sent_data["tools"],
[{"type": "function", "function": {"name": "get_weather"}}],
)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

View File

@ -1,12 +1,12 @@
import unittest
from unittest.mock import patch, MagicMock
from datetime import datetime, timezone
import uuid
import os
import unittest
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from litellm.integrations.deepeval.api import Endpoints, HttpMethods
from litellm.integrations.deepeval.deepeval import DeepEvalLogger
from litellm.integrations.deepeval.api import HttpMethods, Endpoints
from litellm.integrations.deepeval.types import TraceSpanApiStatus, SpanApiType
from litellm.integrations.deepeval.types import SpanApiType, TraceSpanApiStatus
class TestDeepEvalLogger(unittest.TestCase):

View File

@ -410,8 +410,8 @@ inner_object = {
}
],
"tool_choice": "none",
"count": 65,
"count-tolerate" : 67 #over by 2
"count": 65,
"count-tolerate": 67, # over by 2
}
"""
namespace functions {
@ -459,7 +459,7 @@ inner_object_with_enum_only = {
],
"tool_choice": "none",
"count": 73,
"count-tolerate" : 74 #over by 1
"count-tolerate": 74, # over by 1
}
"""
namespace functions {
@ -511,7 +511,7 @@ inner_object_with_enum = {
],
"tool_choice": "none",
"count": 89,
"count-tolerate" : 92, #over by 3
"count-tolerate": 92, # over by 3
}
"""
namespace functions {
@ -568,8 +568,8 @@ inner_object_and_string = {
}
],
"tool_choice": "none",
"count": 103,
"count-tolerate" : 106, #over by 3
"count": 103,
"count-tolerate": 106, # over by 3
}
"""
namespace functions {

View File

@ -6,10 +6,11 @@ import pytest
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
BAD_MESSAGE_ERROR_STR,
ollama_pt,
BedrockConverseMessagesProcessor,
ollama_pt,
)
def test_ollama_pt_simple_messages():
"""Test basic functionality with simple text messages"""
messages = [
@ -43,6 +44,7 @@ def test_ollama_pt_consecutive_user_messages():
assert isinstance(result, dict)
assert result["prompt"] == expected_prompt
@pytest.mark.asyncio
async def test_anthropic_bedrock_thinking_blocks_with_none_content():
"""
@ -55,28 +57,29 @@ async def test_anthropic_bedrock_thinking_blocks_with_none_content():
{
"type": "thinking",
"thinking": "This is a test thinking block",
"signature": "test-signature"
"signature": "test-signature",
}
],
"reasoning_content": "This is the reasoning content"
"reasoning_content": "This is the reasoning content",
}
messages = [
{"role": "user", "content": "What is the capital of France?"},
mock_assistant_message
mock_assistant_message,
]
# test _bedrock_converse_messages_pt_async
result = await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
messages=messages,
model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
llm_provider="bedrock"
llm_provider="bedrock",
)
# verify the result
assert len(result) == 2
assert result[1]["content"][0]["reasoningContent"]["reasoningText"]["text"] == "This is a test thinking block"
assert (
result[1]["content"][0]["reasoningContent"]["reasoningText"]["text"]
== "This is a test thinking block"
)
# def test_ollama_pt_consecutive_system_messages():

View File

@ -1,30 +1,34 @@
import unittest
from datetime import datetime, timezone
from zoneinfo import ZoneInfo
from litellm.litellm_core_utils.duration_parser import get_next_standardized_reset_time
class TestStandardizedResetTime(unittest.TestCase):
def test_day_based_resets(self):
"""Test day-based reset durations (1d, 7d, 30d)"""
# Base time: 2023-05-15 10:30:00 UTC
base_time = datetime(2023, 5, 15, 10, 30, 0, tzinfo=timezone.utc)
# Daily reset (1d) - should reset at next midnight
daily_expected = datetime(2023, 5, 16, 0, 0, 0, tzinfo=timezone.utc)
daily_result = get_next_standardized_reset_time("1d", base_time, "UTC")
self.assertEqual(daily_result, daily_expected)
# Weekly reset (7d) - should reset on next Monday
wednesday = datetime(2023, 5, 17, 15, 45, 0, tzinfo=timezone.utc) # A Wednesday
weekly_expected = datetime(2023, 5, 22, 0, 0, 0, tzinfo=timezone.utc) # Next Monday
weekly_expected = datetime(
2023, 5, 22, 0, 0, 0, tzinfo=timezone.utc
) # Next Monday
weekly_result = get_next_standardized_reset_time("7d", wednesday, "UTC")
self.assertEqual(weekly_result, weekly_expected)
# Monthly reset (30d) - should reset on 1st of next month
monthly_expected = datetime(2023, 6, 1, 0, 0, 0, tzinfo=timezone.utc)
monthly_result = get_next_standardized_reset_time("30d", base_time, "UTC")
self.assertEqual(monthly_result, monthly_expected)
# Custom day reset (3d) - should reset after 3 days
custom_day_expected = datetime(2023, 5, 18, 0, 0, 0, tzinfo=timezone.utc)
custom_day_result = get_next_standardized_reset_time("3d", base_time, "UTC")
@ -34,17 +38,17 @@ class TestStandardizedResetTime(unittest.TestCase):
"""Test hour, minute, and second based reset durations"""
# Base time: 2023-05-15 15:20:30 UTC (3:20:30 PM)
base_time = datetime(2023, 5, 15, 15, 20, 30, tzinfo=timezone.utc)
# 2-hour reset - should reset at next even hour (16:00)
hour_expected = datetime(2023, 5, 15, 16, 0, 0, tzinfo=timezone.utc)
hour_result = get_next_standardized_reset_time("2h", base_time, "UTC")
self.assertEqual(hour_result, hour_expected)
# 30-minute reset - should reset at next 30-minute mark (15:30)
minute_expected = datetime(2023, 5, 15, 15, 30, 0, tzinfo=timezone.utc)
minute_result = get_next_standardized_reset_time("30m", base_time, "UTC")
self.assertEqual(minute_result, minute_expected)
# 15-second reset - should reset at next 15-second mark (15:20:45)
second_expected = datetime(2023, 5, 15, 15, 20, 45, tzinfo=timezone.utc)
second_result = get_next_standardized_reset_time("15s", base_time, "UTC")
@ -54,32 +58,34 @@ class TestStandardizedResetTime(unittest.TestCase):
"""Test timezone handling with different regions"""
# Base time: 2023-05-15 22:30:00 UTC (late in UTC day)
base_time = datetime(2023, 5, 15, 22, 30, 0, tzinfo=timezone.utc)
# Test daily reset in different timezones
# US/Eastern (UTC-4): 6:30 PM, so next reset is midnight same day
eastern = ZoneInfo("US/Eastern")
eastern_expected = datetime(2023, 5, 16, 0, 0, 0, tzinfo=eastern)
eastern_result = get_next_standardized_reset_time("1d", base_time, "US/Eastern")
self.assertEqual(eastern_result, eastern_expected)
# Asia/Kolkata (UTC+5:30): 4:00 AM next day, so next reset is midnight the day after
ist = ZoneInfo("Asia/Kolkata")
ist_expected = datetime(2023, 5, 17, 0, 0, 0, tzinfo=ist)
ist_result = get_next_standardized_reset_time("1d", base_time, "Asia/Kolkata")
self.assertEqual(ist_result, ist_expected)
# Test hourly reset in different timezones
# US/Pacific (UTC-7): 3:30 PM, so next 2h reset is 4:00 PM
pacific = ZoneInfo("US/Pacific")
pacific_expected = datetime(2023, 5, 15, 16, 0, 0, tzinfo=pacific)
pacific_result = get_next_standardized_reset_time("2h", base_time, "US/Pacific")
self.assertEqual(pacific_result, pacific_expected)
# Test minute reset in different timezones
# Europe/London (UTC+1): 11:30 PM, so next 15m reset is 11:45 PM
london = ZoneInfo("Europe/London")
london_expected = datetime(2023, 5, 15, 23, 45, 0, tzinfo=london)
london_result = get_next_standardized_reset_time("15m", base_time, "Europe/London")
london_result = get_next_standardized_reset_time(
"15m", base_time, "Europe/London"
)
self.assertEqual(london_result, london_expected)
def test_edge_cases(self):
@ -89,25 +95,30 @@ class TestStandardizedResetTime(unittest.TestCase):
hour_expected = datetime(2023, 5, 15, 16, 0, 0, tzinfo=timezone.utc)
hour_result = get_next_standardized_reset_time("2h", on_hour, "UTC")
self.assertEqual(hour_result, hour_expected)
# Exactly on minute boundary
on_minute = datetime(2023, 5, 15, 14, 30, 0, tzinfo=timezone.utc)
minute_expected = datetime(2023, 5, 15, 15, 0, 0, tzinfo=timezone.utc)
minute_result = get_next_standardized_reset_time("30m", on_minute, "UTC")
self.assertEqual(minute_result, minute_expected)
# Near day boundary
near_midnight = datetime(2023, 5, 15, 23, 50, 0, tzinfo=timezone.utc)
# 30m near midnight - should roll over to next day
midnight_minute_expected = datetime(2023, 5, 16, 0, 0, 0, tzinfo=timezone.utc)
midnight_minute_result = get_next_standardized_reset_time("30m", near_midnight, "UTC")
midnight_minute_result = get_next_standardized_reset_time(
"30m", near_midnight, "UTC"
)
self.assertEqual(midnight_minute_result, midnight_minute_expected)
# Invalid timezone - should fall back to UTC
invalid_tz_expected = datetime(2023, 5, 16, 0, 0, 0, tzinfo=timezone.utc)
invalid_tz_result = get_next_standardized_reset_time("1d", on_hour, "NonExistentTimeZone")
invalid_tz_result = get_next_standardized_reset_time(
"1d", on_hour, "NonExistentTimeZone"
)
self.assertEqual(invalid_tz_result, invalid_tz_expected)
if __name__ == "__main__":
unittest.main()
unittest.main()

View File

@ -15,9 +15,9 @@ from litellm.types.utils import (
Delta,
Function,
ModelResponseStream,
PromptTokensDetails,
StreamingChoices,
Usage,
PromptTokensDetails,
)

View File

@ -612,29 +612,45 @@ def test_streaming_handler_with_stop_chunk(
assert returned_chunk is None
def test_set_response_id_propagation_empty_to_valid(initialized_custom_stream_wrapper: CustomStreamWrapper):
def test_set_response_id_propagation_empty_to_valid(
initialized_custom_stream_wrapper: CustomStreamWrapper,
):
"""Test that response_id is properly set when first chunk has empty ID and second chunk has valid ID"""
model_response1 = ModelResponseStream(id="", created=1742056047, model=None)
model_response1 = initialized_custom_stream_wrapper.set_model_id(model_response1.id, model_response1)
model_response1 = initialized_custom_stream_wrapper.set_model_id(
model_response1.id, model_response1
)
assert model_response1.id == ""
model_response2 = ModelResponseStream(id="valid-id-123", created=1742056048, model=None)
model_response2 = initialized_custom_stream_wrapper.set_model_id("valid-id-123", model_response2)
model_response2 = ModelResponseStream(
id="valid-id-123", created=1742056048, model=None
)
model_response2 = initialized_custom_stream_wrapper.set_model_id(
"valid-id-123", model_response2
)
assert model_response2.id == "valid-id-123"
assert initialized_custom_stream_wrapper.response_id == "valid-id-123"
def test_set_response_id_propagation_valid_to_invalid(initialized_custom_stream_wrapper: CustomStreamWrapper):
def test_set_response_id_propagation_valid_to_invalid(
initialized_custom_stream_wrapper: CustomStreamWrapper,
):
"""Test that response_id is maintained when first chunk has valid ID and second chunk has invalid ID"""
model_response1 = ModelResponseStream(id="first-valid-id", created=1742056049, model=None)
model_response1 = initialized_custom_stream_wrapper.set_model_id("first-valid-id", model_response1)
model_response1 = ModelResponseStream(
id="first-valid-id", created=1742056049, model=None
)
model_response1 = initialized_custom_stream_wrapper.set_model_id(
"first-valid-id", model_response1
)
assert model_response1.id == "first-valid-id"
assert initialized_custom_stream_wrapper.response_id == "first-valid-id"
model_response2 = ModelResponseStream(id="", created=1742056050, model=None)
model_response2 = initialized_custom_stream_wrapper.set_model_id("", model_response2)
model_response2 = initialized_custom_stream_wrapper.set_model_id(
"", model_response2
)
assert model_response2.id == "first-valid-id"
assert initialized_custom_stream_wrapper.response_id == "first-valid-id"

View File

@ -13,17 +13,16 @@ sys.path.insert(
) # Adds the parent directory to the system path
from unittest.mock import AsyncMock, MagicMock, patch
from messages_with_counts import (
MESSAGES_TEXT,
MESSAGES_WITH_IMAGES,
MESSAGES_WITH_TOOLS,
)
import litellm
from litellm import create_pretrained_tokenizer, decode, encode, get_modified_max_tokens
from litellm import token_counter as token_counter_old
from litellm.litellm_core_utils.token_counter import token_counter as token_counter_new
from tests.large_text import text
from tests.test_litellm.litellm_core_utils.messages_with_counts import (
MESSAGES_TEXT,
MESSAGES_WITH_IMAGES,
MESSAGES_WITH_TOOLS,
)
def token_counter_both_assert_same(**args):

View File

@ -9,10 +9,9 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
#Use the same token_counter as the main test.
from test_token_counter import token_counter
from test_token_counter_tool_data import *
# Use the same token_counter as the main test.
from tests.test_litellm.litellm_core_utils.test_token_counter import token_counter
from tests.test_litellm.litellm_core_utils.test_token_counter_tool_data import *
@pytest.mark.parametrize(

View File

@ -27,19 +27,19 @@ CONTENT_AND_TOOL_CALL = [
]
_OPENHANDS_SYSTEM_MESSAGE = {
"content": [
{
"type": "text",
"text": "You are OpenHands agent, a helpful AI assistant that can "
"interact with a computer to solve tasks.\n\n<ROLE>\nYour primary "
"role is to assist users by executing commands, modifying code, and "
"solving technical problems effectively. You should be thorough, "
"methodical, and prioritize quality over speed.\n* If the user asks a "
"question, like 'why is X happening', dont try to fix the problem. ",
}
],
"role": "system",
}
"content": [
{
"type": "text",
"text": "You are OpenHands agent, a helpful AI assistant that can "
"interact with a computer to solve tasks.\n\n<ROLE>\nYour primary "
"role is to assist users by executing commands, modifying code, and "
"solving technical problems effectively. You should be thorough, "
"methodical, and prioritize quality over speed.\n* If the user asks a "
"question, like 'why is X happening', dont try to fix the problem. ",
}
],
"role": "system",
}
SYSTEM_LONG = [
_OPENHANDS_SYSTEM_MESSAGE,

View File

@ -19,13 +19,15 @@ async def test_get_openai_compatible_provider_info():
"""
config = AzureAIStudioConfig()
api_base, dynamic_api_key, custom_llm_provider = (
config._get_openai_compatible_provider_info(
model="azure_ai/gpt-4o-mini",
api_base="https://my-base",
api_key="my-key",
custom_llm_provider="azure_ai",
)
(
api_base,
dynamic_api_key,
custom_llm_provider,
) = config._get_openai_compatible_provider_info(
model="azure_ai/gpt-4o-mini",
api_base="https://my-base",
api_key="my-key",
custom_llm_provider="azure_ai",
)
assert custom_llm_provider == "azure"

View File

@ -1,6 +1,6 @@
from litellm.llms.bedrock.chat.invoke_transformations.amazon_mistral_transformation import AmazonMistralConfig
from litellm.llms.bedrock.chat.invoke_transformations.amazon_mistral_transformation import (
AmazonMistralConfig,
)
from litellm.types.utils import ModelResponse
@ -10,7 +10,9 @@ def test_mistral_get_outputText():
model_response.choices[0].finish_reason = "None"
# Models like pixtral will return a completion with the openai format.
mock_json_with_choices = {"choices": [{"message": {"content": "Hello!"}, "finish_reason": "stop"}]}
mock_json_with_choices = {
"choices": [{"message": {"content": "Hello!"}, "finish_reason": "stop"}]
}
outputText = AmazonMistralConfig.get_outputText(
completion_response=mock_json_with_choices, model_response=model_response

View File

@ -2,7 +2,6 @@ import os
import sys
from unittest.mock import MagicMock
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
@ -18,7 +17,11 @@ class TestCohereTransform:
def test_map_cohere_params(self):
"""Test that parameters are correctly mapped"""
test_params = {"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}
test_params = {
"temperature": 0.7,
"max_tokens": 200,
"max_completion_tokens": 256,
}
result = self.config.map_openai_params(
non_default_params=test_params,
@ -32,7 +35,10 @@ class TestCohereTransform:
def test_cohere_max_tokens_backward_compat(self):
"""Test that parameters are correctly mapped"""
test_params = {"temperature": 0.7, "max_tokens": 200,}
test_params = {
"temperature": 0.7,
"max_tokens": 200,
}
result = self.config.map_openai_params(
non_default_params=test_params,

View File

@ -1,6 +1,6 @@
from typing import Optional
from unittest.mock import patch
import pytest
import litellm
@ -13,12 +13,26 @@ from litellm.llms.llamafile.chat.transformation import LlamafileChatConfig
("user-provided-key", "secret-key", "user-provided-key", False),
(None, "secret-key", "secret-key", True),
(None, None, "fake-api-key", True),
("", "secret-key", "secret-key", True), # Empty string should fall back to secret
("", None, "fake-api-key", True), # Empty string with no secret should use the fake key
]
(
"",
"secret-key",
"secret-key",
True,
), # Empty string should fall back to secret
(
"",
None,
"fake-api-key",
True,
), # Empty string with no secret should use the fake key
],
)
def test_resolve_api_key(input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called):
with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
def test_resolve_api_key(
input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called
):
with patch(
"litellm.llms.llamafile.chat.transformation.get_secret_str"
) as mock_get_secret:
mock_get_secret.return_value = api_key_from_secret_manager
result = LlamafileChatConfig._resolve_api_key(input_api_key)
@ -34,14 +48,36 @@ def test_resolve_api_key(input_api_key, api_key_from_secret_manager, expected_ap
@pytest.mark.parametrize(
"input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called",
[
("https://user-api.example.com", "https://secret-api.example.com", "https://user-api.example.com", False),
(None, "https://secret-api.example.com", "https://secret-api.example.com", True),
(
"https://user-api.example.com",
"https://secret-api.example.com",
"https://user-api.example.com",
False,
),
(
None,
"https://secret-api.example.com",
"https://secret-api.example.com",
True,
),
(None, None, "http://127.0.0.1:8080/v1", True),
("", "https://secret-api.example.com", "https://secret-api.example.com", True), # Empty string should fall back
]
(
"",
"https://secret-api.example.com",
"https://secret-api.example.com",
True,
), # Empty string should fall back
],
)
def test_resolve_api_base(input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called):
with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
def test_resolve_api_base(
input_api_base,
api_base_from_secret_manager,
expected_api_base,
secret_manager_called,
):
with patch(
"litellm.llms.llamafile.chat.transformation.get_secret_str"
) as mock_get_secret:
mock_get_secret.return_value = api_base_from_secret_manager
result = LlamafileChatConfig._resolve_api_base(input_api_base)
@ -58,31 +94,73 @@ def test_resolve_api_base(input_api_base, api_base_from_secret_manager, expected
"api_base, api_key, secret_base, secret_key, expected_base, expected_key",
[
# User-provided values
("https://user-api.example.com", "user-key", "https://secret-api.example.com", "secret-key", "https://user-api.example.com", "user-key"),
(
"https://user-api.example.com",
"user-key",
"https://secret-api.example.com",
"secret-key",
"https://user-api.example.com",
"user-key",
),
# Fallback to secrets
(None, None, "https://secret-api.example.com", "secret-key", "https://secret-api.example.com", "secret-key"),
(
None,
None,
"https://secret-api.example.com",
"secret-key",
"https://secret-api.example.com",
"secret-key",
),
# Nothing provided, use defaults
(None, None, None, None, "http://127.0.0.1:8080/v1", "fake-api-key"),
# Mixed scenarios
("https://user-api.example.com", None, None, "secret-key", "https://user-api.example.com", "secret-key"),
(None, "user-key", "https://secret-api.example.com", None, "https://secret-api.example.com", "user-key"),
]
(
"https://user-api.example.com",
None,
None,
"secret-key",
"https://user-api.example.com",
"secret-key",
),
(
None,
"user-key",
"https://secret-api.example.com",
None,
"https://secret-api.example.com",
"user-key",
),
],
)
def test_get_openai_compatible_provider_info(api_base, api_key, secret_base, secret_key, expected_base, expected_key):
def test_get_openai_compatible_provider_info(
api_base, api_key, secret_base, secret_key, expected_base, expected_key
):
config = LlamafileChatConfig()
def fake_get_secret(key: str) -> Optional[str]:
return {
"LLAMAFILE_API_BASE": secret_base,
"LLAMAFILE_API_KEY": secret_key
}.get(key)
return {"LLAMAFILE_API_BASE": secret_base, "LLAMAFILE_API_KEY": secret_key}.get(
key
)
patch_secret = patch("litellm.llms.llamafile.chat.transformation.get_secret_str", side_effect=fake_get_secret)
patch_base = patch.object(LlamafileChatConfig, "_resolve_api_base", wraps=LlamafileChatConfig._resolve_api_base)
patch_key = patch.object(LlamafileChatConfig, "_resolve_api_key", wraps=LlamafileChatConfig._resolve_api_key)
patch_secret = patch(
"litellm.llms.llamafile.chat.transformation.get_secret_str",
side_effect=fake_get_secret,
)
patch_base = patch.object(
LlamafileChatConfig,
"_resolve_api_base",
wraps=LlamafileChatConfig._resolve_api_base,
)
patch_key = patch.object(
LlamafileChatConfig,
"_resolve_api_key",
wraps=LlamafileChatConfig._resolve_api_key,
)
with patch_secret as mock_secret, patch_base as mock_base, patch_key as mock_key:
result_base, result_key = config._get_openai_compatible_provider_info(api_base, api_key)
result_base, result_key = config._get_openai_compatible_provider_info(
api_base, api_key
)
assert result_base == expected_base
assert result_key == expected_key
@ -100,8 +178,12 @@ def test_get_openai_compatible_provider_info(api_base, api_key, secret_base, sec
def test_completion_with_custom_llamafile_model():
with patch("litellm.main.openai_chat_completions.completion") as mock_llamafile_completion_func:
mock_llamafile_completion_func.return_value = {} # Return an empty dictionary for the mocked response
with patch(
"litellm.main.openai_chat_completions.completion"
) as mock_llamafile_completion_func:
mock_llamafile_completion_func.return_value = (
{}
) # Return an empty dictionary for the mocked response
provider = "llamafile"
model_name = "my-custom-test-model"

View File

@ -3,7 +3,9 @@ import sys
from pydantic import BaseModel
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")))
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
)
from litellm.llms.lm_studio.chat.transformation import LMStudioChatConfig
from litellm.utils import get_optional_params
@ -44,7 +46,9 @@ class TestLMStudioChatConfigResponseFormat:
custom_llm_provider="lm_studio",
)
mapped = config.map_openai_params(non_default_params, {}, "lm_studio/test-model", False)
mapped = config.map_openai_params(
non_default_params, {}, "lm_studio/test-model", False
)
mapped_schema = mapped["response_format"]["json_schema"]["schema"]
assert mapped_schema["properties"] == schema["properties"]
opt_schema = optional_params["response_format"]["json_schema"]["schema"]

View File

@ -1,45 +1,42 @@
import json
import os
import sys
import json
import uuid
import pytest
from unittest.mock import MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
from litellm.llms.ollama.completion.transformation import (
OllamaConfig,
)
from litellm.types.utils import ModelResponse
from litellm.types.utils import Message
from litellm.llms.ollama.completion.transformation import OllamaConfig
from litellm.types.utils import Message, ModelResponse
class TestOllamaConfig:
def test_transform_response_standard(self):
# Initialize config
config = OllamaConfig()
# Create mock response
raw_response = MagicMock()
raw_response.json.return_value = {
"response": "Hello, I am an AI assistant",
"prompt_eval_count": 10,
"eval_count": 5
"eval_count": 5,
}
# Create properly structured model response object
model_response = ModelResponse(
id="test_id",
choices=[{"message": Message(content="")}],
)
# Create mock encoding
mock_encoding = MagicMock()
mock_encoding.encode.return_value = [1, 2, 3] # Return dummy token IDs
# Transform response
result = config.transform_response(
model="llama2",
@ -52,7 +49,7 @@ class TestOllamaConfig:
litellm_params={},
encoding=mock_encoding,
)
# Verify response
assert result.choices[0]["message"].content == "Hello, I am an AI assistant"
assert result.choices[0]["finish_reason"] == "stop"
@ -67,29 +64,28 @@ class TestOllamaConfig:
def test_transform_response_json_function_call(self, mock_uuid4):
# Setup mock UUID
mock_uuid4.return_value = "test-uuid"
# Initialize config
config = OllamaConfig()
# Create mock response with JSON function call format
raw_response = MagicMock()
raw_response.json.return_value = {
"response": json.dumps({
"name": "get_weather",
"arguments": {"location": "San Francisco"}
})
"response": json.dumps(
{"name": "get_weather", "arguments": {"location": "San Francisco"}}
)
}
# Create properly structured model response object
model_response = ModelResponse(
id="test_id",
choices=[{"message": Message(content="")}],
)
# Create mock encoding
mock_encoding = MagicMock()
mock_encoding.encode.return_value = [1, 2, 3] # Return dummy token IDs
# Transform response
result = config.transform_response(
model="llama2",
@ -102,39 +98,43 @@ class TestOllamaConfig:
litellm_params={},
encoding=mock_encoding,
)
# Verify result has tool_calls
assert result.choices[0]["message"].content is None
assert result.choices[0]["finish_reason"] == "tool_calls"
assert len(result.choices[0]["message"].tool_calls) == 1
assert result.choices[0]["message"].tool_calls[0]["id"].startswith("call_")
assert result.choices[0]["message"].tool_calls[0]["function"]["name"] == "get_weather"
assert json.loads(result.choices[0]["message"].tool_calls[0]["function"]["arguments"]) == {"location": "San Francisco"}
assert (
result.choices[0]["message"].tool_calls[0]["function"]["name"]
== "get_weather"
)
assert json.loads(
result.choices[0]["message"].tool_calls[0]["function"]["arguments"]
) == {"location": "San Francisco"}
# No usage assertions here as we don't need to test them in every case
def test_transform_response_regular_json(self):
# Initialize config
config = OllamaConfig()
# Create mock response with regular JSON (not function call)
raw_response = MagicMock()
raw_response.json.return_value = {
"response": json.dumps({
"result": "success",
"data": {"temperature": 72, "unit": "F"}
})
"response": json.dumps(
{"result": "success", "data": {"temperature": 72, "unit": "F"}}
)
}
# Create properly structured model response object
model_response = ModelResponse(
id="test_id",
choices=[{"message": Message(content="")}],
)
# Create mock encoding
mock_encoding = MagicMock()
mock_encoding.encode.return_value = [1, 2, 3] # Return dummy token IDs
# Transform response
result = config.transform_response(
model="llama2",
@ -147,12 +147,11 @@ class TestOllamaConfig:
litellm_params={},
encoding=mock_encoding,
)
# Verify result has JSON content
expected_content = json.dumps({
"result": "success",
"data": {"temperature": 72, "unit": "F"}
})
expected_content = json.dumps(
{"result": "success", "data": {"temperature": 72, "unit": "F"}}
)
assert result.choices[0]["message"].content == expected_content
assert result.choices[0]["finish_reason"] == "stop"
# No usage assertions here as we don't need to test them in every case
# No usage assertions here as we don't need to test them in every case

View File

@ -1,10 +1,10 @@
import json
import os
import sys
import json
import uuid
import pytest
from unittest.mock import MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../../../..")
@ -14,13 +14,15 @@ sys.path.insert(
Unit tests for OllamaModelInfo.get_models functionality.
"""
# Ensure a dummy httpx module is available for import in tests
import sys, types
import sys
import types
# Provide a dummy httpx module for import in get_models
if 'httpx' not in sys.modules:
if "httpx" not in sys.modules:
# Create a minimal module with HTTPStatusError
httpx_mod = types.ModuleType('httpx')
httpx_mod = types.ModuleType("httpx")
httpx_mod.HTTPStatusError = Exception
sys.modules['httpx'] = httpx_mod
sys.modules["httpx"] = httpx_mod
import httpx
@ -31,6 +33,7 @@ class DummyResponse:
"""
A dummy response object to simulate httpx responses.
"""
def __init__(self, json_data, status_code=200):
self._json = json_data
self.status_code = status_code
@ -38,7 +41,9 @@ class DummyResponse:
def raise_for_status(self):
if self.status_code >= 400:
# Simulate an HTTP status error
raise httpx.HTTPStatusError("Error status code", request=None, response=None)
raise httpx.HTTPStatusError(
"Error status code", request=None, response=None
)
def json(self):
return self._json
@ -51,25 +56,26 @@ class TestOllamaModelInfo:
get_models should extract and return sorted unique model names.
"""
calls = []
sample = {'models': [
{'name': 'zeta'},
{'model': 'alpha'},
{'name': 123}, # non-str should be ignored
'invalid', # non-dict should be ignored
]}
sample = {
"models": [
{"name": "zeta"},
{"model": "alpha"},
{"name": 123}, # non-str should be ignored
"invalid", # non-dict should be ignored
]
}
def mock_get(url):
calls.append(url)
return DummyResponse(sample, status_code=200)
monkeypatch.setattr(httpx, 'get', mock_get)
monkeypatch.setattr(httpx, "get", mock_get)
info = OllamaModelInfo()
models = info.get_models()
# Only 'alpha' and 'zeta' should be returned, sorted alphabetically
assert models == ['alpha', 'zeta']
assert models == ["alpha", "zeta"]
# Ensure correct endpoint was called
assert calls and calls[0].endswith('/api/tags')
assert calls and calls[0].endswith("/api/tags")
def test_get_models_from_list_response(self, monkeypatch):
"""
@ -77,30 +83,30 @@ class TestOllamaModelInfo:
get_models should extract and return sorted unique model names.
"""
sample = [
{'name': 'm1'},
{'model': 'm2'},
{}, # no name/model key should be ignored
{"name": "m1"},
{"model": "m2"},
{}, # no name/model key should be ignored
]
def mock_get(url):
return DummyResponse(sample, status_code=200)
monkeypatch.setattr(httpx, 'get', mock_get)
monkeypatch.setattr(httpx, "get", mock_get)
info = OllamaModelInfo()
models = info.get_models()
assert models == ['m1', 'm2']
assert models == ["m1", "m2"]
def test_get_models_fallback_on_error(self, monkeypatch):
"""
If the httpx.get call raises an exception, get_models should
fall back to the static models_by_provider list prefixed by 'ollama/'.
"""
def mock_get(url):
raise Exception("connection failure")
monkeypatch.setattr(httpx, 'get', mock_get)
monkeypatch.setattr(httpx, "get", mock_get)
info = OllamaModelInfo()
models = info.get_models()
# Default static ollama_models is ['llama2'], so expect ['ollama/llama2']
assert models == ['ollama/llama2']
assert models == ["ollama/llama2"]

View File

@ -1,4 +1,5 @@
import pytest
from litellm.llms.openai.chat.o_series_transformation import OpenAIOSeriesConfig
@ -11,24 +12,20 @@ from litellm.llms.openai.chat.o_series_transformation import OpenAIOSeriesConfig
("o4-mini", True),
("o1-preview", True),
("o3-mini", True),
# Valid O-series models with provider prefix
("openai/o1", True),
("openai/o3", True),
("openai/o4-mini", True),
("openai/o1-preview", True),
("openai/o3-mini", True),
# Non-O-series models
("gpt-4", False),
("gpt-3.5-turbo", False),
("claude-3-opus", False),
# Non-O-series models with provider prefix
("openai/gpt-4", False),
("openai/gpt-3.5-turbo", False),
("anthropic/claude-3-opus", False),
# Edge cases
("o", False), # Too short
("o5", False), # Not a valid O-series model
@ -39,11 +36,12 @@ from litellm.llms.openai.chat.o_series_transformation import OpenAIOSeriesConfig
def test_is_model_o_series_model(model_name: str, expected: bool):
"""
Test that is_model_o_series_model correctly identifies O-series models.
Args:
model_name: The model name to test
expected: The expected result (True if it should be identified as an O-series model)
"""
config = OpenAIOSeriesConfig()
assert config.is_model_o_series_model(model_name) == expected, \
f"Expected {model_name} to be {'an O-series model' if expected else 'not an O-series model'}"
assert (
config.is_model_o_series_model(model_name) == expected
), f"Expected {model_name} to be {'an O-series model' if expected else 'not an O-series model'}"

View File

@ -98,7 +98,6 @@ async def test_openai_client_reuse(function_name, is_async, args):
) as mock_set_cache, patch.object(
BaseOpenAILLM, "get_cached_openai_client"
) as mock_get_cache:
# Setup the mock to return None first time (cache miss) then a client for subsequent calls
mock_client = MagicMock()
mock_get_cache.side_effect = [None] + [

View File

@ -3,15 +3,14 @@ import sys
import pytest
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
from litellm.llms.openrouter.chat.transformation import (
OpenRouterChatCompletionStreamingHandler,
OpenRouterException,
OpenrouterConfig,
OpenRouterException,
)
@ -26,11 +25,7 @@ class TestOpenRouterChatCompletionStreamingHandler:
"id": "test_id",
"created": 1234567890,
"model": "test_model",
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
},
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
"choices": [
{"delta": {"content": "test content", "reasoning": "test reasoning"}}
],
@ -89,7 +84,6 @@ class TestOpenRouterChatCompletionStreamingHandler:
def test_openrouter_extra_body_transformation():
transformed_request = OpenrouterConfig().transform_request(
model="openrouter/deepseek/deepseek-chat",
messages=[{"role": "user", "content": "Hello, world!"}],

View File

@ -10,6 +10,7 @@ sys.path.insert(0, os.path.abspath("../../../../.."))
from litellm.llms.sagemaker.common_utils import AWSEventStreamDecoder
from litellm.llms.sagemaker.completion.transformation import SagemakerConfig
@pytest.mark.asyncio
async def test_aiter_bytes_unicode_decode_error():
"""
@ -96,6 +97,7 @@ async def test_aiter_bytes_valid_chunk_followed_by_unicode_error():
assert len(chunks) == 1
assert chunks[0]["text"] == "hello" # Verify the content of the valid chunk
class TestSagemakerTransform:
def setup_method(self):
self.config = SagemakerConfig()
@ -104,7 +106,11 @@ class TestSagemakerTransform:
def test_map_mistral_params(self):
"""Test that parameters are correctly mapped"""
test_params = {"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}
test_params = {
"temperature": 0.7,
"max_tokens": 200,
"max_completion_tokens": 256,
}
result = self.config.map_openai_params(
non_default_params=test_params,
@ -118,7 +124,10 @@ class TestSagemakerTransform:
def test_mistral_max_tokens_backward_compat(self):
"""Test that parameters are correctly mapped"""
test_params = {"temperature": 0.7, "max_tokens": 200,}
test_params = {
"temperature": 0.7,
"max_tokens": 200,
}
result = self.config.map_openai_params(
non_default_params=test_params,

View File

@ -1,34 +1,36 @@
import json
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIError,
make_call,
make_sync_call,
VertexAIError,
)
from litellm.llms.custom_httpx.http_handler import HTTPHandler
class TestVertexAIHTTPStatus201(unittest.TestCase):
def setUp(self):
# Setup mock messages
self.messages = [{"role": "user", "content": "Hello, how are you?"}]
# Setup mock data
self.mock_data = json.dumps({"messages": self.messages})
# Setup mock headers
self.mock_headers = {"Content-Type": "application/json"}
# Setup mock model
self.mock_model = "gemini-pro"
# Setup mock logging object
self.mock_logging_obj = MagicMock()
self.mock_logging_obj.post_call = MagicMock()
@patch("litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.get_async_httpx_client")
@patch(
"litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.get_async_httpx_client"
)
async def test_async_http_status_201(self, mock_get_client):
"""Test that async make_call handles HTTP 201 status code correctly"""
# Create a mock response with status code 201
@ -36,13 +38,13 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
mock_response.status_code = 201
mock_response.aiter_lines = MagicMock()
mock_response.aiter_lines.return_value = ["test response"]
# Setup mock client
mock_client = MagicMock()
mock_client.post = MagicMock()
mock_client.post.return_value = mock_response
mock_get_client.return_value = mock_client
# Call the make_call function
result = await make_call(
client=None,
@ -51,15 +53,15 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
data=self.mock_data,
model=self.mock_model,
messages=self.messages,
logging_obj=self.mock_logging_obj
logging_obj=self.mock_logging_obj,
)
# Assert that the post method was called
mock_client.post.assert_called_once()
# Assert that no error was raised for status code 201
self.assertIsNotNone(result)
# Verify logging was called
self.mock_logging_obj.post_call.assert_called_once()
@ -72,7 +74,7 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
mock_response.iter_lines = MagicMock()
mock_response.iter_lines.return_value = ["test response"]
mock_post.return_value = mock_response
# Call the make_sync_call function
result = make_sync_call(
client=None,
@ -82,12 +84,12 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
data=self.mock_data,
model=self.mock_model,
messages=self.messages,
logging_obj=self.mock_logging_obj
logging_obj=self.mock_logging_obj,
)
# Assert that no error was raised for status code 201
self.assertIsNotNone(result)
# Verify logging was called
self.mock_logging_obj.post_call.assert_called_once()
@ -100,7 +102,7 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
mock_response.read = MagicMock(return_value=b"Bad Request")
mock_response.headers = {}
mock_post.return_value = mock_response
# Call the make_sync_call function and expect an error
with self.assertRaises(VertexAIError) as context:
make_sync_call(
@ -111,12 +113,12 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
data=self.mock_data,
model=self.mock_model,
messages=self.messages,
logging_obj=self.mock_logging_obj
logging_obj=self.mock_logging_obj,
)
# Assert that the error has the correct status code
self.assertEqual(context.exception.status_code, 400)
if __name__ == "__main__":
unittest.main()
unittest.main()

View File

@ -1,9 +1,7 @@
import base64
import numpy as np
import json
import os
import sys
import traceback
from dotenv import load_dotenv
@ -18,6 +16,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import get_optional_params
from litellm.llms.vertex_ai.gemini.transformation import _process_gemini_image
@ -31,6 +30,7 @@ def encode_image_to_base64(image_path):
def test_completion_pydantic_obj_2():
from pydantic import BaseModel
from litellm.llms.custom_httpx.http_handler import HTTPHandler
litellm.set_verbose = True
@ -110,11 +110,10 @@ def test_completion_pydantic_obj_2():
def test_build_vertex_schema():
from litellm.llms.vertex_ai.common_utils import (
_build_vertex_schema,
)
import json
from litellm.llms.vertex_ai.common_utils import _build_vertex_schema
schema = {
"type": "object",
"my-random-key": "my-random-value",
@ -149,7 +148,6 @@ def test_build_vertex_schema():
],
)
def test_vertex_tool_params(tools, key):
optional_params = get_optional_params(
model="gemini-1.5-pro",
custom_llm_provider="vertex_ai",
@ -1124,7 +1122,6 @@ def test_logprobs():
mock_response.json.return_value = response_body
with patch.object(client, "post", return_value=mock_response):
resp = litellm.completion(
model="gemini/gemini-1.5-flash-002",
messages=[
@ -1140,9 +1137,7 @@ def test_logprobs():
def test_process_gemini_image():
"""Test the _process_gemini_image function for different image sources"""
from litellm.llms.vertex_ai.gemini.transformation import (
_process_gemini_image,
)
from litellm.llms.vertex_ai.gemini.transformation import _process_gemini_image
from litellm.types.llms.vertex_ai import FileDataType
# Test GCS URI
@ -1266,9 +1261,10 @@ def test_vertex_embedding_url(model, expected_url):
assert endpoint == "predict"
import pytest
from unittest.mock import Mock, patch
import pytest
# Add these fixtures below existing fixtures
@pytest.fixture
@ -1435,7 +1431,10 @@ def test_vertex_parallel_tool_calls_false_multiple_tools_error():
tools=tools,
parallel_tool_calls=False,
)
assert "`parallel_tool_calls=False` is not supported when multiple tools are provided" in str(excinfo.value)
assert (
"`parallel_tool_calls=False` is not supported when multiple tools are provided"
in str(excinfo.value)
)
# works when specified as "functions"
with pytest.raises(litellm.utils.UnsupportedParamsError) as excinfo:
@ -1445,7 +1444,10 @@ def test_vertex_parallel_tool_calls_false_multiple_tools_error():
functions=tools,
parallel_tool_calls=False,
)
assert "`parallel_tool_calls=False` is not supported when multiple tools are provided" in str(excinfo.value)
assert (
"`parallel_tool_calls=False` is not supported when multiple tools are provided"
in str(excinfo.value)
)
def test_vertex_parallel_tool_calls_false_single_tool():

View File

@ -13,11 +13,11 @@ sys.path.insert(
import litellm
from litellm.llms.vertex_ai.common_utils import (
_get_vertex_url,
convert_anyof_null_to_nullable,
get_vertex_location_from_url,
get_vertex_project_id_from_url,
set_schema_property_ordering,
_get_vertex_url
)
@ -518,6 +518,7 @@ def test_vertex_ai_complex_response_schema():
assert "additionalProperties" not in type3
assert "additionalProperties" not in type3_prop3_items
@pytest.mark.parametrize(
"stream, expected_endpoint_suffix",
[
@ -537,7 +538,10 @@ def test_get_vertex_url_global_region(stream, expected_endpoint_suffix):
# Mock litellm.VertexGeminiConfig.get_model_for_vertex_ai_url to return model as is
# as we are not testing that part here, just the URL construction
with patch("litellm.VertexGeminiConfig.get_model_for_vertex_ai_url", side_effect=lambda model: model):
with patch(
"litellm.VertexGeminiConfig.get_model_for_vertex_ai_url",
side_effect=lambda model: model,
):
url, endpoint = _get_vertex_url(
mode=mode,
model=model,
@ -548,7 +552,7 @@ def test_get_vertex_url_global_region(stream, expected_endpoint_suffix):
)
expected_url_base = f"https://aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/global/publishers/google/models/{model}"
if stream:
expected_endpoint = "streamGenerateContent"
expected_url = f"{expected_url_base}:{expected_endpoint}?alt=sse"
@ -556,6 +560,5 @@ def test_get_vertex_url_global_region(stream, expected_endpoint_suffix):
expected_endpoint = "generateContent"
expected_url = f"{expected_url_base}:{expected_endpoint}"
assert endpoint == expected_endpoint
assert url == expected_url

View File

@ -24,38 +24,49 @@ class TestAnthropicEndpoints(unittest.TestCase):
{"type": "content_block_delta", "delta": {"text": "more data"}},
"text chunk data again",
]
mock_user_api_key_dict = MagicMock()
mock_request_data = {}
mock_proxy_logging_obj = MagicMock()
mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock(side_effect=lambda **kwargs: kwargs["response"])
mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock(
side_effect=lambda **kwargs: kwargs["response"]
)
# Configure safe_dumps to return a properly formatted JSON string
mock_safe_dumps.side_effect = lambda chunk: json.dumps(chunk)
# Execute
result = [chunk async for chunk in async_data_generator_anthropic(
response=mock_response,
user_api_key_dict=mock_user_api_key_dict,
request_data=mock_request_data,
proxy_logging_obj=mock_proxy_logging_obj,
)]
result = [
chunk
async for chunk in async_data_generator_anthropic(
response=mock_response,
user_api_key_dict=mock_user_api_key_dict,
request_data=mock_request_data,
proxy_logging_obj=mock_proxy_logging_obj,
)
]
# Verify
expected_result = [
'data: {"type": "message_start", "message": {"id": "msg_123"}}\n\n',
'text chunk data',
"text chunk data",
'data: {"type": "content_block_delta", "delta": {"text": "more data"}}\n\n',
'text chunk data again',
"text chunk data again",
]
self.assertEqual(result, expected_result)
# Assert safe_dumps was called for dictionary objects
mock_safe_dumps.assert_any_call({"type": "message_start", "message": {"id": "msg_123"}})
mock_safe_dumps.assert_any_call({"type": "content_block_delta", "delta": {"text": "more data"}})
assert mock_safe_dumps.call_count == 2 # Called twice, once for each dict object
mock_safe_dumps.assert_any_call(
{"type": "message_start", "message": {"id": "msg_123"}}
)
mock_safe_dumps.assert_any_call(
{"type": "content_block_delta", "delta": {"text": "more data"}}
)
assert (
mock_safe_dumps.call_count == 2
) # Called twice, once for each dict object
if __name__ == "__main__":
unittest.main()
unittest.main()

View File

@ -0,0 +1 @@
"""Tests for the LiteLLM Proxy Client CLI package."""

View File

@ -1,5 +1,5 @@
import json
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
import pytest
import requests
@ -238,4 +238,4 @@ def test_chat_completions_all_parameters(cli_runner, mock_chat_client):
presence_penalty=0.5,
frequency_penalty=0.5,
user="test-user",
)
)

Some files were not shown because too many files have changed in this diff Show More