[Feat] Add new Rag Search API / Query API with rerankers (#18217)
* init RAGQueryRequest * init RAGQuery * fix query * fix _execute_query_pipeline * TestRAGOpenAI
This commit is contained in:
parent
b0db9d6bb7
commit
deb8d16967
@ -5,9 +5,9 @@ Provides an all-in-one API for document ingestion:
|
||||
Upload -> (OCR) -> Chunk -> Embed -> Vector Store
|
||||
"""
|
||||
|
||||
from litellm.rag.main import aingest, ingest
|
||||
from litellm.rag.main import aingest, aquery, ingest, query
|
||||
|
||||
__all__ = ["ingest", "aingest"]
|
||||
__all__ = ["ingest", "aingest", "query", "aquery"]
|
||||
|
||||
|
||||
# Expose at litellm.rag level for convenience
|
||||
|
||||
@ -7,12 +7,22 @@ Upload -> (OCR) -> Chunk -> Embed -> Vector Store
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["ingest", "aingest"]
|
||||
__all__ = ["ingest", "aingest", "query", "aquery"]
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Type, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
@ -21,7 +31,14 @@ from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
|
||||
from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
|
||||
from litellm.rag.ingestion.gemini_ingestion import GeminiRAGIngestion
|
||||
from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
|
||||
from litellm.types.rag import RAGIngestOptions, RAGIngestResponse
|
||||
from litellm.rag.rag_query import RAGQuery
|
||||
from litellm.types.rag import (
|
||||
RAGIngestOptions,
|
||||
RAGIngestResponse,
|
||||
RAGQueryRequest,
|
||||
RAGQueryResponse,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -172,6 +189,163 @@ async def aingest(
|
||||
)
|
||||
|
||||
|
||||
async def _execute_query_pipeline(
|
||||
model: str,
|
||||
messages: List[Any],
|
||||
retrieval_config: Dict[str, Any],
|
||||
rerank: Optional[Dict[str, Any]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Execute the RAG query pipeline.
|
||||
"""
|
||||
# 1. Extract query from last user message
|
||||
query_text = RAGQuery.extract_query_from_messages(messages)
|
||||
if not query_text:
|
||||
raise ValueError("No query found in messages for RAG query")
|
||||
|
||||
# 2. Search vector store
|
||||
search_response = await litellm.vector_stores.asearch(
|
||||
vector_store_id=retrieval_config["vector_store_id"],
|
||||
query=query_text,
|
||||
max_num_results=retrieval_config.get("top_k", 10),
|
||||
custom_llm_provider=retrieval_config.get("custom_llm_provider", "openai"),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
rerank_response = None
|
||||
context_chunks = search_response.get("data", [])
|
||||
|
||||
# 3. Optional rerank
|
||||
if rerank and rerank.get("enabled"):
|
||||
documents = RAGQuery.extract_documents_from_search(search_response)
|
||||
if documents:
|
||||
rerank_response = await litellm.arerank(
|
||||
model=rerank["model"],
|
||||
query=query_text,
|
||||
documents=documents,
|
||||
top_n=rerank.get("top_n", 5),
|
||||
)
|
||||
context_chunks = RAGQuery.get_top_chunks_from_rerank(
|
||||
search_response, rerank_response
|
||||
)
|
||||
|
||||
# 4. Build context message and call completion
|
||||
context_message = RAGQuery.build_context_message(context_chunks)
|
||||
modified_messages = messages[:-1] + [context_message] + [messages[-1]]
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
messages=modified_messages,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# 5. Attach search results to response
|
||||
if not stream and isinstance(response, ModelResponse):
|
||||
response = RAGQuery.add_search_results_to_response(
|
||||
response=response,
|
||||
search_results=search_response,
|
||||
rerank_results=rerank_response,
|
||||
)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
|
||||
@client
|
||||
async def aquery(
|
||||
model: str,
|
||||
messages: List[Any],
|
||||
retrieval_config: Dict[str, Any],
|
||||
rerank: Optional[Dict[str, Any]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Async: Query a RAG pipeline.
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["aquery"] = True
|
||||
|
||||
func = partial(
|
||||
query,
|
||||
model=model,
|
||||
messages=messages,
|
||||
retrieval_config=retrieval_config,
|
||||
rerank=rerank,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=model,
|
||||
custom_llm_provider=retrieval_config.get("custom_llm_provider"),
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
def query(
|
||||
model: str,
|
||||
messages: List[Any],
|
||||
retrieval_config: Dict[str, Any],
|
||||
rerank: Optional[Dict[str, Any]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[ModelResponse, Coroutine[Any, Any, ModelResponse]]:
|
||||
"""
|
||||
Query a RAG pipeline.
|
||||
"""
|
||||
local_vars = locals()
|
||||
try:
|
||||
_is_async = kwargs.pop("aquery", False) is True
|
||||
|
||||
if _is_async:
|
||||
return _execute_query_pipeline(
|
||||
model=model,
|
||||
messages=messages,
|
||||
retrieval_config=retrieval_config,
|
||||
rerank=rerank,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
_execute_query_pipeline(
|
||||
model=model,
|
||||
messages=messages,
|
||||
retrieval_config=retrieval_config,
|
||||
rerank=rerank,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise litellm.exception_type(
|
||||
model=model,
|
||||
custom_llm_provider=retrieval_config.get("custom_llm_provider"),
|
||||
original_exception=e,
|
||||
completion_kwargs=local_vars,
|
||||
extra_kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
def ingest(
|
||||
ingest_options: Dict[str, Any],
|
||||
|
||||
120
litellm/rag/rag_query.py
Normal file
120
litellm/rag/rag_query.py
Normal file
@ -0,0 +1,120 @@
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
|
||||
class RAGQuery:
|
||||
CONTENT_PREFIX_STRING = "Context:\n\n"
|
||||
|
||||
@staticmethod
|
||||
def extract_query_from_messages(messages: List[AllMessageValues]) -> Optional[str]:
|
||||
"""
|
||||
Extract the query from the last user message.
|
||||
"""
|
||||
if not messages or len(messages) == 0:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
if not isinstance(last_message, dict) or "content" not in last_message:
|
||||
return None
|
||||
|
||||
content = last_message["content"]
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list) and len(content) > 0:
|
||||
# Handle list of content items, extract text from first text item
|
||||
for item in content:
|
||||
if (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "text"
|
||||
and "text" in item
|
||||
):
|
||||
return item["text"]
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def build_context_message(context_chunks: List[Any]) -> ChatCompletionUserMessage:
|
||||
"""
|
||||
Process search results and build a context message.
|
||||
"""
|
||||
context_content = RAGQuery.CONTENT_PREFIX_STRING
|
||||
|
||||
for chunk in context_chunks:
|
||||
if isinstance(chunk, dict):
|
||||
result_content: Optional[List[VectorStoreResultContent]] = chunk.get(
|
||||
"content"
|
||||
)
|
||||
if result_content:
|
||||
for content_item in result_content:
|
||||
content_text: Optional[str] = content_item.get("text")
|
||||
if content_text:
|
||||
context_content += content_text + "\n\n"
|
||||
elif "text" in chunk: # Fallback for simple dict with text
|
||||
context_content += chunk["text"] + "\n\n"
|
||||
elif isinstance(chunk, str):
|
||||
context_content += chunk + "\n\n"
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": context_content,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def add_search_results_to_response(
|
||||
response: ModelResponse,
|
||||
search_results: VectorStoreSearchResponse,
|
||||
rerank_results: Optional[Any] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Add search results to the response choices.
|
||||
"""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
for choice in response.choices:
|
||||
message = getattr(choice, "message", None)
|
||||
if message is not None:
|
||||
# Get existing provider_specific_fields or create new dict
|
||||
provider_fields = (
|
||||
getattr(message, "provider_specific_fields", None) or {}
|
||||
)
|
||||
|
||||
# Add search results
|
||||
provider_fields["search_results"] = search_results
|
||||
if rerank_results:
|
||||
provider_fields["rerank_results"] = rerank_results
|
||||
|
||||
# Set the provider_specific_fields
|
||||
setattr(message, "provider_specific_fields", provider_fields)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def extract_documents_from_search(
|
||||
search_response: Any,
|
||||
) -> List[Union[str, Dict[str, Any]]]:
|
||||
"""Extract text documents from vector store search response."""
|
||||
documents: List[Union[str, Dict[str, Any]]] = []
|
||||
for result in search_response.get("data", []):
|
||||
content_list = result.get("content", [])
|
||||
for content in content_list:
|
||||
if content.get("type") == "text" and content.get("text"):
|
||||
documents.append(content["text"])
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_top_chunks_from_rerank(search_response: Any, rerank_response: Any) -> List[Any]:
|
||||
"""Get the original search results corresponding to the top reranked results."""
|
||||
top_chunks = []
|
||||
original_results = search_response.get("data", [])
|
||||
for result in rerank_response.get("results", []):
|
||||
index = result.get("index")
|
||||
if index is not None and index < len(original_results):
|
||||
top_chunks.append(original_results[index])
|
||||
return top_chunks
|
||||
@ -7,6 +7,8 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
class RAGChunkingStrategy(TypedDict, total=False):
|
||||
"""
|
||||
@ -187,3 +189,39 @@ class RAGIngestRequest(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="allow") # Allow additional fields
|
||||
|
||||
|
||||
class RAGRetrievalConfig(TypedDict, total=False):
|
||||
"""Configuration for vector store retrieval."""
|
||||
|
||||
vector_store_id: str
|
||||
custom_llm_provider: str
|
||||
top_k: int # max results from vector store
|
||||
filters: Optional[Dict[str, Any]] # optional - vector store filters
|
||||
|
||||
|
||||
class RAGRerankConfig(TypedDict, total=False):
|
||||
"""Configuration for reranking results."""
|
||||
|
||||
enabled: bool
|
||||
model: str
|
||||
top_n: int # final number of chunks after reranking
|
||||
return_documents: Optional[bool]
|
||||
|
||||
|
||||
class RAGQueryRequest(BaseModel):
|
||||
"""Request body for RAG query API."""
|
||||
|
||||
model: str
|
||||
messages: List[Any]
|
||||
retrieval_config: RAGRetrievalConfig
|
||||
rerank: Optional[RAGRerankConfig] = None
|
||||
stream: Optional[bool] = False
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class RAGQueryResponse(ModelResponse):
|
||||
"""Response from RAG query API."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@ -42,4 +42,110 @@ class TestRAGOpenAI(BaseRAGTest):
|
||||
return search_response
|
||||
return None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_query_basic(self):
|
||||
"""Test basic RAG query flow."""
|
||||
import asyncio
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
# First ingest a document
|
||||
filename, unique_id = self.get_unique_filename("rag_query")
|
||||
text_content = (
|
||||
f"LiteLLM is a unified interface for 100+ LLMs. ID: {unique_id}".encode()
|
||||
)
|
||||
|
||||
ingest_response = await litellm.rag.aingest(
|
||||
ingest_options=self.get_base_ingest_options(),
|
||||
file_data=(filename, text_content, "text/plain"),
|
||||
)
|
||||
|
||||
# Check if ingestion succeeded
|
||||
if ingest_response["status"] != "completed":
|
||||
pytest.fail(
|
||||
f"Ingestion failed with status: {ingest_response['status']}, "
|
||||
f"error: {ingest_response.get('error', 'Unknown')}"
|
||||
)
|
||||
|
||||
vector_store_id = ingest_response["vector_store_id"]
|
||||
assert vector_store_id, "vector_store_id should not be empty"
|
||||
|
||||
# Wait for indexing
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Query with RAG
|
||||
response = await litellm.rag.aquery(
|
||||
model="gpt-4o-mini",
|
||||
messages=[{"role": "user", "content": "What is LiteLLM?"}],
|
||||
retrieval_config={
|
||||
"vector_store_id": vector_store_id,
|
||||
"custom_llm_provider": "openai",
|
||||
"top_k": 5,
|
||||
},
|
||||
)
|
||||
|
||||
print(f"RAG Query Response: {response}")
|
||||
|
||||
assert response.choices[0].message.content
|
||||
assert (
|
||||
"search_results" in response.choices[0].message.provider_specific_fields
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_query_with_rerank(self):
|
||||
"""Test RAG query with reranking."""
|
||||
import asyncio
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
# First ingest a document
|
||||
filename, unique_id = self.get_unique_filename("rag_query_rerank")
|
||||
text_content = (
|
||||
f"LiteLLM is a unified interface for 100+ LLMs. ID: {unique_id}".encode()
|
||||
)
|
||||
|
||||
ingest_response = await litellm.rag.aingest(
|
||||
ingest_options=self.get_base_ingest_options(),
|
||||
file_data=(filename, text_content, "text/plain"),
|
||||
)
|
||||
|
||||
# Check if ingestion succeeded
|
||||
if ingest_response["status"] != "completed":
|
||||
pytest.fail(
|
||||
f"Ingestion failed with status: {ingest_response['status']}, "
|
||||
f"error: {ingest_response.get('error', 'Unknown')}"
|
||||
)
|
||||
|
||||
vector_store_id = ingest_response["vector_store_id"]
|
||||
assert vector_store_id, "vector_store_id should not be empty"
|
||||
|
||||
# Wait for indexing
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Query with RAG and rerank
|
||||
response = await litellm.rag.aquery(
|
||||
model="gpt-4o-mini",
|
||||
messages=[{"role": "user", "content": "What is LiteLLM?"}],
|
||||
retrieval_config={
|
||||
"vector_store_id": vector_store_id,
|
||||
"custom_llm_provider": "openai",
|
||||
"top_k": 5,
|
||||
},
|
||||
rerank={
|
||||
"enabled": True,
|
||||
"model": "cohere/rerank-english-v3.0",
|
||||
"top_n": 3,
|
||||
},
|
||||
)
|
||||
|
||||
print(f"RAG Query Response with Rerank: {response.model_dump_json(indent=4)}")
|
||||
|
||||
assert response.choices[0].message.content
|
||||
assert (
|
||||
"search_results" in response.choices[0].message.provider_specific_fields
|
||||
)
|
||||
assert (
|
||||
"rerank_results" in response.choices[0].message.provider_specific_fields
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user