fix: sort ranking hybrid

This commit is contained in:
Steven Kreitzer 2024-04-25 20:00:47 -05:00
parent 9755cd5baa
commit 69822e4c25
2 changed files with 13 additions and 17 deletions

View file

@ -18,8 +18,6 @@ from langchain.retrievers import (
EnsembleRetriever, EnsembleRetriever,
) )
from sentence_transformers import CrossEncoder
from typing import Optional from typing import Optional
from config import SRC_LOG_LEVELS, CHROMA_CLIENT from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@ -34,14 +32,13 @@ def query_embeddings_doc(
embeddings_function, embeddings_function,
reranking_function, reranking_function,
k: int, k: int,
r: Optional[float] = None, r: int,
hybrid: Optional[bool] = False, hybrid: bool,
): ):
try: try:
if hybrid: collection = CHROMA_CLIENT.get_collection(name=collection_name)
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(name=collection_name)
if hybrid:
documents = collection.get() # get all documents documents = collection.get() # get all documents
bm25_retriever = BM25Retriever.from_texts( bm25_retriever = BM25Retriever.from_texts(
texts=documents.get("documents"), texts=documents.get("documents"),
@ -77,24 +74,19 @@ def query_embeddings_doc(
"metadatas": [[d.metadata for d in result]], "metadatas": [[d.metadata for d in result]],
} }
else: else:
# if you use docker use the model from the environment variable
query_embeddings = embeddings_function(query) query_embeddings = embeddings_function(query)
log.info(f"query_embeddings_doc {query_embeddings}")
collection = CHROMA_CLIENT.get_collection(name=collection_name)
result = collection.query( result = collection.query(
query_embeddings=[query_embeddings], query_embeddings=[query_embeddings],
n_results=k, n_results=k,
) )
log.info(f"query_embeddings_doc:result {result}") log.info(f"query_embeddings_doc:result {result}")
return result return result
except Exception as e: except Exception as e:
raise e raise e
def merge_and_sort_query_results(query_results, k): def merge_and_sort_query_results(query_results, k, reverse=False):
# Initialize lists to store combined data # Initialize lists to store combined data
combined_distances = [] combined_distances = []
combined_documents = [] combined_documents = []
@ -109,7 +101,7 @@ def merge_and_sort_query_results(query_results, k):
combined = list(zip(combined_distances, combined_documents, combined_metadatas)) combined = list(zip(combined_distances, combined_documents, combined_metadatas))
# Sort the list based on distances # Sort the list based on distances
combined.sort(key=lambda x: x[0]) combined.sort(key=lambda x: x[0], reverse=reverse)
# We don't have anything :-( # We don't have anything :-(
if not combined: if not combined:
@ -162,7 +154,8 @@ def query_embeddings_collection(
except: except:
pass pass
return merge_and_sort_query_results(results, k) reverse = hybrid and reranking_function is not None
return merge_and_sort_query_results(results, k=k, reverse=reverse)
def rag_template(template: str, context: str, query: str): def rag_template(template: str, context: str, query: str):
@ -484,7 +477,9 @@ class RerankCompressor(BaseDocumentCompressor):
(d, s) for d, s in docs_with_scores if s >= self.r_score (d, s) for d, s in docs_with_scores if s >= self.r_score
] ]
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) reverse = self.reranking_function is not None
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)
final_results = [] final_results = []
for doc, doc_score in result[: self.top_n]: for doc, doc_score in result[: self.top_n]:
metadata = doc.metadata metadata = doc.metadata

View file

@ -121,6 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
rag_app.state.RAG_TEMPLATE, rag_app.state.RAG_TEMPLATE,
rag_app.state.TOP_K, rag_app.state.TOP_K,
rag_app.state.RELEVANCE_THRESHOLD, rag_app.state.RELEVANCE_THRESHOLD,
rag_app.state.HYBRID,
rag_app.state.RAG_EMBEDDING_ENGINE, rag_app.state.RAG_EMBEDDING_ENGINE,
rag_app.state.RAG_EMBEDDING_MODEL, rag_app.state.RAG_EMBEDDING_MODEL,
rag_app.state.sentence_transformer_ef, rag_app.state.sentence_transformer_ef,