diff --git a/litellm/rag/__init__.py b/litellm/rag/__init__.py
index f87e72f0c1..54f4d3ccaa 100644
--- a/litellm/rag/__init__.py
+++ b/litellm/rag/__init__.py
@@ -5,9 +5,9 @@ Provides an all-in-one API for document ingestion:
 Upload -> (OCR) -> Chunk -> Embed -> Vector Store
 """
 
-from litellm.rag.main import aingest, ingest
+from litellm.rag.main import aingest, aquery, ingest, query
 
-__all__ = ["ingest", "aingest"]
+__all__ = ["ingest", "aingest", "query", "aquery"]
 
 # Expose at litellm.rag level for convenience
 
diff --git a/litellm/rag/main.py b/litellm/rag/main.py
index e7a9d3a241..b8461a8daa 100644
--- a/litellm/rag/main.py
+++ b/litellm/rag/main.py
@@ -7,12 +7,22 @@ Upload -> (OCR) -> Chunk -> Embed -> Vector Store
 
 from __future__ import annotations
 
-__all__ = ["ingest", "aingest"]
+__all__ = ["ingest", "aingest", "query", "aquery"]
 
 import asyncio
 import contextvars
 from functools import partial
-from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Coroutine,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import httpx
 
@@ -21,7 +31,14 @@ from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
 from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
 from litellm.rag.ingestion.gemini_ingestion import GeminiRAGIngestion
 from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
-from litellm.types.rag import RAGIngestOptions, RAGIngestResponse
+from litellm.rag.rag_query import RAGQuery
+from litellm.types.rag import (
+    RAGIngestOptions,
+    RAGIngestResponse,
+    RAGQueryRequest,
+    RAGQueryResponse,
+)
+from litellm.types.utils import ModelResponse
 from litellm.utils import client
 
 if TYPE_CHECKING:
@@ -172,6 +189,203 @@ async def aingest(
         )
 
 
+async def _execute_query_pipeline(
+    model: str,
+    messages: List[Any],
+    retrieval_config: Dict[str, Any],
+    rerank: Optional[Dict[str, Any]] = None,
+    stream: bool = False,
+    **kwargs,
+) -> ModelResponse:
+    """
+    Execute the RAG query pipeline: search -> (rerank) -> completion.
+
+    Args:
+        model: Completion model that answers the query.
+        messages: Chat messages; the query text is read from the last one.
+        retrieval_config: Vector store settings (``vector_store_id`` plus
+            optional ``custom_llm_provider``, ``top_k`` and ``filters``).
+        rerank: Optional rerank settings (``enabled``, ``model``, ``top_n``).
+        stream: Forwarded to the completion call; search results are only
+            attached to non-streaming responses.
+        **kwargs: Forwarded to both the vector store search and the
+            completion call.
+
+    Raises:
+        ValueError: If no query text can be extracted from ``messages``.
+    """
+    # 1. Extract query from the last message
+    query_text = RAGQuery.extract_query_from_messages(messages)
+    if not query_text:
+        raise ValueError("No query found in messages for RAG query")
+
+    # 2. Search vector store
+    search_kwargs: Dict[str, Any] = dict(kwargs)
+    filters = retrieval_config.get("filters")
+    if filters is not None:
+        # ``filters`` is a declared retrieval option; forward it to the search
+        search_kwargs["filters"] = filters
+
+    search_response = await litellm.vector_stores.asearch(
+        vector_store_id=retrieval_config["vector_store_id"],
+        query=query_text,
+        max_num_results=retrieval_config.get("top_k", 10),
+        custom_llm_provider=retrieval_config.get("custom_llm_provider", "openai"),
+        **search_kwargs,
+    )
+
+    rerank_response = None
+    # ``data`` may be absent or None; normalise it to a list
+    context_chunks = search_response.get("data") or []
+
+    # 3. Optional rerank
+    if rerank and rerank.get("enabled"):
+        documents = RAGQuery.extract_documents_from_search(search_response)
+        if documents:
+            rerank_response = await litellm.arerank(
+                model=rerank["model"],
+                query=query_text,
+                documents=documents,
+                top_n=rerank.get("top_n", 5),
+            )
+            context_chunks = RAGQuery.get_top_chunks_from_rerank(
+                search_response, rerank_response
+            )
+
+    # 4. Build context message and call completion
+    context_message = RAGQuery.build_context_message(context_chunks)
+    modified_messages = messages[:-1] + [context_message] + [messages[-1]]
+
+    response = await litellm.acompletion(
+        model=model,
+        messages=modified_messages,
+        stream=stream,
+        **kwargs,
+    )
+
+    # 5. Attach search results to response
+    if not stream and isinstance(response, ModelResponse):
+        response = RAGQuery.add_search_results_to_response(
+            response=response,
+            search_results=search_response,
+            rerank_results=rerank_response,
+        )
+
+    return response  # type: ignore[return-value]
+
+
+@client
+async def aquery(
+    model: str,
+    messages: List[Any],
+    retrieval_config: Dict[str, Any],
+    rerank: Optional[Dict[str, Any]] = None,
+    stream: bool = False,
+    **kwargs,
+) -> ModelResponse:
+    """
+    Async: Query a RAG pipeline.
+
+    Runs `query` in the default executor with ``aquery=True`` and awaits
+    the coroutine it returns. Parameters match `query`.
+
+    Raises:
+        Whatever ``litellm.exception_type`` maps the failure to.
+    """
+    local_vars = locals()
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["aquery"] = True
+
+        func = partial(
+            query,
+            model=model,
+            messages=messages,
+            retrieval_config=retrieval_config,
+            rerank=rerank,
+            stream=stream,
+            **kwargs,
+        )
+
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response
+
+        return response
+    except Exception as e:
+        raise litellm.exception_type(
+            model=model,
+            custom_llm_provider=retrieval_config.get("custom_llm_provider"),
+            original_exception=e,
+            completion_kwargs=local_vars,
+            extra_kwargs=kwargs,
+        )
+
+
+@client
+def query(
+    model: str,
+    messages: List[Any],
+    retrieval_config: Dict[str, Any],
+    rerank: Optional[Dict[str, Any]] = None,
+    stream: bool = False,
+    **kwargs,
+) -> Union[ModelResponse, Coroutine[Any, Any, ModelResponse]]:
+    """
+    Query a RAG pipeline.
+
+    Args:
+        model: Completion model that answers the query.
+        messages: Chat messages; the last one carries the query.
+        retrieval_config: Vector store settings, e.g. ``vector_store_id``.
+        rerank: Optional rerank settings.
+        stream: Forwarded to the completion call.
+        **kwargs: ``aquery=True`` returns the pipeline coroutine instead of
+            running it; remaining keys are forwarded to the pipeline.
+
+    Returns:
+        The model response, or a coroutine resolving to it when
+        ``aquery=True``.
+    """
+    local_vars = locals()
+    try:
+        _is_async = kwargs.pop("aquery", False) is True
+
+        if _is_async:
+            return _execute_query_pipeline(
+                model=model,
+                messages=messages,
+                retrieval_config=retrieval_config,
+                rerank=rerank,
+                stream=stream,
+                **kwargs,
+            )
+        else:
+            return asyncio.get_event_loop().run_until_complete(
+                _execute_query_pipeline(
+                    model=model,
+                    messages=messages,
+                    retrieval_config=retrieval_config,
+                    rerank=rerank,
+                    stream=stream,
+                    **kwargs,
+                )
+            )
+    except Exception as e:
+        raise litellm.exception_type(
+            model=model,
+            custom_llm_provider=retrieval_config.get("custom_llm_provider"),
+            original_exception=e,
+            completion_kwargs=local_vars,
+            extra_kwargs=kwargs,
+        )
+
+
 @client
 def ingest(
     ingest_options: Dict[str, Any],
diff --git a/litellm/rag/rag_query.py b/litellm/rag/rag_query.py
new file mode 100644
index 0000000000..53cc6d0089
--- /dev/null
+++ b/litellm/rag/rag_query.py
@@ -0,0 +1,159 @@
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
+
+import litellm
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
+from litellm.types.utils import ModelResponse
+from litellm.types.vector_stores import (
+    VectorStoreResultContent,
+    VectorStoreSearchResponse,
+    VectorStoreSearchResult,
+)
+
+
+class RAGQuery:
+    """Stateless helpers used by the RAG query pipeline."""
+
+    CONTENT_PREFIX_STRING = "Context:\n\n"
+
+    @staticmethod
+    def extract_query_from_messages(messages: List[AllMessageValues]) -> Optional[str]:
+        """
+        Extract the query text from the last message.
+
+        Returns the message's string content, or the first ``text`` part of
+        list content; ``None`` when neither is available.
+        """
+        if not messages:
+            return None
+
+        last_message = messages[-1]
+        if not isinstance(last_message, dict) or "content" not in last_message:
+            return None
+
+        content = last_message["content"]
+
+        if isinstance(content, str):
+            return content
+        if isinstance(content, list):
+            # Content parts: use the first text part
+            for item in content:
+                if (
+                    isinstance(item, dict)
+                    and item.get("type") == "text"
+                    and "text" in item
+                ):
+                    return item["text"]
+
+        return None
+
+    @staticmethod
+    def build_context_message(context_chunks: List[Any]) -> ChatCompletionUserMessage:
+        """
+        Build a user message holding the text of ``context_chunks``.
+
+        Chunks may be search-result dicts (``content`` list, or a plain
+        ``text`` key as fallback) or bare strings; others are skipped.
+        """
+        texts: List[str] = []
+
+        for chunk in context_chunks:
+            if isinstance(chunk, dict):
+                result_content: Optional[List[VectorStoreResultContent]] = chunk.get(
+                    "content"
+                )
+                if result_content:
+                    for content_item in result_content:
+                        content_text: Optional[str] = content_item.get("text")
+                        if content_text:
+                            texts.append(content_text)
+                elif "text" in chunk:  # Fallback for simple dict with text
+                    texts.append(chunk["text"])
+            elif isinstance(chunk, str):
+                texts.append(chunk)
+
+        # Every chunk text is followed by a blank line
+        context_content = RAGQuery.CONTENT_PREFIX_STRING + "".join(
+            text + "\n\n" for text in texts
+        )
+
+        return {
+            "role": "user",
+            "content": context_content,
+        }
+
+    @staticmethod
+    def add_search_results_to_response(
+        response: ModelResponse,
+        search_results: VectorStoreSearchResponse,
+        rerank_results: Optional[Any] = None,
+    ) -> ModelResponse:
+        """
+        Add search results to the response choices.
+
+        Sets ``search_results`` (and ``rerank_results`` when truthy) in each
+        choice message's ``provider_specific_fields`` and returns ``response``.
+        """
+        if hasattr(response, "choices") and response.choices:
+            for choice in response.choices:
+                message = getattr(choice, "message", None)
+                if message is not None:
+                    # Get existing provider_specific_fields or create new dict
+                    provider_fields = (
+                        getattr(message, "provider_specific_fields", None) or {}
+                    )
+
+                    # Add search results
+                    provider_fields["search_results"] = search_results
+                    if rerank_results:
+                        provider_fields["rerank_results"] = rerank_results
+
+                    # Set the provider_specific_fields
+                    setattr(message, "provider_specific_fields", provider_fields)
+        return response
+
+    @staticmethod
+    def _rerankable_results(search_response: Any) -> List[Tuple[Any, str]]:
+        """
+        Pair each search result that has text with its joined text.
+
+        One entry per result keeps rerank indices aligned with results.
+        """
+        pairs: List[Tuple[Any, str]] = []
+        for result in search_response.get("data") or []:
+            texts = [
+                content["text"]
+                for content in result.get("content") or []
+                if content.get("type") == "text" and content.get("text")
+            ]
+            if texts:
+                pairs.append((result, "\n\n".join(texts)))
+        return pairs
+
+    @staticmethod
+    def extract_documents_from_search(
+        search_response: Any,
+    ) -> List[Union[str, Dict[str, Any]]]:
+        """
+        Extract one text document per search result that contains text.
+
+        Multiple text parts of a result are joined so that document ``i``
+        always maps back to a single result.
+        """
+        return [text for _, text in RAGQuery._rerankable_results(search_response)]
+
+    @staticmethod
+    def get_top_chunks_from_rerank(search_response: Any, rerank_response: Any) -> List[Any]:
+        """
+        Get the original search results for the top reranked documents.
+
+        Rerank indices refer to the list returned by
+        `extract_documents_from_search`; out-of-range or non-integer indices
+        are ignored.
+        """
+        candidates = RAGQuery._rerankable_results(search_response)
+        top_chunks = []
+        for result in rerank_response.get("results") or []:
+            index = result.get("index")
+            if isinstance(index, int) and 0 <= index < len(candidates):
+                top_chunks.append(candidates[index][0])
+        return top_chunks
diff --git a/litellm/types/rag.py b/litellm/types/rag.py
index dd724ca217..fe237a1343 100644
--- a/litellm/types/rag.py
+++ b/litellm/types/rag.py
@@ -7,6 +7,8 @@ from typing import Any, Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, ConfigDict
 from typing_extensions import TypedDict
 
+from litellm.types.utils import ModelResponse
+
 
 class RAGChunkingStrategy(TypedDict, total=False):
     """
