forked from open-webui/open-webui
		
	refac: rag pipeline
This commit is contained in:
		
							parent
							
								
									8f1563a7a5
								
							
						
					
					
						commit
						ce9a5d12e0
					
				
					 3 changed files with 179 additions and 154 deletions
				
			
		|  | @ -47,9 +47,11 @@ from apps.web.models.documents import ( | |||
| 
 | ||||
| from apps.rag.utils import ( | ||||
|     get_model_path, | ||||
|     query_embeddings_doc, | ||||
|     get_embeddings_function, | ||||
|     query_embeddings_collection, | ||||
|     get_embedding_function, | ||||
|     query_doc, | ||||
|     query_doc_with_hybrid_search, | ||||
|     query_collection, | ||||
|     query_collection_with_hybrid_search, | ||||
| ) | ||||
| 
 | ||||
| from utils.misc import ( | ||||
|  | @ -147,6 +149,15 @@ update_reranking_model( | |||
|     RAG_RERANKING_MODEL_AUTO_UPDATE, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| app.state.EMBEDDING_FUNCTION = get_embedding_function( | ||||
|     app.state.RAG_EMBEDDING_ENGINE, | ||||
|     app.state.RAG_EMBEDDING_MODEL, | ||||
|     app.state.sentence_transformer_ef, | ||||
|     app.state.OPENAI_API_KEY, | ||||
|     app.state.OPENAI_API_BASE_URL, | ||||
| ) | ||||
| 
 | ||||
| origins = ["*"] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -227,6 +238,14 @@ async def update_embedding_config( | |||
| 
 | ||||
|         update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) | ||||
| 
 | ||||
|         app.state.EMBEDDING_FUNCTION = get_embedding_function( | ||||
|             app.state.RAG_EMBEDDING_ENGINE, | ||||
|             app.state.RAG_EMBEDDING_MODEL, | ||||
|             app.state.sentence_transformer_ef, | ||||
|             app.state.OPENAI_API_KEY, | ||||
|             app.state.OPENAI_API_BASE_URL, | ||||
|         ) | ||||
| 
 | ||||
|         return { | ||||
|             "status": True, | ||||
|             "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, | ||||
|  | @ -367,26 +386,21 @@ def query_doc_handler( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         embeddings_function = get_embeddings_function( | ||||
|             app.state.RAG_EMBEDDING_ENGINE, | ||||
|             app.state.RAG_EMBEDDING_MODEL, | ||||
|             app.state.sentence_transformer_ef, | ||||
|             app.state.OPENAI_API_KEY, | ||||
|             app.state.OPENAI_API_BASE_URL, | ||||
|         ) | ||||
| 
 | ||||
|         return query_embeddings_doc( | ||||
|         if app.state.ENABLE_RAG_HYBRID_SEARCH: | ||||
|             return query_doc_with_hybrid_search( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query=form_data.query, | ||||
|                 embeddings_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 reranking_function=app.state.sentence_transformer_rf, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|                 r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, | ||||
|             embeddings_function=embeddings_function, | ||||
|             reranking_function=app.state.sentence_transformer_rf, | ||||
|             hybrid_search=( | ||||
|                 form_data.hybrid | ||||
|                 if form_data.hybrid | ||||
|                 else app.state.ENABLE_RAG_HYBRID_SEARCH | ||||
|             ), | ||||
|             ) | ||||
|         else: | ||||
|             return query_doc( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query=form_data.query, | ||||
|                 embeddings_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|  | @ -410,27 +424,23 @@ def query_collection_handler( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         embeddings_function = get_embeddings_function( | ||||
|             app.state.RAG_EMBEDDING_ENGINE, | ||||
|             app.state.RAG_EMBEDDING_MODEL, | ||||
|             app.state.sentence_transformer_ef, | ||||
|             app.state.OPENAI_API_KEY, | ||||
|             app.state.OPENAI_API_BASE_URL, | ||||
|         ) | ||||
| 
 | ||||
|         return query_embeddings_collection( | ||||
|         if app.state.ENABLE_RAG_HYBRID_SEARCH: | ||||
|             return query_collection_with_hybrid_search( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query=form_data.query, | ||||
|                 embeddings_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 reranking_function=app.state.sentence_transformer_rf, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|                 r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, | ||||
|             embeddings_function=embeddings_function, | ||||
|             reranking_function=app.state.sentence_transformer_rf, | ||||
|             hybrid_search=( | ||||
|                 form_data.hybrid | ||||
|                 if form_data.hybrid | ||||
|                 else app.state.ENABLE_RAG_HYBRID_SEARCH | ||||
|             ), | ||||
|             ) | ||||
|         else: | ||||
|             return query_collection( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query=form_data.query, | ||||
|                 embeddings_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|  | @ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b | |||
| 
 | ||||
|         collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
| 
 | ||||
|         embedding_func = get_embeddings_function( | ||||
|         embedding_func = get_embedding_function( | ||||
|             app.state.RAG_EMBEDDING_ENGINE, | ||||
|             app.state.RAG_EMBEDDING_MODEL, | ||||
|             app.state.sentence_transformer_ef, | ||||
|  |  | |||
|  | @ -26,20 +26,38 @@ log = logging.getLogger(__name__) | |||
| log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||
| 
 | ||||
| 
 | ||||
| def query_embeddings_doc( | ||||
| def query_doc( | ||||
|     collection_name: str, | ||||
|     query: str, | ||||
|     embeddings_function, | ||||
|     reranking_function, | ||||
|     embedding_function, | ||||
|     k: int, | ||||
|     r: int, | ||||
|     hybrid_search: bool, | ||||
| ): | ||||
|     try: | ||||
|         collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||
|         query_embeddings = embedding_function(query) | ||||
|         result = collection.query( | ||||
|             query_embeddings=[query_embeddings], | ||||
|             n_results=k, | ||||
|         ) | ||||
| 
 | ||||
|         if hybrid_search: | ||||
|         log.info(f"query_doc:result {result}") | ||||
|         return result | ||||
|     except Exception as e: | ||||
|         raise e | ||||
| 
 | ||||
| 
 | ||||
| def query_doc_with_hybrid_search( | ||||
|     collection_name: str, | ||||
|     query: str, | ||||
|     embedding_function, | ||||
|     k: int, | ||||
|     reranking_function, | ||||
|     r: int, | ||||
| ): | ||||
|     try: | ||||
|         collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||
|         documents = collection.get()  # get all documents | ||||
| 
 | ||||
|         bm25_retriever = BM25Retriever.from_texts( | ||||
|             texts=documents.get("documents"), | ||||
|             metadatas=documents.get("metadatas"), | ||||
|  | @ -48,7 +66,7 @@ def query_embeddings_doc( | |||
| 
 | ||||
|         chroma_retriever = ChromaRetriever( | ||||
|             collection=collection, | ||||
|                 embeddings_function=embeddings_function, | ||||
|             embedding_function=embedding_function, | ||||
|             top_n=k, | ||||
|         ) | ||||
| 
 | ||||
|  | @ -57,7 +75,7 @@ def query_embeddings_doc( | |||
|         ) | ||||
| 
 | ||||
|         compressor = RerankCompressor( | ||||
|                 embeddings_function=embeddings_function, | ||||
|             embedding_function=embedding_function, | ||||
|             reranking_function=reranking_function, | ||||
|             r_score=r, | ||||
|             top_n=k, | ||||
|  | @ -73,14 +91,7 @@ def query_embeddings_doc( | |||
|             "documents": [[d.page_content for d in result]], | ||||
|             "metadatas": [[d.metadata for d in result]], | ||||
|         } | ||||
|         else: | ||||
|             query_embeddings = embeddings_function(query) | ||||
|             result = collection.query( | ||||
|                 query_embeddings=[query_embeddings], | ||||
|                 n_results=k, | ||||
|             ) | ||||
| 
 | ||||
|         log.info(f"query_embeddings_doc:result {result}") | ||||
|         log.info(f"query_doc_with_hybrid_search:result {result}") | ||||
|         return result | ||||
|     except Exception as e: | ||||
|         raise e | ||||
|  | @ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False): | |||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def query_embeddings_collection( | ||||
| def query_collection( | ||||
|     collection_names: List[str], | ||||
|     query: str, | ||||
|     embedding_function, | ||||
|     k: int, | ||||
|     r: float, | ||||
|     embeddings_function, | ||||
|     reranking_function, | ||||
|     hybrid_search: bool, | ||||
| ): | ||||
| 
 | ||||
|     results = [] | ||||
| 
 | ||||
|     for collection_name in collection_names: | ||||
|         try: | ||||
|             result = query_embeddings_doc( | ||||
|             result = query_doc( | ||||
|                 collection_name=collection_name, | ||||
|                 query=query, | ||||
|                 k=k, | ||||
|                 r=r, | ||||
|                 embeddings_function=embeddings_function, | ||||
|                 embedding_function=embedding_function, | ||||
|             ) | ||||
|             results.append(result) | ||||
|         except: | ||||
|             pass | ||||
|     return merge_and_sort_query_results(results, k=k) | ||||
| 
 | ||||
| 
 | ||||
| def query_collection_with_hybrid_search( | ||||
|     collection_names: List[str], | ||||
|     query: str, | ||||
|     embedding_function, | ||||
|     k: int, | ||||
|     reranking_function, | ||||
|     r: float, | ||||
| ): | ||||
| 
 | ||||
|     results = [] | ||||
|     for collection_name in collection_names: | ||||
|         try: | ||||
|             result = query_doc_with_hybrid_search( | ||||
|                 collection_name=collection_name, | ||||
|                 query=query, | ||||
|                 embedding_function=embedding_function, | ||||
|                 k=k, | ||||
|                 reranking_function=reranking_function, | ||||
|                 hybrid_search=hybrid_search, | ||||
|                 r=r, | ||||
|             ) | ||||
|             results.append(result) | ||||
|         except: | ||||
|             pass | ||||
| 
 | ||||
|     reverse = hybrid_search and reranking_function is not None | ||||
|     return merge_and_sort_query_results(results, k=k, reverse=reverse) | ||||
|     return merge_and_sort_query_results(results, k=k, reverse=True) | ||||
| 
 | ||||
| 
 | ||||
| def rag_template(template: str, context: str, query: str): | ||||
|  | @ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str): | |||
|     return template | ||||
| 
 | ||||
| 
 | ||||
| def get_embeddings_function( | ||||
| def get_embedding_function( | ||||
|     embedding_engine, | ||||
|     embedding_model, | ||||
|     embedding_function, | ||||
|  | @ -204,19 +232,13 @@ def rag_messages( | |||
|     docs, | ||||
|     messages, | ||||
|     template, | ||||
|     embedding_function, | ||||
|     k, | ||||
|     reranking_function, | ||||
|     r, | ||||
|     hybrid_search, | ||||
|     embedding_engine, | ||||
|     embedding_model, | ||||
|     embedding_function, | ||||
|     reranking_function, | ||||
|     openai_key, | ||||
|     openai_url, | ||||
| ): | ||||
|     log.debug( | ||||
|         f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}" | ||||
|     ) | ||||
|     log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}") | ||||
| 
 | ||||
|     last_user_message_idx = None | ||||
|     for i in range(len(messages) - 1, -1, -1): | ||||
|  | @ -243,14 +265,6 @@ def rag_messages( | |||
|         content_type = None | ||||
|         query = "" | ||||
| 
 | ||||
|     embeddings_function = get_embeddings_function( | ||||
|         embedding_engine, | ||||
|         embedding_model, | ||||
|         embedding_function, | ||||
|         openai_key, | ||||
|         openai_url, | ||||
|     ) | ||||
| 
 | ||||
|     extracted_collections = [] | ||||
|     relevant_contexts = [] | ||||
| 
 | ||||
|  | @ -271,25 +285,30 @@ def rag_messages( | |||
|         try: | ||||
|             if doc["type"] == "text": | ||||
|                 context = doc["content"] | ||||
|             elif doc["type"] == "collection": | ||||
|                 context = query_embeddings_collection( | ||||
|                     collection_names=doc["collection_names"], | ||||
|             else: | ||||
|                 if hybrid_search: | ||||
|                     context = query_collection_with_hybrid_search( | ||||
|                         collection_names=( | ||||
|                             doc["collection_names"] | ||||
|                             if doc["type"] == "collection" | ||||
|                             else [doc["collection_name"]] | ||||
|                         ), | ||||
|                         query=query, | ||||
|                         embedding_function=embedding_function, | ||||
|                         k=k, | ||||
|                     r=r, | ||||
|                     embeddings_function=embeddings_function, | ||||
|                         reranking_function=reranking_function, | ||||
|                     hybrid_search=hybrid_search, | ||||
|                         r=r, | ||||
|                     ) | ||||
|                 else: | ||||
|                 context = query_embeddings_doc( | ||||
|                     collection_name=doc["collection_name"], | ||||
|                     context = query_collection( | ||||
|                         collection_names=( | ||||
|                             doc["collection_names"] | ||||
|                             if doc["type"] == "collection" | ||||
|                             else [doc["collection_name"]] | ||||
|                         ), | ||||
|                         query=query, | ||||
|                         embedding_function=embedding_function, | ||||
|                         k=k, | ||||
|                     r=r, | ||||
|                     embeddings_function=embeddings_function, | ||||
|                     reranking_function=reranking_function, | ||||
|                     hybrid_search=hybrid_search, | ||||
|                     ) | ||||
|         except Exception as e: | ||||
|             log.exception(e) | ||||
|  | @ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun | |||
| 
 | ||||
| class ChromaRetriever(BaseRetriever): | ||||
|     collection: Any | ||||
|     embeddings_function: Any | ||||
|     embedding_function: Any | ||||
|     top_n: int | ||||
| 
 | ||||
|     def _get_relevant_documents( | ||||
|  | @ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever): | |||
|         *, | ||||
|         run_manager: CallbackManagerForRetrieverRun, | ||||
|     ) -> List[Document]: | ||||
|         query_embeddings = self.embeddings_function(query) | ||||
|         query_embeddings = self.embedding_function(query) | ||||
| 
 | ||||
|         results = self.collection.query( | ||||
|             query_embeddings=[query_embeddings], | ||||
|  | @ -445,7 +464,7 @@ from sentence_transformers import util | |||
| 
 | ||||
| 
 | ||||
| class RerankCompressor(BaseDocumentCompressor): | ||||
|     embeddings_function: Any | ||||
|     embedding_function: Any | ||||
|     reranking_function: Any | ||||
|     r_score: float | ||||
|     top_n: int | ||||
|  | @ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor): | |||
|                 [(query, doc.page_content) for doc in documents] | ||||
|             ) | ||||
|         else: | ||||
|             query_embedding = self.embeddings_function(query) | ||||
|             document_embedding = self.embeddings_function( | ||||
|             query_embedding = self.embedding_function(query) | ||||
|             document_embedding = self.embedding_function( | ||||
|                 [doc.page_content for doc in documents] | ||||
|             ) | ||||
|             scores = util.cos_sim(query_embedding, document_embedding)[0] | ||||
|  |  | |||
|  | @ -117,18 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware): | |||
|             if "docs" in data: | ||||
|                 data = {**data} | ||||
|                 data["messages"] = rag_messages( | ||||
|                     data["docs"], | ||||
|                     data["messages"], | ||||
|                     rag_app.state.RAG_TEMPLATE, | ||||
|                     rag_app.state.TOP_K, | ||||
|                     rag_app.state.RELEVANCE_THRESHOLD, | ||||
|                     rag_app.state.ENABLE_RAG_HYBRID_SEARCH, | ||||
|                     rag_app.state.RAG_EMBEDDING_ENGINE, | ||||
|                     rag_app.state.RAG_EMBEDDING_MODEL, | ||||
|                     rag_app.state.sentence_transformer_ef, | ||||
|                     rag_app.state.sentence_transformer_rf, | ||||
|                     rag_app.state.OPENAI_API_KEY, | ||||
|                     rag_app.state.OPENAI_API_BASE_URL, | ||||
|                     docs=data["docs"], | ||||
|                     messages=data["messages"], | ||||
|                     template=rag_app.state.RAG_TEMPLATE, | ||||
|                     embedding_function=rag_app.state.EMBEDDING_FUNCTION, | ||||
|                     k=rag_app.state.TOP_K, | ||||
|                     reranking_function=rag_app.state.sentence_transformer_rf, | ||||
|                     r=rag_app.state.RELEVANCE_THRESHOLD, | ||||
|                     hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH, | ||||
|                 ) | ||||
|                 del data["docs"] | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek