From ccf08fb91e3960b1b8457306ac231d18dc7f21e2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 17 Feb 2024 22:29:52 -0800 Subject: [PATCH] feat: editable chunk params --- backend/apps/rag/main.py | 39 +++++++++++- src/lib/apis/rag/index.ts | 58 +++++++++++++++++ .../documents/Settings/General.svelte | 62 +++++++++++++++++-- 3 files changed, 152 insertions(+), 7 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index f24b6b90..79981680 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -71,6 +71,9 @@ from constants import ERROR_MESSAGES app = FastAPI() +app.state.CHUNK_SIZE = CHUNK_SIZE +app.state.CHUNK_OVERLAP = CHUNK_OVERLAP + origins = ["*"] app.add_middleware( @@ -92,7 +95,7 @@ class StoreWebForm(CollectionNameForm): def store_data_in_vector_db(data, collection_name) -> bool: text_splitter = RecursiveCharacterTextSplitter( - chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP + chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP ) docs = text_splitter.split_documents(data) @@ -116,7 +119,39 @@ def store_data_in_vector_db(data, collection_name) -> bool: @app.get("/") async def get_status(): - return {"status": True} + return { + "status": True, + "chunk_size": app.state.CHUNK_SIZE, + "chunk_overlap": app.state.CHUNK_OVERLAP, + } + + +@app.get("/chunk") +async def get_chunk_params(user=Depends(get_admin_user)): + return { + "status": True, + "chunk_size": app.state.CHUNK_SIZE, + "chunk_overlap": app.state.CHUNK_OVERLAP, + } + + +class ChunkParamUpdateForm(BaseModel): + chunk_size: int + chunk_overlap: int + + +@app.post("/chunk/update") +async def update_chunk_params( + form_data: ChunkParamUpdateForm, user=Depends(get_admin_user) +): + app.state.CHUNK_SIZE = form_data.chunk_size + app.state.CHUNK_OVERLAP = form_data.chunk_overlap + + return { + "status": True, + "chunk_size": app.state.CHUNK_SIZE, + "chunk_overlap": app.state.CHUNK_OVERLAP, + } class QueryDocForm(BaseModel): diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index fc3571aa..5819badb 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -1,5 +1,63 @@ import { RAG_API_BASE_URL } from '$lib/constants'; +export const getChunkParams = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/chunk`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateChunkParams = async (token: string, size: number, overlap: number) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + chunk_size: size, + chunk_overlap: overlap + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => { const data = new FormData(); data.append('file', file); diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index c3c7df5b..038fa83b 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -1,6 +1,6 @@
{ - // console.log('submit'); + submitHandler(); saveHandler(); }} > @@ -93,14 +107,52 @@ + +
+ +
+
Chunk Params
+ +
+
+
Chunk Size
+ +
+ +
+
+ +
+
Chunk Overlap
+ +
+ +
+
+
+
- +