open-webui/backend/apps/rag/utils.py

493 lines
14 KiB
Python
Raw Normal View History

2024-04-25 14:49:59 +02:00
import os
import logging
2024-04-14 23:55:00 +02:00
import requests
from typing import List
2024-04-14 23:55:00 +02:00
from apps.ollama.main import (
generate_ollama_embeddings,
GenerateEmbeddingsForm,
)
2024-03-09 04:26:39 +01:00
2024-04-25 14:49:59 +02:00
from huggingface_hub import snapshot_download
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
2024-04-22 22:49:58 +02:00
from langchain.retrievers import (
ContextualCompressionRetriever,
2024-04-22 22:49:58 +02:00
EnsembleRetriever,
)
2024-04-25 23:03:00 +02:00
from typing import Optional
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
2024-04-15 01:48:15 +02:00
# Module-level logger; its level comes from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["RAG"])
2024-03-09 04:26:39 +01:00
2024-04-22 22:49:58 +02:00
def query_embeddings_doc(
    collection_name: str,
    query: str,
    embeddings_function,
    reranking_function,
    k: int,
    r: float,
    hybrid_search: bool,
):
    """Query one Chroma collection for documents relevant to ``query``.

    Args:
        collection_name: Name of the Chroma collection to search.
        query: User query text.
        embeddings_function: Callable that turns a string (or a list of
            strings) into embedding(s). Used for the vector search and, in
            hybrid mode, by the rerank compressor.
        reranking_function: Optional reranker exposing ``predict`` over
            ``(query, text)`` pairs. Only consulted in hybrid mode.
        k: Maximum number of documents to return.
        r: Minimum relevance score a document must reach in hybrid mode.
            A falsy value disables the threshold.
        hybrid_search: When true, combine BM25 and vector retrieval and rerank
            the union. Otherwise run a plain Chroma vector query.

    Returns:
        A Chroma-style dict with ``"distances"``, ``"documents"`` and
        ``"metadatas"``, each wrapping one inner list. In hybrid mode the
        ``"distances"`` entries are the scores assigned by ``RerankCompressor``.
        In vector mode they are whatever ``collection.query`` returned.

    Raises:
        Exceptions from Chroma, the embedding function or the retrievers are
        propagated unchanged.
    """
    collection = CHROMA_CLIENT.get_collection(name=collection_name)

    if hybrid_search:
        result = _query_collection_hybrid(
            collection=collection,
            query=query,
            embeddings_function=embeddings_function,
            reranking_function=reranking_function,
            k=k,
            r=r,
        )
    else:
        query_embeddings = embeddings_function(query)
        result = collection.query(
            query_embeddings=[query_embeddings],
            n_results=k,
        )

    # The result contains the full retrieved text; keep it out of INFO logs.
    log.debug(f"query_embeddings_doc:result {result}")
    return result


def _query_collection_hybrid(
    collection, query, embeddings_function, reranking_function, k, r
):
    """Run BM25 + vector retrieval over ``collection`` and rerank the union."""
    documents = collection.get()  # get all documents
    texts = documents.get("documents")
    if not texts:
        # Building BM25 over an empty corpus fails; an empty collection simply
        # has no matches.
        return {"distances": [[]], "documents": [[]], "metadatas": [[]]}

    bm25_retriever = BM25Retriever.from_texts(
        texts=texts,
        metadatas=documents.get("metadatas"),
    )
    bm25_retriever.k = k

    chroma_retriever = ChromaRetriever(
        collection=collection,
        embeddings_function=embeddings_function,
        top_n=k,
    )
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
    )

    compressor = RerankCompressor(
        embeddings_function=embeddings_function,
        reranking_function=reranking_function,
        r_score=r,
        top_n=k,
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=ensemble_retriever
    )

    reranked = compression_retriever.invoke(query)
    return {
        "distances": [[d.metadata.get("score") for d in reranked]],
        "documents": [[d.page_content for d in reranked]],
        "metadatas": [[d.metadata for d in reranked]],
    }
