feat: hybrid search

This commit is contained in:
Steven Kreitzer 2024-04-22 15:49:58 -05:00 committed by Steven Kreitzer
parent f3e5700d49
commit 4e0b32b505
7 changed files with 406 additions and 110 deletions

View file

@ -49,8 +49,8 @@ from apps.web.models.documents import (
from apps.rag.utils import (
query_embeddings_doc,
query_embeddings_function,
query_embeddings_collection,
generate_openai_embeddings,
)
from utils.misc import (
@ -67,6 +67,8 @@ from config import (
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
DEVICE_TYPE,
@ -91,6 +93,7 @@ app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
@ -105,6 +108,12 @@ if app.state.RAG_EMBEDDING_ENGINE == "":
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
app.state.RAG_RERANKING_MODEL,
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
)
origins = ["*"]
@ -134,6 +143,7 @@ async def get_status():
"template": app.state.RAG_TEMPLATE,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.RAG_RERANKING_MODEL,
}
@ -150,6 +160,11 @@ async def get_embedding_config(user=Depends(get_admin_user)):
}
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
class OpenAIConfigForm(BaseModel):
url: str
key: str
@ -205,6 +220,36 @@ async def update_embedding_config(
)
class RerankingModelUpdateForm(BaseModel):
reranking_model: str
@app.post("/reranking/update")
async def update_reranking_config(
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
app.state.RAG_RERANKING_MODEL,
device=DEVICE_TYPE,
)
return {
"status": True,
"reranking_model": app.state.RAG_RERANKING_MODEL,
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
return {
@ -286,34 +331,21 @@ def query_doc_handler(
user=Depends(get_current_user),
):
try:
if app.state.RAG_EMBEDDING_ENGINE == "":
query_embeddings = app.state.sentence_transformer_ef.encode(
form_data.query
).tolist()
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": app.state.RAG_EMBEDDING_MODEL,
"prompt": form_data.query,
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.OPENAI_API_KEY,
url=app.state.OPENAI_API_BASE_URL,
)
embeddings_function = query_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
return query_embeddings_doc(
collection_name=form_data.collection_name,
query=form_data.query,
query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
)
except Exception as e:
log.exception(e)
raise HTTPException(
@ -334,33 +366,21 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
if app.state.RAG_EMBEDDING_ENGINE == "":
query_embeddings = app.state.sentence_transformer_ef.encode(
form_data.query
).tolist()
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": app.state.RAG_EMBEDDING_MODEL,
"prompt": form_data.query,
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.OPENAI_API_KEY,
url=app.state.OPENAI_API_BASE_URL,
)
embeddings_function = embeddings_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
return query_embeddings_collection(
collection_names=form_data.collection_names,
query_embeddings=query_embeddings,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
)
except Exception as e:
log.exception(e)
raise HTTPException(
@ -427,8 +447,6 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
texts = [doc.page_content for doc in docs]
texts = list(map(lambda x: x.replace("\n", " "), texts))
metadatas = [doc.metadata for doc in docs]
try:
@ -440,26 +458,20 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = query_embeddings_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
if app.state.RAG_EMBEDDING_ENGINE == "":
embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
embeddings = embedding_func(embedding_texts)
else:
embeddings = [
generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
]
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
embeddings = [
generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=text,
key=app.state.OPENAI_API_KEY,
url=app.state.OPENAI_API_BASE_URL,
)
for text in texts
embedding_func(embedding_texts) for text in texts
]
for batch in create_batches(