@@ -187,3 +189,39 @@ class RAGIngestRequest(BaseModel):
model_config = ConfigDict(extra="allow") # Allow additional fields + +class RAGRetrievalConfig(TypedDict, total=False): + """Configuration for vector store retrieval.""" + + vector_store_id: str + custom_llm_provider: str + top_k: int # max results from vector store + filters: Optional[Dict[str, Any]] # optional - vector store filters + + +class RAGRerankConfig(TypedDict, total=False): + """Configuration for reranking results.""" + + enabled: bool + model: str + top_n: int # final number of chunks after reranking + return_documents: Optional[bool] + + +class RAGQueryRequest(BaseModel): + """Request body for RAG query API.""" + + model: str + messages: List[Any] + retrieval_config: RAGRetrievalConfig + rerank: Optional[RAGRerankConfig] = None + stream: Optional[bool] = False + + model_config = ConfigDict(extra="allow") + + +class RAGQueryResponse(ModelResponse): + """Response from RAG query API.""" + + pass + diff --git a/tests/vector_store_tests/rag/test_rag_openai.py b/tests/vector_store_tests/rag/test_rag_openai.py index d077ebe0cb..a9cffa3776 100644 --- a/tests/vector_store_tests/rag/test_rag_openai.py +++ b/tests/vector_store_tests/rag/test_rag_openai.py @@ -42,4 +42,110 @@ class TestRAGOpenAI(BaseRAGTest): return search_response return None + @pytest.mark.asyncio + async def test_rag_query_basic(self): + """Test basic RAG query flow.""" + import asyncio + + litellm._turn_on_debug() + + # First ingest a document + filename, unique_id = self.get_unique_filename("rag_query") + text_content = ( + f"LiteLLM is a unified interface for 100+ LLMs. 
ID: {unique_id}".encode() + ) + + ingest_response = await litellm.rag.aingest( + ingest_options=self.get_base_ingest_options(), + file_data=(filename, text_content, "text/plain"), + ) + + # Check if ingestion succeeded + if ingest_response["status"] != "completed": + pytest.fail( + f"Ingestion failed with status: {ingest_response['status']}, " + f"error: {ingest_response.get('error', 'Unknown')}" + ) + + vector_store_id = ingest_response["vector_store_id"] + assert vector_store_id, "vector_store_id should not be empty" + + # Wait for indexing + await asyncio.sleep(10) + + # Query with RAG + response = await litellm.rag.aquery( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "What is LiteLLM?"}], + retrieval_config={ + "vector_store_id": vector_store_id, + "custom_llm_provider": "openai", + "top_k": 5, + }, + ) + + print(f"RAG Query Response: {response}") + + assert response.choices[0].message.content + assert ( + "search_results" in response.choices[0].message.provider_specific_fields + ) + + @pytest.mark.asyncio + async def test_rag_query_with_rerank(self): + """Test RAG query with reranking.""" + import asyncio + + litellm._turn_on_debug() + + # First ingest a document + filename, unique_id = self.get_unique_filename("rag_query_rerank") + text_content = ( + f"LiteLLM is a unified interface for 100+ LLMs. 
ID: {unique_id}".encode() + ) + + ingest_response = await litellm.rag.aingest( + ingest_options=self.get_base_ingest_options(), + file_data=(filename, text_content, "text/plain"), + ) + + # Check if ingestion succeeded + if ingest_response["status"] != "completed": + pytest.fail( + f"Ingestion failed with status: {ingest_response['status']}, " + f"error: {ingest_response.get('error', 'Unknown')}" + ) + + vector_store_id = ingest_response["vector_store_id"] + assert vector_store_id, "vector_store_id should not be empty" + + # Wait for indexing + await asyncio.sleep(10) + + # Query with RAG and rerank + response = await litellm.rag.aquery( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "What is LiteLLM?"}], + retrieval_config={ + "vector_store_id": vector_store_id, + "custom_llm_provider": "openai", + "top_k": 5, + }, + rerank={ + "enabled": True, + "model": "cohere/rerank-english-v3.0", + "top_n": 3, + }, + ) + + print(f"RAG Query Response with Rerank: {response.model_dump_json(indent=4)}") + + assert response.choices[0].message.content + assert ( + "search_results" in response.choices[0].message.provider_specific_fields + ) + assert ( + "rerank_results" in response.choices[0].message.provider_specific_fields + ) + \ No newline at end of file