feat: hybrid search and reranking support

This commit is contained in:
Steven Kreitzer 2024-04-22 18:36:46 -05:00
parent db801aee79
commit c0259aad67
10 changed files with 262 additions and 131 deletions

View file

@ -64,6 +64,8 @@ from config import (
SRC_LOG_LEVELS,
UPLOAD_DIR,
DOCS_DIR,
RAG_TOP_K,
RAG_RELEVANCE_THRESHOLD,
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@ -86,7 +88,8 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()
app.state.TOP_K = 4
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
@ -107,12 +110,17 @@ if app.state.RAG_EMBEDDING_ENGINE == "":
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
else:
app.state.sentence_transformer_ef = None
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
app.state.RAG_RERANKING_MODEL,
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
)
if not app.state.RAG_RERANKING_MODEL == "":
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
app.state.RAG_RERANKING_MODEL,
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
)
else:
app.state.sentence_transformer_rf = None
origins = ["*"]
@ -185,22 +193,22 @@ async def update_embedding_config(
)
try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = None
if form_data.openai_config != None:
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key
app.state.sentence_transformer_ef = None
else:
sentence_transformer_ef = sentence_transformers.SentenceTransformer(
app.state.RAG_EMBEDDING_MODEL,
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
app.state.sentence_transformer_ef = (
sentence_transformers.SentenceTransformer(
app.state.RAG_EMBEDDING_MODEL,
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
)
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = sentence_transformer_ef
return {
"status": True,
@ -222,7 +230,7 @@ async def update_embedding_config(
class RerankingModelUpdateForm(BaseModel):
reranking_model: str
@app.post("/reranking/update")
async def update_reranking_config(
@ -233,10 +241,14 @@ async def update_reranking_config(
)
try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
app.state.RAG_RERANKING_MODEL,
device=DEVICE_TYPE,
)
if app.state.RAG_RERANKING_MODEL == "":
app.state.sentence_transformer_rf = None
else:
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
app.state.RAG_RERANKING_MODEL,
device=DEVICE_TYPE,
)
return {
"status": True,
@ -302,11 +314,13 @@ async def get_query_settings(user=Depends(get_admin_user)):
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
}
class QuerySettingsForm(BaseModel):
k: Optional[int] = None
r: Optional[float] = None
template: Optional[str] = None
@ -316,6 +330,7 @@ async def update_query_settings(
):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
return {"status": True, "template": app.state.RAG_TEMPLATE}
@ -323,6 +338,7 @@ class QueryDocForm(BaseModel):
collection_name: str
query: str
k: Optional[int] = None
r: Optional[float] = None
@app.post("/query/doc")
@ -343,6 +359,7 @@ def query_doc_handler(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
)
@ -358,6 +375,7 @@ class QueryCollectionsForm(BaseModel):
collection_names: List[str]
query: str
k: Optional[int] = None
r: Optional[float] = None
@app.post("/query/collection")
@ -378,6 +396,7 @@ def query_collection_handler(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
)
@ -467,12 +486,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
if app.state.RAG_EMBEDDING_ENGINE == "":
embeddings = embedding_func(embedding_texts)
else:
embeddings = [
embedding_func(embedding_texts) for text in texts
]
embeddings = embedding_func(embedding_texts)
for batch in create_batches(
api=CHROMA_CLIENT,