From 7e5e2c42c942a52388ca73bb1edc3ffaaadf76a6 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 19:26:39 -0800 Subject: [PATCH] refac: rag routes --- backend/apps/rag/main.py | 89 ++++++--------------------------------- backend/apps/rag/utils.py | 89 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 75 deletions(-) create mode 100644 backend/apps/rag/utils.py diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 45ad6970..6781a9a1 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -44,6 +44,8 @@ from apps.web.models.documents import ( DocumentResponse, ) +from apps.rag.utils import query_doc, query_collection + from utils.misc import ( calculate_sha256, calculate_sha256_string, @@ -248,21 +250,18 @@ class QueryDocForm(BaseModel): @app.post("/query/doc") -def query_doc( +def query_doc_handler( form_data: QueryDocForm, user=Depends(get_current_user), ): + try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, + return query_doc( + collection_name=form_data.collection_name, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, embedding_function=app.state.sentence_transformer_ef, ) - result = collection.query( - query_texts=[form_data.query], - n_results=form_data.k if form_data.k else app.state.TOP_K, - ) - return result except Exception as e: print(e) raise HTTPException( @@ -277,76 +276,16 @@ class QueryCollectionsForm(BaseModel): k: Optional[int] = None -def merge_and_sort_query_results(query_results, k): - # Initialize lists to store combined data - combined_ids = [] - combined_distances = [] - combined_metadatas = [] - combined_documents = [] - - # Combine data from each dictionary - for data in query_results: - combined_ids.extend(data["ids"][0]) - combined_distances.extend(data["distances"][0]) - combined_metadatas.extend(data["metadatas"][0]) - combined_documents.extend(data["documents"][0]) - - # Create a list of tuples (distance, id, metadata, document) - combined = list( - zip(combined_distances, combined_ids, combined_metadatas, combined_documents) - ) - - # Sort the list based on distances - combined.sort(key=lambda x: x[0]) - - # Unzip the sorted list - sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) - - # Slicing the lists to include only k elements - sorted_distances = list(sorted_distances)[:k] - sorted_ids = list(sorted_ids)[:k] - sorted_metadatas = list(sorted_metadatas)[:k] - sorted_documents = list(sorted_documents)[:k] - - # Create the output dictionary - merged_query_results = { - "ids": [sorted_ids], - "distances": [sorted_distances], - "metadatas": [sorted_metadatas], - "documents": [sorted_documents], - "embeddings": None, - "uris": None, - "data": None, - } - - return merged_query_results - - @app.post("/query/collection") -def query_collection( +def query_collection_handler( form_data: QueryCollectionsForm, user=Depends(get_current_user), ): - results = [] - - for collection_name in form_data.collection_names: - try: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=collection_name, - embedding_function=app.state.sentence_transformer_ef, - ) - - result = collection.query( - query_texts=[form_data.query], - n_results=form_data.k if form_data.k else app.state.TOP_K, - ) - results.append(result) - except: - pass - - return merge_and_sort_query_results( - results, form_data.k if form_data.k else app.state.TOP_K + return query_collection( + collection_names=form_data.collection_names, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, + embedding_function=app.state.sentence_transformer_ef, ) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py new file mode 100644 index 00000000..19374397 --- /dev/null +++ b/backend/apps/rag/utils.py @@ -0,0 +1,89 @@ +from typing import List + +from config import CHROMA_CLIENT + + +def query_doc(collection_name: str, query: str, k: int, embedding_function): + try: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + embedding_function=embedding_function, + ) + result = collection.query( + query_texts=[query], + n_results=k, + ) + return result + except Exception as e: + raise e + + +def merge_and_sort_query_results(query_results, k): + # Initialize lists to store combined data + combined_ids = [] + combined_distances = [] + combined_metadatas = [] + combined_documents = [] + + # Combine data from each dictionary + for data in query_results: + combined_ids.extend(data["ids"][0]) + combined_distances.extend(data["distances"][0]) + combined_metadatas.extend(data["metadatas"][0]) + combined_documents.extend(data["documents"][0]) + + # Create a list of tuples (distance, id, metadata, document) + combined = list( + zip(combined_distances, combined_ids, combined_metadatas, combined_documents) + ) + + # Sort the list based on distances + combined.sort(key=lambda x: x[0]) + + # Unzip the sorted list + sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) + + # Slicing the lists to include only k elements + sorted_distances = list(sorted_distances)[:k] + sorted_ids = list(sorted_ids)[:k] + sorted_metadatas = list(sorted_metadatas)[:k] + sorted_documents = list(sorted_documents)[:k] + + # Create the output dictionary + merged_query_results = { + "ids": [sorted_ids], + "distances": [sorted_distances], + "metadatas": [sorted_metadatas], + "documents": [sorted_documents], + "embeddings": None, + "uris": None, + "data": None, + } + + return merged_query_results + + +def query_collection( + collection_names: List[str], query: str, k: int, embedding_function +): + + results = [] + + for collection_name in collection_names: + try: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + embedding_function=embedding_function, + ) + + result = collection.query( + query_texts=[query], + n_results=k, + ) + results.append(result) + except: + pass + + return merge_and_sort_query_results(results, k)