feat: toggle hybrid search

This commit is contained in:
Steven Kreitzer 2024-04-25 17:31:21 -05:00
parent 984dbf13ab
commit 9755cd5baa
4 changed files with 133 additions and 88 deletions

View file

@ -70,6 +70,7 @@ from config import (
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_HYBRID,
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
@ -91,6 +92,8 @@ app = FastAPI()
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.HYBRID = RAG_HYBRID
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
@ -321,6 +324,7 @@ async def get_query_settings(user=Depends(get_admin_user)):
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.HYBRID,
}
@ -328,6 +332,7 @@ class QuerySettingsForm(BaseModel):
k: Optional[int] = None
r: Optional[float] = None
template: Optional[str] = None
hybrid: Optional[bool] = None
@app.post("/query/settings/update")
@ -337,7 +342,14 @@ async def update_query_settings(
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
return {"status": True, "template": app.state.RAG_TEMPLATE}
app.state.HYBRID = form_data.hybrid if form_data.hybrid else False
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.HYBRID,
}
class QueryDocForm(BaseModel):
@ -345,6 +357,7 @@ class QueryDocForm(BaseModel):
query: str
k: Optional[int] = None
r: Optional[float] = None
hybrid: Optional[bool] = None
@app.post("/query/doc")
@ -368,6 +381,7 @@ def query_doc_handler(
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID,
)
except Exception as e:
log.exception(e)
@ -382,6 +396,7 @@ class QueryCollectionsForm(BaseModel):
query: str
k: Optional[int] = None
r: Optional[float] = None
hybrid: Optional[bool] = None
@app.post("/query/collection")
@ -405,6 +420,7 @@ def query_collection_handler(
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID,
)
except Exception as e:
log.exception(e)

View file

@ -32,13 +32,13 @@ def query_embeddings_doc(
collection_name: str,
query: str,
embeddings_function,
reranking_function,
k: int,
reranking_function: Optional[CrossEncoder] = None,
r: Optional[float] = None,
hybrid: Optional[bool] = False,
):
try:
if reranking_function:
if hybrid:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(name=collection_name)
@ -142,6 +142,7 @@ def query_embeddings_collection(
r: float,
embeddings_function,
reranking_function,
hybrid: bool,
):
results = []
@ -155,6 +156,7 @@ def query_embeddings_collection(
r=r,
embeddings_function=embeddings_function,
reranking_function=reranking_function,
hybrid=hybrid,
)
results.append(result)
except:
@ -211,6 +213,7 @@ def rag_messages(
template,
k,
r,
hybrid,
embedding_engine,
embedding_model,
embedding_function,
@ -283,6 +286,7 @@ def rag_messages(
r=r,
embeddings_function=embeddings_function,
reranking_function=reranking_function,
hybrid=hybrid,
)
else:
context = query_embeddings_doc(
@ -292,6 +296,7 @@ def rag_messages(
r=r,
embeddings_function=embeddings_function,
reranking_function=reranking_function,
hybrid=hybrid,
)
except Exception as e:
log.exception(e)