From cebf733b9d0e188e1fd903707a2342678008f4ff Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 26 Apr 2024 14:41:39 -0400 Subject: [PATCH] refac: naming convention --- backend/apps/rag/main.py | 23 ++++++++++++++++------- backend/apps/rag/utils.py | 14 +++++++------- backend/config.py | 5 ++++- backend/main.py | 2 +- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index d72fde74..654b2481 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -70,7 +70,7 @@ from config import ( RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_HYBRID, + ENABLE_RAG_HYBRID_SEARCH, RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, @@ -92,7 +92,8 @@ app = FastAPI() app.state.TOP_K = RAG_TOP_K app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD -app.state.HYBRID = RAG_HYBRID + +app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP @@ -324,7 +325,7 @@ async def get_query_settings(user=Depends(get_admin_user)): "template": app.state.RAG_TEMPLATE, "k": app.state.TOP_K, "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.HYBRID, + "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, } @@ -342,13 +343,13 @@ async def update_query_settings( app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.TOP_K = form_data.k if form_data.k else 4 app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 - app.state.HYBRID = form_data.hybrid if form_data.hybrid else False + app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False return { "status": True, "template": app.state.RAG_TEMPLATE, "k": app.state.TOP_K, "r": app.state.RELEVANCE_THRESHOLD, - "hybrid": app.state.HYBRID, + "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, } @@ -381,7 +382,11 @@ def query_doc_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, - hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, + hybrid_search=( + form_data.hybrid + if form_data.hybrid + else app.state.ENABLE_RAG_HYBRID_SEARCH + ), ) except Exception as e: log.exception(e) @@ -420,7 +425,11 @@ def query_collection_handler( r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, embeddings_function=embeddings_function, reranking_function=app.state.sentence_transformer_rf, - hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, + hybrid_search=( + form_data.hybrid + if form_data.hybrid + else app.state.ENABLE_RAG_HYBRID_SEARCH + ), ) except Exception as e: log.exception(e) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 62c29b2b..e9fe8319 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -33,12 +33,12 @@ def query_embeddings_doc( reranking_function, k: int, r: int, - hybrid: bool, + hybrid_search: bool, ): try: collection = CHROMA_CLIENT.get_collection(name=collection_name) - if hybrid: + if hybrid_search: documents = collection.get() # get all documents bm25_retriever = BM25Retriever.from_texts( texts=documents.get("documents"), @@ -134,7 +134,7 @@ def query_embeddings_collection( r: float, embeddings_function, reranking_function, - hybrid: bool, + hybrid_search: bool, ): results = [] @@ -148,7 +148,7 @@ def query_embeddings_collection( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, - hybrid=hybrid, + hybrid_search=hybrid_search, ) results.append(result) except: @@ -206,7 +206,7 @@ def rag_messages( template, k, r, - hybrid, + hybrid_search, embedding_engine, embedding_model, embedding_function, @@ -279,7 +279,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, - hybrid=hybrid, + hybrid_search=hybrid_search, ) else: context = query_embeddings_doc( @@ -289,7 +289,7 @@ def rag_messages( r=r, embeddings_function=embeddings_function, reranking_function=reranking_function, - hybrid=hybrid, + hybrid_search=hybrid_search, ) except Exception as e: log.exception(e) diff --git a/backend/config.py b/backend/config.py index e60a789b..f67fd017 100644 --- a/backend/config.py +++ b/backend/config.py @@ -422,7 +422,10 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) -RAG_HYBRID = os.environ.get("RAG_HYBRID", "").lower() == "true" + +ENABLE_RAG_HYBRID_SEARCH = ( + os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true" +) RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") diff --git a/backend/main.py b/backend/main.py index 284d8371..b0dc3a7f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -121,7 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware): rag_app.state.RAG_TEMPLATE, rag_app.state.TOP_K, rag_app.state.RELEVANCE_THRESHOLD, - rag_app.state.HYBRID, + rag_app.state.ENABLE_RAG_HYBRID_SEARCH, rag_app.state.RAG_EMBEDDING_ENGINE, rag_app.state.RAG_EMBEDDING_MODEL, rag_app.state.sentence_transformer_ef,