forked from open-webui/open-webui
		
	fix: sort ranking hybrid
This commit is contained in:
		
							parent
							
								
									9755cd5baa
								
							
						
					
					
						commit
						69822e4c25
					
				
					 2 changed files with 13 additions and 17 deletions
				
			
		|  | @@ -18,8 +18,6 @@ from langchain.retrievers import ( | ||||||
|     EnsembleRetriever, |     EnsembleRetriever, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from sentence_transformers import CrossEncoder |  | ||||||
| 
 |  | ||||||
| from typing import Optional | from typing import Optional | ||||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||||
| 
 | 
 | ||||||
|  | @@ -34,14 +32,13 @@ def query_embeddings_doc( | ||||||
|     embeddings_function, |     embeddings_function, | ||||||
|     reranking_function, |     reranking_function, | ||||||
|     k: int, |     k: int, | ||||||
|     r: Optional[float] = None, |     r: int, | ||||||
|     hybrid: Optional[bool] = False, |     hybrid: bool, | ||||||
| ): | ): | ||||||
|     try: |     try: | ||||||
|         if hybrid: |  | ||||||
|             # if you use docker use the model from the environment variable |  | ||||||
|         collection = CHROMA_CLIENT.get_collection(name=collection_name) |         collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||||
| 
 | 
 | ||||||
|  |         if hybrid: | ||||||
|             documents = collection.get()  # get all documents |             documents = collection.get()  # get all documents | ||||||
|             bm25_retriever = BM25Retriever.from_texts( |             bm25_retriever = BM25Retriever.from_texts( | ||||||
|                 texts=documents.get("documents"), |                 texts=documents.get("documents"), | ||||||
|  | @@ -77,12 +74,7 @@ def query_embeddings_doc( | ||||||
|                 "metadatas": [[d.metadata for d in result]], |                 "metadatas": [[d.metadata for d in result]], | ||||||
|             } |             } | ||||||
|         else: |         else: | ||||||
|             # if you use docker use the model from the environment variable |  | ||||||
|             query_embeddings = embeddings_function(query) |             query_embeddings = embeddings_function(query) | ||||||
| 
 |  | ||||||
|             log.info(f"query_embeddings_doc {query_embeddings}") |  | ||||||
|             collection = CHROMA_CLIENT.get_collection(name=collection_name) |  | ||||||
| 
 |  | ||||||
|             result = collection.query( |             result = collection.query( | ||||||
|                 query_embeddings=[query_embeddings], |                 query_embeddings=[query_embeddings], | ||||||
|                 n_results=k, |                 n_results=k, | ||||||
|  | @@ -94,7 +86,7 @@ def query_embeddings_doc( | ||||||
|         raise e |         raise e | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def merge_and_sort_query_results(query_results, k): | def merge_and_sort_query_results(query_results, k, reverse=False): | ||||||
|     # Initialize lists to store combined data |     # Initialize lists to store combined data | ||||||
|     combined_distances = [] |     combined_distances = [] | ||||||
|     combined_documents = [] |     combined_documents = [] | ||||||
|  | @@ -109,7 +101,7 @@ def merge_and_sort_query_results(query_results, k): | ||||||
|     combined = list(zip(combined_distances, combined_documents, combined_metadatas)) |     combined = list(zip(combined_distances, combined_documents, combined_metadatas)) | ||||||
| 
 | 
 | ||||||
|     # Sort the list based on distances |     # Sort the list based on distances | ||||||
|     combined.sort(key=lambda x: x[0]) |     combined.sort(key=lambda x: x[0], reverse=reverse) | ||||||
| 
 | 
 | ||||||
|     # We don't have anything :-( |     # We don't have anything :-( | ||||||
|     if not combined: |     if not combined: | ||||||
|  | @@ -162,7 +154,8 @@ def query_embeddings_collection( | ||||||
|         except: |         except: | ||||||
|             pass |             pass | ||||||
| 
 | 
 | ||||||
|     return merge_and_sort_query_results(results, k) |     reverse = hybrid and reranking_function is not None | ||||||
|  |     return merge_and_sort_query_results(results, k=k, reverse=reverse) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def rag_template(template: str, context: str, query: str): | def rag_template(template: str, context: str, query: str): | ||||||
|  | @@ -484,7 +477,9 @@ class RerankCompressor(BaseDocumentCompressor): | ||||||
|                 (d, s) for d, s in docs_with_scores if s >= self.r_score |                 (d, s) for d, s in docs_with_scores if s >= self.r_score | ||||||
|             ] |             ] | ||||||
| 
 | 
 | ||||||
|         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) |         reverse = self.reranking_function is not None | ||||||
|  |         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse) | ||||||
|  | 
 | ||||||
|         final_results = [] |         final_results = [] | ||||||
|         for doc, doc_score in result[: self.top_n]: |         for doc, doc_score in result[: self.top_n]: | ||||||
|             metadata = doc.metadata |             metadata = doc.metadata | ||||||
|  |  | ||||||
|  | @@ -121,6 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware): | ||||||
|                     rag_app.state.RAG_TEMPLATE, |                     rag_app.state.RAG_TEMPLATE, | ||||||
|                     rag_app.state.TOP_K, |                     rag_app.state.TOP_K, | ||||||
|                     rag_app.state.RELEVANCE_THRESHOLD, |                     rag_app.state.RELEVANCE_THRESHOLD, | ||||||
|  |                     rag_app.state.HYBRID, | ||||||
|                     rag_app.state.RAG_EMBEDDING_ENGINE, |                     rag_app.state.RAG_EMBEDDING_ENGINE, | ||||||
|                     rag_app.state.RAG_EMBEDDING_MODEL, |                     rag_app.state.RAG_EMBEDDING_MODEL, | ||||||
|                     rag_app.state.sentence_transformer_ef, |                     rag_app.state.sentence_transformer_ef, | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue