forked from open-webui/open-webui
fix: sort ranking hybrid
This commit is contained in:
parent
9755cd5baa
commit
69822e4c25
2 changed files with 13 additions and 17 deletions
|
@ -18,8 +18,6 @@ from langchain.retrievers import (
|
||||||
EnsembleRetriever,
|
EnsembleRetriever,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sentence_transformers import CrossEncoder
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||||
|
|
||||||
|
@ -34,14 +32,13 @@ def query_embeddings_doc(
|
||||||
embeddings_function,
|
embeddings_function,
|
||||||
reranking_function,
|
reranking_function,
|
||||||
k: int,
|
k: int,
|
||||||
r: Optional[float] = None,
|
r: int,
|
||||||
hybrid: Optional[bool] = False,
|
hybrid: bool,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if hybrid:
|
|
||||||
# if you use docker use the model from the environment variable
|
|
||||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||||
|
|
||||||
|
if hybrid:
|
||||||
documents = collection.get() # get all documents
|
documents = collection.get() # get all documents
|
||||||
bm25_retriever = BM25Retriever.from_texts(
|
bm25_retriever = BM25Retriever.from_texts(
|
||||||
texts=documents.get("documents"),
|
texts=documents.get("documents"),
|
||||||
|
@ -77,12 +74,7 @@ def query_embeddings_doc(
|
||||||
"metadatas": [[d.metadata for d in result]],
|
"metadatas": [[d.metadata for d in result]],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# if you use docker use the model from the environment variable
|
|
||||||
query_embeddings = embeddings_function(query)
|
query_embeddings = embeddings_function(query)
|
||||||
|
|
||||||
log.info(f"query_embeddings_doc {query_embeddings}")
|
|
||||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
|
||||||
|
|
||||||
result = collection.query(
|
result = collection.query(
|
||||||
query_embeddings=[query_embeddings],
|
query_embeddings=[query_embeddings],
|
||||||
n_results=k,
|
n_results=k,
|
||||||
|
@ -94,7 +86,7 @@ def query_embeddings_doc(
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def merge_and_sort_query_results(query_results, k):
|
def merge_and_sort_query_results(query_results, k, reverse=False):
|
||||||
# Initialize lists to store combined data
|
# Initialize lists to store combined data
|
||||||
combined_distances = []
|
combined_distances = []
|
||||||
combined_documents = []
|
combined_documents = []
|
||||||
|
@ -109,7 +101,7 @@ def merge_and_sort_query_results(query_results, k):
|
||||||
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
|
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
|
||||||
|
|
||||||
# Sort the list based on distances
|
# Sort the list based on distances
|
||||||
combined.sort(key=lambda x: x[0])
|
combined.sort(key=lambda x: x[0], reverse=reverse)
|
||||||
|
|
||||||
# We don't have anything :-(
|
# We don't have anything :-(
|
||||||
if not combined:
|
if not combined:
|
||||||
|
@ -162,7 +154,8 @@ def query_embeddings_collection(
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return merge_and_sort_query_results(results, k)
|
reverse = hybrid and reranking_function is not None
|
||||||
|
return merge_and_sort_query_results(results, k=k, reverse=reverse)
|
||||||
|
|
||||||
|
|
||||||
def rag_template(template: str, context: str, query: str):
|
def rag_template(template: str, context: str, query: str):
|
||||||
|
@ -484,7 +477,9 @@ class RerankCompressor(BaseDocumentCompressor):
|
||||||
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
||||||
]
|
]
|
||||||
|
|
||||||
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
reverse = self.reranking_function is not None
|
||||||
|
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)
|
||||||
|
|
||||||
final_results = []
|
final_results = []
|
||||||
for doc, doc_score in result[: self.top_n]:
|
for doc, doc_score in result[: self.top_n]:
|
||||||
metadata = doc.metadata
|
metadata = doc.metadata
|
||||||
|
|
|
@ -121,6 +121,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||||
rag_app.state.RAG_TEMPLATE,
|
rag_app.state.RAG_TEMPLATE,
|
||||||
rag_app.state.TOP_K,
|
rag_app.state.TOP_K,
|
||||||
rag_app.state.RELEVANCE_THRESHOLD,
|
rag_app.state.RELEVANCE_THRESHOLD,
|
||||||
|
rag_app.state.HYBRID,
|
||||||
rag_app.state.RAG_EMBEDDING_ENGINE,
|
rag_app.state.RAG_EMBEDDING_ENGINE,
|
||||||
rag_app.state.RAG_EMBEDDING_MODEL,
|
rag_app.state.RAG_EMBEDDING_MODEL,
|
||||||
rag_app.state.sentence_transformer_ef,
|
rag_app.state.sentence_transformer_ef,
|
||||||
|
|
Loading…
Reference in a new issue