refac: rag pipeline

This commit is contained in:
Timothy J. Baek 2024-04-27 15:38:50 -04:00
parent 8f1563a7a5
commit ce9a5d12e0
3 changed files with 179 additions and 154 deletions

View file

@ -47,9 +47,11 @@ from apps.web.models.documents import (
from apps.rag.utils import (
get_model_path,
query_embeddings_doc,
get_embeddings_function,
query_embeddings_collection,
get_embedding_function,
query_doc,
query_doc_with_hybrid_search,
query_collection,
query_collection_with_hybrid_search,
)
from utils.misc import (
@ -147,6 +149,15 @@ update_reranking_model(
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
origins = ["*"]
@ -227,6 +238,14 @@ async def update_embedding_config(
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
@ -367,26 +386,21 @@ def query_doc_handler(
user=Depends(get_current_user),
):
try:
embeddings_function = get_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
return query_embeddings_doc(
if app.state.ENABLE_RAG_HYBRID_SEARCH:
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
reranking_function=app.state.sentence_transformer_rf,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
hybrid_search=(
form_data.hybrid
if form_data.hybrid
else app.state.ENABLE_RAG_HYBRID_SEARCH
),
)
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
)
except Exception as e:
log.exception(e)
@ -410,27 +424,23 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
embeddings_function = get_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
return query_embeddings_collection(
if app.state.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
reranking_function=app.state.sentence_transformer_rf,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
hybrid_search=(
form_data.hybrid
if form_data.hybrid
else app.state.ENABLE_RAG_HYBRID_SEARCH
),
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
embeddings_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
)
except Exception as e:
log.exception(e)
raise HTTPException(
@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embeddings_function(
embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,

View file

@ -26,20 +26,38 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def query_embeddings_doc(
def query_doc(
collection_name: str,
query: str,
embeddings_function,
reranking_function,
embedding_function,
k: int,
r: int,
hybrid_search: bool,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
query_embeddings = embedding_function(query)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
if hybrid_search:
log.info(f"query_doc:result {result}")
return result
except Exception as e:
raise e
def query_doc_with_hybrid_search(
collection_name: str,
query: str,
embedding_function,
k: int,
reranking_function,
r: int,
):
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
documents = collection.get() # get all documents
bm25_retriever = BM25Retriever.from_texts(
texts=documents.get("documents"),
metadatas=documents.get("metadatas"),
@ -48,7 +66,7 @@ def query_embeddings_doc(
chroma_retriever = ChromaRetriever(
collection=collection,
embeddings_function=embeddings_function,
embedding_function=embedding_function,
top_n=k,
)
@ -57,7 +75,7 @@ def query_embeddings_doc(
)
compressor = RerankCompressor(
embeddings_function=embeddings_function,
embedding_function=embedding_function,
reranking_function=reranking_function,
r_score=r,
top_n=k,
@ -73,14 +91,7 @@ def query_embeddings_doc(
"documents": [[d.page_content for d in result]],
"metadatas": [[d.metadata for d in result]],
}
else:
query_embeddings = embeddings_function(query)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
log.info(f"query_embeddings_doc:result {result}")
log.info(f"query_doc_with_hybrid_search:result {result}")
return result
except Exception as e:
raise e
@ -127,35 +138,52 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
return result
def query_embeddings_collection(
def query_collection(
collection_names: List[str],
query: str,
embedding_function,
k: int,
r: float,
embeddings_function,
reranking_function,
hybrid_search: bool,
):
results = []
for collection_name in collection_names:
try:
result = query_embeddings_doc(
result = query_doc(
collection_name=collection_name,
query=query,
k=k,
r=r,
embeddings_function=embeddings_function,
embedding_function=embedding_function,
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
collection_names: List[str],
query: str,
embedding_function,
k: int,
reranking_function,
r: float,
):
results = []
for collection_name in collection_names:
try:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
hybrid_search=hybrid_search,
r=r,
)
results.append(result)
except:
pass
reverse = hybrid_search and reranking_function is not None
return merge_and_sort_query_results(results, k=k, reverse=reverse)
return merge_and_sort_query_results(results, k=k, reverse=True)
def rag_template(template: str, context: str, query: str):
@ -164,7 +192,7 @@ def rag_template(template: str, context: str, query: str):
return template
def get_embeddings_function(
def get_embedding_function(
embedding_engine,
embedding_model,
embedding_function,
@ -204,19 +232,13 @@ def rag_messages(
docs,
messages,
template,
embedding_function,
k,
reranking_function,
r,
hybrid_search,
embedding_engine,
embedding_model,
embedding_function,
reranking_function,
openai_key,
openai_url,
):
log.debug(
f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
)
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
@ -243,14 +265,6 @@ def rag_messages(
content_type = None
query = ""
embeddings_function = get_embeddings_function(
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
)
extracted_collections = []
relevant_contexts = []
@ -271,25 +285,30 @@ def rag_messages(
try:
if doc["type"] == "text":
context = doc["content"]
elif doc["type"] == "collection":
context = query_embeddings_collection(
collection_names=doc["collection_names"],
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
query=query,
embedding_function=embedding_function,
k=k,
r=r,
embeddings_function=embeddings_function,
reranking_function=reranking_function,
hybrid_search=hybrid_search,
r=r,
)
else:
context = query_embeddings_doc(
collection_name=doc["collection_name"],
context = query_collection(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
query=query,
embedding_function=embedding_function,
k=k,
r=r,
embeddings_function=embeddings_function,
reranking_function=reranking_function,
hybrid_search=hybrid_search,
)
except Exception as e:
log.exception(e)
@ -404,7 +423,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
class ChromaRetriever(BaseRetriever):
collection: Any
embeddings_function: Any
embedding_function: Any
top_n: int
def _get_relevant_documents(
@ -413,7 +432,7 @@ class ChromaRetriever(BaseRetriever):
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
query_embeddings = self.embeddings_function(query)
query_embeddings = self.embedding_function(query)
results = self.collection.query(
query_embeddings=[query_embeddings],
@ -445,7 +464,7 @@ from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor):
embeddings_function: Any
embedding_function: Any
reranking_function: Any
r_score: float
top_n: int
@ -465,8 +484,8 @@ class RerankCompressor(BaseDocumentCompressor):
[(query, doc.page_content) for doc in documents]
)
else:
query_embedding = self.embeddings_function(query)
document_embedding = self.embeddings_function(
query_embedding = self.embedding_function(query)
document_embedding = self.embedding_function(
[doc.page_content for doc in documents]
)
scores = util.cos_sim(query_embedding, document_embedding)[0]

View file

@ -117,18 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
if "docs" in data:
data = {**data}
data["messages"] = rag_messages(
data["docs"],
data["messages"],
rag_app.state.RAG_TEMPLATE,
rag_app.state.TOP_K,
rag_app.state.RELEVANCE_THRESHOLD,
rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
rag_app.state.RAG_EMBEDDING_ENGINE,
rag_app.state.RAG_EMBEDDING_MODEL,
rag_app.state.sentence_transformer_ef,
rag_app.state.sentence_transformer_rf,
rag_app.state.OPENAI_API_KEY,
rag_app.state.OPENAI_API_BASE_URL,
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
)
del data["docs"]