diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 2db2cf1f..654b2481 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -70,6 +70,7 @@ from config import (
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+    ENABLE_RAG_HYBRID_SEARCH,
     RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
@@ -91,6 +92,9 @@ app = FastAPI()
 
 app.state.TOP_K = RAG_TOP_K
 app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
+
+app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
+
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 
@@ -321,6 +325,7 @@ async def get_query_settings(user=Depends(get_admin_user)):
         "template": app.state.RAG_TEMPLATE,
         "k": app.state.TOP_K,
         "r": app.state.RELEVANCE_THRESHOLD,
+        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
     }
 
 
@@ -328,6 +333,7 @@ class QuerySettingsForm(BaseModel):
     k: Optional[int] = None
     r: Optional[float] = None
     template: Optional[str] = None
+    hybrid: Optional[bool] = None
 
 
 @app.post("/query/settings/update")
@@ -337,7 +343,14 @@ async def update_query_settings(
     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
     app.state.TOP_K = form_data.k if form_data.k else 4
     app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
-    return {"status": True, "template": app.state.RAG_TEMPLATE}
+    app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
+    return {
+        "status": True,
+        "template": app.state.RAG_TEMPLATE,
+        "k": app.state.TOP_K,
+        "r": app.state.RELEVANCE_THRESHOLD,
+        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
+    }
 
 
 class QueryDocForm(BaseModel):
@@ -345,6 +358,7 @@ class QueryDocForm(BaseModel):
     query: str
     k: Optional[int] = None
     r: Optional[float] = None
+    hybrid: Optional[bool] = None
 
 
 @app.post("/query/doc")
@@ -368,6 +382,11 @@ def query_doc_handler(
             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             embeddings_function=embeddings_function,
             reranking_function=app.state.sentence_transformer_rf,
+            hybrid_search=(
+                form_data.hybrid
+                if form_data.hybrid
+                else app.state.ENABLE_RAG_HYBRID_SEARCH
+            ),
         )
     except Exception as e:
         log.exception(e)
@@ -382,6 +401,7 @@ class QueryCollectionsForm(BaseModel):
     query: str
     k: Optional[int] = None
     r: Optional[float] = None
+    hybrid: Optional[bool] = None
 
 
 @app.post("/query/collection")
@@ -405,6 +425,11 @@ def query_collection_handler(
             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
             embeddings_function=embeddings_function,
             reranking_function=app.state.sentence_transformer_rf,
+            hybrid_search=(
+                form_data.hybrid
+                if form_data.hybrid
+                else app.state.ENABLE_RAG_HYBRID_SEARCH
+            ),
         )
     except Exception as e:
         log.exception(e)
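For reference, a minimal sketch (not part of the patch) of exercising the new hybrid fields over HTTP. The endpoint paths, request fields, and response keys come from the routes above; the base URL, mount point, and token are assumptions.

# Sketch only: exercise the new "hybrid" settings/query fields.
# BASE (mount point/port) and the admin token are placeholders, not values from this diff.
import requests

BASE = "http://localhost:8080/rag/api/v1"  # assumed mount point of the RAG sub-app
HEADERS = {"Authorization": "Bearer <admin-token>"}  # placeholder admin token

# Persist hybrid search via the admin settings endpoint.
settings = requests.post(
    f"{BASE}/query/settings/update",
    json={"template": "", "k": 4, "r": 0.0, "hybrid": True},
    headers=HEADERS,
).json()
# -> {"status": True, "template": ..., "k": 4, "r": 0.0, "hybrid": True}

# Or override the stored setting for a single document query.
hits = requests.post(
    f"{BASE}/query/doc",
    json={"collection_name": "<collection>", "query": "example question", "hybrid": True},
    headers=HEADERS,
).json()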
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index da71495b..e9fe8319 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -18,8 +18,6 @@ from langchain.retrievers import (
     EnsembleRetriever,
 )
 
-from sentence_transformers import CrossEncoder
-
 from typing import Optional
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@@ -32,16 +30,15 @@ def query_embeddings_doc(
     collection_name: str,
     query: str,
     embeddings_function,
+    reranking_function,
     k: int,
-    reranking_function: Optional[CrossEncoder] = None,
-    r: Optional[float] = None,
+    r: float,
+    hybrid_search: bool,
 ):
     try:
+        collection = CHROMA_CLIENT.get_collection(name=collection_name)
 
-        if reranking_function:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(name=collection_name)
-
+        if hybrid_search:
             documents = collection.get()  # get all documents
             bm25_retriever = BM25Retriever.from_texts(
                 texts=documents.get("documents"),
@@ -77,24 +74,19 @@ def query_embeddings_doc(
                 "metadatas": [[d.metadata for d in result]],
             }
         else:
-            # if you use docker use the model from the environment variable
             query_embeddings = embeddings_function(query)
-
-            log.info(f"query_embeddings_doc {query_embeddings}")
-            collection = CHROMA_CLIENT.get_collection(name=collection_name)
-
             result = collection.query(
                 query_embeddings=[query_embeddings],
                 n_results=k,
             )
-            log.info(f"query_embeddings_doc:result {result}")
+            log.info(f"query_embeddings_doc:result {result}")
 
             return result
     except Exception as e:
         raise e
 
 
-def merge_and_sort_query_results(query_results, k):
+def merge_and_sort_query_results(query_results, k, reverse=False):
     # Initialize lists to store combined data
     combined_distances = []
     combined_documents = []
@@ -109,7 +101,7 @@
     combined = list(zip(combined_distances, combined_documents, combined_metadatas))
 
     # Sort the list based on distances
-    combined.sort(key=lambda x: x[0])
+    combined.sort(key=lambda x: x[0], reverse=reverse)
 
     # We don't have anything :-(
     if not combined:
@@ -142,6 +134,7 @@ def query_embeddings_collection(
     r: float,
     embeddings_function,
     reranking_function,
+    hybrid_search: bool,
 ):
     results = []
 
@@ -155,12 +148,14 @@ def query_embeddings_collection(
                 r=r,
                 embeddings_function=embeddings_function,
                 reranking_function=reranking_function,
+                hybrid_search=hybrid_search,
             )
             results.append(result)
         except:
             pass
 
-    return merge_and_sort_query_results(results, k)
+    reverse = hybrid_search and reranking_function is not None
+    return merge_and_sort_query_results(results, k=k, reverse=reverse)
 
 
 def rag_template(template: str, context: str, query: str):
@@ -211,6 +206,7 @@ def rag_messages(
     template,
     k,
     r,
+    hybrid_search,
     embedding_engine,
     embedding_model,
     embedding_function,
@@ -283,6 +279,7 @@ def rag_messages(
                     r=r,
                     embeddings_function=embeddings_function,
                     reranking_function=reranking_function,
+                    hybrid_search=hybrid_search,
                 )
             else:
                 context = query_embeddings_doc(
@@ -292,6 +289,7 @@ def rag_messages(
                     r=r,
                     embeddings_function=embeddings_function,
                     reranking_function=reranking_function,
+                    hybrid_search=hybrid_search,
                 )
         except Exception as e:
             log.exception(e)
@@ -479,7 +477,9 @@ class RerankCompressor(BaseDocumentCompressor):
                 (d, s) for d, s in docs_with_scores if s >= self.r_score
             ]
 
-        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
+        reverse = self.reranking_function is not None
+        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)
+
         final_results = []
         for doc, doc_score in result[: self.top_n]:
             metadata = doc.metadata
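The reverse flag threaded through merge_and_sort_query_results and RerankCompressor exists because the two paths rank in opposite directions: a plain Chroma query returns distances (smaller is better, ascending sort), while the reranked/hybrid path produces relevance scores (larger is better, descending sort). A standalone illustration with invented numbers:

# Illustration only, with made-up numbers; not code from this repository.
distance_hits = [(0.12, "doc-a"), (0.34, "doc-b"), (0.55, "doc-c")]  # Chroma distances
score_hits = [(0.91, "doc-c"), (0.40, "doc-a"), (0.05, "doc-b")]     # reranker scores

best_by_distance = sorted(distance_hits, key=lambda x: x[0])[0]           # reverse=False
best_by_score = sorted(score_hits, key=lambda x: x[0], reverse=True)[0]   # reverse=True

print(best_by_distance)  # (0.12, 'doc-a') -> closest vector wins
print(best_by_score)     # (0.91, 'doc-c') -> highest relevance score wins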
diff --git a/backend/config.py b/backend/config.py
index f1c7b241..d354b116 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -423,6 +423,10 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
 RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
 
+ENABLE_RAG_HYBRID_SEARCH = (
+    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
+)
+
 RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
 
 RAG_EMBEDDING_MODEL = os.environ.get(
diff --git a/backend/main.py b/backend/main.py
index 1b92ae73..b0dc3a7f 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -121,6 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
                     rag_app.state.RELEVANCE_THRESHOLD,
+                    rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
                     rag_app.state.RAG_EMBEDDING_ENGINE,
                     rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
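A note on the new flag added in backend/config.py above: hybrid search is enabled only when ENABLE_RAG_HYBRID_SEARCH is the literal string "true" (case-insensitive); values such as "1" or "yes", or leaving the variable unset, keep it off. A quick standalone check of that parsing rule:

# Standalone check of the parsing rule added in backend/config.py.
import os

for value in ("true", "TRUE", "1", "yes", ""):
    os.environ["ENABLE_RAG_HYBRID_SEARCH"] = value
    enabled = os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
    print(repr(value), "->", enabled)  # only "true" and "TRUE" yield True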
diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte
index c6695bb6..9fb7c677 100644
--- a/src/lib/components/documents/Settings/General.svelte
+++ b/src/lib/components/documents/Settings/General.svelte
@@ -43,7 +43,8 @@
 	let querySettings = {
 		template: '',
 		r: 0.0,
-		k: 4
+		k: 4,
+		hybrid: false
 	};
 
 	const scanHandler = async () => {
@@ -174,6 +175,12 @@
 		}
 	};
 
+	const toggleHybridSearch = async () => {
+		querySettings.hybrid = !querySettings.hybrid;
+
+		querySettings = await updateQuerySettings(localStorage.token, querySettings);
+	};
+
 	onMount(async () => {
 		const res = await getRAGConfig(localStorage.token);
 
@@ -202,6 +209,24 @@
 [Markup garbled in extraction. Recoverable content: a new {$i18n.t('Hybrid Search')} toggle, wired to toggleHybridSearch, is added below the {$i18n.t('General Settings')} heading and above the {$i18n.t('Embedding Model Engine')} selector.]
@@ -386,78 +411,74 @@
 [Markup garbled in extraction. Recoverable content: the {$i18n.t('Update Reranking Model')} controls and the note {$i18n.t('Note: If you choose a reranking model, it will use that to score and rerank instead of the embedding model.')} are wrapped in a {#if querySettings.hybrid === true} block, so they render only when hybrid search is enabled.]
@@ -583,25 +604,27 @@
 [Markup garbled in extraction. Recoverable content: the {$i18n.t('Relevance Threshold')} input is likewise wrapped in a {#if querySettings.hybrid === true} block; the {$i18n.t('RAG Template')} field that follows is unchanged.]