forked from open-webui/open-webui
feat: add rag top k value setting
This commit is contained in:
parent
9694c6569f
commit
47a05a47b4
5 changed files with 123 additions and 38 deletions
|
@ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
|
|||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.TOP_K = 4
|
||||
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
|
@ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)):
|
|||
}
|
||||
|
||||
|
||||
class RAGTemplateForm(BaseModel):
|
||||
template: str
|
||||
@app.get("/query/settings")
|
||||
async def get_query_settings(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"status": True,
|
||||
"template": app.state.RAG_TEMPLATE,
|
||||
"k": app.state.TOP_K,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/template/update")
|
||||
async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
|
||||
# TODO: check template requirements
|
||||
app.state.RAG_TEMPLATE = (
|
||||
form_data.template if form_data.template != "" else RAG_TEMPLATE
|
||||
)
|
||||
class QuerySettingsForm(BaseModel):
|
||||
k: Optional[int] = None
|
||||
template: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/query/settings/update")
|
||||
async def update_query_settings(
|
||||
form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
|
||||
app.state.TOP_K = form_data.k if form_data.k else 4
|
||||
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
||||
|
||||
|
||||
class QueryDocForm(BaseModel):
|
||||
collection_name: str
|
||||
query: str
|
||||
k: Optional[int] = 4
|
||||
k: Optional[int] = None
|
||||
|
||||
|
||||
@app.post("/query/doc")
|
||||
|
@ -240,7 +252,10 @@ def query_doc(
|
|||
name=form_data.collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
||||
result = collection.query(
|
||||
query_texts=[form_data.query],
|
||||
n_results=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
@ -253,7 +268,7 @@ def query_doc(
|
|||
class QueryCollectionsForm(BaseModel):
|
||||
collection_names: List[str]
|
||||
query: str
|
||||
k: Optional[int] = 4
|
||||
k: Optional[int] = None
|
||||
|
||||
|
||||
def merge_and_sort_query_results(query_results, k):
|
||||
|
@ -317,13 +332,16 @@ def query_collection(
|
|||
)
|
||||
|
||||
result = collection.query(
|
||||
query_texts=[form_data.query], n_results=form_data.k
|
||||
query_texts=[form_data.query],
|
||||
n_results=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
results.append(result)
|
||||
except:
|
||||
pass
|
||||
|
||||
return merge_and_sort_query_results(results, form_data.k)
|
||||
return merge_and_sort_query_results(
|
||||
results, form_data.k if form_data.k else app.state.TOP_K
|
||||
)
|
||||
|
||||
|
||||
@app.post("/web")
|
||||
|
@ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
|
|||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
] or file_ext in ["xls", "xlsx"]:
|
||||
loader = UnstructuredExcelLoader(file_path)
|
||||
elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0):
|
||||
elif file_ext in known_source_ext or (
|
||||
file_content_type and file_content_type.find("text/") >= 0
|
||||
):
|
||||
loader = TextLoader(file_path)
|
||||
else:
|
||||
loader = TextLoader(file_path)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue