refac: rag pipeline

This commit is contained in:
Timothy J. Baek 2024-04-27 15:38:50 -04:00
parent 8f1563a7a5
commit ce9a5d12e0
3 changed files with 179 additions and 154 deletions

View file

@ -47,9 +47,11 @@ from apps.web.models.documents import (
from apps.rag.utils import ( from apps.rag.utils import (
get_model_path, get_model_path,
query_embeddings_doc, get_embedding_function,
get_embeddings_function, query_doc,
query_embeddings_collection, query_doc_with_hybrid_search,
query_collection,
query_collection_with_hybrid_search,
) )
from utils.misc import ( from utils.misc import (
@ -147,6 +149,15 @@ update_reranking_model(
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
) )
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
origins = ["*"] origins = ["*"]
@ -227,6 +238,14 @@ async def update_embedding_config(
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
@ -367,27 +386,22 @@ def query_doc_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
embeddings_function = get_embeddings_function( if app.state.ENABLE_RAG_HYBRID_SEARCH:
app.state.RAG_EMBEDDING_ENGINE, return query_doc_with_hybrid_search(
app.state.RAG_EMBEDDING_MODEL, collection_name=form_data.collection_name,
app.state.sentence_transformer_ef, query=form_data.query,
app.state.OPENAI_API_KEY, embeddings_function=app.state.EMBEDDING_FUNCTION,
app.state.OPENAI_API_BASE_URL, reranking_function=app.state.sentence_transformer_rf,
) k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
return query_embeddings_doc( )
collection_name=form_data.collection_name, else:
query=form_data.query, return query_doc(
k=form_data.k if form_data.k else app.state.TOP_K, collection_name=form_data.collection_name,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, query=form_data.query,
embeddings_function=embeddings_function, embeddings_function=app.state.EMBEDDING_FUNCTION,
reranking_function=app.state.sentence_transformer_rf, k=form_data.k if form_data.k else app.state.TOP_K,
hybrid_search=( )
form_data.hybrid
if form_data.hybrid
else app.state.ENABLE_RAG_HYBRID_SEARCH
),
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
@ -410,27 +424,23 @@ def query_collection_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
embeddings_function = get_embeddings_function( if app.state.ENABLE_RAG_HYBRID_SEARCH:
app.state.RAG_EMBEDDING_ENGINE, return query_collection_with_hybrid_search(
app.state.RAG_EMBEDDING_MODEL, collection_names=form_data.collection_names,
app.state.sentence_transformer_ef, query=form_data.query,
app.state.OPENAI_API_KEY, embeddings_function=app.state.EMBEDDING_FUNCTION,
app.state.OPENAI_API_BASE_URL, reranking_function=app.state.sentence_transformer_rf,
) k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
)
return query_embeddings_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
hybrid_search=(
form_data.hybrid
if form_data.hybrid
else app.state.ENABLE_RAG_HYBRID_SEARCH
),
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name) collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embeddings_function( embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,

View file

@ -26,61 +26,72 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def query_embeddings_doc( def query_doc(
collection_name: str, collection_name: str,
query: str, query: str,
embeddings_function, embedding_function,
reranking_function,
k: int, k: int,
r: int,
hybrid_search: bool,
): ):
try: try:
collection = CHROMA_CLIENT.get_collection(name=collection_name) collection = CHROMA_CLIENT.get_collection(name=collection_name)
query_embeddings = embedding_function(query)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
if hybrid_search: log.info(f"query_doc:result {result}")
documents = collection.get() # get all documents return result
bm25_retriever = BM25Retriever.from_texts( except Exception as e:
texts=documents.get("documents"), raise e
metadatas=documents.get("metadatas"),
)
bm25_retriever.k = k
chroma_retriever = ChromaRetriever(
collection=collection,
embeddings_function=embeddings_function,
top_n=k,
)
ensemble_retriever = EnsembleRetriever( def query_doc_with_hybrid_search(
retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] collection_name: str,
) query: str,
embedding_function,
k: int,
reranking_function,
r: int,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
documents = collection.get() # get all documents
compressor = RerankCompressor( bm25_retriever = BM25Retriever.from_texts(
embeddings_function=embeddings_function, texts=documents.get("documents"),
reranking_function=reranking_function, metadatas=documents.get("metadatas"),
r_score=r, )
top_n=k, bm25_retriever.k = k
)
compression_retriever = ContextualCompressionRetriever( chroma_retriever = ChromaRetriever(
base_compressor=compressor, base_retriever=ensemble_retriever collection=collection,
) embedding_function=embedding_function,
top_n=k,
)
result = compression_retriever.invoke(query) ensemble_retriever = EnsembleRetriever(
result = { retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
"distances": [[d.metadata.get("score") for d in result]], )
"documents": [[d.page_content for d in result]],
"metadatas": [[d.metadata for d in result]],
}
else:
query_embeddings = embeddings_function(query)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
log.info(f"query_embeddings_doc:result {result}") compressor = RerankCompressor(
embedding_function=embedding_function,
reranking_function=reranking_function,
r_score=r,
top_n=k,
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=ensemble_retriever
)
result = compression_retriever.invoke(query)
result = {
"distances": [[d.metadata.get("score") for d in result]],
"documents": [[d.page_content for d in result]],
"metadatas": [[d.metadata for d in result]],
}
log.info(f"query_doc_with_hybrid_search:result {result}")
return result return result
except Exception as e: except Exception as e:
raise e raise e
@ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
return result return result
def query_embeddings_collection( def query_collection(
collection_names: List[str], collection_names: List[str],
query: str, query: str,
embedding_function,
k: int, k: int,
r: float,
embeddings_function,
reranking_function,
hybrid_search: bool,
): ):
results = [] results = []
for collection_name in collection_names: for collection_name in collection_names:
try: try:
result = query_embeddings_doc( result = query_doc(
collection_name=collection_name, collection_name=collection_name,
query=query, query=query,
k=k, k=k,
r=r, embedding_function=embedding_function,
embeddings_function=embeddings_function, )
results.append(result)
except:
pass
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
collection_names: List[str],
query: str,
embedding_function,
k: int,
reranking_function,
r: float,
):
results = []
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function, reranking_function=reranking_function,
hybrid_search=hybrid_search, r=r,
) )
results.append(result) results.append(result)
except: except:
pass pass
reverse = hybrid_search and reranking_function is not None return merge_and_sort_query_results(results, k=k, reverse=True)
return merge_and_sort_query_results(results, k=k, reverse=reverse)
def rag_template(template: str, context: str, query: str): def rag_template(template: str, context: str, query: str):
@ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str):
return template return template
def get_embeddings_function( def get_embedding_function(
embedding_engine, embedding_engine,
embedding_model, embedding_model,
embedding_function, embedding_function,
@ -204,19 +232,13 @@ def rag_messages(
docs, docs,
messages, messages,
template, template,
embedding_function,
k, k,
reranking_function,
r, r,
hybrid_search, hybrid_search,
embedding_engine,
embedding_model,
embedding_function,
reranking_function,
openai_key,
openai_url,
): ):
log.debug( log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
)
last_user_message_idx = None last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1): for i in range(len(messages) - 1, -1, -1):
@ -243,14 +265,6 @@ def rag_messages(
content_type = None content_type = None
query = "" query = ""
embeddings_function = get_embeddings_function(
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
)
extracted_collections = [] extracted_collections = []
relevant_contexts = [] relevant_contexts = []
@ -271,26 +285,31 @@ def rag_messages(
try: try:
if doc["type"] == "text": if doc["type"] == "text":
context = doc["content"] context = doc["content"]
elif doc["type"] == "collection":
context = query_embeddings_collection(
collection_names=doc["collection_names"],
query=query,
k=k,
r=r,
embeddings_function=embeddings_function,
reranking_function=reranking_function,
hybrid_search=hybrid_search,
)
else: else:
context = query_embeddings_doc( if hybrid_search:
collection_name=doc["collection_name"], context = query_collection_with_hybrid_search(
query=query, collection_names=(
k=k, doc["collection_names"]
r=r, if doc["type"] == "collection"
embeddings_function=embeddings_function, else [doc["collection_name"]]
reranking_function=reranking_function, ),
hybrid_search=hybrid_search, query=query,
) embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
else:
context = query_collection(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
query=query,
embedding_function=embedding_function,
k=k,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
context = None context = None
@ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
class ChromaRetriever(BaseRetriever): class ChromaRetriever(BaseRetriever):
collection: Any collection: Any
embeddings_function: Any embedding_function: Any
top_n: int top_n: int
def _get_relevant_documents( def _get_relevant_documents(
@ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever):
*, *,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]: ) -> List[Document]:
query_embeddings = self.embeddings_function(query) query_embeddings = self.embedding_function(query)
results = self.collection.query( results = self.collection.query(
query_embeddings=[query_embeddings], query_embeddings=[query_embeddings],
@ -445,7 +464,7 @@ from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor): class RerankCompressor(BaseDocumentCompressor):
embeddings_function: Any embedding_function: Any
reranking_function: Any reranking_function: Any
r_score: float r_score: float
top_n: int top_n: int
@ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor):
[(query, doc.page_content) for doc in documents] [(query, doc.page_content) for doc in documents]
) )
else: else:
query_embedding = self.embeddings_function(query) query_embedding = self.embedding_function(query)
document_embedding = self.embeddings_function( document_embedding = self.embedding_function(
[doc.page_content for doc in documents] [doc.page_content for doc in documents]
) )
scores = util.cos_sim(query_embedding, document_embedding)[0] scores = util.cos_sim(query_embedding, document_embedding)[0]

View file

@ -117,18 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
if "docs" in data: if "docs" in data:
data = {**data} data = {**data}
data["messages"] = rag_messages( data["messages"] = rag_messages(
data["docs"], docs=data["docs"],
data["messages"], messages=data["messages"],
rag_app.state.RAG_TEMPLATE, template=rag_app.state.RAG_TEMPLATE,
rag_app.state.TOP_K, embedding_function=rag_app.state.EMBEDDING_FUNCTION,
rag_app.state.RELEVANCE_THRESHOLD, k=rag_app.state.TOP_K,
rag_app.state.ENABLE_RAG_HYBRID_SEARCH, reranking_function=rag_app.state.sentence_transformer_rf,
rag_app.state.RAG_EMBEDDING_ENGINE, r=rag_app.state.RELEVANCE_THRESHOLD,
rag_app.state.RAG_EMBEDDING_MODEL, hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
rag_app.state.sentence_transformer_ef,
rag_app.state.sentence_transformer_rf,
rag_app.state.OPENAI_API_KEY,
rag_app.state.OPENAI_API_BASE_URL,
) )
del data["docs"] del data["docs"]