forked from open-webui/open-webui
feat: hybrid search and reranking support
This commit is contained in:
parent
db801aee79
commit
c0259aad67
10 changed files with 262 additions and 131 deletions
|
@ -64,6 +64,8 @@ from config import (
|
|||
SRC_LOG_LEVELS,
|
||||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
RAG_TOP_K,
|
||||
RAG_RELEVANCE_THRESHOLD,
|
||||
RAG_EMBEDDING_ENGINE,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
|
@ -86,7 +88,8 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|||
app = FastAPI()
|
||||
|
||||
|
||||
app.state.TOP_K = 4
|
||||
app.state.TOP_K = RAG_TOP_K
|
||||
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
||||
app.state.CHUNK_SIZE = CHUNK_SIZE
|
||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
|
||||
|
@ -107,12 +110,17 @@ if app.state.RAG_EMBEDDING_ENGINE == "":
|
|||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
else:
|
||||
app.state.sentence_transformer_ef = None
|
||||
|
||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||
app.state.RAG_RERANKING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
if not app.state.RAG_RERANKING_MODEL == "":
|
||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||
app.state.RAG_RERANKING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
else:
|
||||
app.state.sentence_transformer_rf = None
|
||||
|
||||
|
||||
origins = ["*"]
|
||||
|
@ -185,22 +193,22 @@ async def update_embedding_config(
|
|||
)
|
||||
try:
|
||||
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = None
|
||||
|
||||
if form_data.openai_config != None:
|
||||
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
||||
|
||||
app.state.sentence_transformer_ef = None
|
||||
else:
|
||||
sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
app.state.sentence_transformer_ef = (
|
||||
sentence_transformers.SentenceTransformer(
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
)
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = sentence_transformer_ef
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
|
@ -222,7 +230,7 @@ async def update_embedding_config(
|
|||
|
||||
class RerankingModelUpdateForm(BaseModel):
|
||||
reranking_model: str
|
||||
|
||||
|
||||
|
||||
@app.post("/reranking/update")
|
||||
async def update_reranking_config(
|
||||
|
@ -233,10 +241,14 @@ async def update_reranking_config(
|
|||
)
|
||||
try:
|
||||
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
|
||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||
app.state.RAG_RERANKING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
|
||||
if app.state.RAG_RERANKING_MODEL == "":
|
||||
app.state.sentence_transformer_rf = None
|
||||
else:
|
||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||
app.state.RAG_RERANKING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
|
@ -302,11 +314,13 @@ async def get_query_settings(user=Depends(get_admin_user)):
|
|||
"status": True,
|
||||
"template": app.state.RAG_TEMPLATE,
|
||||
"k": app.state.TOP_K,
|
||||
"r": app.state.RELEVANCE_THRESHOLD,
|
||||
}
|
||||
|
||||
|
||||
class QuerySettingsForm(BaseModel):
|
||||
k: Optional[int] = None
|
||||
r: Optional[float] = None
|
||||
template: Optional[str] = None
|
||||
|
||||
|
||||
|
@ -316,6 +330,7 @@ async def update_query_settings(
|
|||
):
|
||||
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
|
||||
app.state.TOP_K = form_data.k if form_data.k else 4
|
||||
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
|
||||
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
||||
|
||||
|
||||
|
@ -323,6 +338,7 @@ class QueryDocForm(BaseModel):
|
|||
collection_name: str
|
||||
query: str
|
||||
k: Optional[int] = None
|
||||
r: Optional[float] = None
|
||||
|
||||
|
||||
@app.post("/query/doc")
|
||||
|
@ -343,6 +359,7 @@ def query_doc_handler(
|
|||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||
embeddings_function=embeddings_function,
|
||||
reranking_function=app.state.sentence_transformer_rf,
|
||||
)
|
||||
|
@ -358,6 +375,7 @@ class QueryCollectionsForm(BaseModel):
|
|||
collection_names: List[str]
|
||||
query: str
|
||||
k: Optional[int] = None
|
||||
r: Optional[float] = None
|
||||
|
||||
|
||||
@app.post("/query/collection")
|
||||
|
@ -378,6 +396,7 @@ def query_collection_handler(
|
|||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||
embeddings_function=embeddings_function,
|
||||
reranking_function=app.state.sentence_transformer_rf,
|
||||
)
|
||||
|
@ -467,12 +486,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
)
|
||||
|
||||
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
embeddings = embedding_func(embedding_texts)
|
||||
else:
|
||||
embeddings = [
|
||||
embedding_func(embedding_texts) for text in texts
|
||||
]
|
||||
embeddings = embedding_func(embedding_texts)
|
||||
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue