diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 654b2481..405546dd 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -48,7 +48,7 @@ from apps.web.models.documents import ( from apps.rag.utils import ( get_model_path, query_embeddings_doc, - query_embeddings_function, + get_embeddings_function, query_embeddings_collection, ) @@ -367,7 +367,7 @@ def query_doc_handler( user=Depends(get_current_user), ): try: - embeddings_function = query_embeddings_function( + embeddings_function = get_embeddings_function( app.state.RAG_EMBEDDING_ENGINE, app.state.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, @@ -410,7 +410,7 @@ def query_collection_handler( user=Depends(get_current_user), ): try: - embeddings_function = query_embeddings_function( + embeddings_function = get_embeddings_function( app.state.RAG_EMBEDDING_ENGINE, app.state.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, @@ -508,7 +508,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b collection = CHROMA_CLIENT.create_collection(name=collection_name) - embedding_func = query_embeddings_function( + embedding_func = get_embeddings_function( app.state.RAG_EMBEDDING_ENGINE, app.state.RAG_EMBEDDING_MODEL, app.state.sentence_transformer_ef, diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index e9fe8319..668f38e4 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -164,7 +164,7 @@ def rag_template(template: str, context: str, query: str): return template -def query_embeddings_function( +def get_embeddings_function( embedding_engine, embedding_model, embedding_function, @@ -243,7 +243,7 @@ def rag_messages( content_type = None query = "" - embeddings_function = query_embeddings_function( + embeddings_function = get_embeddings_function( embedding_engine, embedding_model, embedding_function,