forked from open-webui/open-webui
		
	fix: various api rag results
This commit is contained in:
		
							parent
							
								
									877ed69004
								
							
						
					
					
						commit
						5b8fd14470
					
				
					 6 changed files with 55 additions and 37 deletions
				
			
		|  | @ -391,16 +391,16 @@ def query_doc_handler( | |||
|             return query_doc_with_hybrid_search( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query=form_data.query, | ||||
|                 embeddings_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 reranking_function=app.state.sentence_transformer_rf, | ||||
|                 embedding_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|                 reranking_function=app.state.sentence_transformer_rf, | ||||
|                 r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, | ||||
|             ) | ||||
|         else: | ||||
|             return query_doc( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query=form_data.query, | ||||
|                 embeddings_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 embedding_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
|     except Exception as e: | ||||
|  | @ -429,16 +429,16 @@ def query_collection_handler( | |||
|             return query_collection_with_hybrid_search( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query=form_data.query, | ||||
|                 embeddings_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 reranking_function=app.state.sentence_transformer_rf, | ||||
|                 embedding_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|                 reranking_function=app.state.sentence_transformer_rf, | ||||
|                 r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, | ||||
|             ) | ||||
|         else: | ||||
|             return query_collection( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query=form_data.query, | ||||
|                 embeddings_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 embedding_function=app.state.EMBEDDING_FUNCTION, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -35,6 +35,7 @@ def query_doc( | |||
|     try: | ||||
|         collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||
|         query_embeddings = embedding_function(query) | ||||
| 
 | ||||
|         result = collection.query( | ||||
|             query_embeddings=[query_embeddings], | ||||
|             n_results=k, | ||||
|  | @ -76,9 +77,9 @@ def query_doc_with_hybrid_search( | |||
| 
 | ||||
|         compressor = RerankCompressor( | ||||
|             embedding_function=embedding_function, | ||||
|             top_n=k, | ||||
|             reranking_function=reranking_function, | ||||
|             r_score=r, | ||||
|             top_n=k, | ||||
|         ) | ||||
| 
 | ||||
|         compression_retriever = ContextualCompressionRetriever( | ||||
|  | @ -91,6 +92,7 @@ def query_doc_with_hybrid_search( | |||
|             "documents": [[d.page_content for d in result]], | ||||
|             "metadatas": [[d.metadata for d in result]], | ||||
|         } | ||||
| 
 | ||||
|         log.info(f"query_doc_with_hybrid_search:result {result}") | ||||
|         return result | ||||
|     except Exception as e: | ||||
|  | @ -167,7 +169,6 @@ def query_collection_with_hybrid_search( | |||
|     reranking_function, | ||||
|     r: float, | ||||
| ): | ||||
| 
 | ||||
|     results = [] | ||||
|     for collection_name in collection_names: | ||||
|         try: | ||||
|  | @ -182,7 +183,6 @@ def query_collection_with_hybrid_search( | |||
|             results.append(result) | ||||
|         except: | ||||
|             pass | ||||
| 
 | ||||
|     return merge_and_sort_query_results(results, k=k, reverse=True) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -443,13 +443,15 @@ class ChromaRetriever(BaseRetriever): | |||
|         metadatas = results["metadatas"][0] | ||||
|         documents = results["documents"][0] | ||||
| 
 | ||||
|         return [ | ||||
|             Document( | ||||
|                 metadata=metadatas[idx], | ||||
|                 page_content=documents[idx], | ||||
|         results = [] | ||||
|         for idx in range(len(ids)): | ||||
|             results.append( | ||||
|                 Document( | ||||
|                     metadata=metadatas[idx], | ||||
|                     page_content=documents[idx], | ||||
|                 ) | ||||
|             ) | ||||
|             for idx in range(len(ids)) | ||||
|         ] | ||||
|         return results | ||||
| 
 | ||||
| 
 | ||||
| import operator | ||||
|  | @ -465,9 +467,9 @@ from sentence_transformers import util | |||
| 
 | ||||
| class RerankCompressor(BaseDocumentCompressor): | ||||
|     embedding_function: Any | ||||
|     top_n: int | ||||
|     reranking_function: Any | ||||
|     r_score: float | ||||
|     top_n: int | ||||
| 
 | ||||
|     class Config: | ||||
|         extra = Extra.forbid | ||||
|  | @ -479,7 +481,9 @@ class RerankCompressor(BaseDocumentCompressor): | |||
|         query: str, | ||||
|         callbacks: Optional[Callbacks] = None, | ||||
|     ) -> Sequence[Document]: | ||||
|         if self.reranking_function: | ||||
|         reranking = self.reranking_function is not None | ||||
| 
 | ||||
|         if reranking: | ||||
|             scores = self.reranking_function.predict( | ||||
|                 [(query, doc.page_content) for doc in documents] | ||||
|             ) | ||||
|  | @ -496,9 +500,7 @@ class RerankCompressor(BaseDocumentCompressor): | |||
|                 (d, s) for d, s in docs_with_scores if s >= self.r_score | ||||
|             ] | ||||
| 
 | ||||
|         reverse = self.reranking_function is not None | ||||
|         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse) | ||||
| 
 | ||||
|         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) | ||||
|         final_results = [] | ||||
|         for doc, doc_score in result[: self.top_n]: | ||||
|             metadata = doc.metadata | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Steven Kreitzer
						Steven Kreitzer