2024-04-26 03:00:47 +02:00
def merge_and_sort_query_results(query_results, k, reverse=False):
    """Pool several Chroma-style query results and keep the first ``k``.

    Each element of ``query_results`` must provide ``"distances"``,
    ``"documents"`` and ``"metadatas"`` keys whose first inner list holds that
    result's values. All entries are pooled, ordered by their distance value
    (ascending, or descending when ``reverse`` is true; ties keep input order),
    and truncated to the first ``k``.

    Returns:
        ``{"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}``
        where every inner list holds at most ``k`` items. The lists are empty
        when there is nothing to merge.
    """
    pooled_distances = []
    pooled_documents = []
    pooled_metadatas = []
    for single in query_results:
        pooled_distances += single["distances"][0]
        pooled_documents += single["documents"][0]
        pooled_metadatas += single["metadatas"][0]

    ranked = sorted(
        zip(pooled_distances, pooled_documents, pooled_metadatas),
        key=lambda entry: entry[0],
        reverse=reverse,
    )
    top = ranked[:k]

    return {
        "distances": [[distance for distance, _, _ in top]],
        "documents": [[document for _, document, _ in top]],
        "metadatas": [[metadata for _, _, metadata in top]],
    }
2024-03-09 04:26:39 +01:00
def query_embeddings_collection(
    collection_names: List[str],
    query: str,
    k: int,
    r: float,
    embeddings_function,
    reranking_function,
    hybrid_search: bool,
):
    """Query several collections and merge their results into the top ``k``.

    Each collection is queried with ``query_embeddings_doc``. A collection that
    fails is logged and skipped, so one bad collection does not hide the others.

    Args:
        collection_names: Chroma collections to search.
        query: User query text.
        k: Maximum number of documents in the merged result.
        r: Relevance threshold forwarded to hybrid search.
        embeddings_function: Embedding callable forwarded to each query.
        reranking_function: Optional reranker forwarded to each query.
        hybrid_search: Whether to use BM25 + vector retrieval with reranking.

    Returns:
        The merged result from ``merge_and_sort_query_results``.
    """
    results = []
    for collection_name in collection_names:
        try:
            result = query_embeddings_doc(
                collection_name=collection_name,
                query=query,
                k=k,
                r=r,
                embeddings_function=embeddings_function,
                reranking_function=reranking_function,
                hybrid_search=hybrid_search,
            )
        except Exception as e:
            # Best effort: skip this collection but keep the failure visible.
            log.exception(
                f"query_embeddings_collection: failed to query {collection_name}: {e}"
            )
            continue
        results.append(result)

    # In hybrid mode the "distances" are RerankCompressor scores (reranker
    # output or cosine similarity), where higher is better, so merge
    # descending. Plain vector results are distances and merge ascending.
    return merge_and_sort_query_results(results, k=k, reverse=bool(hybrid_search))
2024-04-14 23:55:00 +02:00
2024-03-09 07:34:47 +01:00
def rag_template(template: str, context: str, query: str):
    """Fill the ``[context]`` and ``[query]`` placeholders of a RAG prompt.

    Both placeholders are substituted in a single pass over ``template``.
    Placeholder-like text inside the inserted values (e.g. a retrieved
    document that literally contains ``[query]``) is therefore left untouched
    rather than being substituted a second time.

    Args:
        template: Prompt template containing ``[context]`` and/or ``[query]``.
        context: Retrieved context text to insert.
        query: User query text to insert.

    Returns:
        The template with every placeholder occurrence replaced.
    """
    import re

    replacements = {"[context]": context, "[query]": query}
    # A function replacement inserts the values literally (no backslash or
    # group-reference processing).
    return re.sub(
        r"\[context\]|\[query\]",
        lambda match: replacements[match.group(0)],
        template,
    )
