forked from open-webui/open-webui
refac: rag pipeline
This commit is contained in:
parent
8f1563a7a5
commit
ce9a5d12e0
3 changed files with 179 additions and 154 deletions
|
@ -47,9 +47,11 @@ from apps.web.models.documents import (
|
||||||
|
|
||||||
from apps.rag.utils import (
|
from apps.rag.utils import (
|
||||||
get_model_path,
|
get_model_path,
|
||||||
query_embeddings_doc,
|
get_embedding_function,
|
||||||
get_embeddings_function,
|
query_doc,
|
||||||
query_embeddings_collection,
|
query_doc_with_hybrid_search,
|
||||||
|
query_collection,
|
||||||
|
query_collection_with_hybrid_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.misc import (
|
from utils.misc import (
|
||||||
|
@ -147,6 +149,15 @@ update_reranking_model(
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
|
app.state.RAG_EMBEDDING_ENGINE,
|
||||||
|
app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
app.state.sentence_transformer_ef,
|
||||||
|
app.state.OPENAI_API_KEY,
|
||||||
|
app.state.OPENAI_API_BASE_URL,
|
||||||
|
)
|
||||||
|
|
||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -227,6 +238,14 @@ async def update_embedding_config(
|
||||||
|
|
||||||
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
|
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
|
||||||
|
|
||||||
|
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
|
app.state.RAG_EMBEDDING_ENGINE,
|
||||||
|
app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
app.state.sentence_transformer_ef,
|
||||||
|
app.state.OPENAI_API_KEY,
|
||||||
|
app.state.OPENAI_API_BASE_URL,
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||||
|
@ -367,27 +386,22 @@ def query_doc_handler(
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
embeddings_function = get_embeddings_function(
|
if app.state.ENABLE_RAG_HYBRID_SEARCH:
|
||||||
app.state.RAG_EMBEDDING_ENGINE,
|
return query_doc_with_hybrid_search(
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
collection_name=form_data.collection_name,
|
||||||
app.state.sentence_transformer_ef,
|
query=form_data.query,
|
||||||
app.state.OPENAI_API_KEY,
|
embeddings_function=app.state.EMBEDDING_FUNCTION,
|
||||||
app.state.OPENAI_API_BASE_URL,
|
reranking_function=app.state.sentence_transformer_rf,
|
||||||
)
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
|
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||||
return query_embeddings_doc(
|
)
|
||||||
collection_name=form_data.collection_name,
|
else:
|
||||||
query=form_data.query,
|
return query_doc(
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
collection_name=form_data.collection_name,
|
||||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
query=form_data.query,
|
||||||
embeddings_function=embeddings_function,
|
embeddings_function=app.state.EMBEDDING_FUNCTION,
|
||||||
reranking_function=app.state.sentence_transformer_rf,
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
hybrid_search=(
|
)
|
||||||
form_data.hybrid
|
|
||||||
if form_data.hybrid
|
|
||||||
else app.state.ENABLE_RAG_HYBRID_SEARCH
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -410,27 +424,23 @@ def query_collection_handler(
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
embeddings_function = get_embeddings_function(
|
if app.state.ENABLE_RAG_HYBRID_SEARCH:
|
||||||
app.state.RAG_EMBEDDING_ENGINE,
|
return query_collection_with_hybrid_search(
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
collection_names=form_data.collection_names,
|
||||||
app.state.sentence_transformer_ef,
|
query=form_data.query,
|
||||||
app.state.OPENAI_API_KEY,
|
embeddings_function=app.state.EMBEDDING_FUNCTION,
|
||||||
app.state.OPENAI_API_BASE_URL,
|
reranking_function=app.state.sentence_transformer_rf,
|
||||||
)
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
|
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return query_collection(
|
||||||
|
collection_names=form_data.collection_names,
|
||||||
|
query=form_data.query,
|
||||||
|
embeddings_function=app.state.EMBEDDING_FUNCTION,
|
||||||
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
|
)
|
||||||
|
|
||||||
return query_embeddings_collection(
|
|
||||||
collection_names=form_data.collection_names,
|
|
||||||
query=form_data.query,
|
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
||||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
|
||||||
embeddings_function=embeddings_function,
|
|
||||||
reranking_function=app.state.sentence_transformer_rf,
|
|
||||||
hybrid_search=(
|
|
||||||
form_data.hybrid
|
|
||||||
if form_data.hybrid
|
|
||||||
else app.state.ENABLE_RAG_HYBRID_SEARCH
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
||||||
|
|
||||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||||
|
|
||||||
embedding_func = get_embeddings_function(
|
embedding_func = get_embedding_function(
|
||||||
app.state.RAG_EMBEDDING_ENGINE,
|
app.state.RAG_EMBEDDING_ENGINE,
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
app.state.RAG_EMBEDDING_MODEL,
|
||||||
app.state.sentence_transformer_ef,
|
app.state.sentence_transformer_ef,
|
||||||
|
|
|
@ -26,61 +26,72 @@ log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
def query_embeddings_doc(
|
def query_doc(
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query: str,
|
query: str,
|
||||||
embeddings_function,
|
embedding_function,
|
||||||
reranking_function,
|
|
||||||
k: int,
|
k: int,
|
||||||
r: int,
|
|
||||||
hybrid_search: bool,
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||||
|
query_embeddings = embedding_function(query)
|
||||||
|
result = collection.query(
|
||||||
|
query_embeddings=[query_embeddings],
|
||||||
|
n_results=k,
|
||||||
|
)
|
||||||
|
|
||||||
if hybrid_search:
|
log.info(f"query_doc:result {result}")
|
||||||
documents = collection.get() # get all documents
|
return result
|
||||||
bm25_retriever = BM25Retriever.from_texts(
|
except Exception as e:
|
||||||
texts=documents.get("documents"),
|
raise e
|
||||||
metadatas=documents.get("metadatas"),
|
|
||||||
)
|
|
||||||
bm25_retriever.k = k
|
|
||||||
|
|
||||||
chroma_retriever = ChromaRetriever(
|
|
||||||
collection=collection,
|
|
||||||
embeddings_function=embeddings_function,
|
|
||||||
top_n=k,
|
|
||||||
)
|
|
||||||
|
|
||||||
ensemble_retriever = EnsembleRetriever(
|
def query_doc_with_hybrid_search(
|
||||||
retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
|
collection_name: str,
|
||||||
)
|
query: str,
|
||||||
|
embedding_function,
|
||||||
|
k: int,
|
||||||
|
reranking_function,
|
||||||
|
r: int,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||||
|
documents = collection.get() # get all documents
|
||||||
|
|
||||||
compressor = RerankCompressor(
|
bm25_retriever = BM25Retriever.from_texts(
|
||||||
embeddings_function=embeddings_function,
|
texts=documents.get("documents"),
|
||||||
reranking_function=reranking_function,
|
metadatas=documents.get("metadatas"),
|
||||||
r_score=r,
|
)
|
||||||
top_n=k,
|
bm25_retriever.k = k
|
||||||
)
|
|
||||||
|
|
||||||
compression_retriever = ContextualCompressionRetriever(
|
chroma_retriever = ChromaRetriever(
|
||||||
base_compressor=compressor, base_retriever=ensemble_retriever
|
collection=collection,
|
||||||
)
|
embedding_function=embedding_function,
|
||||||
|
top_n=k,
|
||||||
|
)
|
||||||
|
|
||||||
result = compression_retriever.invoke(query)
|
ensemble_retriever = EnsembleRetriever(
|
||||||
result = {
|
retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
|
||||||
"distances": [[d.metadata.get("score") for d in result]],
|
)
|
||||||
"documents": [[d.page_content for d in result]],
|
|
||||||
"metadatas": [[d.metadata for d in result]],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
query_embeddings = embeddings_function(query)
|
|
||||||
result = collection.query(
|
|
||||||
query_embeddings=[query_embeddings],
|
|
||||||
n_results=k,
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info(f"query_embeddings_doc:result {result}")
|
compressor = RerankCompressor(
|
||||||
|
embedding_function=embedding_function,
|
||||||
|
reranking_function=reranking_function,
|
||||||
|
r_score=r,
|
||||||
|
top_n=k,
|
||||||
|
)
|
||||||
|
|
||||||
|
compression_retriever = ContextualCompressionRetriever(
|
||||||
|
base_compressor=compressor, base_retriever=ensemble_retriever
|
||||||
|
)
|
||||||
|
|
||||||
|
result = compression_retriever.invoke(query)
|
||||||
|
result = {
|
||||||
|
"distances": [[d.metadata.get("score") for d in result]],
|
||||||
|
"documents": [[d.page_content for d in result]],
|
||||||
|
"metadatas": [[d.metadata for d in result]],
|
||||||
|
}
|
||||||
|
log.info(f"query_doc_with_hybrid_search:result {result}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
@ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def query_embeddings_collection(
|
def query_collection(
|
||||||
collection_names: List[str],
|
collection_names: List[str],
|
||||||
query: str,
|
query: str,
|
||||||
|
embedding_function,
|
||||||
k: int,
|
k: int,
|
||||||
r: float,
|
|
||||||
embeddings_function,
|
|
||||||
reranking_function,
|
|
||||||
hybrid_search: bool,
|
|
||||||
):
|
):
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for collection_name in collection_names:
|
for collection_name in collection_names:
|
||||||
try:
|
try:
|
||||||
result = query_embeddings_doc(
|
result = query_doc(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
query=query,
|
query=query,
|
||||||
k=k,
|
k=k,
|
||||||
r=r,
|
embedding_function=embedding_function,
|
||||||
embeddings_function=embeddings_function,
|
)
|
||||||
|
results.append(result)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return merge_and_sort_query_results(results, k=k)
|
||||||
|
|
||||||
|
|
||||||
|
def query_collection_with_hybrid_search(
|
||||||
|
collection_names: List[str],
|
||||||
|
query: str,
|
||||||
|
embedding_function,
|
||||||
|
k: int,
|
||||||
|
reranking_function,
|
||||||
|
r: float,
|
||||||
|
):
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for collection_name in collection_names:
|
||||||
|
try:
|
||||||
|
result = query_doc_with_hybrid_search(
|
||||||
|
collection_name=collection_name,
|
||||||
|
query=query,
|
||||||
|
embedding_function=embedding_function,
|
||||||
|
k=k,
|
||||||
reranking_function=reranking_function,
|
reranking_function=reranking_function,
|
||||||
hybrid_search=hybrid_search,
|
r=r,
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
reverse = hybrid_search and reranking_function is not None
|
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||||
return merge_and_sort_query_results(results, k=k, reverse=reverse)
|
|
||||||
|
|
||||||
|
|
||||||
def rag_template(template: str, context: str, query: str):
|
def rag_template(template: str, context: str, query: str):
|
||||||
|
@ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str):
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
def get_embeddings_function(
|
def get_embedding_function(
|
||||||
embedding_engine,
|
embedding_engine,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
embedding_function,
|
embedding_function,
|
||||||
|
@ -204,19 +232,13 @@ def rag_messages(
|
||||||
docs,
|
docs,
|
||||||
messages,
|
messages,
|
||||||
template,
|
template,
|
||||||
|
embedding_function,
|
||||||
k,
|
k,
|
||||||
|
reranking_function,
|
||||||
r,
|
r,
|
||||||
hybrid_search,
|
hybrid_search,
|
||||||
embedding_engine,
|
|
||||||
embedding_model,
|
|
||||||
embedding_function,
|
|
||||||
reranking_function,
|
|
||||||
openai_key,
|
|
||||||
openai_url,
|
|
||||||
):
|
):
|
||||||
log.debug(
|
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
|
||||||
f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
|
|
||||||
)
|
|
||||||
|
|
||||||
last_user_message_idx = None
|
last_user_message_idx = None
|
||||||
for i in range(len(messages) - 1, -1, -1):
|
for i in range(len(messages) - 1, -1, -1):
|
||||||
|
@ -243,14 +265,6 @@ def rag_messages(
|
||||||
content_type = None
|
content_type = None
|
||||||
query = ""
|
query = ""
|
||||||
|
|
||||||
embeddings_function = get_embeddings_function(
|
|
||||||
embedding_engine,
|
|
||||||
embedding_model,
|
|
||||||
embedding_function,
|
|
||||||
openai_key,
|
|
||||||
openai_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
extracted_collections = []
|
extracted_collections = []
|
||||||
relevant_contexts = []
|
relevant_contexts = []
|
||||||
|
|
||||||
|
@ -271,26 +285,31 @@ def rag_messages(
|
||||||
try:
|
try:
|
||||||
if doc["type"] == "text":
|
if doc["type"] == "text":
|
||||||
context = doc["content"]
|
context = doc["content"]
|
||||||
elif doc["type"] == "collection":
|
|
||||||
context = query_embeddings_collection(
|
|
||||||
collection_names=doc["collection_names"],
|
|
||||||
query=query,
|
|
||||||
k=k,
|
|
||||||
r=r,
|
|
||||||
embeddings_function=embeddings_function,
|
|
||||||
reranking_function=reranking_function,
|
|
||||||
hybrid_search=hybrid_search,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
context = query_embeddings_doc(
|
if hybrid_search:
|
||||||
collection_name=doc["collection_name"],
|
context = query_collection_with_hybrid_search(
|
||||||
query=query,
|
collection_names=(
|
||||||
k=k,
|
doc["collection_names"]
|
||||||
r=r,
|
if doc["type"] == "collection"
|
||||||
embeddings_function=embeddings_function,
|
else [doc["collection_name"]]
|
||||||
reranking_function=reranking_function,
|
),
|
||||||
hybrid_search=hybrid_search,
|
query=query,
|
||||||
)
|
embedding_function=embedding_function,
|
||||||
|
k=k,
|
||||||
|
reranking_function=reranking_function,
|
||||||
|
r=r,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
context = query_collection(
|
||||||
|
collection_names=(
|
||||||
|
doc["collection_names"]
|
||||||
|
if doc["type"] == "collection"
|
||||||
|
else [doc["collection_name"]]
|
||||||
|
),
|
||||||
|
query=query,
|
||||||
|
embedding_function=embedding_function,
|
||||||
|
k=k,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
context = None
|
context = None
|
||||||
|
@ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||||
|
|
||||||
class ChromaRetriever(BaseRetriever):
|
class ChromaRetriever(BaseRetriever):
|
||||||
collection: Any
|
collection: Any
|
||||||
embeddings_function: Any
|
embedding_function: Any
|
||||||
top_n: int
|
top_n: int
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
|
@ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever):
|
||||||
*,
|
*,
|
||||||
run_manager: CallbackManagerForRetrieverRun,
|
run_manager: CallbackManagerForRetrieverRun,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
query_embeddings = self.embeddings_function(query)
|
query_embeddings = self.embedding_function(query)
|
||||||
|
|
||||||
results = self.collection.query(
|
results = self.collection.query(
|
||||||
query_embeddings=[query_embeddings],
|
query_embeddings=[query_embeddings],
|
||||||
|
@ -445,7 +464,7 @@ from sentence_transformers import util
|
||||||
|
|
||||||
|
|
||||||
class RerankCompressor(BaseDocumentCompressor):
|
class RerankCompressor(BaseDocumentCompressor):
|
||||||
embeddings_function: Any
|
embedding_function: Any
|
||||||
reranking_function: Any
|
reranking_function: Any
|
||||||
r_score: float
|
r_score: float
|
||||||
top_n: int
|
top_n: int
|
||||||
|
@ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor):
|
||||||
[(query, doc.page_content) for doc in documents]
|
[(query, doc.page_content) for doc in documents]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query_embedding = self.embeddings_function(query)
|
query_embedding = self.embedding_function(query)
|
||||||
document_embedding = self.embeddings_function(
|
document_embedding = self.embedding_function(
|
||||||
[doc.page_content for doc in documents]
|
[doc.page_content for doc in documents]
|
||||||
)
|
)
|
||||||
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
||||||
|
|
|
@ -117,18 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||||
if "docs" in data:
|
if "docs" in data:
|
||||||
data = {**data}
|
data = {**data}
|
||||||
data["messages"] = rag_messages(
|
data["messages"] = rag_messages(
|
||||||
data["docs"],
|
docs=data["docs"],
|
||||||
data["messages"],
|
messages=data["messages"],
|
||||||
rag_app.state.RAG_TEMPLATE,
|
template=rag_app.state.RAG_TEMPLATE,
|
||||||
rag_app.state.TOP_K,
|
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||||
rag_app.state.RELEVANCE_THRESHOLD,
|
k=rag_app.state.TOP_K,
|
||||||
rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
|
reranking_function=rag_app.state.sentence_transformer_rf,
|
||||||
rag_app.state.RAG_EMBEDDING_ENGINE,
|
r=rag_app.state.RELEVANCE_THRESHOLD,
|
||||||
rag_app.state.RAG_EMBEDDING_MODEL,
|
hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
|
||||||
rag_app.state.sentence_transformer_ef,
|
|
||||||
rag_app.state.sentence_transformer_rf,
|
|
||||||
rag_app.state.OPENAI_API_KEY,
|
|
||||||
rag_app.state.OPENAI_API_BASE_URL,
|
|
||||||
)
|
)
|
||||||
del data["docs"]
|
del data["docs"]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue