forked from open-webui/open-webui
		
	revert: original rag pipeline
This commit is contained in:
		
							parent
							
								
									7d88689f51
								
							
						
					
					
						commit
						984dbf13ab
					
				
					 1 changed file with 51 additions and 34 deletions
				
			
		|  | @ -18,6 +18,9 @@ from langchain.retrievers import ( | |||
|     EnsembleRetriever, | ||||
| ) | ||||
| 
 | ||||
| from sentence_transformers import CrossEncoder | ||||
| 
 | ||||
| from typing import Optional | ||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||
| 
 | ||||
| 
 | ||||
|  | @ -28,50 +31,64 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) | |||
| def query_embeddings_doc( | ||||
|     collection_name: str, | ||||
|     query: str, | ||||
|     k: int, | ||||
|     r: float, | ||||
|     embeddings_function, | ||||
|     reranking_function, | ||||
|     k: int, | ||||
|     reranking_function: Optional[CrossEncoder] = None, | ||||
|     r: Optional[float] = None, | ||||
| ): | ||||
|     try: | ||||
|         # if you use docker use the model from the environment variable | ||||
|         collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||
| 
 | ||||
|         documents = collection.get()  # get all documents | ||||
|         bm25_retriever = BM25Retriever.from_texts( | ||||
|             texts=documents.get("documents"), | ||||
|             metadatas=documents.get("metadatas"), | ||||
|         ) | ||||
|         bm25_retriever.k = k | ||||
|         if reranking_function: | ||||
|             # if you use docker use the model from the environment variable | ||||
|             collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||
| 
 | ||||
|         chroma_retriever = ChromaRetriever( | ||||
|             collection=collection, | ||||
|             embeddings_function=embeddings_function, | ||||
|             top_n=k, | ||||
|         ) | ||||
|             documents = collection.get()  # get all documents | ||||
|             bm25_retriever = BM25Retriever.from_texts( | ||||
|                 texts=documents.get("documents"), | ||||
|                 metadatas=documents.get("metadatas"), | ||||
|             ) | ||||
|             bm25_retriever.k = k | ||||
| 
 | ||||
|         ensemble_retriever = EnsembleRetriever( | ||||
|             retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] | ||||
|         ) | ||||
|             chroma_retriever = ChromaRetriever( | ||||
|                 collection=collection, | ||||
|                 embeddings_function=embeddings_function, | ||||
|                 top_n=k, | ||||
|             ) | ||||
| 
 | ||||
|         compressor = RerankCompressor( | ||||
|             embeddings_function=embeddings_function, | ||||
|             reranking_function=reranking_function, | ||||
|             r_score=r, | ||||
|             top_n=k, | ||||
|         ) | ||||
|             ensemble_retriever = EnsembleRetriever( | ||||
|                 retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] | ||||
|             ) | ||||
| 
 | ||||
|         compression_retriever = ContextualCompressionRetriever( | ||||
|             base_compressor=compressor, base_retriever=ensemble_retriever | ||||
|         ) | ||||
|             compressor = RerankCompressor( | ||||
|                 embeddings_function=embeddings_function, | ||||
|                 reranking_function=reranking_function, | ||||
|                 r_score=r, | ||||
|                 top_n=k, | ||||
|             ) | ||||
| 
 | ||||
|         result = compression_retriever.invoke(query) | ||||
|         result = { | ||||
|             "distances": [[d.metadata.get("score") for d in result]], | ||||
|             "documents": [[d.page_content for d in result]], | ||||
|             "metadatas": [[d.metadata for d in result]], | ||||
|         } | ||||
|             compression_retriever = ContextualCompressionRetriever( | ||||
|                 base_compressor=compressor, base_retriever=ensemble_retriever | ||||
|             ) | ||||
| 
 | ||||
|             result = compression_retriever.invoke(query) | ||||
|             result = { | ||||
|                 "distances": [[d.metadata.get("score") for d in result]], | ||||
|                 "documents": [[d.page_content for d in result]], | ||||
|                 "metadatas": [[d.metadata for d in result]], | ||||
|             } | ||||
|         else: | ||||
|             # if you use docker use the model from the environment variable | ||||
|             query_embeddings = embeddings_function(query) | ||||
| 
 | ||||
|             log.info(f"query_embeddings_doc {query_embeddings}") | ||||
|             collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||
| 
 | ||||
|             result = collection.query( | ||||
|                 query_embeddings=[query_embeddings], | ||||
|                 n_results=k, | ||||
|             ) | ||||
| 
 | ||||
|             log.info(f"query_embeddings_doc:result {result}") | ||||
|         return result | ||||
|     except Exception as e: | ||||
|         raise e | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek