From 98948814fd28508d968b47c0ea092784874778ad Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 10 Mar 2024 13:32:34 -0700 Subject: [PATCH] feat: toggle pdf ocr --- backend/apps/rag/main.py | 37 ++-- src/lib/apis/rag/index.ts | 21 ++- .../documents/Settings/General.svelte | 169 ++++++++++-------- 3 files changed, 137 insertions(+), 90 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 6781a9a1..b21724cc 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -77,6 +77,7 @@ from constants import ERROR_MESSAGES app = FastAPI() +app.state.PDF_EXTRACT_IMAGES = False app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.RAG_TEMPLATE = RAG_TEMPLATE @@ -184,12 +185,15 @@ async def update_embedding_model( } -@app.get("/chunk") -async def get_chunk_params(user=Depends(get_admin_user)): +@app.get("/config") +async def get_rag_config(user=Depends(get_admin_user)): return { "status": True, - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "chunk": { + "chunk_size": app.state.CHUNK_SIZE, + "chunk_overlap": app.state.CHUNK_OVERLAP, + }, } @@ -198,17 +202,24 @@ class ChunkParamUpdateForm(BaseModel): chunk_overlap: int -@app.post("/chunk/update") -async def update_chunk_params( - form_data: ChunkParamUpdateForm, user=Depends(get_admin_user) -): - app.state.CHUNK_SIZE = form_data.chunk_size - app.state.CHUNK_OVERLAP = form_data.chunk_overlap +class ConfigUpdateForm(BaseModel): + pdf_extract_images: bool + chunk: ChunkParamUpdateForm + + +@app.post("/config/update") +async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): + app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images + app.state.CHUNK_SIZE = form_data.chunk.chunk_size + app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap return { "status": True, - "chunk_size": app.state.CHUNK_SIZE, - "chunk_overlap": app.state.CHUNK_OVERLAP, + "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, + "chunk": { + "chunk_size": app.state.CHUNK_SIZE, + "chunk_overlap": app.state.CHUNK_OVERLAP, + }, } @@ -364,7 +375,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ] if file_ext == "pdf": - loader = PyPDFLoader(file_path, extract_images=True) + loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) elif file_ext == "csv": loader = CSVLoader(file_path) elif file_ext == "rst": diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 6dcfbbe7..668fe227 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -1,9 +1,9 @@ import { RAG_API_BASE_URL } from '$lib/constants'; -export const getChunkParams = async (token: string) => { +export const getRAGConfig = async (token: string) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/chunk`, { + const res = await fetch(`${RAG_API_BASE_URL}/config`, { method: 'GET', headers: { 'Content-Type': 'application/json', @@ -27,18 +27,27 @@ export const getChunkParams = async (token: string) => { return res; }; -export const updateChunkParams = async (token: string, size: number, overlap: number) => { +type ChunkConfigForm = { + chunk_size: number; + chunk_overlap: number; +}; + +type RAGConfigForm = { + pdf_extract_images: boolean; + chunk: ChunkConfigForm; +}; + +export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, { + const res = await fetch(`${RAG_API_BASE_URL}/config/update`, { method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ - chunk_size: size, - chunk_overlap: overlap + ...payload }) }) .then(async (res) => { diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index 28f3e71a..d3342b6a 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -1,10 +1,10 @@