feat: hybrid search

Steven Kreitzer 2024-04-22 15:49:58 -05:00 committed by Steven Kreitzer
parent f3e5700d49
commit 4e0b32b505
7 changed files with 406 additions and 110 deletions


@@ -1,5 +1,8 @@
 import logging
 import requests
+import operator
+import sentence_transformers
 from typing import List
@@ -8,6 +11,11 @@ from apps.ollama.main import (
     GenerateEmbeddingsForm,
 )
+from langchain.retrievers import (
+    BM25Retriever,
+    EnsembleRetriever,
+)
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
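
For orientation on the new imports: BM25Retriever does keyword (lexical) retrieval over raw texts, and EnsembleRetriever fuses the ranked lists of several retrievers via weighted reciprocal rank fusion. A minimal standalone sketch of the keyword side, with a made-up corpus (BM25Retriever needs the rank_bm25 package installed):

    from langchain.retrievers import BM25Retriever

    # Toy corpus; in the diff below the texts come from collection.get().
    retriever = BM25Retriever.from_texts(
        texts=[
            "hybrid search combines keyword and vector retrieval",
            "BM25 ranks documents by term frequency and rarity",
        ],
        metadatas=[{"source": "a"}, {"source": "b"}],
    )
    retriever.k = 1  # keep only the best match
    docs = retriever.invoke("keyword retrieval")  # returns List[Document]
    print(docs[0].page_content)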
@@ -15,60 +23,96 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


-def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
+def query_embeddings_doc(
+    collection_name: str,
+    query: str,
+    k: int,
+    embeddings_function,
+    reranking_function,
+):
     try:
-        # if you use docker use the model from the environment variable
-        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(name=collection_name)
-        result = collection.query(
-            query_embeddings=[query_embeddings],
-            n_results=k,
-        )
-        log.info(f"query_embeddings_doc:result {result}")
+
+        # keyword search
+        documents = collection.get()  # get all documents
+        bm25_retriever = BM25Retriever.from_texts(
+            texts=documents.get("documents"),
+            metadatas=documents.get("metadatas"),
+        )
+        bm25_retriever.k = k
+
+        # semantic search (vector)
+        chroma_retriever = ChromaRetriever(
+            collection=collection,
+            k=k,
+            embeddings_function=embeddings_function,
+        )
+
+        # hybrid search (ensemble)
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, chroma_retriever],
+            weights=[0.6, 0.4],
+        )
+
+        documents = ensemble_retriever.invoke(query)
+        result = query_results_rank(
+            query=query,
+            documents=documents,
+            k=k,
+            reranking_function=reranking_function,
+        )
+        result = {
+            "distances": [[d[1].item() for d in result]],
+            "documents": [[d[0].page_content for d in result]],
+            "metadatas": [[d[0].metadata for d in result]],
+        }
+
         return result
     except Exception as e:
         raise e


+def query_results_rank(query: str, documents, k: int, reranking_function):
+    scores = reranking_function.predict(
+        [(query, doc.page_content) for doc in documents]
+    )
+    docs_with_scores = list(zip(documents, scores))
+    result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
+    return result[:k]
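
The reranking_function threaded through query_embeddings_doc is only assumed here to expose a predict() method over (query, passage) pairs that returns one relevance score per pair; a sentence_transformers CrossEncoder fits that contract, which also explains the .item() calls on the scores above (they are numpy scalars). A sketch with an illustrative model name, not one taken from this commit:

    import sentence_transformers

    # Any model exposing .predict() on (query, passage) pairs works here.
    reranker = sentence_transformers.CrossEncoder(
        "cross-encoder/ms-marco-MiniLM-L-6-v2"
    )
    scores = reranker.predict([
        ("what is hybrid search?", "hybrid search mixes BM25 and vectors"),
        ("what is hybrid search?", "bananas are rich in potassium"),
    ])
    # Higher score means more relevant; query_results_rank sorts on this.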
 def merge_and_sort_query_results(query_results, k):
     # Initialize lists to store combined data
-    combined_ids = []
     combined_distances = []
-    combined_metadatas = []
     combined_documents = []
+    combined_metadatas = []

     # Combine data from each dictionary
     for data in query_results:
-        combined_ids.extend(data["ids"][0])
         combined_distances.extend(data["distances"][0])
-        combined_metadatas.extend(data["metadatas"][0])
         combined_documents.extend(data["documents"][0])
+        combined_metadatas.extend(data["metadatas"][0])

-    # Create a list of tuples (distance, id, metadata, document)
+    # Create a list of tuples (distance, document, metadata)
     combined = list(
-        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
+        zip(combined_distances, combined_documents, combined_metadatas)
     )

     # Sort the list based on distances
     combined.sort(key=lambda x: x[0])

     # Unzip the sorted list
-    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
+    sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)

     # Slicing the lists to include only k elements
     sorted_distances = list(sorted_distances)[:k]
-    sorted_ids = list(sorted_ids)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
     sorted_documents = list(sorted_documents)[:k]
+    sorted_metadatas = list(sorted_metadatas)[:k]

     # Create the output dictionary
     merged_query_results = {
-        "ids": [sorted_ids],
         "distances": [sorted_distances],
-        "metadatas": [sorted_metadatas],
         "documents": [sorted_documents],
+        "metadatas": [sorted_metadatas],
         "embeddings": None,
         "uris": None,
         "data": None,
@@ -78,19 +122,23 @@ def merge_and_sort_query_results(query_results, k):

 def query_embeddings_collection(
-    collection_names: List[str], query: str, query_embeddings, k: int
+    collection_names: List[str],
+    query: str,
+    k: int,
+    embeddings_function,
+    reranking_function,
 ):
     results = []
-    log.info(f"query_embeddings_collection {query_embeddings}")

     for collection_name in collection_names:
         try:
             result = query_embeddings_doc(
                 collection_name=collection_name,
                 query=query,
-                query_embeddings=query_embeddings,
                 k=k,
+                embeddings_function=embeddings_function,
+                reranking_function=reranking_function,
             )
             results.append(result)
         except:
@@ -105,6 +153,33 @@ def rag_template(template: str, context: str, query: str):
     return template


+def query_embeddings_function(
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
+    if embedding_engine == "":
+        return lambda query: embedding_function.encode(query).tolist()
+    elif embedding_engine == "ollama":
+        return lambda query: generate_ollama_embeddings(
+            GenerateEmbeddingsForm(
+                **{
+                    "model": embedding_model,
+                    "prompt": query,
+                }
+            )
+        )
+    elif embedding_engine == "openai":
+        return lambda query: generate_openai_embeddings(
+            model=embedding_model,
+            text=query,
+            key=openai_key,
+            url=openai_url,
+        )


 def rag_messages(
     docs,
     messages,
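
query_embeddings_function returns a closure so callers can embed any number of queries without re-branching on the engine each time. A sketch of the default (local sentence-transformers) path; the model name is illustrative only:

    import sentence_transformers

    model = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2")

    embed = query_embeddings_function(
        embedding_engine="",  # "" selects the local encode() path
        embedding_model="all-MiniLM-L6-v2",
        embedding_function=model,
        openai_key=None,  # unused by the local engine
        openai_url=None,
    )
    query_embeddings = embed("what is hybrid search?")  # list[float]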
@@ -113,11 +188,12 @@ def rag_messages(
     embedding_engine,
     embedding_model,
     embedding_function,
+    reranking_function,
     openai_key,
     openai_url,
 ):
     log.debug(
-        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
+        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
     )

     last_user_message_idx = None
@@ -155,38 +231,29 @@ def rag_messages(
             if doc["type"] == "text":
                 context = doc["content"]
             else:
-                if embedding_engine == "":
-                    query_embeddings = embedding_function.encode(query).tolist()
-                elif embedding_engine == "ollama":
-                    query_embeddings = generate_ollama_embeddings(
-                        GenerateEmbeddingsForm(
-                            **{
-                                "model": embedding_model,
-                                "prompt": query,
-                            }
-                        )
-                    )
-                elif embedding_engine == "openai":
-                    query_embeddings = generate_openai_embeddings(
-                        model=embedding_model,
-                        text=query,
-                        key=openai_key,
-                        url=openai_url,
-                    )
+                embeddings_function = query_embeddings_function(
+                    embedding_engine,
+                    embedding_model,
+                    embedding_function,
+                    openai_key,
+                    openai_url,
+                )

                 if doc["type"] == "collection":
                     context = query_embeddings_collection(
                         collection_names=doc["collection_names"],
                         query=query,
-                        query_embeddings=query_embeddings,
                         k=k,
+                        embeddings_function=embeddings_function,
+                        reranking_function=reranking_function,
                     )
                 else:
                     context = query_embeddings_doc(
                         collection_name=doc["collection_name"],
                         query=query,
-                        query_embeddings=query_embeddings,
                         k=k,
+                        embeddings_function=embeddings_function,
+                        reranking_function=reranking_function,
                     )
         except Exception as e:
@@ -250,3 +317,41 @@ def generate_openai_embeddings(
     except Exception as e:
         print(e)
         return None
+
+
+from typing import Any
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class ChromaRetriever(BaseRetriever):
+    collection: Any
+    k: int
+    embeddings_function: Any
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+    ) -> List[Document]:
+        query_embeddings = self.embeddings_function(query)
+
+        results = self.collection.query(
+            query_embeddings=[query_embeddings],
+            n_results=self.k,
+        )
+
+        ids = results["ids"][0]
+        metadatas = results["metadatas"][0]
+        documents = results["documents"][0]
+
+        return [
+            Document(
+                metadata=metadatas[idx],
+                page_content=documents[idx],
+            )
+            for idx in range(len(ids))
+        ]
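
ChromaRetriever adapts a Chroma collection to the langchain BaseRetriever interface so it can sit beside BM25Retriever inside the EnsembleRetriever above. A usage sketch; the collection name, embedding closure, and query are placeholders, not values from this commit:

    collection = CHROMA_CLIENT.get_collection(name="my-docs")  # placeholder

    retriever = ChromaRetriever(
        collection=collection,
        k=4,
        embeddings_function=embed,  # e.g. the closure from query_embeddings_function
    )

    # BaseRetriever provides invoke(); each hit comes back as a langchain Document.
    for doc in retriever.invoke("what is hybrid search?"):
        print(doc.metadata, doc.page_content[:80])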