[Feat] Add new RAG Search API / Query API with rerankers (#18217)

* init RAGQueryRequest

* init RAGQuery

* fix query

* fix _execute_query_pipeline

* TestRAGOpenAI
This commit is contained in:
Ishaan Jaff 2025-12-19 19:05:07 +05:30 committed by GitHub
parent b0db9d6bb7
commit deb8d16967
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 443 additions and 5 deletions

View File

@ -5,9 +5,9 @@ Provides an all-in-one API for document ingestion:
Upload -> (OCR) -> Chunk -> Embed -> Vector Store
"""
from litellm.rag.main import aingest, ingest
from litellm.rag.main import aingest, aquery, ingest, query
__all__ = ["ingest", "aingest"]
__all__ = ["ingest", "aingest", "query", "aquery"]
# Expose at litellm.rag level for convenience

View File

@ -7,12 +7,22 @@ Upload -> (OCR) -> Chunk -> Embed -> Vector Store
from __future__ import annotations
__all__ = ["ingest", "aingest"]
__all__ = ["ingest", "aingest", "query", "aquery"]
import asyncio
import contextvars
from functools import partial
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Coroutine,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
import httpx
@ -21,7 +31,14 @@ from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
from litellm.rag.ingestion.gemini_ingestion import GeminiRAGIngestion
from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
from litellm.types.rag import RAGIngestOptions, RAGIngestResponse
from litellm.rag.rag_query import RAGQuery
from litellm.types.rag import (
RAGIngestOptions,
RAGIngestResponse,
RAGQueryRequest,
RAGQueryResponse,
)
from litellm.types.utils import ModelResponse
from litellm.utils import client
if TYPE_CHECKING:
@ -172,6 +189,163 @@ async def aingest(
)
async def _execute_query_pipeline(
    model: str,
    messages: List[Any],
    retrieval_config: Dict[str, Any],
    rerank: Optional[Dict[str, Any]] = None,
    stream: bool = False,
    **kwargs,
) -> ModelResponse:
    """
    Execute the RAG query pipeline: search -> (optional) rerank -> completion.

    Args:
        model: Model name passed to ``litellm.acompletion``.
        messages: Chat messages; the query text is taken from the last one.
        retrieval_config: Vector store settings. ``vector_store_id`` is
            required; ``top_k`` (default 10), ``custom_llm_provider``
            (default ``"openai"``) and ``filters`` are optional.
        rerank: Optional rerank settings. Reranking runs only when
            ``enabled`` is truthy, in which case ``model`` is required and
            ``top_n`` defaults to 5.
        stream: Passed through to ``litellm.acompletion``.
        **kwargs: Forwarded to both the vector store search and the
            completion call.

    Returns:
        The completion response. For non-streaming ``ModelResponse`` results
        the search (and rerank) results are attached to each choice's message
        via ``RAGQuery.add_search_results_to_response``.

    Raises:
        ValueError: If no query text can be extracted from ``messages``, if
            ``vector_store_id`` is missing, or if reranking is enabled
            without a ``model``.
    """
    # 1. Extract query from the last message
    query_text = RAGQuery.extract_query_from_messages(messages)
    if not query_text:
        raise ValueError("No query found in messages for RAG query")

    # Validate configuration up front so a bad config fails before any
    # network call is made, and with a clearer error than a bare KeyError.
    vector_store_id = retrieval_config.get("vector_store_id")
    if not vector_store_id:
        raise ValueError(
            "retrieval_config must include 'vector_store_id' for RAG query"
        )

    rerank_enabled = bool(rerank and rerank.get("enabled"))
    rerank_model = rerank.get("model") if rerank else None
    if rerank_enabled and not rerank_model:
        raise ValueError("rerank config must include 'model' when rerank is enabled")

    # 2. Search vector store
    search_kwargs: Dict[str, Any] = dict(kwargs)
    filters = retrieval_config.get("filters")
    if filters is not None:
        # Forward the configured filters; an explicit `filters` kwarg from the
        # caller takes precedence and also avoids a duplicate-keyword error.
        search_kwargs.setdefault("filters", filters)

    search_response = await litellm.vector_stores.asearch(
        vector_store_id=vector_store_id,
        query=query_text,
        max_num_results=retrieval_config.get("top_k", 10),
        custom_llm_provider=retrieval_config.get("custom_llm_provider", "openai"),
        **search_kwargs,
    )

    rerank_response = None
    context_chunks = search_response.get("data", [])

    # 3. Optional rerank
    if rerank_enabled and rerank is not None:
        documents = RAGQuery.extract_documents_from_search(search_response)
        # Nothing to rerank when the search returned no text documents;
        # keep the raw search results as context in that case.
        if documents:
            rerank_response = await litellm.arerank(
                model=rerank_model,
                query=query_text,
                documents=documents,
                top_n=rerank.get("top_n", 5),
            )
            context_chunks = RAGQuery.get_top_chunks_from_rerank(
                search_response, rerank_response
            )

    # 4. Build context message and call completion.
    # The context is inserted immediately before the final message so the
    # query remains the last message in the conversation.
    context_message = RAGQuery.build_context_message(context_chunks)
    modified_messages = messages[:-1] + [context_message] + [messages[-1]]

    response = await litellm.acompletion(
        model=model,
        messages=modified_messages,
        stream=stream,
        **kwargs,
    )

    # 5. Attach search results to response (non-streaming responses only)
    if not stream and isinstance(response, ModelResponse):
        response = RAGQuery.add_search_results_to_response(
            response=response,
            search_results=search_response,
            rerank_results=rerank_response,
        )

    return response  # type: ignore[return-value]
@client
async def aquery(
    model: str,
    messages: List[Any],
    retrieval_config: Dict[str, Any],
    rerank: Optional[Dict[str, Any]] = None,
    stream: bool = False,
    **kwargs,
) -> ModelResponse:
    """
    Async: Query a RAG pipeline.

    Runs the synchronous ``query`` entry point on the default executor with
    ``kwargs["aquery"] = True``, then awaits the result if ``query`` handed
    back a coroutine.

    Args:
        model: Model name for the completion step.
        messages: Chat messages for the query.
        retrieval_config: Vector store retrieval settings.
        rerank: Optional rerank settings.
        stream: Whether the completion should stream.
        **kwargs: Extra arguments forwarded to ``query``.

    Returns:
        The response produced by the query pipeline.

    Raises:
        Whatever ``litellm.exception_type`` produces for any exception raised
        while running the pipeline.
    """
    # Snapshot the call arguments before anything else is bound locally;
    # they are reported to the exception mapper on failure.
    call_arguments = locals()
    try:
        event_loop = asyncio.get_event_loop()
        # Marker that `query` pops to decide between returning the coroutine
        # and blocking on it.
        kwargs["aquery"] = True

        sync_entrypoint = partial(
            query,
            model=model,
            messages=messages,
            retrieval_config=retrieval_config,
            rerank=rerank,
            stream=stream,
            **kwargs,
        )

        # Execute inside a copy of the current context so context variables
        # set by the caller are visible to the executor.
        current_context = contextvars.copy_context()
        entrypoint_in_context = partial(current_context.run, sync_entrypoint)

        outcome = await event_loop.run_in_executor(None, entrypoint_in_context)
        return (await outcome) if asyncio.iscoroutine(outcome) else outcome
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=retrieval_config.get("custom_llm_provider"),
            original_exception=e,
            completion_kwargs=call_arguments,
            extra_kwargs=kwargs,
        )
@client
def query(
    model: str,
    messages: List[Any],
    retrieval_config: Dict[str, Any],
    rerank: Optional[Dict[str, Any]] = None,
    stream: bool = False,
    **kwargs,
) -> Union[ModelResponse, Coroutine[Any, Any, ModelResponse]]:
    """
    Query a RAG pipeline.

    Args:
        model: Model name for the completion step.
        messages: Chat messages; the query is taken from the last one.
        retrieval_config: Vector store retrieval settings.
        rerank: Optional rerank settings.
        stream: Whether the completion should stream.
        **kwargs: Extra arguments forwarded to the pipeline. The internal
            ``aquery`` flag is popped here and never forwarded.

    Returns:
        The pipeline response when called synchronously, or the un-awaited
        pipeline coroutine when ``kwargs["aquery"]`` is ``True`` (the path
        used by ``aquery``).

    Raises:
        Whatever ``litellm.exception_type`` produces for any exception raised
        while building or running the pipeline.
    """
    local_vars = locals()
    try:
        _is_async = kwargs.pop("aquery", False) is True

        run_pipeline = partial(
            _execute_query_pipeline,
            model=model,
            messages=messages,
            retrieval_config=retrieval_config,
            rerank=rerank,
            stream=stream,
            **kwargs,
        )

        if _is_async:
            # The async caller awaits the coroutine itself.
            return run_pipeline()

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop: this happens in non-main threads and, on
            # newer Python versions, whenever no loop has been set. Create
            # and register one so the synchronous call still works.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # The coroutine is created only after a loop has been obtained.
        return loop.run_until_complete(run_pipeline())
    except Exception as e:
        raise litellm.exception_type(
            model=model,
            custom_llm_provider=retrieval_config.get("custom_llm_provider"),
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
@client
def ingest(
ingest_options: Dict[str, Any],

120
litellm/rag/rag_query.py Normal file
View File

@ -0,0 +1,120 @@
from typing import Any, Dict, List, Optional, Union, cast
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from litellm.types.utils import ModelResponse
from litellm.types.vector_stores import (
VectorStoreResultContent,
VectorStoreSearchResponse,
VectorStoreSearchResult,
)
class RAGQuery:
    """Stateless helpers for the RAG query pipeline.

    Covers extracting the query text from chat messages, turning vector store
    search results into a context message, mapping rerank results back to
    search results, and attaching retrieval metadata to a completion response.
    """

    # Prefix placed at the start of the generated context message.
    CONTENT_PREFIX_STRING = "Context:\n\n"

    @staticmethod
    def extract_query_from_messages(messages: List[AllMessageValues]) -> Optional[str]:
        """
        Extract the query text from the last message.

        Only the final entry of ``messages`` is inspected; its role is not
        checked. String content is returned as-is. For list content, the
        ``text`` of the first ``{"type": "text", ...}`` item is returned.

        Returns:
            The query text, or ``None`` when ``messages`` is empty or the last
            message carries no usable text content.
        """
        if not messages:
            return None

        last_message = messages[-1]
        if not isinstance(last_message, dict) or "content" not in last_message:
            return None

        content = last_message["content"]
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Handle list of content items, extract text from first text item
            for item in content:
                if (
                    isinstance(item, dict)
                    and item.get("type") == "text"
                    and "text" in item
                ):
                    return item["text"]
        return None

    @staticmethod
    def _texts_from_chunk(chunk: Dict[str, Any]) -> List[Any]:
        """Return the text pieces carried by a single dict-shaped chunk."""
        result_content = chunk.get("content")
        if result_content:
            if isinstance(result_content, str):
                # Tolerate chunks whose content is a plain string rather than
                # a list of content items.
                return [result_content]
            return [
                content_item["text"]
                for content_item in result_content
                if isinstance(content_item, dict) and content_item.get("text")
            ]
        if "text" in chunk:  # Fallback for simple dict with text
            return [chunk["text"]]
        return []

    @staticmethod
    def build_context_message(context_chunks: List[Any]) -> ChatCompletionUserMessage:
        """
        Process search results and build a context message.

        Each chunk may be a dict with a ``content`` list of items carrying
        ``text``, a dict with a top-level ``text`` key, or a plain string.
        Chunks of any other type are ignored.

        Returns:
            A ``user`` message whose content is ``CONTENT_PREFIX_STRING``
            followed by every extracted text, each terminated by a blank line.
        """
        texts: List[Any] = []
        for chunk in context_chunks:
            if isinstance(chunk, dict):
                texts.extend(RAGQuery._texts_from_chunk(chunk))
            elif isinstance(chunk, str):
                texts.append(chunk)

        context_content = RAGQuery.CONTENT_PREFIX_STRING + "".join(
            text + "\n\n" for text in texts
        )
        return {
            "role": "user",
            "content": context_content,
        }

    @staticmethod
    def add_search_results_to_response(
        response: ModelResponse,
        search_results: VectorStoreSearchResponse,
        rerank_results: Optional[Any] = None,
    ) -> ModelResponse:
        """
        Add search results to the response choices.

        For every choice that has a ``message``, stores ``search_results``
        (and ``rerank_results`` when truthy) in the message's
        ``provider_specific_fields``. An existing fields dict is updated in
        place; otherwise a new dict is created.

        Returns:
            The same ``response`` object.
        """
        for choice in getattr(response, "choices", None) or []:
            message = getattr(choice, "message", None)
            if message is None:
                continue

            # Get existing provider_specific_fields or create new dict
            provider_fields = getattr(message, "provider_specific_fields", None) or {}
            provider_fields["search_results"] = search_results
            if rerank_results:
                provider_fields["rerank_results"] = rerank_results
            setattr(message, "provider_specific_fields", provider_fields)
        return response

    @staticmethod
    def _iter_text_documents(search_response: Any) -> List[Any]:
        """Return ``(result_index, text)`` pairs for every non-empty text item.

        Pairs are listed in search-result order. This single traversal defines
        the document order used for reranking, so a document's position can be
        mapped back to the search result it came from.
        """
        pairs: List[Any] = []
        for result_index, result in enumerate(search_response.get("data") or []):
            for content in result.get("content") or []:
                if content.get("type") == "text" and content.get("text"):
                    pairs.append((result_index, content["text"]))
        return pairs

    @staticmethod
    def extract_documents_from_search(
        search_response: Any,
    ) -> List[Union[str, Dict[str, Any]]]:
        """Extract text documents from vector store search response.

        A search result contributes one document per non-empty text content
        item, so the returned list may be shorter or longer than
        ``search_response["data"]``.
        """
        return [text for _, text in RAGQuery._iter_text_documents(search_response)]

    @staticmethod
    def get_top_chunks_from_rerank(search_response: Any, rerank_response: Any) -> List[Any]:
        """Get the original search results corresponding to the top reranked results.

        Each rerank ``index`` is treated as a position in the document list
        produced by ``extract_documents_from_search`` and is mapped back to
        the search result that document came from. This differs from the
        result's own position whenever a result has zero or several text
        items. Results are returned in rerank order, each at most once;
        missing or out-of-range indices are skipped.
        """
        original_results = search_response.get("data") or []
        # document position -> index of the search result it came from
        owning_result = [
            result_index
            for result_index, _ in RAGQuery._iter_text_documents(search_response)
        ]

        top_chunks: List[Any] = []
        already_added = set()
        for result in rerank_response.get("results") or []:
            index = result.get("index")
            # Negative indices are rejected so they cannot wrap around.
            if not isinstance(index, int) or not 0 <= index < len(owning_result):
                continue
            result_index = owning_result[index]
            if result_index in already_added:
                continue
            already_added.add(result_index)
            top_chunks.append(original_results[result_index])
        return top_chunks

View File

@ -7,6 +7,8 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from typing_extensions import TypedDict
from litellm.types.utils import ModelResponse
class RAGChunkingStrategy(TypedDict, total=False):
"""
@ -187,3 +189,39 @@ class RAGIngestRequest(BaseModel):
model_config = ConfigDict(extra="allow") # Allow additional fields
class RAGRetrievalConfig(TypedDict, total=False):
    """Configuration for vector store retrieval.

    Declared with ``total=False``, so no key is required at the type level.
    """

    vector_store_id: str  # id of the vector store to search
    custom_llm_provider: str  # provider name used for the vector store call
    top_k: int  # max results from vector store
    filters: Optional[Dict[str, Any]]  # optional - vector store filters
class RAGRerankConfig(TypedDict, total=False):
    """Configuration for reranking results.

    Declared with ``total=False``, so no key is required at the type level.
    """

    enabled: bool  # whether reranking should run
    model: str  # rerank model name
    top_n: int  # final number of chunks after reranking
    return_documents: Optional[bool]
class RAGQueryRequest(BaseModel):
    """Request body for RAG query API.

    Unknown extra fields are accepted (``extra="allow"``).
    """

    model: str  # model name for the completion step
    messages: List[Any]  # chat messages
    retrieval_config: RAGRetrievalConfig  # vector store retrieval settings
    rerank: Optional[RAGRerankConfig] = None  # optional rerank settings
    stream: Optional[bool] = False  # streaming is off by default
    model_config = ConfigDict(extra="allow")  # Allow additional fields
class RAGQueryResponse(ModelResponse):
    """Response from RAG query API.

    A plain subclass of ``ModelResponse``; it adds no fields or methods.
    """

    pass

View File

@ -42,4 +42,110 @@ class TestRAGOpenAI(BaseRAGTest):
return search_response
return None
@pytest.mark.asyncio
async def test_rag_query_basic(self):
    """Ingest one document, then run a plain RAG query against its vector store.

    Passes when the completion has content and the search results are exposed
    through the message's ``provider_specific_fields``.
    """
    import asyncio

    litellm._turn_on_debug()

    # --- Ingest a small, uniquely tagged text document ---
    doc_name, unique_id = self.get_unique_filename("rag_query")
    payload = f"LiteLLM is a unified interface for 100+ LLMs. ID: {unique_id}".encode()
    upload = (doc_name, payload, "text/plain")

    ingest_response = await litellm.rag.aingest(
        ingest_options=self.get_base_ingest_options(),
        file_data=upload,
    )

    # Stop right away when ingestion did not complete.
    ingestion_completed = ingest_response["status"] == "completed"
    if not ingestion_completed:
        pytest.fail(
            f"Ingestion failed with status: {ingest_response['status']}, "
            f"error: {ingest_response.get('error', 'Unknown')}"
        )

    vector_store_id = ingest_response["vector_store_id"]
    assert vector_store_id, "vector_store_id should not be empty"

    # Wait for indexing
    await asyncio.sleep(10)

    # --- Query with RAG ---
    retrieval_settings = {
        "vector_store_id": vector_store_id,
        "custom_llm_provider": "openai",
        "top_k": 5,
    }
    question = [{"role": "user", "content": "What is LiteLLM?"}]
    response = await litellm.rag.aquery(
        model="gpt-4o-mini",
        messages=question,
        retrieval_config=retrieval_settings,
    )
    print(f"RAG Query Response: {response}")

    answer = response.choices[0].message
    assert answer.content
    assert "search_results" in answer.provider_specific_fields
@pytest.mark.asyncio
async def test_rag_query_with_rerank(self):
    """Ingest one document, then run a RAG query with reranking enabled.

    Passes when the completion has content and both the search results and
    the rerank results are exposed through the message's
    ``provider_specific_fields``.
    """
    import asyncio

    litellm._turn_on_debug()

    # --- Ingest a small, uniquely tagged text document ---
    doc_name, unique_id = self.get_unique_filename("rag_query_rerank")
    payload = f"LiteLLM is a unified interface for 100+ LLMs. ID: {unique_id}".encode()
    upload = (doc_name, payload, "text/plain")

    ingest_response = await litellm.rag.aingest(
        ingest_options=self.get_base_ingest_options(),
        file_data=upload,
    )

    # Stop right away when ingestion did not complete.
    ingestion_completed = ingest_response["status"] == "completed"
    if not ingestion_completed:
        pytest.fail(
            f"Ingestion failed with status: {ingest_response['status']}, "
            f"error: {ingest_response.get('error', 'Unknown')}"
        )

    vector_store_id = ingest_response["vector_store_id"]
    assert vector_store_id, "vector_store_id should not be empty"

    # Wait for indexing
    await asyncio.sleep(10)

    # --- Query with RAG and rerank ---
    retrieval_settings = {
        "vector_store_id": vector_store_id,
        "custom_llm_provider": "openai",
        "top_k": 5,
    }
    rerank_settings = {
        "enabled": True,
        "model": "cohere/rerank-english-v3.0",
        "top_n": 3,
    }
    question = [{"role": "user", "content": "What is LiteLLM?"}]
    response = await litellm.rag.aquery(
        model="gpt-4o-mini",
        messages=question,
        retrieval_config=retrieval_settings,
        rerank=rerank_settings,
    )
    print(f"RAG Query Response with Rerank: {response.model_dump_json(indent=4)}")

    answer = response.choices[0].message
    assert answer.content
    assert "search_results" in answer.provider_specific_fields
    assert "rerank_results" in answer.provider_specific_fields