forked from open-webui/open-webui

feat: hybrid search and reranking support

Author: Steven Kreitzer
parent db801aee79
commit c0259aad67

10 changed files with 262 additions and 131 deletions
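In short: document queries now run BM25 keyword search and Chroma vector search as a 50/50 ensemble, optionally rerank the merged hits with a sentence-transformers CrossEncoder, and drop anything scoring below a new relevance threshold `r`. The threshold is exposed as the `RAG_RELEVANCE_THRESHOLD` config value, on the query-settings endpoints, and as an optional `r` field on the query forms. A minimal sketch of calling the updated doc-query endpoint; the mount point, port, and auth token are assumptions, while the payload fields come from `QueryDocForm` below and the collection name is hypothetical:

```python
import requests

resp = requests.post(
    "http://localhost:8080/rag/api/v1/query/doc",  # assumed mount point
    headers={"Authorization": "Bearer <token>"},   # assumed auth scheme
    json={
        "collection_name": "docs-abc123",  # hypothetical collection
        "query": "How do I enable hybrid search?",
        "k": 4,     # top-k results to keep
        "r": 0.35,  # discard hits scoring below this relevance threshold
    },
)
# Response shape per query_embeddings_doc below:
# {"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}
print(resp.json())
```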
@@ -64,6 +64,8 @@ from config import (
     SRC_LOG_LEVELS,
     UPLOAD_DIR,
     DOCS_DIR,
+    RAG_TOP_K,
+    RAG_RELEVANCE_THRESHOLD,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@@ -86,7 +88,8 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 app = FastAPI()
 
 
-app.state.TOP_K = 4
+app.state.TOP_K = RAG_TOP_K
+app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 
@@ -107,12 +110,17 @@ if app.state.RAG_EMBEDDING_ENGINE == "":
         device=DEVICE_TYPE,
         trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     )
+else:
+    app.state.sentence_transformer_ef = None
 
-app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-    app.state.RAG_RERANKING_MODEL,
-    device=DEVICE_TYPE,
-    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-)
+if not app.state.RAG_RERANKING_MODEL == "":
+    app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+        app.state.RAG_RERANKING_MODEL,
+        device=DEVICE_TYPE,
+        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+    )
+else:
+    app.state.sentence_transformer_rf = None
 
 
 origins = ["*"]
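For context, the reranker loaded here is a standard sentence-transformers cross-encoder: it scores (query, passage) pairs jointly rather than embedding each side separately. A small sketch; the checkpoint name is illustrative, not mandated by this commit:

```python
import sentence_transformers

# Any CrossEncoder checkpoint works; this one is just an example.
reranker = sentence_transformers.CrossEncoder("BAAI/bge-reranker-base")
scores = reranker.predict([
    ("what is hybrid search", "Hybrid search ensembles BM25 with vector search."),
    ("what is hybrid search", "Bananas are rich in potassium."),
])
# Higher score = more relevant; RerankCompressor below sorts on this.
print(scores)
```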
@@ -185,22 +193,22 @@ async def update_embedding_config(
     )
     try:
         app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
+        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
 
         if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
-            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+            app.state.sentence_transformer_ef = None
 
             if form_data.openai_config != None:
                 app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                 app.state.OPENAI_API_KEY = form_data.openai_config.key
-
-            app.state.sentence_transformer_ef = None
         else:
-            sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-                app.state.RAG_EMBEDDING_MODEL,
-                device=DEVICE_TYPE,
-                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+            app.state.sentence_transformer_ef = (
+                sentence_transformers.SentenceTransformer(
+                    app.state.RAG_EMBEDDING_MODEL,
+                    device=DEVICE_TYPE,
+                    trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+                )
             )
-            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-            app.state.sentence_transformer_ef = sentence_transformer_ef
 
         return {
             "status": True,
@@ -222,7 +230,7 @@ async def update_embedding_config(
 
 class RerankingModelUpdateForm(BaseModel):
     reranking_model: str
-    
+
 
 
 @app.post("/reranking/update")
 async def update_reranking_config(
@@ -233,10 +241,14 @@ async def update_reranking_config(
     )
     try:
         app.state.RAG_RERANKING_MODEL = form_data.reranking_model
-        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-            app.state.RAG_RERANKING_MODEL,
-            device=DEVICE_TYPE,
-        )
+
+        if app.state.RAG_RERANKING_MODEL == "":
+            app.state.sentence_transformer_rf = None
+        else:
+            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+                app.state.RAG_RERANKING_MODEL,
+                device=DEVICE_TYPE,
+            )
 
         return {
             "status": True,
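An empty model name now unloads the reranker instead of feeding "" to the CrossEncoder constructor. A hedged sketch of the admin call; the mount point and token are assumptions, the field name comes from RerankingModelUpdateForm above:

```python
import requests

requests.post(
    "http://localhost:8080/rag/api/v1/reranking/update",  # assumed mount point
    headers={"Authorization": "Bearer <admin-token>"},
    json={"reranking_model": ""},  # empty string -> reranker disabled
)
```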
@@ -302,11 +314,13 @@ async def get_query_settings(user=Depends(get_admin_user)):
         "status": True,
         "template": app.state.RAG_TEMPLATE,
         "k": app.state.TOP_K,
+        "r": app.state.RELEVANCE_THRESHOLD,
     }
 
 
 class QuerySettingsForm(BaseModel):
     k: Optional[int] = None
+    r: Optional[float] = None
     template: Optional[str] = None
 
 
@@ -316,6 +330,7 @@ async def update_query_settings(
 ):
     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
     app.state.TOP_K = form_data.k if form_data.k else 4
+    app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
     return {"status": True, "template": app.state.RAG_TEMPLATE}
 
 
@@ -323,6 +338,7 @@ class QueryDocForm(BaseModel):
     collection_name: str
     query: str
     k: Optional[int] = None
+    r: Optional[float] = None
 
 
 @app.post("/query/doc")
@@ -343,6 +359,7 @@ def query_doc_handler(
             collection_name=form_data.collection_name,
             query=form_data.query,
             k=form_data.k if form_data.k else app.state.TOP_K,
+            r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             embeddings_function=embeddings_function,
             reranking_function=app.state.sentence_transformer_rf,
         )
@@ -358,6 +375,7 @@ class QueryCollectionsForm(BaseModel):
     collection_names: List[str]
     query: str
     k: Optional[int] = None
+    r: Optional[float] = None
 
 
 @app.post("/query/collection")
@@ -378,6 +396,7 @@ def query_collection_handler(
             collection_names=form_data.collection_names,
             query=form_data.query,
             k=form_data.k if form_data.k else app.state.TOP_K,
+            r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             embeddings_function=embeddings_function,
             reranking_function=app.state.sentence_transformer_rf,
         )
@@ -467,12 +486,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
         )
 
         embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            embeddings = embedding_func(embedding_texts)
-        else:
-            embeddings = [
-                embedding_func(embedding_texts) for text in texts
-            ]
+        embeddings = embedding_func(embedding_texts)
 
         for batch in create_batches(
             api=CHROMA_CLIENT,
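The per-engine branching disappears here because the embedding function returned by query_embeddings_function now accepts either a single string or a list of texts (see generate_multiple in the next file, which by its content is the RAG utility module; the rendering does not show filenames). A tiny sketch of the simplified path, with illustrative chunks:

```python
# embedding_func is whatever query_embeddings_function returned.
texts = ["first chunk\nwith a newline", "second chunk"]
embedding_texts = [t.replace("\n", " ") for t in texts]
embeddings = embedding_func(embedding_texts)  # one vector per text
```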
@@ -1,8 +1,5 @@
 import logging
 import requests
-import operator
-
-import sentence_transformers
 
 from typing import List
 
@@ -11,8 +8,10 @@ from apps.ollama.main import (
     GenerateEmbeddingsForm,
 )
 
-from langchain_core.documents import Document
-from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import (
+    BM25Retriever,
+    ContextualCompressionRetriever,
+    EnsembleRetriever,
+)
 
@@ -27,6 +26,7 @@ def query_embeddings_doc(
     collection_name: str,
     query: str,
     k: int,
+    r: float,
     embeddings_function,
     reranking_function,
 ):
@@ -34,38 +34,39 @@ def query_embeddings_doc(
         # if you use docker use the model from the environment variable
         collection = CHROMA_CLIENT.get_collection(name=collection_name)
 
-        # keyword search
-        documents = collection.get() # get all documents
+        documents = collection.get()  # get all documents
         bm25_retriever = BM25Retriever.from_texts(
             texts=documents.get("documents"),
             metadatas=documents.get("metadatas"),
         )
         bm25_retriever.k = k
 
-        # semantic search (vector)
         chroma_retriever = ChromaRetriever(
             collection=collection,
-            k=k,
             embeddings_function=embeddings_function,
+            top_n=k,
         )
 
-        # hybrid search (ensemble)
         ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, chroma_retriever],
-            weights=[0.6, 0.4]
+            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
         )
 
-        documents = ensemble_retriever.invoke(query)
-        result = query_results_rank(
-            query=query,
-            documents=documents,
-            k=k,
+        compressor = RerankCompressor(
+            embeddings_function=embeddings_function,
             reranking_function=reranking_function,
+            r_score=r,
+            top_n=k,
         )
+
+        compression_retriever = ContextualCompressionRetriever(
+            base_compressor=compressor, base_retriever=ensemble_retriever
+        )
+
+        result = compression_retriever.invoke(query)
         result = {
-            "distances": [[d[1].item() for d in result]],
-            "documents": [[d[0].page_content for d in result]],
-            "metadatas": [[d[0].metadata for d in result]],
+            "distances": [[d.metadata.get("score") for d in result]],
+            "documents": [[d.page_content for d in result]],
+            "metadatas": [[d.metadata for d in result]],
         }
 
         return result
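The retrieval chain above is plain langchain composition: EnsembleRetriever merges both retrievers' hits, then ContextualCompressionRetriever runs them through the RerankCompressor defined at the bottom of this file. A runnable toy sketch of the ensemble step alone; two BM25 retrievers stand in for the real BM25 + Chroma pair, and the corpus is illustrative (requires the rank_bm25 package):

```python
from langchain.retrievers import BM25Retriever, EnsembleRetriever

corpus = [
    "Hybrid search ensembles BM25 keyword matching with vector search.",
    "The relevance threshold r filters out weak matches.",
    "Bananas are rich in potassium.",
]
a = BM25Retriever.from_texts(texts=corpus)
a.k = 2
b = BM25Retriever.from_texts(texts=corpus)
b.k = 2
ensemble = EnsembleRetriever(retrievers=[a, b], weights=[0.5, 0.5])
print([d.page_content for d in ensemble.invoke("what is hybrid search?")])
```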
@@ -73,58 +74,52 @@ def query_embeddings_doc(
         raise e
 
 
-def query_results_rank(query: str, documents, k: int, reranking_function):
-    scores = reranking_function.predict([(query, doc.page_content) for doc in documents])
-    docs_with_scores = list(zip(documents, scores))
-    result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
-    return result[: k]
-
-
 def merge_and_sort_query_results(query_results, k):
     # Initialize lists to store combined data
     combined_distances = []
     combined_documents = []
     combined_metadatas = []
 
-    # Combine data from each dictionary
     for data in query_results:
         combined_distances.extend(data["distances"][0])
         combined_documents.extend(data["documents"][0])
         combined_metadatas.extend(data["metadatas"][0])
 
     # Create a list of tuples (distance, document, metadata)
-    combined = list(
-        zip(combined_distances, combined_documents, combined_metadatas)
-    )
+    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
 
     # Sort the list based on distances
     combined.sort(key=lambda x: x[0])
 
-    # Unzip the sorted list
-    sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
+    # We don't have anything :-(
+    if not combined:
+        sorted_distances = []
+        sorted_documents = []
+        sorted_metadatas = []
+    else:
+        # Unzip the sorted list
+        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
 
-    # Slicing the lists to include only k elements
-    sorted_distances = list(sorted_distances)[:k]
-    sorted_documents = list(sorted_documents)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
+        # Slicing the lists to include only k elements
+        sorted_distances = list(sorted_distances)[:k]
+        sorted_documents = list(sorted_documents)[:k]
+        sorted_metadatas = list(sorted_metadatas)[:k]
 
     # Create the output dictionary
-    merged_query_results = {
+    result = {
         "distances": [sorted_distances],
         "documents": [sorted_documents],
        "metadatas": [sorted_metadatas],
-        "embeddings": None,
-        "uris": None,
-        "data": None,
     }
 
-    return merged_query_results
+    return result
 
 
 def query_embeddings_collection(
     collection_names: List[str],
     query: str,
     k: int,
+    r: float,
     embeddings_function,
     reranking_function,
 ):
@@ -137,6 +132,7 @@ def query_embeddings_collection(
                 collection_name=collection_name,
                 query=query,
                 k=k,
+                r=r,
                 embeddings_function=embeddings_function,
                 reranking_function=reranking_function,
             )
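query_embeddings_collection runs the hybrid query per collection and folds the results together with merge_and_sort_query_results, which sorts ascending on the reported distance. A worked sketch against that function, with toy result dicts shaped like its inputs:

```python
a = {"distances": [[0.2, 0.6]], "documents": [["d1", "d2"]], "metadatas": [[{}, {}]]}
b = {"distances": [[0.4]], "documents": [["d3"]], "metadatas": [[{}]]}
merged = merge_and_sort_query_results([a, b], k=2)
print(merged["documents"])  # [['d1', 'd3']] -- lowest distance first
```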
@@ -162,22 +158,31 @@ def query_embeddings_function(
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
-    elif embedding_engine == "ollama":
-        return lambda query: generate_ollama_embeddings(
-            GenerateEmbeddingsForm(
-                **{
-                    "model": embedding_model,
-                    "prompt": query,
-                }
-            )
-        )
-    elif embedding_engine == "openai":
-        return lambda query: generate_openai_embeddings(
-            model=embedding_model,
-            text=query,
-            key=openai_key,
-            url=openai_url,
-        )
+    elif embedding_engine in ["ollama", "openai"]:
+        if embedding_engine == "ollama":
+            func = lambda query: generate_ollama_embeddings(
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": embedding_model,
+                        "prompt": query,
+                    }
+                )
+            )
+        elif embedding_engine == "openai":
+            func = lambda query: generate_openai_embeddings(
+                model=embedding_model,
+                text=query,
+                key=openai_key,
+                url=openai_url,
+            )
+
+        def generate_multiple(query, f):
+            if isinstance(query, list):
+                return [f(q) for q in query]
+            else:
+                return f(query)
+
+        return lambda query: generate_multiple(query, func)
 
 
 def rag_messages(
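The generate_multiple wrapper is what lets the remote engines (Ollama, OpenAI) embed a whole batch: a list argument fans out into one call per text. A self-contained sketch with a stand-in embedder:

```python
def generate_multiple(query, f):
    if isinstance(query, list):
        return [f(q) for q in query]
    else:
        return f(query)

fake_embed = lambda q: [float(len(q))]  # stand-in for a real API call
print(generate_multiple("hello", fake_embed))      # [5.0]
print(generate_multiple(["a", "bc"], fake_embed))  # [[1.0], [2.0]]
```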
@@ -185,6 +190,7 @@ def rag_messages(
     messages,
     template,
     k,
+    r,
     embedding_engine,
     embedding_model,
     embedding_function,
@@ -221,53 +227,68 @@ def rag_messages(
         content_type = None
         query = ""
 
+    embeddings_function = query_embeddings_function(
+        embedding_engine,
+        embedding_model,
+        embedding_function,
+        openai_key,
+        openai_url,
+    )
+
+    extracted_collections = []
     relevant_contexts = []
 
     for doc in docs:
         context = None
 
-        try:
+        collection = doc.get("collection_name")
+        if collection:
+            collection = [collection]
+        else:
+            collection = doc.get("collection_names", [])
+
+        collection = set(collection).difference(extracted_collections)
+        if not collection:
+            log.debug(f"skipping {doc} as it has already been extracted")
+            continue
+
+        try:
             if doc["type"] == "text":
                 context = doc["content"]
-            else:
-                embeddings_function = query_embeddings_function(
-                    embedding_engine,
-                    embedding_model,
-                    embedding_function,
-                    openai_key,
-                    openai_url,
-                )
-
-                if doc["type"] == "collection":
-                    context = query_embeddings_collection(
-                        collection_names=doc["collection_names"],
-                        query=query,
-                        k=k,
-                        embeddings_function=embeddings_function,
-                        reranking_function=reranking_function,
-                    )
-                else:
-                    context = query_embeddings_doc(
-                        collection_name=doc["collection_name"],
-                        query=query,
-                        k=k,
-                        embeddings_function=embeddings_function,
-                        reranking_function=reranking_function,
-                    )
-
+            elif doc["type"] == "collection":
+                context = query_embeddings_collection(
+                    collection_names=doc["collection_names"],
+                    query=query,
+                    k=k,
+                    r=r,
+                    embeddings_function=embeddings_function,
+                    reranking_function=reranking_function,
+                )
+            else:
+                context = query_embeddings_doc(
+                    collection_name=doc["collection_name"],
+                    query=query,
+                    k=k,
+                    r=r,
+                    embeddings_function=embeddings_function,
+                    reranking_function=reranking_function,
+                )
         except Exception as e:
             log.exception(e)
             context = None
 
-        relevant_contexts.append(context)
+        if context:
+            relevant_contexts.append(context)
+
+        extracted_collections.extend(collection)
+
+    log.debug(f"relevant_contexts: {relevant_contexts}")
 
     context_string = ""
     for context in relevant_contexts:
-        if context:
-            context_string += " ".join(context["documents"][0]) + "\n"
+        items = context["documents"][0]
+        context_string += "\n\n".join(items)
+    context_string = context_string.strip()
 
     ra_content = rag_template(
         template=template,
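The new bookkeeping around extracted_collections means a collection referenced by several attached docs is only queried once per message. A distilled sketch of that dedup logic, with toy doc dicts:

```python
docs = [{"collection_name": "a"}, {"collection_names": ["a", "b"]}]
extracted_collections = []
for doc in docs:
    collection = doc.get("collection_name")
    collection = [collection] if collection else doc.get("collection_names", [])
    collection = set(collection).difference(extracted_collections)
    if not collection:
        continue  # everything here was already queried
    print("querying", sorted(collection))
    extracted_collections.extend(collection)
# querying ['a']
# querying ['b']
```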
@@ -275,6 +296,8 @@ def rag_messages(
         query=query,
     )
 
+    log.debug(f"ra_content: {ra_content}")
+
     if content_type == "list":
         new_content = []
         for content_item in user_message["content"]:
@@ -321,15 +344,14 @@ def generate_openai_embeddings(
 
 from typing import Any
 
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
-from langchain_core.callbacks import CallbackManagerForRetrieverRun
 
 
 class ChromaRetriever(BaseRetriever):
     collection: Any
-    k: int
     embeddings_function: Any
+    top_n: int
 
     def _get_relevant_documents(
         self,
@@ -341,7 +363,7 @@ class ChromaRetriever(BaseRetriever):
 
         results = self.collection.query(
             query_embeddings=[query_embeddings],
-            n_results=self.k,
+            n_results=self.top_n,
         )
 
         ids = results["ids"][0]
@@ -355,3 +377,60 @@ class ChromaRetriever(BaseRetriever):
             )
             for idx in range(len(ids))
         ]
+
+
+import operator
+
+from typing import Optional, Sequence
+
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.callbacks import Callbacks
+from langchain_core.pydantic_v1 import Extra
+
+from sentence_transformers import util
+
+
+class RerankCompressor(BaseDocumentCompressor):
+    embeddings_function: Any
+    reranking_function: Any
+    r_score: float
+    top_n: int
+
+    class Config:
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        if self.reranking_function:
+            scores = self.reranking_function.predict(
+                [(query, doc.page_content) for doc in documents]
+            )
+        else:
+            query_embedding = self.embeddings_function(query)
+            document_embedding = self.embeddings_function(
+                [doc.page_content for doc in documents]
+            )
+            scores = util.cos_sim(query_embedding, document_embedding)[0]
+
+        docs_with_scores = list(zip(documents, scores.tolist()))
+        if self.r_score:
+            docs_with_scores = [
+                (d, s) for d, s in docs_with_scores if s >= self.r_score
+            ]
+
+        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
+        final_results = []
+        for doc, doc_score in result[: self.top_n]:
+            metadata = doc.metadata
+            metadata["score"] = doc_score
+            doc = Document(
+                page_content=doc.page_content,
+                metadata=metadata,
+            )
+            final_results.append(doc)
+        return final_results
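RerankCompressor is a standard langchain BaseDocumentCompressor: with a CrossEncoder it reranks, without one it falls back to cosine similarity between query and document embeddings, then applies the r_score cutoff and keeps top_n. A standalone sketch exercising the fallback path; the embedding model name and texts are illustrative:

```python
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer

st = SentenceTransformer("all-MiniLM-L6-v2")  # example model
embed = lambda q: st.encode(q).tolist()

compressor = RerankCompressor(
    embeddings_function=embed,
    reranking_function=None,  # no CrossEncoder -> cos_sim fallback
    r_score=0.2,              # drop documents scoring below 0.2
    top_n=2,
)
docs = [
    Document(page_content="Hybrid search ensembles BM25 and vectors."),
    Document(page_content="Bananas are rich in potassium."),
]
for d in compressor.compress_documents(docs, "what is hybrid search?"):
    print(round(d.metadata["score"], 3), d.page_content)
```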