2024-03-11 02:40:50 +01:00
2024-04-22 22:49:58 +02:00
def query_embeddings_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
):
    """Build the embedding callable used for retrieval queries.

    - ``""``: delegate to the local ``embedding_function.encode`` and convert
      the result with ``tolist()``.
    - ``"ollama"`` / ``"openai"``: call the respective remote embedding helper.
      The returned callable accepts either one string or a list of strings; a
      list is embedded item by item.
    - Any other engine: returns ``None``.
    """
    if embedding_engine == "":
        return lambda query: embedding_function.encode(query).tolist()

    if embedding_engine == "ollama":

        def embed_one(text):
            return generate_ollama_embeddings(
                GenerateEmbeddingsForm(model=embedding_model, prompt=text)
            )

    elif embedding_engine == "openai":

        def embed_one(text):
            return generate_openai_embeddings(
                model=embedding_model,
                text=text,
                key=openai_key,
                url=openai_url,
            )

    else:
        return None

    def embed(query):
        if isinstance(query, list):
            return [embed_one(item) for item in query]
        return embed_one(query)

    return embed
2024-04-22 22:49:58 +02:00
2024-04-15 01:48:15 +02:00
def _extract_query(user_message):
    """Return ``(content_type, query)`` for a chat message's ``content``.

    ``content_type`` is ``"list"`` for multi-part content, where the query is
    the first ``"text"`` part or ``""`` if there is none. It is ``"text"`` for
    plain string content, and ``None`` (with query ``""``) for anything else.
    """
    content = user_message["content"]
    if isinstance(content, list):
        for item in content:
            if item["type"] == "text":
                return "list", item["text"]
        return "list", ""
    if isinstance(content, str):
        return "text", content
    return None, ""


def _build_context_string(relevant_contexts):
    """Join retrieved contexts into a single prompt-ready string.

    Dict contexts contribute the documents of their first result row, with
    ``None`` entries dropped. String contexts (from ``"text"`` docs) are used
    as-is. Pieces are separated by a blank line, so the end of one context
    never runs into the start of the next.
    """
    pieces = []
    for context in relevant_contexts:
        if isinstance(context, str):
            pieces.append(context)
        else:
            pieces.extend(
                item for item in context["documents"][0] if item is not None
            )
    return "\n\n".join(pieces).strip()


