forked from open-webui/open-webui
		
	revert: original rag pipeline
This commit is contained in:
		
							parent
							
								
									7d88689f51
								
							
						
					
					
						commit
						984dbf13ab
					
				
					 1 changed files with 51 additions and 34 deletions
				
			
		|  | @ -18,6 +18,9 @@ from langchain.retrievers import ( | ||||||
|     EnsembleRetriever, |     EnsembleRetriever, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | from sentence_transformers import CrossEncoder | ||||||
|  | 
 | ||||||
|  | from typing import Optional | ||||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -28,50 +31,64 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||||
| def query_embeddings_doc( | def query_embeddings_doc( | ||||||
|     collection_name: str, |     collection_name: str, | ||||||
|     query: str, |     query: str, | ||||||
|     k: int, |  | ||||||
|     r: float, |  | ||||||
|     embeddings_function, |     embeddings_function, | ||||||
|     reranking_function, |     k: int, | ||||||
|  |     reranking_function: Optional[CrossEncoder] = None, | ||||||
|  |     r: Optional[float] = None, | ||||||
| ): | ): | ||||||
|     try: |     try: | ||||||
|         # if you use docker use the model from the environment variable |  | ||||||
|         collection = CHROMA_CLIENT.get_collection(name=collection_name) |  | ||||||
| 
 | 
 | ||||||
|         documents = collection.get()  # get all documents |         if reranking_function: | ||||||
|         bm25_retriever = BM25Retriever.from_texts( |             # if you use docker use the model from the environment variable | ||||||
|             texts=documents.get("documents"), |             collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||||
|             metadatas=documents.get("metadatas"), |  | ||||||
|         ) |  | ||||||
|         bm25_retriever.k = k |  | ||||||
| 
 | 
 | ||||||
|         chroma_retriever = ChromaRetriever( |             documents = collection.get()  # get all documents | ||||||
|             collection=collection, |             bm25_retriever = BM25Retriever.from_texts( | ||||||
|             embeddings_function=embeddings_function, |                 texts=documents.get("documents"), | ||||||
|             top_n=k, |                 metadatas=documents.get("metadatas"), | ||||||
|         ) |             ) | ||||||
|  |             bm25_retriever.k = k | ||||||
| 
 | 
 | ||||||
|         ensemble_retriever = EnsembleRetriever( |             chroma_retriever = ChromaRetriever( | ||||||
|             retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] |                 collection=collection, | ||||||
|         ) |                 embeddings_function=embeddings_function, | ||||||
|  |                 top_n=k, | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         compressor = RerankCompressor( |             ensemble_retriever = EnsembleRetriever( | ||||||
|             embeddings_function=embeddings_function, |                 retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] | ||||||
|             reranking_function=reranking_function, |             ) | ||||||
|             r_score=r, |  | ||||||
|             top_n=k, |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         compression_retriever = ContextualCompressionRetriever( |             compressor = RerankCompressor( | ||||||
|             base_compressor=compressor, base_retriever=ensemble_retriever |                 embeddings_function=embeddings_function, | ||||||
|         ) |                 reranking_function=reranking_function, | ||||||
|  |                 r_score=r, | ||||||
|  |                 top_n=k, | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|         result = compression_retriever.invoke(query) |             compression_retriever = ContextualCompressionRetriever( | ||||||
|         result = { |                 base_compressor=compressor, base_retriever=ensemble_retriever | ||||||
|             "distances": [[d.metadata.get("score") for d in result]], |             ) | ||||||
|             "documents": [[d.page_content for d in result]], |  | ||||||
|             "metadatas": [[d.metadata for d in result]], |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|  |             result = compression_retriever.invoke(query) | ||||||
|  |             result = { | ||||||
|  |                 "distances": [[d.metadata.get("score") for d in result]], | ||||||
|  |                 "documents": [[d.page_content for d in result]], | ||||||
|  |                 "metadatas": [[d.metadata for d in result]], | ||||||
|  |             } | ||||||
|  |         else: | ||||||
|  |             # if you use docker use the model from the environment variable | ||||||
|  |             query_embeddings = embeddings_function(query) | ||||||
|  | 
 | ||||||
|  |             log.info(f"query_embeddings_doc {query_embeddings}") | ||||||
|  |             collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||||
|  | 
 | ||||||
|  |             result = collection.query( | ||||||
|  |                 query_embeddings=[query_embeddings], | ||||||
|  |                 n_results=k, | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             log.info(f"query_embeddings_doc:result {result}") | ||||||
|         return result |         return result | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         raise e |         raise e | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek