From 47a05a47b4dc813b0b0357660e19ed56cbab454f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 2 Mar 2024 18:56:57 -0800 Subject: [PATCH] feat: add rag top k value setting --- backend/apps/rag/main.py | 48 ++++++++++++------ src/lib/apis/rag/index.ts | 40 +++++++++++++-- .../documents/Settings/General.svelte | 49 ++++++++++++++++--- src/routes/(app)/+page.svelte | 12 ++--- src/routes/(app)/c/[id]/+page.svelte | 12 ++--- 5 files changed, 123 insertions(+), 38 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 83c10233..2a8b2a49 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.TOP_K = 4 + app.state.sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction( model_name=app.state.RAG_EMBEDDING_MODEL, @@ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)): } -class RAGTemplateForm(BaseModel): - template: str +@app.get("/query/settings") +async def get_query_settings(user=Depends(get_admin_user)): + return { + "status": True, + "template": app.state.RAG_TEMPLATE, + "k": app.state.TOP_K, + } -@app.post("/template/update") -async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)): - # TODO: check template requirements - app.state.RAG_TEMPLATE = ( - form_data.template if form_data.template != "" else RAG_TEMPLATE - ) +class QuerySettingsForm(BaseModel): + k: Optional[int] = None + template: Optional[str] = None + + +@app.post("/query/settings/update") +async def update_query_settings( + form_data: QuerySettingsForm, user=Depends(get_admin_user) +): + app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE + app.state.TOP_K = form_data.k if form_data.k else 4 return {"status": True, "template": app.state.RAG_TEMPLATE} class QueryDocForm(BaseModel): collection_name: str query: str - k: Optional[int] = 4 + k: Optional[int] = None @app.post("/query/doc") @@ -240,7 +252,10 @@ def query_doc( name=form_data.collection_name, embedding_function=app.state.sentence_transformer_ef, ) - result = collection.query(query_texts=[form_data.query], n_results=form_data.k) + result = collection.query( + query_texts=[form_data.query], + n_results=form_data.k if form_data.k else app.state.TOP_K, + ) return result except Exception as e: print(e) @@ -253,7 +268,7 @@ def query_doc( class QueryCollectionsForm(BaseModel): collection_names: List[str] query: str - k: Optional[int] = 4 + k: Optional[int] = None def merge_and_sort_query_results(query_results, k): @@ -317,13 +332,16 @@ def query_collection( ) result = collection.query( - query_texts=[form_data.query], n_results=form_data.k + query_texts=[form_data.query], + n_results=form_data.k if form_data.k else app.state.TOP_K, ) results.append(result) except: pass - return merge_and_sort_query_results(results, form_data.k) + return merge_and_sort_query_results( + results, form_data.k if form_data.k else app.state.TOP_K + ) @app.post("/web") @@ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str): "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ] or file_ext in ["xls", "xlsx"]: loader = UnstructuredExcelLoader(file_path) - elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0): + elif file_ext in known_source_ext or ( + file_content_type and file_content_type.find("text/") >= 0 + ): loader = TextLoader(file_path) else: loader = TextLoader(file_path) diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index ed36f014..4e8e9b14 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -85,17 +85,49 @@ export const getRAGTemplate = async (token: string) => { return res?.template ?? ''; }; -export const updateRAGTemplate = async (token: string, template: string) => { +export const getQuerySettings = async (token: string) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/template/update`, { + const res = await fetch(`${RAG_API_BASE_URL}/query/settings`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +type QuerySettings = { + k: number | null; + template: string | null; +}; + +export const updateQuerySettings = async (token: string, settings: QuerySettings) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/query/settings/update`, { method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ - template: template + ...settings }) }) .then(async (res) => { @@ -183,7 +215,7 @@ export const queryDoc = async ( token: string, collection_name: string, query: string, - k: number + k: number | null = null ) => { let error = null; diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index 9bf496a2..28f3e71a 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -2,10 +2,10 @@ import { getDocs } from '$lib/apis/documents'; import { getChunkParams, - getRAGTemplate, + getQuerySettings, scanDocs, updateChunkParams, - updateRAGTemplate + updateQuerySettings } from '$lib/apis/rag'; import { documents } from '$lib/stores'; import { onMount } from 'svelte'; @@ -18,7 +18,10 @@ let chunkSize = 0; let chunkOverlap = 0; - let template = ''; + let querySettings = { + template: '', + k: 4 + }; const scanHandler = async () => { loading = true; @@ -33,7 +36,7 @@ const submitHandler = async () => { const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap); - await updateRAGTemplate(localStorage.token, template); + querySettings = await updateQuerySettings(localStorage.token, querySettings); }; onMount(async () => { @@ -44,7 +47,7 @@ chunkOverlap = res.chunk_overlap; } - template = await getRAGTemplate(localStorage.token); + querySettings = await getQuerySettings(localStorage.token); }); @@ -156,10 +159,44 @@ +
Query Params
+ +
+
+
Top K
+ +
+ +
+
+ + +
+
RAG Template