def rag_messages(
    docs,
    messages,
    template,
    k,
    r,
    hybrid_search,
    embedding_engine,
    embedding_model,
    embedding_function,
    reranking_function,
    openai_key,
    openai_url,
):
    """Augment the last user message with context retrieved from ``docs``.

    For every doc, context is gathered as follows:
    - ``"text"`` docs use their ``content`` directly.
    - ``"collection"`` docs query all their ``collection_names``.
    - Any other doc queries its single ``collection_name``.

    A doc whose collections were all handled by an earlier doc is skipped.
    Retrieval failures are logged and that doc is ignored. The gathered text is
    rendered through ``rag_template`` and replaces the user's text content.

    Args:
        docs: Documents/collections selected for retrieval.
        messages: Chat messages (dicts with ``"role"`` and ``"content"``).
        template: RAG prompt template with ``[context]``/``[query]``.
        k, r, hybrid_search: Retrieval parameters forwarded to the query helpers.
        embedding_engine, embedding_model, embedding_function, openai_key,
        openai_url: Used to build the query embedding callable.
        reranking_function: Optional reranker forwarded to the query helpers.

    Returns:
        ``messages``. It is modified in place: the last user message is replaced
        by an augmented copy. It is returned unchanged when there is no user
        message.
    """
    # openai_key is a credential and is deliberately not logged.
    log.debug(
        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_url}"
    )

    last_user_message_idx = None
    for i in reversed(range(len(messages))):
        if messages[i]["role"] == "user":
            last_user_message_idx = i
            break
    if last_user_message_idx is None:
        log.debug("rag_messages: no user message found, returning messages unchanged")
        return messages

    user_message = messages[last_user_message_idx]
    content_type, query = _extract_query(user_message)

    embeddings_function = query_embeddings_function(
        embedding_engine,
        embedding_model,
        embedding_function,
        openai_key,
        openai_url,
    )

    extracted_collections = []
    relevant_contexts = []
    for doc in docs:
        single_collection = doc.get("collection_name")
        requested = (
            [single_collection]
            if single_collection
            else (doc.get("collection_names") or [])
        )
        remaining = set(requested).difference(extracted_collections)
        # Only docs that reference collections can be "already extracted";
        # docs without any (e.g. "text" docs) must still be processed.
        if requested and not remaining:
            log.debug(f"skipping {doc} as it has already been extracted")
            continue

        context = None
        try:
            if doc["type"] == "text":
                context = doc["content"]
            elif doc["type"] == "collection":
                context = query_embeddings_collection(
                    collection_names=doc["collection_names"],
                    query=query,
                    k=k,
                    r=r,
                    embeddings_function=embeddings_function,
                    reranking_function=reranking_function,
                    hybrid_search=hybrid_search,
                )
            else:
                context = query_embeddings_doc(
                    collection_name=doc["collection_name"],
                    query=query,
                    k=k,
                    r=r,
                    embeddings_function=embeddings_function,
                    reranking_function=reranking_function,
                    hybrid_search=hybrid_search,
                )
        except Exception as e:
            log.exception(e)
            context = None

        if context:
            relevant_contexts.append(context)

        # Mark these collections as handled whether or not retrieval produced
        # a context, so later docs pointing at them are not queried again.
        extracted_collections.extend(remaining)

    context_string = _build_context_string(relevant_contexts)

    ra_content = rag_template(
        template=template,
        context=context_string,
        query=query,
    )
    log.debug(f"ra_content: {ra_content}")

    if content_type == "list":
        # Replace text parts with the augmented prompt; keep other parts
        # (e.g. images) as they are.
        new_content = [
            {"type": "text", "text": ra_content} if item["type"] == "text" else item
            for item in user_message["content"]
        ]
        new_user_message = {**user_message, "content": new_content}
    else:
        new_user_message = {**user_message, "content": ra_content}

    messages[last_user_message_idx] = new_user_message
    return messages
2024-04-04 20:07:42 +02:00
2024-04-25 14:49:59 +02:00
def get_model_path(model: str, update_model: bool = False):
    """Resolve ``model`` to a local path via the Hugging Face snapshot cache.

    Resolution order:
    1. Return ``model`` unchanged if it is an existing path, or if it looks
       like a filesystem path (contains a backslash or more than one ``/``)
       and ``update_model`` is false.
    2. Otherwise treat it as a repo id, prefixing ``sentence-transformers/``
       when it has no ``/``, and call ``snapshot_download``. That call uses
       ``local_files_only`` unless ``update_model`` is true, and caches under
       ``$SENTENCE_TRANSFORMERS_HOME``.
    3. If that call fails, log the error and return the (possibly prefixed)
       repo id.
    """
    local_only = not update_model
    download_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": local_only,
    }
    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {download_kwargs}")

    # Inspiration from upstream sentence_transformers
    looks_like_path = "\\" in model or model.count("/") > 1
    if os.path.exists(model) or (looks_like_path and local_only):
        return model

    if "/" not in model:
        # Bare short names refer to the sentence-transformers organisation.
        model = f"sentence-transformers/{model}"
    download_kwargs["repo_id"] = model

    try:
        snapshot_path = snapshot_download(**download_kwargs)
        log.debug(f"model_repo_path: {snapshot_path}")
        return snapshot_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
