Litellm fix GitHub action testing (#11163)
* test: add __init__.py files * refactor: rename test folder to avoid naming conflict * test: update workflows * test: update tests * test: update imports * test: update tests * test: remove unused import * ci(test-litellm.yml): add pytest retry to github workflow * test: fix test
This commit is contained in:
parent
9a35c41462
commit
ef42461c1e
3
.github/workflows/test-litellm.yml
vendored
3
.github/workflows/test-litellm.yml
vendored
@ -28,6 +28,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
poetry install --with dev,proxy-dev --extras proxy
|
||||
poetry run pip install "pytest-retry==1.6.3"
|
||||
poetry run pip install pytest-xdist
|
||||
- name: Setup litellm-enterprise as local package
|
||||
run: |
|
||||
@ -36,4 +37,4 @@ jobs:
|
||||
cd ..
|
||||
- name: Run tests
|
||||
run: |
|
||||
poetry run pytest tests/litellm -x -vv -n 4
|
||||
poetry run pytest tests/test_litellm -x -vv -n 4
|
||||
@ -24,7 +24,7 @@ repos:
|
||||
rev: 7.0.0 # The version of flake8 to use
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: ^litellm/tests/|^litellm/proxy/tests/|^litellm/tests/litellm/|^tests/litellm/
|
||||
exclude: ^litellm/tests/|^litellm/proxy/tests/|^litellm/tests/test_litellm/|^tests/test_litellm/
|
||||
additional_dependencies: [flake8-print]
|
||||
files: (litellm/|litellm_proxy_extras/|enterprise/).*\.py
|
||||
- repo: https://github.com/python-poetry/poetry
|
||||
|
||||
@ -4392,7 +4392,7 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_tool_choice": true,
|
||||
"deprecation_date": "2025-1-6"
|
||||
"deprecation_date": "2025-01-06"
|
||||
},
|
||||
"groq/llama3-groq-8b-8192-tool-use-preview": {
|
||||
"max_tokens": 8192,
|
||||
@ -4405,7 +4405,7 @@
|
||||
"supports_function_calling": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_tool_choice": true,
|
||||
"deprecation_date": "2025-1-6"
|
||||
"deprecation_date": "2025-01-06"
|
||||
},
|
||||
"groq/qwen-qwq-32b": {
|
||||
"max_tokens": 128000,
|
||||
|
||||
@ -1,207 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock, ANY
|
||||
import json
|
||||
import datetime
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
|
||||
from litellm.integrations.athina import AthinaLogger
|
||||
|
||||
class TestAthinaLogger(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Set up environment variables for testing
|
||||
self.env_patcher = patch.dict('os.environ', {
|
||||
'ATHINA_API_KEY': 'test-api-key',
|
||||
'ATHINA_BASE_URL': 'https://test.athina.ai'
|
||||
})
|
||||
self.env_patcher.start()
|
||||
self.logger = AthinaLogger()
|
||||
|
||||
# Setup common test variables
|
||||
self.start_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
self.end_time = datetime.datetime(2023, 1, 1, 12, 0, 1)
|
||||
self.print_verbose = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
self.env_patcher.stop()
|
||||
|
||||
def test_init(self):
|
||||
"""Test the initialization of AthinaLogger"""
|
||||
self.assertEqual(self.logger.athina_api_key, 'test-api-key')
|
||||
self.assertEqual(self.logger.athina_logging_url, 'https://test.athina.ai/api/v1/log/inference')
|
||||
self.assertEqual(self.logger.headers, {
|
||||
'athina-api-key': 'test-api-key',
|
||||
'Content-Type': 'application/json'
|
||||
})
|
||||
|
||||
@patch('litellm.module_level_client.post')
|
||||
def test_log_event_success(self, mock_post):
|
||||
"""Test successful logging of an event"""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = "Success"
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Create test data
|
||||
kwargs = {
|
||||
'model': 'gpt-4',
|
||||
'messages': [{'role': 'user', 'content': 'Hello'}],
|
||||
'stream': False,
|
||||
'litellm_params': {
|
||||
'metadata': {
|
||||
'environment': 'test-environment',
|
||||
'prompt_slug': 'test-prompt',
|
||||
'customer_id': 'test-customer',
|
||||
'customer_user_id': 'test-user',
|
||||
'session_id': 'test-session',
|
||||
'external_reference_id': 'test-ext-ref',
|
||||
'context': 'test-context',
|
||||
'expected_response': 'test-expected',
|
||||
'user_query': 'test-query',
|
||||
'tags': ['test-tag'],
|
||||
'user_feedback': 'test-feedback',
|
||||
'model_options': {'test-opt': 'test-val'},
|
||||
'custom_attributes': {'test-attr': 'test-val'}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response_obj = MagicMock()
|
||||
response_obj.model_dump.return_value = {
|
||||
'id': 'resp-123',
|
||||
'choices': [{'message': {'content': 'Hi there'}}],
|
||||
'usage': {
|
||||
'prompt_tokens': 10,
|
||||
'completion_tokens': 5,
|
||||
'total_tokens': 15
|
||||
}
|
||||
}
|
||||
|
||||
# Call the method
|
||||
self.logger.log_event(kwargs, response_obj, self.start_time, self.end_time, self.print_verbose)
|
||||
|
||||
# Verify the results
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
self.assertEqual(call_args[0][0], 'https://test.athina.ai/api/v1/log/inference')
|
||||
self.assertEqual(call_args[1]['headers'], self.logger.headers)
|
||||
|
||||
# Parse and verify the sent data
|
||||
sent_data = json.loads(call_args[1]['data'])
|
||||
self.assertEqual(sent_data['language_model_id'], 'gpt-4')
|
||||
self.assertEqual(sent_data['prompt'], kwargs['messages'])
|
||||
self.assertEqual(sent_data['prompt_tokens'], 10)
|
||||
self.assertEqual(sent_data['completion_tokens'], 5)
|
||||
self.assertEqual(sent_data['total_tokens'], 15)
|
||||
self.assertEqual(sent_data['response_time'], 1000) # 1 second = 1000ms
|
||||
self.assertEqual(sent_data['customer_id'], 'test-customer')
|
||||
self.assertEqual(sent_data['session_id'], 'test-session')
|
||||
self.assertEqual(sent_data['environment'], 'test-environment')
|
||||
self.assertEqual(sent_data['prompt_slug'], 'test-prompt')
|
||||
self.assertEqual(sent_data['external_reference_id'], 'test-ext-ref')
|
||||
self.assertEqual(sent_data['context'], 'test-context')
|
||||
self.assertEqual(sent_data['expected_response'], 'test-expected')
|
||||
self.assertEqual(sent_data['user_query'], 'test-query')
|
||||
self.assertEqual(sent_data['tags'], ['test-tag'])
|
||||
self.assertEqual(sent_data['user_feedback'], 'test-feedback')
|
||||
self.assertEqual(sent_data['model_options'], {'test-opt': 'test-val'})
|
||||
self.assertEqual(sent_data['custom_attributes'], {'test-attr': 'test-val'})
|
||||
# Verify the print_verbose was called
|
||||
self.print_verbose.assert_called_once_with("Athina Logger Succeeded - Success")
|
||||
|
||||
@patch('litellm.module_level_client.post')
|
||||
def test_log_event_error_response(self, mock_post):
|
||||
"""Test handling of error response from the API"""
|
||||
# Setup mock error response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "Bad Request"
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Create test data
|
||||
kwargs = {
|
||||
'model': 'gpt-4',
|
||||
'messages': [{'role': 'user', 'content': 'Hello'}],
|
||||
'stream': False
|
||||
}
|
||||
|
||||
response_obj = MagicMock()
|
||||
response_obj.model_dump.return_value = {
|
||||
'id': 'resp-123',
|
||||
'choices': [{'message': {'content': 'Hi there'}}],
|
||||
'usage': {
|
||||
'prompt_tokens': 10,
|
||||
'completion_tokens': 5,
|
||||
'total_tokens': 15
|
||||
}
|
||||
}
|
||||
|
||||
# Call the method
|
||||
self.logger.log_event(kwargs, response_obj, self.start_time, self.end_time, self.print_verbose)
|
||||
|
||||
# Verify print_verbose was called with error message
|
||||
self.print_verbose.assert_called_once_with("Athina Logger Error - Bad Request, 400")
|
||||
|
||||
@patch('litellm.module_level_client.post')
|
||||
def test_log_event_exception(self, mock_post):
|
||||
"""Test handling of exceptions during logging"""
|
||||
# Setup mock to raise exception
|
||||
mock_post.side_effect = Exception("Test exception")
|
||||
|
||||
# Create test data
|
||||
kwargs = {
|
||||
'model': 'gpt-4',
|
||||
'messages': [{'role': 'user', 'content': 'Hello'}],
|
||||
'stream': False
|
||||
}
|
||||
|
||||
response_obj = MagicMock()
|
||||
response_obj.model_dump.return_value = {}
|
||||
|
||||
# Call the method
|
||||
self.logger.log_event(kwargs, response_obj, self.start_time, self.end_time, self.print_verbose)
|
||||
|
||||
# Verify print_verbose was called with exception info
|
||||
self.print_verbose.assert_called_once()
|
||||
self.assertIn("Athina Logger Error - Test exception", self.print_verbose.call_args[0][0])
|
||||
|
||||
@patch('litellm.module_level_client.post')
|
||||
def test_log_event_with_tools(self, mock_post):
|
||||
"""Test logging with tools/functions data"""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Create test data with tools
|
||||
kwargs = {
|
||||
'model': 'gpt-4',
|
||||
'messages': [{'role': 'user', 'content': "What's the weather?"}],
|
||||
'stream': False,
|
||||
'optional_params': {
|
||||
'tools': [{'type': 'function', 'function': {'name': 'get_weather'}}]
|
||||
}
|
||||
}
|
||||
|
||||
response_obj = MagicMock()
|
||||
response_obj.model_dump.return_value = {
|
||||
'id': 'resp-123',
|
||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}
|
||||
}
|
||||
|
||||
# Call the method
|
||||
self.logger.log_event(kwargs, response_obj, self.start_time, self.end_time, self.print_verbose)
|
||||
|
||||
# Verify the results
|
||||
sent_data = json.loads(mock_post.call_args[1]['data'])
|
||||
self.assertEqual(sent_data['tools'], [{'type': 'function', 'function': {'name': 'get_weather'}}])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@ -1 +0,0 @@
|
||||
"""Tests for the LiteLLM Proxy Client CLI package."""
|
||||
0
tests/test_litellm/__init__.py
Normal file
0
tests/test_litellm/__init__.py
Normal file
@ -16,7 +16,7 @@ from litellm.caching.redis_cache import RedisCache
|
||||
@pytest.fixture
|
||||
def redis_no_ping():
|
||||
"""Patch RedisCache initialization to prevent async ping tasks from being created"""
|
||||
with patch('asyncio.get_running_loop') as mock_get_loop:
|
||||
with patch("asyncio.get_running_loop") as mock_get_loop:
|
||||
# Either raise an exception or return a mock that will handle the task creation
|
||||
mock_get_loop.side_effect = RuntimeError("No running event loop")
|
||||
yield
|
||||
@ -64,32 +64,32 @@ async def test_redis_client_init_with_socket_timeout(monkeypatch, redis_no_ping)
|
||||
async def test_redis_cache_async_batch_get_cache(monkeypatch, redis_no_ping):
|
||||
monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
|
||||
redis_cache = RedisCache()
|
||||
|
||||
|
||||
# Create an AsyncMock for the Redis client
|
||||
mock_redis_instance = AsyncMock()
|
||||
|
||||
|
||||
# Make sure the mock can be used as an async context manager
|
||||
mock_redis_instance.__aenter__.return_value = mock_redis_instance
|
||||
mock_redis_instance.__aexit__.return_value = None
|
||||
|
||||
|
||||
# Setup the return value for mget
|
||||
mock_redis_instance.mget.return_value = [
|
||||
b'{"key1": "value1"}',
|
||||
None,
|
||||
b'{"key3": "value3"}'
|
||||
b'{"key3": "value3"}',
|
||||
]
|
||||
|
||||
|
||||
test_keys = ["key1", "key2", "key3"]
|
||||
|
||||
|
||||
with patch.object(
|
||||
redis_cache, "init_async_client", return_value=mock_redis_instance
|
||||
):
|
||||
# Call async_batch_get_cache
|
||||
result = await redis_cache.async_batch_get_cache(key_list=test_keys)
|
||||
|
||||
|
||||
# Verify mget was called with the correct keys
|
||||
mock_redis_instance.mget.assert_called_once()
|
||||
|
||||
|
||||
# Check that results were properly decoded
|
||||
assert result["key1"] == {"key1": "value1"}
|
||||
assert result["key2"] is None
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -13,27 +13,30 @@ sys.path.insert(
|
||||
def test_redis_semantic_cache_initialization(monkeypatch):
|
||||
# Mock the redisvl import
|
||||
semantic_cache_mock = MagicMock()
|
||||
with patch.dict("sys.modules", {
|
||||
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
|
||||
"redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=MagicMock())
|
||||
}):
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
|
||||
"redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=MagicMock()),
|
||||
},
|
||||
):
|
||||
from litellm.caching.redis_semantic_cache import RedisSemanticCache
|
||||
|
||||
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("REDIS_HOST", "localhost")
|
||||
monkeypatch.setenv("REDIS_PORT", "6379")
|
||||
monkeypatch.setenv("REDIS_PASSWORD", "test_password")
|
||||
|
||||
|
||||
# Initialize the cache with a similarity threshold
|
||||
redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
|
||||
|
||||
|
||||
# Verify the semantic cache was initialized with correct parameters
|
||||
assert redis_semantic_cache.similarity_threshold == 0.8
|
||||
|
||||
|
||||
# Use pytest.approx for floating point comparison to handle precision issues
|
||||
assert redis_semantic_cache.distance_threshold == pytest.approx(0.2, abs=1e-10)
|
||||
assert redis_semantic_cache.embedding_model == "text-embedding-ada-002"
|
||||
|
||||
|
||||
# Test initialization with missing similarity_threshold
|
||||
with pytest.raises(ValueError, match="similarity_threshold must be provided"):
|
||||
RedisSemanticCache()
|
||||
@ -43,42 +46,48 @@ def test_redis_semantic_cache_get_cache(monkeypatch):
|
||||
# Mock the redisvl import and embedding function
|
||||
semantic_cache_mock = MagicMock()
|
||||
custom_vectorizer_mock = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
|
||||
"redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=custom_vectorizer_mock)
|
||||
}):
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
|
||||
"redisvl.utils.vectorize": MagicMock(
|
||||
CustomTextVectorizer=custom_vectorizer_mock
|
||||
),
|
||||
},
|
||||
):
|
||||
from litellm.caching.redis_semantic_cache import RedisSemanticCache
|
||||
|
||||
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("REDIS_HOST", "localhost")
|
||||
monkeypatch.setenv("REDIS_PORT", "6379")
|
||||
monkeypatch.setenv("REDIS_PASSWORD", "test_password")
|
||||
|
||||
|
||||
# Initialize cache
|
||||
redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
|
||||
|
||||
|
||||
# Mock the llmcache.check method to return a result
|
||||
mock_result = [
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"response": '{"content": "Paris is the capital of France."}',
|
||||
"vector_distance": 0.1 # Distance of 0.1 means similarity of 0.9
|
||||
"vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9
|
||||
}
|
||||
]
|
||||
redis_semantic_cache.llmcache.check = MagicMock(return_value=mock_result)
|
||||
|
||||
|
||||
# Mock the embedding function
|
||||
with patch("litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}):
|
||||
with patch(
|
||||
"litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||
):
|
||||
# Test get_cache with a message
|
||||
result = redis_semantic_cache.get_cache(
|
||||
key="test_key",
|
||||
messages=[{"content": "What is the capital of France?"}]
|
||||
key="test_key", messages=[{"content": "What is the capital of France?"}]
|
||||
)
|
||||
|
||||
|
||||
# Verify result is properly parsed
|
||||
assert result == {"content": "Paris is the capital of France."}
|
||||
|
||||
|
||||
# Verify llmcache.check was called
|
||||
redis_semantic_cache.llmcache.check.assert_called_once()
|
||||
|
||||
@ -88,43 +97,50 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch):
|
||||
# Mock the redisvl import
|
||||
semantic_cache_mock = MagicMock()
|
||||
custom_vectorizer_mock = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {
|
||||
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
|
||||
"redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=custom_vectorizer_mock)
|
||||
}):
|
||||
|
||||
with patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
|
||||
"redisvl.utils.vectorize": MagicMock(
|
||||
CustomTextVectorizer=custom_vectorizer_mock
|
||||
),
|
||||
},
|
||||
):
|
||||
from litellm.caching.redis_semantic_cache import RedisSemanticCache
|
||||
|
||||
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("REDIS_HOST", "localhost")
|
||||
monkeypatch.setenv("REDIS_PORT", "6379")
|
||||
monkeypatch.setenv("REDIS_PASSWORD", "test_password")
|
||||
|
||||
|
||||
# Initialize cache
|
||||
redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
|
||||
|
||||
|
||||
# Mock the async methods
|
||||
mock_result = [
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"response": '{"content": "Paris is the capital of France."}',
|
||||
"vector_distance": 0.1 # Distance of 0.1 means similarity of 0.9
|
||||
"vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
redis_semantic_cache.llmcache.acheck = AsyncMock(return_value=mock_result)
|
||||
redis_semantic_cache._get_async_embedding = AsyncMock(return_value=[0.1, 0.2, 0.3])
|
||||
|
||||
redis_semantic_cache._get_async_embedding = AsyncMock(
|
||||
return_value=[0.1, 0.2, 0.3]
|
||||
)
|
||||
|
||||
# Test async_get_cache with a message
|
||||
result = await redis_semantic_cache.async_get_cache(
|
||||
key="test_key",
|
||||
messages=[{"content": "What is the capital of France?"}],
|
||||
metadata={}
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
# Verify result is properly parsed
|
||||
assert result == {"content": "Paris is the capital of France."}
|
||||
|
||||
|
||||
# Verify methods were called
|
||||
redis_semantic_cache._get_async_embedding.assert_called_once()
|
||||
redis_semantic_cache.llmcache.acheck.assert_called_once()
|
||||
redis_semantic_cache.llmcache.acheck.assert_called_once()
|
||||
@ -162,13 +162,13 @@ class TestSlackAlerting(unittest.TestCase):
|
||||
self.assertEqual(event, "soft_budget_crossed")
|
||||
self.assertTrue("Total Soft Budget" in event_message)
|
||||
|
||||
# Calling update_values with alerting args should try to start the periodic task
|
||||
# Calling update_values with alerting args should try to start the periodic task
|
||||
@patch("asyncio.create_task")
|
||||
def test_update_values_starts_periodic_task(self, mock_create_task):
|
||||
# Make it do nothing (or return a dummy future)
|
||||
mock_create_task.return_value = AsyncMock() # prevents awaiting errors
|
||||
|
||||
assert(self.slack_alerting.periodic_started == False)
|
||||
assert self.slack_alerting.periodic_started == False
|
||||
|
||||
self.slack_alerting.update_values(alerting_args={"slack_alerting": "True"})
|
||||
assert(self.slack_alerting.periodic_started == True)
|
||||
assert self.slack_alerting.periodic_started == True
|
||||
@ -5,7 +5,6 @@ from litellm.integrations.arize.arize_phoenix import ArizePhoenixLogger
|
||||
|
||||
|
||||
class TestArizePhoenixConfig(unittest.TestCase):
|
||||
|
||||
@patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
@ -7,15 +7,17 @@ from typing import Optional
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
import asyncio
|
||||
import litellm
|
||||
|
||||
import pytest
|
||||
from litellm.integrations.arize.arize import ArizeLogger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
import litellm
|
||||
from litellm.integrations._types.open_inference import (
|
||||
SpanAttributes,
|
||||
MessageAttributes,
|
||||
SpanAttributes,
|
||||
ToolCallAttributes,
|
||||
)
|
||||
from litellm.integrations.arize.arize import ArizeLogger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import Choices, StandardCallbackDynamicParams
|
||||
|
||||
|
||||
@ -25,6 +27,7 @@ def test_arize_set_attributes():
|
||||
Ensures that the correct span attributes are being added during a request.
|
||||
"""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
span = MagicMock() # Mocked tracing span to test attribute setting
|
||||
@ -1,18 +1,20 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../..")) # Adds the parent directory to the system-path
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
|
||||
from litellm.integrations.agentops.agentops import AgentOps, AgentOpsConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_response():
|
||||
return {
|
||||
"token": "test_jwt_token",
|
||||
"project_id": "test_project_id"
|
||||
}
|
||||
return {"token": "test_jwt_token", "project_id": "test_project_id"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agentops_config():
|
||||
@ -21,16 +23,20 @@ def agentops_config():
|
||||
api_key="test_api_key",
|
||||
service_name="test_service",
|
||||
deployment_environment="test_env",
|
||||
auth_endpoint="https://api.agentops.ai/v3/auth/token"
|
||||
auth_endpoint="https://api.agentops.ai/v3/auth/token",
|
||||
)
|
||||
|
||||
|
||||
def test_agentops_config_from_env():
|
||||
"""Test that AgentOpsConfig correctly reads from environment variables"""
|
||||
with patch.dict(os.environ, {
|
||||
"AGENTOPS_API_KEY": "test_key",
|
||||
"AGENTOPS_SERVICE_NAME": "test_service",
|
||||
"AGENTOPS_ENVIRONMENT": "test_env"
|
||||
}):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"AGENTOPS_API_KEY": "test_key",
|
||||
"AGENTOPS_SERVICE_NAME": "test_service",
|
||||
"AGENTOPS_ENVIRONMENT": "test_env",
|
||||
},
|
||||
):
|
||||
config = AgentOpsConfig.from_env()
|
||||
assert config.api_key == "test_key"
|
||||
assert config.service_name == "test_service"
|
||||
@ -38,6 +44,7 @@ def test_agentops_config_from_env():
|
||||
assert config.endpoint == "https://otlp.agentops.cloud/v1/traces"
|
||||
assert config.auth_endpoint == "https://api.agentops.ai/v3/auth/token"
|
||||
|
||||
|
||||
def test_agentops_config_defaults():
|
||||
"""Test that AgentOpsConfig uses correct default values"""
|
||||
config = AgentOpsConfig()
|
||||
@ -47,52 +54,64 @@ def test_agentops_config_defaults():
|
||||
assert config.endpoint == "https://otlp.agentops.cloud/v1/traces"
|
||||
assert config.auth_endpoint == "https://api.agentops.ai/v3/auth/token"
|
||||
|
||||
@patch('litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token')
|
||||
|
||||
@patch("litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token")
|
||||
def test_fetch_auth_token_success(mock_fetch_auth_token, mock_auth_response):
|
||||
"""Test successful JWT token fetch"""
|
||||
mock_fetch_auth_token.return_value = mock_auth_response
|
||||
|
||||
|
||||
config = AgentOpsConfig(api_key="test_key")
|
||||
agentops = AgentOps(config=config)
|
||||
|
||||
mock_fetch_auth_token.assert_called_once_with("test_key", "https://api.agentops.ai/v3/auth/token")
|
||||
assert agentops.resource_attributes.get("project.id") == mock_auth_response.get("project_id")
|
||||
|
||||
@patch('litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token')
|
||||
mock_fetch_auth_token.assert_called_once_with(
|
||||
"test_key", "https://api.agentops.ai/v3/auth/token"
|
||||
)
|
||||
assert agentops.resource_attributes.get("project.id") == mock_auth_response.get(
|
||||
"project_id"
|
||||
)
|
||||
|
||||
|
||||
@patch("litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token")
|
||||
def test_fetch_auth_token_failure(mock_fetch_auth_token):
|
||||
"""Test failed JWT token fetch"""
|
||||
mock_fetch_auth_token.side_effect = Exception("Failed to fetch auth token: Unauthorized")
|
||||
|
||||
mock_fetch_auth_token.side_effect = Exception(
|
||||
"Failed to fetch auth token: Unauthorized"
|
||||
)
|
||||
|
||||
config = AgentOpsConfig(api_key="test_key")
|
||||
agentops = AgentOps(config=config)
|
||||
|
||||
|
||||
mock_fetch_auth_token.assert_called_once()
|
||||
assert "project.id" not in agentops.resource_attributes
|
||||
|
||||
@patch('litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token')
|
||||
def test_agentops_initialization(mock_fetch_auth_token, agentops_config, mock_auth_response):
|
||||
|
||||
@patch("litellm.integrations.agentops.agentops.AgentOps._fetch_auth_token")
|
||||
def test_agentops_initialization(
|
||||
mock_fetch_auth_token, agentops_config, mock_auth_response
|
||||
):
|
||||
"""Test AgentOps initialization with config"""
|
||||
mock_fetch_auth_token.return_value = mock_auth_response
|
||||
|
||||
|
||||
agentops = AgentOps(config=agentops_config)
|
||||
|
||||
|
||||
assert agentops.resource_attributes["service.name"] == "test_service"
|
||||
assert agentops.resource_attributes["deployment.environment"] == "test_env"
|
||||
assert agentops.resource_attributes["telemetry.sdk.name"] == "agentops"
|
||||
assert agentops.resource_attributes["project.id"] == "test_project_id"
|
||||
|
||||
|
||||
def test_agentops_initialization_no_auth():
|
||||
"""Test AgentOps initialization without authentication"""
|
||||
test_config = AgentOpsConfig(
|
||||
endpoint="https://otlp.agentops.cloud/v1/traces",
|
||||
api_key=None, # No API key
|
||||
service_name="test_service",
|
||||
deployment_environment="test_env"
|
||||
deployment_environment="test_env",
|
||||
)
|
||||
|
||||
|
||||
agentops = AgentOps(config=test_config)
|
||||
|
||||
|
||||
assert agentops.resource_attributes["service.name"] == "test_service"
|
||||
assert agentops.resource_attributes["deployment.environment"] == "test_env"
|
||||
assert agentops.resource_attributes["telemetry.sdk.name"] == "agentops"
|
||||
assert "project.id" not in agentops.resource_attributes
|
||||
assert "project.id" not in agentops.resource_attributes
|
||||
220
tests/test_litellm/integrations/test_athina.py
Normal file
220
tests/test_litellm/integrations/test_athina.py
Normal file
@ -0,0 +1,220 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
|
||||
from litellm.integrations.athina import AthinaLogger
|
||||
|
||||
|
||||
class TestAthinaLogger(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Set up environment variables for testing
|
||||
self.env_patcher = patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"ATHINA_API_KEY": "test-api-key",
|
||||
"ATHINA_BASE_URL": "https://test.athina.ai",
|
||||
},
|
||||
)
|
||||
self.env_patcher.start()
|
||||
self.logger = AthinaLogger()
|
||||
|
||||
# Setup common test variables
|
||||
self.start_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
self.end_time = datetime.datetime(2023, 1, 1, 12, 0, 1)
|
||||
self.print_verbose = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
self.env_patcher.stop()
|
||||
|
||||
def test_init(self):
|
||||
"""Test the initialization of AthinaLogger"""
|
||||
self.assertEqual(self.logger.athina_api_key, "test-api-key")
|
||||
self.assertEqual(
|
||||
self.logger.athina_logging_url,
|
||||
"https://test.athina.ai/api/v1/log/inference",
|
||||
)
|
||||
self.assertEqual(
|
||||
self.logger.headers,
|
||||
{"athina-api-key": "test-api-key", "Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
@patch("litellm.module_level_client.post")
|
||||
def test_log_event_success(self, mock_post):
|
||||
"""Test successful logging of an event"""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = "Success"
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Create test data
|
||||
kwargs = {
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"stream": False,
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
"environment": "test-environment",
|
||||
"prompt_slug": "test-prompt",
|
||||
"customer_id": "test-customer",
|
||||
"customer_user_id": "test-user",
|
||||
"session_id": "test-session",
|
||||
"external_reference_id": "test-ext-ref",
|
||||
"context": "test-context",
|
||||
"expected_response": "test-expected",
|
||||
"user_query": "test-query",
|
||||
"tags": ["test-tag"],
|
||||
"user_feedback": "test-feedback",
|
||||
"model_options": {"test-opt": "test-val"},
|
||||
"custom_attributes": {"test-attr": "test-val"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
response_obj = MagicMock()
|
||||
response_obj.model_dump.return_value = {
|
||||
"id": "resp-123",
|
||||
"choices": [{"message": {"content": "Hi there"}}],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
|
||||
# Call the method
|
||||
self.logger.log_event(
|
||||
kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
|
||||
)
|
||||
|
||||
# Verify the results
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
self.assertEqual(call_args[0][0], "https://test.athina.ai/api/v1/log/inference")
|
||||
self.assertEqual(call_args[1]["headers"], self.logger.headers)
|
||||
|
||||
# Parse and verify the sent data
|
||||
sent_data = json.loads(call_args[1]["data"])
|
||||
self.assertEqual(sent_data["language_model_id"], "gpt-4")
|
||||
self.assertEqual(sent_data["prompt"], kwargs["messages"])
|
||||
self.assertEqual(sent_data["prompt_tokens"], 10)
|
||||
self.assertEqual(sent_data["completion_tokens"], 5)
|
||||
self.assertEqual(sent_data["total_tokens"], 15)
|
||||
self.assertEqual(sent_data["response_time"], 1000) # 1 second = 1000ms
|
||||
self.assertEqual(sent_data["customer_id"], "test-customer")
|
||||
self.assertEqual(sent_data["session_id"], "test-session")
|
||||
self.assertEqual(sent_data["environment"], "test-environment")
|
||||
self.assertEqual(sent_data["prompt_slug"], "test-prompt")
|
||||
self.assertEqual(sent_data["external_reference_id"], "test-ext-ref")
|
||||
self.assertEqual(sent_data["context"], "test-context")
|
||||
self.assertEqual(sent_data["expected_response"], "test-expected")
|
||||
self.assertEqual(sent_data["user_query"], "test-query")
|
||||
self.assertEqual(sent_data["tags"], ["test-tag"])
|
||||
self.assertEqual(sent_data["user_feedback"], "test-feedback")
|
||||
self.assertEqual(sent_data["model_options"], {"test-opt": "test-val"})
|
||||
self.assertEqual(sent_data["custom_attributes"], {"test-attr": "test-val"})
|
||||
# Verify the print_verbose was called
|
||||
self.print_verbose.assert_called_once_with("Athina Logger Succeeded - Success")
|
||||
|
||||
@patch("litellm.module_level_client.post")
|
||||
def test_log_event_error_response(self, mock_post):
|
||||
"""Test handling of error response from the API"""
|
||||
# Setup mock error response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "Bad Request"
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Create test data
|
||||
kwargs = {
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
response_obj = MagicMock()
|
||||
response_obj.model_dump.return_value = {
|
||||
"id": "resp-123",
|
||||
"choices": [{"message": {"content": "Hi there"}}],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
|
||||
# Call the method
|
||||
self.logger.log_event(
|
||||
kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
|
||||
)
|
||||
|
||||
# Verify print_verbose was called with error message
|
||||
self.print_verbose.assert_called_once_with(
|
||||
"Athina Logger Error - Bad Request, 400"
|
||||
)
|
||||
|
||||
@patch("litellm.module_level_client.post")
|
||||
def test_log_event_exception(self, mock_post):
|
||||
"""Test handling of exceptions during logging"""
|
||||
# Setup mock to raise exception
|
||||
mock_post.side_effect = Exception("Test exception")
|
||||
|
||||
# Create test data
|
||||
kwargs = {
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
response_obj = MagicMock()
|
||||
response_obj.model_dump.return_value = {}
|
||||
|
||||
# Call the method
|
||||
self.logger.log_event(
|
||||
kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
|
||||
)
|
||||
|
||||
# Verify print_verbose was called with exception info
|
||||
self.print_verbose.assert_called_once()
|
||||
self.assertIn(
|
||||
"Athina Logger Error - Test exception", self.print_verbose.call_args[0][0]
|
||||
)
|
||||
|
||||
@patch("litellm.module_level_client.post")
|
||||
def test_log_event_with_tools(self, mock_post):
|
||||
"""Test logging with tools/functions data"""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Create test data with tools
|
||||
kwargs = {
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "What's the weather?"}],
|
||||
"stream": False,
|
||||
"optional_params": {
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
},
|
||||
}
|
||||
|
||||
response_obj = MagicMock()
|
||||
response_obj.model_dump.return_value = {
|
||||
"id": "resp-123",
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
|
||||
# Call the method
|
||||
self.logger.log_event(
|
||||
kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
|
||||
)
|
||||
|
||||
# Verify the results
|
||||
sent_data = json.loads(mock_post.call_args[1]["data"])
|
||||
self.assertEqual(
|
||||
sent_data["tools"],
|
||||
[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -1,12 +1,12 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
import os
|
||||
import unittest
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm.integrations.deepeval.api import Endpoints, HttpMethods
|
||||
from litellm.integrations.deepeval.deepeval import DeepEvalLogger
|
||||
from litellm.integrations.deepeval.api import HttpMethods, Endpoints
|
||||
from litellm.integrations.deepeval.types import TraceSpanApiStatus, SpanApiType
|
||||
from litellm.integrations.deepeval.types import SpanApiType, TraceSpanApiStatus
|
||||
|
||||
|
||||
class TestDeepEvalLogger(unittest.TestCase):
|
||||
0
tests/test_litellm/litellm_core_utils/__init__.py
Normal file
0
tests/test_litellm/litellm_core_utils/__init__.py
Normal file
@ -410,8 +410,8 @@ inner_object = {
|
||||
}
|
||||
],
|
||||
"tool_choice": "none",
|
||||
"count": 65,
|
||||
"count-tolerate" : 67 #over by 2
|
||||
"count": 65,
|
||||
"count-tolerate": 67, # over by 2
|
||||
}
|
||||
"""
|
||||
namespace functions {
|
||||
@ -459,7 +459,7 @@ inner_object_with_enum_only = {
|
||||
],
|
||||
"tool_choice": "none",
|
||||
"count": 73,
|
||||
"count-tolerate" : 74 #over by 1
|
||||
"count-tolerate": 74, # over by 1
|
||||
}
|
||||
"""
|
||||
namespace functions {
|
||||
@ -511,7 +511,7 @@ inner_object_with_enum = {
|
||||
],
|
||||
"tool_choice": "none",
|
||||
"count": 89,
|
||||
"count-tolerate" : 92, #over by 3
|
||||
"count-tolerate": 92, # over by 3
|
||||
}
|
||||
"""
|
||||
namespace functions {
|
||||
@ -568,8 +568,8 @@ inner_object_and_string = {
|
||||
}
|
||||
],
|
||||
"tool_choice": "none",
|
||||
"count": 103,
|
||||
"count-tolerate" : 106, #over by 3
|
||||
"count": 103,
|
||||
"count-tolerate": 106, # over by 3
|
||||
}
|
||||
"""
|
||||
namespace functions {
|
||||
@ -6,10 +6,11 @@ import pytest
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
BAD_MESSAGE_ERROR_STR,
|
||||
ollama_pt,
|
||||
BedrockConverseMessagesProcessor,
|
||||
ollama_pt,
|
||||
)
|
||||
|
||||
|
||||
def test_ollama_pt_simple_messages():
|
||||
"""Test basic functionality with simple text messages"""
|
||||
messages = [
|
||||
@ -43,6 +44,7 @@ def test_ollama_pt_consecutive_user_messages():
|
||||
assert isinstance(result, dict)
|
||||
assert result["prompt"] == expected_prompt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_bedrock_thinking_blocks_with_none_content():
|
||||
"""
|
||||
@ -55,28 +57,29 @@ async def test_anthropic_bedrock_thinking_blocks_with_none_content():
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "This is a test thinking block",
|
||||
"signature": "test-signature"
|
||||
"signature": "test-signature",
|
||||
}
|
||||
],
|
||||
"reasoning_content": "This is the reasoning content"
|
||||
"reasoning_content": "This is the reasoning content",
|
||||
}
|
||||
messages = [
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
mock_assistant_message
|
||||
mock_assistant_message,
|
||||
]
|
||||
|
||||
# test _bedrock_converse_messages_pt_async
|
||||
result = await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
|
||||
messages=messages,
|
||||
model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
llm_provider="bedrock"
|
||||
llm_provider="bedrock",
|
||||
)
|
||||
|
||||
|
||||
# verify the result
|
||||
assert len(result) == 2
|
||||
assert result[1]["content"][0]["reasoningContent"]["reasoningText"]["text"] == "This is a test thinking block"
|
||||
|
||||
assert (
|
||||
result[1]["content"][0]["reasoningContent"]["reasoningText"]["text"]
|
||||
== "This is a test thinking block"
|
||||
)
|
||||
|
||||
|
||||
# def test_ollama_pt_consecutive_system_messages():
|
||||
@ -1,30 +1,34 @@
|
||||
import unittest
|
||||
from datetime import datetime, timezone
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from litellm.litellm_core_utils.duration_parser import get_next_standardized_reset_time
|
||||
|
||||
|
||||
class TestStandardizedResetTime(unittest.TestCase):
|
||||
def test_day_based_resets(self):
|
||||
"""Test day-based reset durations (1d, 7d, 30d)"""
|
||||
# Base time: 2023-05-15 10:30:00 UTC
|
||||
base_time = datetime(2023, 5, 15, 10, 30, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
# Daily reset (1d) - should reset at next midnight
|
||||
daily_expected = datetime(2023, 5, 16, 0, 0, 0, tzinfo=timezone.utc)
|
||||
daily_result = get_next_standardized_reset_time("1d", base_time, "UTC")
|
||||
self.assertEqual(daily_result, daily_expected)
|
||||
|
||||
|
||||
# Weekly reset (7d) - should reset on next Monday
|
||||
wednesday = datetime(2023, 5, 17, 15, 45, 0, tzinfo=timezone.utc) # A Wednesday
|
||||
weekly_expected = datetime(2023, 5, 22, 0, 0, 0, tzinfo=timezone.utc) # Next Monday
|
||||
weekly_expected = datetime(
|
||||
2023, 5, 22, 0, 0, 0, tzinfo=timezone.utc
|
||||
) # Next Monday
|
||||
weekly_result = get_next_standardized_reset_time("7d", wednesday, "UTC")
|
||||
self.assertEqual(weekly_result, weekly_expected)
|
||||
|
||||
|
||||
# Monthly reset (30d) - should reset on 1st of next month
|
||||
monthly_expected = datetime(2023, 6, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
monthly_result = get_next_standardized_reset_time("30d", base_time, "UTC")
|
||||
self.assertEqual(monthly_result, monthly_expected)
|
||||
|
||||
|
||||
# Custom day reset (3d) - should reset after 3 days
|
||||
custom_day_expected = datetime(2023, 5, 18, 0, 0, 0, tzinfo=timezone.utc)
|
||||
custom_day_result = get_next_standardized_reset_time("3d", base_time, "UTC")
|
||||
@ -34,17 +38,17 @@ class TestStandardizedResetTime(unittest.TestCase):
|
||||
"""Test hour, minute, and second based reset durations"""
|
||||
# Base time: 2023-05-15 15:20:30 UTC (3:20:30 PM)
|
||||
base_time = datetime(2023, 5, 15, 15, 20, 30, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
# 2-hour reset - should reset at next even hour (16:00)
|
||||
hour_expected = datetime(2023, 5, 15, 16, 0, 0, tzinfo=timezone.utc)
|
||||
hour_result = get_next_standardized_reset_time("2h", base_time, "UTC")
|
||||
self.assertEqual(hour_result, hour_expected)
|
||||
|
||||
|
||||
# 30-minute reset - should reset at next 30-minute mark (15:30)
|
||||
minute_expected = datetime(2023, 5, 15, 15, 30, 0, tzinfo=timezone.utc)
|
||||
minute_result = get_next_standardized_reset_time("30m", base_time, "UTC")
|
||||
self.assertEqual(minute_result, minute_expected)
|
||||
|
||||
|
||||
# 15-second reset - should reset at next 15-second mark (15:20:45)
|
||||
second_expected = datetime(2023, 5, 15, 15, 20, 45, tzinfo=timezone.utc)
|
||||
second_result = get_next_standardized_reset_time("15s", base_time, "UTC")
|
||||
@ -54,32 +58,34 @@ class TestStandardizedResetTime(unittest.TestCase):
|
||||
"""Test timezone handling with different regions"""
|
||||
# Base time: 2023-05-15 22:30:00 UTC (late in UTC day)
|
||||
base_time = datetime(2023, 5, 15, 22, 30, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
# Test daily reset in different timezones
|
||||
# US/Eastern (UTC-4): 6:30 PM, so next reset is midnight same day
|
||||
eastern = ZoneInfo("US/Eastern")
|
||||
eastern_expected = datetime(2023, 5, 16, 0, 0, 0, tzinfo=eastern)
|
||||
eastern_result = get_next_standardized_reset_time("1d", base_time, "US/Eastern")
|
||||
self.assertEqual(eastern_result, eastern_expected)
|
||||
|
||||
|
||||
# Asia/Kolkata (UTC+5:30): 4:00 AM next day, so next reset is midnight the day after
|
||||
ist = ZoneInfo("Asia/Kolkata")
|
||||
ist_expected = datetime(2023, 5, 17, 0, 0, 0, tzinfo=ist)
|
||||
ist_result = get_next_standardized_reset_time("1d", base_time, "Asia/Kolkata")
|
||||
self.assertEqual(ist_result, ist_expected)
|
||||
|
||||
|
||||
# Test hourly reset in different timezones
|
||||
# US/Pacific (UTC-7): 3:30 PM, so next 2h reset is 4:00 PM
|
||||
pacific = ZoneInfo("US/Pacific")
|
||||
pacific_expected = datetime(2023, 5, 15, 16, 0, 0, tzinfo=pacific)
|
||||
pacific_result = get_next_standardized_reset_time("2h", base_time, "US/Pacific")
|
||||
self.assertEqual(pacific_result, pacific_expected)
|
||||
|
||||
|
||||
# Test minute reset in different timezones
|
||||
# Europe/London (UTC+1): 11:30 PM, so next 15m reset is 11:45 PM
|
||||
london = ZoneInfo("Europe/London")
|
||||
london_expected = datetime(2023, 5, 15, 23, 45, 0, tzinfo=london)
|
||||
london_result = get_next_standardized_reset_time("15m", base_time, "Europe/London")
|
||||
london_result = get_next_standardized_reset_time(
|
||||
"15m", base_time, "Europe/London"
|
||||
)
|
||||
self.assertEqual(london_result, london_expected)
|
||||
|
||||
def test_edge_cases(self):
|
||||
@ -89,25 +95,30 @@ class TestStandardizedResetTime(unittest.TestCase):
|
||||
hour_expected = datetime(2023, 5, 15, 16, 0, 0, tzinfo=timezone.utc)
|
||||
hour_result = get_next_standardized_reset_time("2h", on_hour, "UTC")
|
||||
self.assertEqual(hour_result, hour_expected)
|
||||
|
||||
|
||||
# Exactly on minute boundary
|
||||
on_minute = datetime(2023, 5, 15, 14, 30, 0, tzinfo=timezone.utc)
|
||||
minute_expected = datetime(2023, 5, 15, 15, 0, 0, tzinfo=timezone.utc)
|
||||
minute_result = get_next_standardized_reset_time("30m", on_minute, "UTC")
|
||||
self.assertEqual(minute_result, minute_expected)
|
||||
|
||||
|
||||
# Near day boundary
|
||||
near_midnight = datetime(2023, 5, 15, 23, 50, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
# 30m near midnight - should roll over to next day
|
||||
midnight_minute_expected = datetime(2023, 5, 16, 0, 0, 0, tzinfo=timezone.utc)
|
||||
midnight_minute_result = get_next_standardized_reset_time("30m", near_midnight, "UTC")
|
||||
midnight_minute_result = get_next_standardized_reset_time(
|
||||
"30m", near_midnight, "UTC"
|
||||
)
|
||||
self.assertEqual(midnight_minute_result, midnight_minute_expected)
|
||||
|
||||
|
||||
# Invalid timezone - should fall back to UTC
|
||||
invalid_tz_expected = datetime(2023, 5, 16, 0, 0, 0, tzinfo=timezone.utc)
|
||||
invalid_tz_result = get_next_standardized_reset_time("1d", on_hour, "NonExistentTimeZone")
|
||||
invalid_tz_result = get_next_standardized_reset_time(
|
||||
"1d", on_hour, "NonExistentTimeZone"
|
||||
)
|
||||
self.assertEqual(invalid_tz_result, invalid_tz_expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
@ -15,9 +15,9 @@ from litellm.types.utils import (
|
||||
Delta,
|
||||
Function,
|
||||
ModelResponseStream,
|
||||
PromptTokensDetails,
|
||||
StreamingChoices,
|
||||
Usage,
|
||||
PromptTokensDetails,
|
||||
)
|
||||
|
||||
|
||||
@ -612,29 +612,45 @@ def test_streaming_handler_with_stop_chunk(
|
||||
assert returned_chunk is None
|
||||
|
||||
|
||||
def test_set_response_id_propagation_empty_to_valid(initialized_custom_stream_wrapper: CustomStreamWrapper):
|
||||
def test_set_response_id_propagation_empty_to_valid(
|
||||
initialized_custom_stream_wrapper: CustomStreamWrapper,
|
||||
):
|
||||
"""Test that response_id is properly set when first chunk has empty ID and second chunk has valid ID"""
|
||||
|
||||
model_response1 = ModelResponseStream(id="", created=1742056047, model=None)
|
||||
model_response1 = initialized_custom_stream_wrapper.set_model_id(model_response1.id, model_response1)
|
||||
model_response1 = initialized_custom_stream_wrapper.set_model_id(
|
||||
model_response1.id, model_response1
|
||||
)
|
||||
assert model_response1.id == ""
|
||||
|
||||
model_response2 = ModelResponseStream(id="valid-id-123", created=1742056048, model=None)
|
||||
model_response2 = initialized_custom_stream_wrapper.set_model_id("valid-id-123", model_response2)
|
||||
model_response2 = ModelResponseStream(
|
||||
id="valid-id-123", created=1742056048, model=None
|
||||
)
|
||||
model_response2 = initialized_custom_stream_wrapper.set_model_id(
|
||||
"valid-id-123", model_response2
|
||||
)
|
||||
assert model_response2.id == "valid-id-123"
|
||||
assert initialized_custom_stream_wrapper.response_id == "valid-id-123"
|
||||
|
||||
|
||||
def test_set_response_id_propagation_valid_to_invalid(initialized_custom_stream_wrapper: CustomStreamWrapper):
|
||||
def test_set_response_id_propagation_valid_to_invalid(
|
||||
initialized_custom_stream_wrapper: CustomStreamWrapper,
|
||||
):
|
||||
"""Test that response_id is maintained when first chunk has valid ID and second chunk has invalid ID"""
|
||||
|
||||
model_response1 = ModelResponseStream(id="first-valid-id", created=1742056049, model=None)
|
||||
model_response1 = initialized_custom_stream_wrapper.set_model_id("first-valid-id", model_response1)
|
||||
model_response1 = ModelResponseStream(
|
||||
id="first-valid-id", created=1742056049, model=None
|
||||
)
|
||||
model_response1 = initialized_custom_stream_wrapper.set_model_id(
|
||||
"first-valid-id", model_response1
|
||||
)
|
||||
assert model_response1.id == "first-valid-id"
|
||||
assert initialized_custom_stream_wrapper.response_id == "first-valid-id"
|
||||
|
||||
model_response2 = ModelResponseStream(id="", created=1742056050, model=None)
|
||||
model_response2 = initialized_custom_stream_wrapper.set_model_id("", model_response2)
|
||||
model_response2 = initialized_custom_stream_wrapper.set_model_id(
|
||||
"", model_response2
|
||||
)
|
||||
assert model_response2.id == "first-valid-id"
|
||||
assert initialized_custom_stream_wrapper.response_id == "first-valid-id"
|
||||
|
||||
@ -13,17 +13,16 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from messages_with_counts import (
|
||||
MESSAGES_TEXT,
|
||||
MESSAGES_WITH_IMAGES,
|
||||
MESSAGES_WITH_TOOLS,
|
||||
)
|
||||
|
||||
import litellm
|
||||
from litellm import create_pretrained_tokenizer, decode, encode, get_modified_max_tokens
|
||||
from litellm import token_counter as token_counter_old
|
||||
from litellm.litellm_core_utils.token_counter import token_counter as token_counter_new
|
||||
from tests.large_text import text
|
||||
from tests.test_litellm.litellm_core_utils.messages_with_counts import (
|
||||
MESSAGES_TEXT,
|
||||
MESSAGES_WITH_IMAGES,
|
||||
MESSAGES_WITH_TOOLS,
|
||||
)
|
||||
|
||||
|
||||
def token_counter_both_assert_same(**args):
|
||||
@ -9,10 +9,9 @@ sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
#Use the same token_counter as the main test.
|
||||
from test_token_counter import token_counter
|
||||
|
||||
from test_token_counter_tool_data import *
|
||||
# Use the same token_counter as the main test.
|
||||
from tests.test_litellm.litellm_core_utils.test_token_counter import token_counter
|
||||
from tests.test_litellm.litellm_core_utils.test_token_counter_tool_data import *
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -27,19 +27,19 @@ CONTENT_AND_TOOL_CALL = [
|
||||
]
|
||||
|
||||
_OPENHANDS_SYSTEM_MESSAGE = {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are OpenHands agent, a helpful AI assistant that can "
|
||||
"interact with a computer to solve tasks.\n\n<ROLE>\nYour primary "
|
||||
"role is to assist users by executing commands, modifying code, and "
|
||||
"solving technical problems effectively. You should be thorough, "
|
||||
"methodical, and prioritize quality over speed.\n* If the user asks a "
|
||||
"question, like 'why is X happening', don’t try to fix the problem. ",
|
||||
}
|
||||
],
|
||||
"role": "system",
|
||||
}
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are OpenHands agent, a helpful AI assistant that can "
|
||||
"interact with a computer to solve tasks.\n\n<ROLE>\nYour primary "
|
||||
"role is to assist users by executing commands, modifying code, and "
|
||||
"solving technical problems effectively. You should be thorough, "
|
||||
"methodical, and prioritize quality over speed.\n* If the user asks a "
|
||||
"question, like 'why is X happening', don’t try to fix the problem. ",
|
||||
}
|
||||
],
|
||||
"role": "system",
|
||||
}
|
||||
|
||||
SYSTEM_LONG = [
|
||||
_OPENHANDS_SYSTEM_MESSAGE,
|
||||
@ -19,13 +19,15 @@ async def test_get_openai_compatible_provider_info():
|
||||
"""
|
||||
config = AzureAIStudioConfig()
|
||||
|
||||
api_base, dynamic_api_key, custom_llm_provider = (
|
||||
config._get_openai_compatible_provider_info(
|
||||
model="azure_ai/gpt-4o-mini",
|
||||
api_base="https://my-base",
|
||||
api_key="my-key",
|
||||
custom_llm_provider="azure_ai",
|
||||
)
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
custom_llm_provider,
|
||||
) = config._get_openai_compatible_provider_info(
|
||||
model="azure_ai/gpt-4o-mini",
|
||||
api_base="https://my-base",
|
||||
api_key="my-key",
|
||||
custom_llm_provider="azure_ai",
|
||||
)
|
||||
|
||||
assert custom_llm_provider == "azure"
|
||||
@ -1,6 +1,6 @@
|
||||
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.amazon_mistral_transformation import AmazonMistralConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.amazon_mistral_transformation import (
|
||||
AmazonMistralConfig,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
@ -10,7 +10,9 @@ def test_mistral_get_outputText():
|
||||
model_response.choices[0].finish_reason = "None"
|
||||
|
||||
# Models like pixtral will return a completion with the openai format.
|
||||
mock_json_with_choices = {"choices": [{"message": {"content": "Hello!"}, "finish_reason": "stop"}]}
|
||||
mock_json_with_choices = {
|
||||
"choices": [{"message": {"content": "Hello!"}, "finish_reason": "stop"}]
|
||||
}
|
||||
|
||||
outputText = AmazonMistralConfig.get_outputText(
|
||||
completion_response=mock_json_with_choices, model_response=model_response
|
||||
@ -2,7 +2,6 @@ import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
@ -18,7 +17,11 @@ class TestCohereTransform:
|
||||
|
||||
def test_map_cohere_params(self):
|
||||
"""Test that parameters are correctly mapped"""
|
||||
test_params = {"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}
|
||||
test_params = {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 200,
|
||||
"max_completion_tokens": 256,
|
||||
}
|
||||
|
||||
result = self.config.map_openai_params(
|
||||
non_default_params=test_params,
|
||||
@ -32,7 +35,10 @@ class TestCohereTransform:
|
||||
|
||||
def test_cohere_max_tokens_backward_compat(self):
|
||||
"""Test that parameters are correctly mapped"""
|
||||
test_params = {"temperature": 0.7, "max_tokens": 200,}
|
||||
test_params = {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 200,
|
||||
}
|
||||
|
||||
result = self.config.map_openai_params(
|
||||
non_default_params=test_params,
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
@ -13,12 +13,26 @@ from litellm.llms.llamafile.chat.transformation import LlamafileChatConfig
|
||||
("user-provided-key", "secret-key", "user-provided-key", False),
|
||||
(None, "secret-key", "secret-key", True),
|
||||
(None, None, "fake-api-key", True),
|
||||
("", "secret-key", "secret-key", True), # Empty string should fall back to secret
|
||||
("", None, "fake-api-key", True), # Empty string with no secret should use the fake key
|
||||
]
|
||||
(
|
||||
"",
|
||||
"secret-key",
|
||||
"secret-key",
|
||||
True,
|
||||
), # Empty string should fall back to secret
|
||||
(
|
||||
"",
|
||||
None,
|
||||
"fake-api-key",
|
||||
True,
|
||||
), # Empty string with no secret should use the fake key
|
||||
],
|
||||
)
|
||||
def test_resolve_api_key(input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called):
|
||||
with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
|
||||
def test_resolve_api_key(
|
||||
input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called
|
||||
):
|
||||
with patch(
|
||||
"litellm.llms.llamafile.chat.transformation.get_secret_str"
|
||||
) as mock_get_secret:
|
||||
mock_get_secret.return_value = api_key_from_secret_manager
|
||||
|
||||
result = LlamafileChatConfig._resolve_api_key(input_api_key)
|
||||
@ -34,14 +48,36 @@ def test_resolve_api_key(input_api_key, api_key_from_secret_manager, expected_ap
|
||||
@pytest.mark.parametrize(
|
||||
"input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called",
|
||||
[
|
||||
("https://user-api.example.com", "https://secret-api.example.com", "https://user-api.example.com", False),
|
||||
(None, "https://secret-api.example.com", "https://secret-api.example.com", True),
|
||||
(
|
||||
"https://user-api.example.com",
|
||||
"https://secret-api.example.com",
|
||||
"https://user-api.example.com",
|
||||
False,
|
||||
),
|
||||
(
|
||||
None,
|
||||
"https://secret-api.example.com",
|
||||
"https://secret-api.example.com",
|
||||
True,
|
||||
),
|
||||
(None, None, "http://127.0.0.1:8080/v1", True),
|
||||
("", "https://secret-api.example.com", "https://secret-api.example.com", True), # Empty string should fall back
|
||||
]
|
||||
(
|
||||
"",
|
||||
"https://secret-api.example.com",
|
||||
"https://secret-api.example.com",
|
||||
True,
|
||||
), # Empty string should fall back
|
||||
],
|
||||
)
|
||||
def test_resolve_api_base(input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called):
|
||||
with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
|
||||
def test_resolve_api_base(
|
||||
input_api_base,
|
||||
api_base_from_secret_manager,
|
||||
expected_api_base,
|
||||
secret_manager_called,
|
||||
):
|
||||
with patch(
|
||||
"litellm.llms.llamafile.chat.transformation.get_secret_str"
|
||||
) as mock_get_secret:
|
||||
mock_get_secret.return_value = api_base_from_secret_manager
|
||||
|
||||
result = LlamafileChatConfig._resolve_api_base(input_api_base)
|
||||
@ -58,31 +94,73 @@ def test_resolve_api_base(input_api_base, api_base_from_secret_manager, expected
|
||||
"api_base, api_key, secret_base, secret_key, expected_base, expected_key",
|
||||
[
|
||||
# User-provided values
|
||||
("https://user-api.example.com", "user-key", "https://secret-api.example.com", "secret-key", "https://user-api.example.com", "user-key"),
|
||||
(
|
||||
"https://user-api.example.com",
|
||||
"user-key",
|
||||
"https://secret-api.example.com",
|
||||
"secret-key",
|
||||
"https://user-api.example.com",
|
||||
"user-key",
|
||||
),
|
||||
# Fallback to secrets
|
||||
(None, None, "https://secret-api.example.com", "secret-key", "https://secret-api.example.com", "secret-key"),
|
||||
(
|
||||
None,
|
||||
None,
|
||||
"https://secret-api.example.com",
|
||||
"secret-key",
|
||||
"https://secret-api.example.com",
|
||||
"secret-key",
|
||||
),
|
||||
# Nothing provided, use defaults
|
||||
(None, None, None, None, "http://127.0.0.1:8080/v1", "fake-api-key"),
|
||||
# Mixed scenarios
|
||||
("https://user-api.example.com", None, None, "secret-key", "https://user-api.example.com", "secret-key"),
|
||||
(None, "user-key", "https://secret-api.example.com", None, "https://secret-api.example.com", "user-key"),
|
||||
]
|
||||
(
|
||||
"https://user-api.example.com",
|
||||
None,
|
||||
None,
|
||||
"secret-key",
|
||||
"https://user-api.example.com",
|
||||
"secret-key",
|
||||
),
|
||||
(
|
||||
None,
|
||||
"user-key",
|
||||
"https://secret-api.example.com",
|
||||
None,
|
||||
"https://secret-api.example.com",
|
||||
"user-key",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_openai_compatible_provider_info(api_base, api_key, secret_base, secret_key, expected_base, expected_key):
|
||||
def test_get_openai_compatible_provider_info(
|
||||
api_base, api_key, secret_base, secret_key, expected_base, expected_key
|
||||
):
|
||||
config = LlamafileChatConfig()
|
||||
|
||||
def fake_get_secret(key: str) -> Optional[str]:
|
||||
return {
|
||||
"LLAMAFILE_API_BASE": secret_base,
|
||||
"LLAMAFILE_API_KEY": secret_key
|
||||
}.get(key)
|
||||
return {"LLAMAFILE_API_BASE": secret_base, "LLAMAFILE_API_KEY": secret_key}.get(
|
||||
key
|
||||
)
|
||||
|
||||
patch_secret = patch("litellm.llms.llamafile.chat.transformation.get_secret_str", side_effect=fake_get_secret)
|
||||
patch_base = patch.object(LlamafileChatConfig, "_resolve_api_base", wraps=LlamafileChatConfig._resolve_api_base)
|
||||
patch_key = patch.object(LlamafileChatConfig, "_resolve_api_key", wraps=LlamafileChatConfig._resolve_api_key)
|
||||
patch_secret = patch(
|
||||
"litellm.llms.llamafile.chat.transformation.get_secret_str",
|
||||
side_effect=fake_get_secret,
|
||||
)
|
||||
patch_base = patch.object(
|
||||
LlamafileChatConfig,
|
||||
"_resolve_api_base",
|
||||
wraps=LlamafileChatConfig._resolve_api_base,
|
||||
)
|
||||
patch_key = patch.object(
|
||||
LlamafileChatConfig,
|
||||
"_resolve_api_key",
|
||||
wraps=LlamafileChatConfig._resolve_api_key,
|
||||
)
|
||||
|
||||
with patch_secret as mock_secret, patch_base as mock_base, patch_key as mock_key:
|
||||
result_base, result_key = config._get_openai_compatible_provider_info(api_base, api_key)
|
||||
result_base, result_key = config._get_openai_compatible_provider_info(
|
||||
api_base, api_key
|
||||
)
|
||||
|
||||
assert result_base == expected_base
|
||||
assert result_key == expected_key
|
||||
@ -100,8 +178,12 @@ def test_get_openai_compatible_provider_info(api_base, api_key, secret_base, sec
|
||||
|
||||
|
||||
def test_completion_with_custom_llamafile_model():
|
||||
with patch("litellm.main.openai_chat_completions.completion") as mock_llamafile_completion_func:
|
||||
mock_llamafile_completion_func.return_value = {} # Return an empty dictionary for the mocked response
|
||||
with patch(
|
||||
"litellm.main.openai_chat_completions.completion"
|
||||
) as mock_llamafile_completion_func:
|
||||
mock_llamafile_completion_func.return_value = (
|
||||
{}
|
||||
) # Return an empty dictionary for the mocked response
|
||||
|
||||
provider = "llamafile"
|
||||
model_name = "my-custom-test-model"
|
||||
@ -3,7 +3,9 @@ import sys
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")))
|
||||
sys.path.insert(
|
||||
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
|
||||
)
|
||||
|
||||
from litellm.llms.lm_studio.chat.transformation import LMStudioChatConfig
|
||||
from litellm.utils import get_optional_params
|
||||
@ -44,7 +46,9 @@ class TestLMStudioChatConfigResponseFormat:
|
||||
custom_llm_provider="lm_studio",
|
||||
)
|
||||
|
||||
mapped = config.map_openai_params(non_default_params, {}, "lm_studio/test-model", False)
|
||||
mapped = config.map_openai_params(
|
||||
non_default_params, {}, "lm_studio/test-model", False
|
||||
)
|
||||
mapped_schema = mapped["response_format"]["json_schema"]["schema"]
|
||||
assert mapped_schema["properties"] == schema["properties"]
|
||||
opt_schema = optional_params["response_format"]["json_schema"]["schema"]
|
||||
@ -1,45 +1,42 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.llms.ollama.completion.transformation import (
|
||||
OllamaConfig,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.types.utils import Message
|
||||
from litellm.llms.ollama.completion.transformation import OllamaConfig
|
||||
from litellm.types.utils import Message, ModelResponse
|
||||
|
||||
|
||||
class TestOllamaConfig:
|
||||
def test_transform_response_standard(self):
|
||||
# Initialize config
|
||||
config = OllamaConfig()
|
||||
|
||||
|
||||
# Create mock response
|
||||
raw_response = MagicMock()
|
||||
raw_response.json.return_value = {
|
||||
"response": "Hello, I am an AI assistant",
|
||||
"prompt_eval_count": 10,
|
||||
"eval_count": 5
|
||||
"eval_count": 5,
|
||||
}
|
||||
|
||||
|
||||
# Create properly structured model response object
|
||||
model_response = ModelResponse(
|
||||
id="test_id",
|
||||
choices=[{"message": Message(content="")}],
|
||||
)
|
||||
|
||||
|
||||
# Create mock encoding
|
||||
mock_encoding = MagicMock()
|
||||
mock_encoding.encode.return_value = [1, 2, 3] # Return dummy token IDs
|
||||
|
||||
|
||||
# Transform response
|
||||
result = config.transform_response(
|
||||
model="llama2",
|
||||
@ -52,7 +49,7 @@ class TestOllamaConfig:
|
||||
litellm_params={},
|
||||
encoding=mock_encoding,
|
||||
)
|
||||
|
||||
|
||||
# Verify response
|
||||
assert result.choices[0]["message"].content == "Hello, I am an AI assistant"
|
||||
assert result.choices[0]["finish_reason"] == "stop"
|
||||
@ -67,29 +64,28 @@ class TestOllamaConfig:
|
||||
def test_transform_response_json_function_call(self, mock_uuid4):
|
||||
# Setup mock UUID
|
||||
mock_uuid4.return_value = "test-uuid"
|
||||
|
||||
|
||||
# Initialize config
|
||||
config = OllamaConfig()
|
||||
|
||||
|
||||
# Create mock response with JSON function call format
|
||||
raw_response = MagicMock()
|
||||
raw_response.json.return_value = {
|
||||
"response": json.dumps({
|
||||
"name": "get_weather",
|
||||
"arguments": {"location": "San Francisco"}
|
||||
})
|
||||
"response": json.dumps(
|
||||
{"name": "get_weather", "arguments": {"location": "San Francisco"}}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
# Create properly structured model response object
|
||||
model_response = ModelResponse(
|
||||
id="test_id",
|
||||
choices=[{"message": Message(content="")}],
|
||||
)
|
||||
|
||||
|
||||
# Create mock encoding
|
||||
mock_encoding = MagicMock()
|
||||
mock_encoding.encode.return_value = [1, 2, 3] # Return dummy token IDs
|
||||
|
||||
|
||||
# Transform response
|
||||
result = config.transform_response(
|
||||
model="llama2",
|
||||
@ -102,39 +98,43 @@ class TestOllamaConfig:
|
||||
litellm_params={},
|
||||
encoding=mock_encoding,
|
||||
)
|
||||
|
||||
|
||||
# Verify result has tool_calls
|
||||
assert result.choices[0]["message"].content is None
|
||||
assert result.choices[0]["finish_reason"] == "tool_calls"
|
||||
assert len(result.choices[0]["message"].tool_calls) == 1
|
||||
assert result.choices[0]["message"].tool_calls[0]["id"].startswith("call_")
|
||||
assert result.choices[0]["message"].tool_calls[0]["function"]["name"] == "get_weather"
|
||||
assert json.loads(result.choices[0]["message"].tool_calls[0]["function"]["arguments"]) == {"location": "San Francisco"}
|
||||
assert (
|
||||
result.choices[0]["message"].tool_calls[0]["function"]["name"]
|
||||
== "get_weather"
|
||||
)
|
||||
assert json.loads(
|
||||
result.choices[0]["message"].tool_calls[0]["function"]["arguments"]
|
||||
) == {"location": "San Francisco"}
|
||||
# No usage assertions here as we don't need to test them in every case
|
||||
|
||||
def test_transform_response_regular_json(self):
|
||||
# Initialize config
|
||||
config = OllamaConfig()
|
||||
|
||||
|
||||
# Create mock response with regular JSON (not function call)
|
||||
raw_response = MagicMock()
|
||||
raw_response.json.return_value = {
|
||||
"response": json.dumps({
|
||||
"result": "success",
|
||||
"data": {"temperature": 72, "unit": "F"}
|
||||
})
|
||||
"response": json.dumps(
|
||||
{"result": "success", "data": {"temperature": 72, "unit": "F"}}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
# Create properly structured model response object
|
||||
model_response = ModelResponse(
|
||||
id="test_id",
|
||||
choices=[{"message": Message(content="")}],
|
||||
)
|
||||
|
||||
|
||||
# Create mock encoding
|
||||
mock_encoding = MagicMock()
|
||||
mock_encoding.encode.return_value = [1, 2, 3] # Return dummy token IDs
|
||||
|
||||
|
||||
# Transform response
|
||||
result = config.transform_response(
|
||||
model="llama2",
|
||||
@ -147,12 +147,11 @@ class TestOllamaConfig:
|
||||
litellm_params={},
|
||||
encoding=mock_encoding,
|
||||
)
|
||||
|
||||
|
||||
# Verify result has JSON content
|
||||
expected_content = json.dumps({
|
||||
"result": "success",
|
||||
"data": {"temperature": 72, "unit": "F"}
|
||||
})
|
||||
expected_content = json.dumps(
|
||||
{"result": "success", "data": {"temperature": 72, "unit": "F"}}
|
||||
)
|
||||
assert result.choices[0]["message"].content == expected_content
|
||||
assert result.choices[0]["finish_reason"] == "stop"
|
||||
# No usage assertions here as we don't need to test them in every case
|
||||
# No usage assertions here as we don't need to test them in every case
|
||||
@ -1,10 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
@ -14,13 +14,15 @@ sys.path.insert(
|
||||
Unit tests for OllamaModelInfo.get_models functionality.
|
||||
"""
|
||||
# Ensure a dummy httpx module is available for import in tests
|
||||
import sys, types
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Provide a dummy httpx module for import in get_models
|
||||
if 'httpx' not in sys.modules:
|
||||
if "httpx" not in sys.modules:
|
||||
# Create a minimal module with HTTPStatusError
|
||||
httpx_mod = types.ModuleType('httpx')
|
||||
httpx_mod = types.ModuleType("httpx")
|
||||
httpx_mod.HTTPStatusError = Exception
|
||||
sys.modules['httpx'] = httpx_mod
|
||||
sys.modules["httpx"] = httpx_mod
|
||||
|
||||
import httpx
|
||||
|
||||
@ -31,6 +33,7 @@ class DummyResponse:
|
||||
"""
|
||||
A dummy response object to simulate httpx responses.
|
||||
"""
|
||||
|
||||
def __init__(self, json_data, status_code=200):
|
||||
self._json = json_data
|
||||
self.status_code = status_code
|
||||
@ -38,7 +41,9 @@ class DummyResponse:
|
||||
def raise_for_status(self):
|
||||
if self.status_code >= 400:
|
||||
# Simulate an HTTP status error
|
||||
raise httpx.HTTPStatusError("Error status code", request=None, response=None)
|
||||
raise httpx.HTTPStatusError(
|
||||
"Error status code", request=None, response=None
|
||||
)
|
||||
|
||||
def json(self):
|
||||
return self._json
|
||||
@ -51,25 +56,26 @@ class TestOllamaModelInfo:
|
||||
get_models should extract and return sorted unique model names.
|
||||
"""
|
||||
calls = []
|
||||
sample = {'models': [
|
||||
{'name': 'zeta'},
|
||||
{'model': 'alpha'},
|
||||
{'name': 123}, # non-str should be ignored
|
||||
'invalid', # non-dict should be ignored
|
||||
]}
|
||||
sample = {
|
||||
"models": [
|
||||
{"name": "zeta"},
|
||||
{"model": "alpha"},
|
||||
{"name": 123}, # non-str should be ignored
|
||||
"invalid", # non-dict should be ignored
|
||||
]
|
||||
}
|
||||
|
||||
def mock_get(url):
|
||||
calls.append(url)
|
||||
return DummyResponse(sample, status_code=200)
|
||||
|
||||
monkeypatch.setattr(httpx, 'get', mock_get)
|
||||
monkeypatch.setattr(httpx, "get", mock_get)
|
||||
info = OllamaModelInfo()
|
||||
models = info.get_models()
|
||||
# Only 'alpha' and 'zeta' should be returned, sorted alphabetically
|
||||
assert models == ['alpha', 'zeta']
|
||||
assert models == ["alpha", "zeta"]
|
||||
# Ensure correct endpoint was called
|
||||
assert calls and calls[0].endswith('/api/tags')
|
||||
|
||||
assert calls and calls[0].endswith("/api/tags")
|
||||
|
||||
def test_get_models_from_list_response(self, monkeypatch):
|
||||
"""
|
||||
@ -77,30 +83,30 @@ class TestOllamaModelInfo:
|
||||
get_models should extract and return sorted unique model names.
|
||||
"""
|
||||
sample = [
|
||||
{'name': 'm1'},
|
||||
{'model': 'm2'},
|
||||
{}, # no name/model key should be ignored
|
||||
{"name": "m1"},
|
||||
{"model": "m2"},
|
||||
{}, # no name/model key should be ignored
|
||||
]
|
||||
|
||||
def mock_get(url):
|
||||
return DummyResponse(sample, status_code=200)
|
||||
|
||||
monkeypatch.setattr(httpx, 'get', mock_get)
|
||||
monkeypatch.setattr(httpx, "get", mock_get)
|
||||
info = OllamaModelInfo()
|
||||
models = info.get_models()
|
||||
assert models == ['m1', 'm2']
|
||||
|
||||
assert models == ["m1", "m2"]
|
||||
|
||||
def test_get_models_fallback_on_error(self, monkeypatch):
|
||||
"""
|
||||
If the httpx.get call raises an exception, get_models should
|
||||
fall back to the static models_by_provider list prefixed by 'ollama/'.
|
||||
"""
|
||||
|
||||
def mock_get(url):
|
||||
raise Exception("connection failure")
|
||||
|
||||
monkeypatch.setattr(httpx, 'get', mock_get)
|
||||
monkeypatch.setattr(httpx, "get", mock_get)
|
||||
info = OllamaModelInfo()
|
||||
models = info.get_models()
|
||||
# Default static ollama_models is ['llama2'], so expect ['ollama/llama2']
|
||||
assert models == ['ollama/llama2']
|
||||
assert models == ["ollama/llama2"]
|
||||
@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from litellm.llms.openai.chat.o_series_transformation import OpenAIOSeriesConfig
|
||||
|
||||
|
||||
@ -11,24 +12,20 @@ from litellm.llms.openai.chat.o_series_transformation import OpenAIOSeriesConfig
|
||||
("o4-mini", True),
|
||||
("o1-preview", True),
|
||||
("o3-mini", True),
|
||||
|
||||
# Valid O-series models with provider prefix
|
||||
("openai/o1", True),
|
||||
("openai/o3", True),
|
||||
("openai/o4-mini", True),
|
||||
("openai/o1-preview", True),
|
||||
("openai/o3-mini", True),
|
||||
|
||||
# Non-O-series models
|
||||
("gpt-4", False),
|
||||
("gpt-3.5-turbo", False),
|
||||
("claude-3-opus", False),
|
||||
|
||||
# Non-O-series models with provider prefix
|
||||
("openai/gpt-4", False),
|
||||
("openai/gpt-3.5-turbo", False),
|
||||
("anthropic/claude-3-opus", False),
|
||||
|
||||
# Edge cases
|
||||
("o", False), # Too short
|
||||
("o5", False), # Not a valid O-series model
|
||||
@ -39,11 +36,12 @@ from litellm.llms.openai.chat.o_series_transformation import OpenAIOSeriesConfig
|
||||
def test_is_model_o_series_model(model_name: str, expected: bool):
|
||||
"""
|
||||
Test that is_model_o_series_model correctly identifies O-series models.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: The model name to test
|
||||
expected: The expected result (True if it should be identified as an O-series model)
|
||||
"""
|
||||
config = OpenAIOSeriesConfig()
|
||||
assert config.is_model_o_series_model(model_name) == expected, \
|
||||
f"Expected {model_name} to be {'an O-series model' if expected else 'not an O-series model'}"
|
||||
assert (
|
||||
config.is_model_o_series_model(model_name) == expected
|
||||
), f"Expected {model_name} to be {'an O-series model' if expected else 'not an O-series model'}"
|
||||
@ -98,7 +98,6 @@ async def test_openai_client_reuse(function_name, is_async, args):
|
||||
) as mock_set_cache, patch.object(
|
||||
BaseOpenAILLM, "get_cached_openai_client"
|
||||
) as mock_get_cache:
|
||||
|
||||
# Setup the mock to return None first time (cache miss) then a client for subsequent calls
|
||||
mock_client = MagicMock()
|
||||
mock_get_cache.side_effect = [None] + [
|
||||
@ -3,15 +3,14 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.llms.openrouter.chat.transformation import (
|
||||
OpenRouterChatCompletionStreamingHandler,
|
||||
OpenRouterException,
|
||||
OpenrouterConfig,
|
||||
OpenRouterException,
|
||||
)
|
||||
|
||||
|
||||
@ -26,11 +25,7 @@ class TestOpenRouterChatCompletionStreamingHandler:
|
||||
"id": "test_id",
|
||||
"created": 1234567890,
|
||||
"model": "test_model",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30
|
||||
},
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||
"choices": [
|
||||
{"delta": {"content": "test content", "reasoning": "test reasoning"}}
|
||||
],
|
||||
@ -89,7 +84,6 @@ class TestOpenRouterChatCompletionStreamingHandler:
|
||||
|
||||
|
||||
def test_openrouter_extra_body_transformation():
|
||||
|
||||
transformed_request = OpenrouterConfig().transform_request(
|
||||
model="openrouter/deepseek/deepseek-chat",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
@ -10,6 +10,7 @@ sys.path.insert(0, os.path.abspath("../../../../.."))
|
||||
from litellm.llms.sagemaker.common_utils import AWSEventStreamDecoder
|
||||
from litellm.llms.sagemaker.completion.transformation import SagemakerConfig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiter_bytes_unicode_decode_error():
|
||||
"""
|
||||
@ -96,6 +97,7 @@ async def test_aiter_bytes_valid_chunk_followed_by_unicode_error():
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0]["text"] == "hello" # Verify the content of the valid chunk
|
||||
|
||||
|
||||
class TestSagemakerTransform:
|
||||
def setup_method(self):
|
||||
self.config = SagemakerConfig()
|
||||
@ -104,7 +106,11 @@ class TestSagemakerTransform:
|
||||
|
||||
def test_map_mistral_params(self):
|
||||
"""Test that parameters are correctly mapped"""
|
||||
test_params = {"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}
|
||||
test_params = {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 200,
|
||||
"max_completion_tokens": 256,
|
||||
}
|
||||
|
||||
result = self.config.map_openai_params(
|
||||
non_default_params=test_params,
|
||||
@ -118,7 +124,10 @@ class TestSagemakerTransform:
|
||||
|
||||
def test_mistral_max_tokens_backward_compat(self):
|
||||
"""Test that parameters are correctly mapped"""
|
||||
test_params = {"temperature": 0.7, "max_tokens": 200,}
|
||||
test_params = {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 200,
|
||||
}
|
||||
|
||||
result = self.config.map_openai_params(
|
||||
non_default_params=test_params,
|
||||
@ -1,34 +1,36 @@
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexAIError,
|
||||
make_call,
|
||||
make_sync_call,
|
||||
VertexAIError,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
class TestVertexAIHTTPStatus201(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Setup mock messages
|
||||
self.messages = [{"role": "user", "content": "Hello, how are you?"}]
|
||||
|
||||
|
||||
# Setup mock data
|
||||
self.mock_data = json.dumps({"messages": self.messages})
|
||||
|
||||
|
||||
# Setup mock headers
|
||||
self.mock_headers = {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
# Setup mock model
|
||||
self.mock_model = "gemini-pro"
|
||||
|
||||
|
||||
# Setup mock logging object
|
||||
self.mock_logging_obj = MagicMock()
|
||||
self.mock_logging_obj.post_call = MagicMock()
|
||||
|
||||
@patch("litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.get_async_httpx_client")
|
||||
@patch(
|
||||
"litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.get_async_httpx_client"
|
||||
)
|
||||
async def test_async_http_status_201(self, mock_get_client):
|
||||
"""Test that async make_call handles HTTP 201 status code correctly"""
|
||||
# Create a mock response with status code 201
|
||||
@ -36,13 +38,13 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
|
||||
mock_response.status_code = 201
|
||||
mock_response.aiter_lines = MagicMock()
|
||||
mock_response.aiter_lines.return_value = ["test response"]
|
||||
|
||||
|
||||
# Setup mock client
|
||||
mock_client = MagicMock()
|
||||
mock_client.post = MagicMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
|
||||
# Call the make_call function
|
||||
result = await make_call(
|
||||
client=None,
|
||||
@ -51,15 +53,15 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
|
||||
data=self.mock_data,
|
||||
model=self.mock_model,
|
||||
messages=self.messages,
|
||||
logging_obj=self.mock_logging_obj
|
||||
logging_obj=self.mock_logging_obj,
|
||||
)
|
||||
|
||||
|
||||
# Assert that the post method was called
|
||||
mock_client.post.assert_called_once()
|
||||
|
||||
|
||||
# Assert that no error was raised for status code 201
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
|
||||
# Verify logging was called
|
||||
self.mock_logging_obj.post_call.assert_called_once()
|
||||
|
||||
@ -72,7 +74,7 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
|
||||
mock_response.iter_lines = MagicMock()
|
||||
mock_response.iter_lines.return_value = ["test response"]
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
|
||||
# Call the make_sync_call function
|
||||
result = make_sync_call(
|
||||
client=None,
|
||||
@ -82,12 +84,12 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
|
||||
data=self.mock_data,
|
||||
model=self.mock_model,
|
||||
messages=self.messages,
|
||||
logging_obj=self.mock_logging_obj
|
||||
logging_obj=self.mock_logging_obj,
|
||||
)
|
||||
|
||||
|
||||
# Assert that no error was raised for status code 201
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
|
||||
# Verify logging was called
|
||||
self.mock_logging_obj.post_call.assert_called_once()
|
||||
|
||||
@ -100,7 +102,7 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
|
||||
mock_response.read = MagicMock(return_value=b"Bad Request")
|
||||
mock_response.headers = {}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
|
||||
# Call the make_sync_call function and expect an error
|
||||
with self.assertRaises(VertexAIError) as context:
|
||||
make_sync_call(
|
||||
@ -111,12 +113,12 @@ class TestVertexAIHTTPStatus201(unittest.TestCase):
|
||||
data=self.mock_data,
|
||||
model=self.mock_model,
|
||||
messages=self.messages,
|
||||
logging_obj=self.mock_logging_obj
|
||||
logging_obj=self.mock_logging_obj,
|
||||
)
|
||||
|
||||
|
||||
# Assert that the error has the correct status code
|
||||
self.assertEqual(context.exception.status_code, 400)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
@ -1,9 +1,7 @@
|
||||
import base64
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@ -18,6 +16,7 @@ sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import get_optional_params
|
||||
from litellm.llms.vertex_ai.gemini.transformation import _process_gemini_image
|
||||
@ -31,6 +30,7 @@ def encode_image_to_base64(image_path):
|
||||
|
||||
def test_completion_pydantic_obj_2():
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
litellm.set_verbose = True
|
||||
@ -110,11 +110,10 @@ def test_completion_pydantic_obj_2():
|
||||
|
||||
|
||||
def test_build_vertex_schema():
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
_build_vertex_schema,
|
||||
)
|
||||
import json
|
||||
|
||||
from litellm.llms.vertex_ai.common_utils import _build_vertex_schema
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"my-random-key": "my-random-value",
|
||||
@ -149,7 +148,6 @@ def test_build_vertex_schema():
|
||||
],
|
||||
)
|
||||
def test_vertex_tool_params(tools, key):
|
||||
|
||||
optional_params = get_optional_params(
|
||||
model="gemini-1.5-pro",
|
||||
custom_llm_provider="vertex_ai",
|
||||
@ -1124,7 +1122,6 @@ def test_logprobs():
|
||||
mock_response.json.return_value = response_body
|
||||
|
||||
with patch.object(client, "post", return_value=mock_response):
|
||||
|
||||
resp = litellm.completion(
|
||||
model="gemini/gemini-1.5-flash-002",
|
||||
messages=[
|
||||
@ -1140,9 +1137,7 @@ def test_logprobs():
|
||||
|
||||
def test_process_gemini_image():
|
||||
"""Test the _process_gemini_image function for different image sources"""
|
||||
from litellm.llms.vertex_ai.gemini.transformation import (
|
||||
_process_gemini_image,
|
||||
)
|
||||
from litellm.llms.vertex_ai.gemini.transformation import _process_gemini_image
|
||||
from litellm.types.llms.vertex_ai import FileDataType
|
||||
|
||||
# Test GCS URI
|
||||
@ -1266,9 +1261,10 @@ def test_vertex_embedding_url(model, expected_url):
|
||||
assert endpoint == "predict"
|
||||
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Add these fixtures below existing fixtures
|
||||
@pytest.fixture
|
||||
@ -1435,7 +1431,10 @@ def test_vertex_parallel_tool_calls_false_multiple_tools_error():
|
||||
tools=tools,
|
||||
parallel_tool_calls=False,
|
||||
)
|
||||
assert "`parallel_tool_calls=False` is not supported when multiple tools are provided" in str(excinfo.value)
|
||||
assert (
|
||||
"`parallel_tool_calls=False` is not supported when multiple tools are provided"
|
||||
in str(excinfo.value)
|
||||
)
|
||||
|
||||
# works when specified as "functions"
|
||||
with pytest.raises(litellm.utils.UnsupportedParamsError) as excinfo:
|
||||
@ -1445,7 +1444,10 @@ def test_vertex_parallel_tool_calls_false_multiple_tools_error():
|
||||
functions=tools,
|
||||
parallel_tool_calls=False,
|
||||
)
|
||||
assert "`parallel_tool_calls=False` is not supported when multiple tools are provided" in str(excinfo.value)
|
||||
assert (
|
||||
"`parallel_tool_calls=False` is not supported when multiple tools are provided"
|
||||
in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
def test_vertex_parallel_tool_calls_false_single_tool():
|
||||
@ -13,11 +13,11 @@ sys.path.insert(
|
||||
|
||||
import litellm
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
_get_vertex_url,
|
||||
convert_anyof_null_to_nullable,
|
||||
get_vertex_location_from_url,
|
||||
get_vertex_project_id_from_url,
|
||||
set_schema_property_ordering,
|
||||
_get_vertex_url
|
||||
)
|
||||
|
||||
|
||||
@ -518,6 +518,7 @@ def test_vertex_ai_complex_response_schema():
|
||||
assert "additionalProperties" not in type3
|
||||
assert "additionalProperties" not in type3_prop3_items
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stream, expected_endpoint_suffix",
|
||||
[
|
||||
@ -537,7 +538,10 @@ def test_get_vertex_url_global_region(stream, expected_endpoint_suffix):
|
||||
|
||||
# Mock litellm.VertexGeminiConfig.get_model_for_vertex_ai_url to return model as is
|
||||
# as we are not testing that part here, just the URL construction
|
||||
with patch("litellm.VertexGeminiConfig.get_model_for_vertex_ai_url", side_effect=lambda model: model):
|
||||
with patch(
|
||||
"litellm.VertexGeminiConfig.get_model_for_vertex_ai_url",
|
||||
side_effect=lambda model: model,
|
||||
):
|
||||
url, endpoint = _get_vertex_url(
|
||||
mode=mode,
|
||||
model=model,
|
||||
@ -548,7 +552,7 @@ def test_get_vertex_url_global_region(stream, expected_endpoint_suffix):
|
||||
)
|
||||
|
||||
expected_url_base = f"https://aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/global/publishers/google/models/{model}"
|
||||
|
||||
|
||||
if stream:
|
||||
expected_endpoint = "streamGenerateContent"
|
||||
expected_url = f"{expected_url_base}:{expected_endpoint}?alt=sse"
|
||||
@ -556,6 +560,5 @@ def test_get_vertex_url_global_region(stream, expected_endpoint_suffix):
|
||||
expected_endpoint = "generateContent"
|
||||
expected_url = f"{expected_url_base}:{expected_endpoint}"
|
||||
|
||||
|
||||
assert endpoint == expected_endpoint
|
||||
assert url == expected_url
|
||||
@ -24,38 +24,49 @@ class TestAnthropicEndpoints(unittest.TestCase):
|
||||
{"type": "content_block_delta", "delta": {"text": "more data"}},
|
||||
"text chunk data again",
|
||||
]
|
||||
|
||||
|
||||
mock_user_api_key_dict = MagicMock()
|
||||
mock_request_data = {}
|
||||
mock_proxy_logging_obj = MagicMock()
|
||||
mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock(side_effect=lambda **kwargs: kwargs["response"])
|
||||
|
||||
mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock(
|
||||
side_effect=lambda **kwargs: kwargs["response"]
|
||||
)
|
||||
|
||||
# Configure safe_dumps to return a properly formatted JSON string
|
||||
mock_safe_dumps.side_effect = lambda chunk: json.dumps(chunk)
|
||||
|
||||
|
||||
# Execute
|
||||
result = [chunk async for chunk in async_data_generator_anthropic(
|
||||
response=mock_response,
|
||||
user_api_key_dict=mock_user_api_key_dict,
|
||||
request_data=mock_request_data,
|
||||
proxy_logging_obj=mock_proxy_logging_obj,
|
||||
)]
|
||||
|
||||
result = [
|
||||
chunk
|
||||
async for chunk in async_data_generator_anthropic(
|
||||
response=mock_response,
|
||||
user_api_key_dict=mock_user_api_key_dict,
|
||||
request_data=mock_request_data,
|
||||
proxy_logging_obj=mock_proxy_logging_obj,
|
||||
)
|
||||
]
|
||||
|
||||
# Verify
|
||||
expected_result = [
|
||||
'data: {"type": "message_start", "message": {"id": "msg_123"}}\n\n',
|
||||
'text chunk data',
|
||||
"text chunk data",
|
||||
'data: {"type": "content_block_delta", "delta": {"text": "more data"}}\n\n',
|
||||
'text chunk data again',
|
||||
"text chunk data again",
|
||||
]
|
||||
|
||||
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
# Assert safe_dumps was called for dictionary objects
|
||||
mock_safe_dumps.assert_any_call({"type": "message_start", "message": {"id": "msg_123"}})
|
||||
mock_safe_dumps.assert_any_call({"type": "content_block_delta", "delta": {"text": "more data"}})
|
||||
assert mock_safe_dumps.call_count == 2 # Called twice, once for each dict object
|
||||
mock_safe_dumps.assert_any_call(
|
||||
{"type": "message_start", "message": {"id": "msg_123"}}
|
||||
)
|
||||
mock_safe_dumps.assert_any_call(
|
||||
{"type": "content_block_delta", "delta": {"text": "more data"}}
|
||||
)
|
||||
assert (
|
||||
mock_safe_dumps.call_count == 2
|
||||
) # Called twice, once for each dict object
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
1
tests/test_litellm/proxy/client/cli/__init__.py
Normal file
1
tests/test_litellm/proxy/client/cli/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Tests for the LiteLLM Proxy Client CLI package."""
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
@ -238,4 +238,4 @@ def test_chat_completions_all_parameters(cli_runner, mock_chat_client):
|
||||
presence_penalty=0.5,
|
||||
frequency_penalty=0.5,
|
||||
user="test-user",
|
||||
)
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user