2024-04-15 01:15:39 +02:00
def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    """Request an embedding for ``text`` from an OpenAI-compatible API.

    Posts ``{"input": text, "model": model}`` to ``{url}/embeddings`` with a
    bearer token, and returns the ``embedding`` of the first entry of the
    response's ``data`` list.

    Args:
        model: Embedding model name.
        text: Text to embed.
        key: API key sent as a bearer token.
        url: Base URL of the API (without the ``/embeddings`` suffix).

    Returns:
        The embedding, or ``None`` on any failure (request error, HTTP error
        status, or a response without the expected ``data`` entry). Failures
        are logged instead of raised, keeping the best-effort contract callers
        rely on.
    """
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" not in data:
            # Raising a plain string is itself a TypeError; use a real exception.
            raise ValueError(
                "Unexpected embeddings response: missing 'data' field"
            )
        return data["data"][0]["embedding"]
    except Exception as e:
        log.exception(f"generate_openai_embeddings failed: {e}")
        return None
2024-04-22 22:49:58 +02:00
from typing import Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
2024-04-22 22:49:58 +02:00
class ChromaRetriever(BaseRetriever):
    """LangChain retriever that runs an embedding query against a Chroma collection.

    Embeds the query with ``embeddings_function``, asks ``collection`` for
    ``top_n`` results, and wraps each hit's text and metadata in a ``Document``.
    """

    # Chroma collection exposing ``query(query_embeddings=..., n_results=...)``.
    collection: Any
    # Callable turning the query string into an embedding vector.
    embeddings_function: Any
    # Number of results requested from the collection.
    top_n: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Return the collection's hits for ``query`` as LangChain documents."""
        response = self.collection.query(
            query_embeddings=[self.embeddings_function(query)],
            n_results=self.top_n,
        )
        id_row = response["ids"][0]
        metadata_row = response["metadatas"][0]
        text_row = response["documents"][0]

        found: List[Document] = []
        for position, _ in enumerate(id_row):
            found.append(
                Document(
                    page_content=text_row[position],
                    metadata=metadata_row[position],
                )
            )
        return found
import operator
from typing import Optional, Sequence
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra
from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor):
    """Score, filter and trim retrieved documents against the query.

    Documents are scored in one of two ways:
    - With ``reranking_function.predict`` over ``(query, page_content)`` pairs.
    - When no reranker is configured, with the cosine similarity between the
      query embedding and each document embedding.

    Documents scoring below ``r_score`` are dropped (when ``r_score`` is
    truthy). The rest are ordered from highest to lowest score, and the best
    ``top_n`` are returned with their score stored in ``metadata["score"]``.
    """

    # Callable producing embeddings for a string or a list of strings.
    embeddings_function: Any
    # Optional reranker with ``predict`` over (query, text) pairs; falsy means
    # "use cosine similarity of embeddings".
    reranking_function: Any
    # Minimum score a document must reach to be kept; falsy disables the filter.
    r_score: float
    # Maximum number of documents returned.
    top_n: int

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Return the best-scoring ``documents`` for ``query``, best first."""
        if not documents:
            # Nothing to score; avoids calling the reranker/embedder on empty input.
            return []

        if self.reranking_function:
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            query_embedding = self.embeddings_function(query)
            document_embedding = self.embeddings_function(
                [doc.page_content for doc in documents]
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(zip(documents, scores.tolist()))
        if self.r_score:
            docs_with_scores = [
                (d, s) for d, s in docs_with_scores if s >= self.r_score
            ]

        # Reranker scores and cosine similarities are both higher-is-better
        # (the r_score threshold above relies on that). Rank descending in
        # both cases so top_n keeps the most relevant documents, never the
        # least similar ones.
        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        final_results = []
        for doc, doc_score in result[: self.top_n]:
            # Build new metadata instead of mutating doc.metadata, which belongs
            # to the Document objects handed out by the underlying retriever.
            final_results.append(
                Document(
                    page_content=doc.page_content,
                    metadata={**doc.metadata, "score": doc_score},
                )
            )
        return final_results