diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 423f1e03..e1a5e6eb 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -138,20 +138,22 @@ async def get_status():
     }


-@app.get("/embedding/model")
-async def get_embedding_model(user=Depends(get_admin_user)):
+@app.get("/embedding")
+async def get_embedding_config(user=Depends(get_admin_user)):
     return {
         "status": True,
+        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }


 class EmbeddingModelUpdateForm(BaseModel):
+    embedding_engine: str
     embedding_model: str


-@app.post("/embedding/model/update")
-async def update_embedding_model(
+@app.post("/embedding/update")
+async def update_embedding_config(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
@@ -160,18 +162,26 @@ async def update_embedding_model(
     )

     try:
-        sentence_transformer_ef = (
-            embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name=get_embedding_model_path(form_data.embedding_model, True),
-                device=DEVICE_TYPE,
-            )
-        )
+        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine

-        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-        app.state.sentence_transformer_ef = sentence_transformer_ef
+        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+            app.state.sentence_transformer_ef = None
+        else:
+            sentence_transformer_ef = (
+                embedding_functions.SentenceTransformerEmbeddingFunction(
+                    model_name=get_embedding_model_path(
+                        form_data.embedding_model, True
+                    ),
+                    device=DEVICE_TYPE,
+                )
+            )
+            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+            app.state.sentence_transformer_ef = sentence_transformer_ef

         return {
             "status": True,
+            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.RAG_EMBEDDING_MODEL,
         }
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index 33c70e2b..bfcee55f 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -346,10 +346,10 @@ export const resetVectorDB = async (token: string) => {
 	return res;
 };

-export const getEmbeddingModel = async (token: string) => {
+export const getEmbeddingConfig = async (token: string) => {
 	let error = null;

-	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding`, {
 		method: 'GET',
 		headers: {
 			'Content-Type': 'application/json',
@@ -374,13 +374,14 @@ export const getEmbeddingModel = async (token: string) => {
 };

 type EmbeddingModelUpdateForm = {
+	embedding_engine: string;
 	embedding_model: string;
 };

-export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => {
+export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
 	let error = null;

-	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding/update`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte
index 85df678c..c9142fbe 100644
--- a/src/lib/components/documents/Settings/General.svelte
+++ b/src/lib/components/documents/Settings/General.svelte
@@ -7,11 +7,11 @@
 		scanDocs,
 		updateQuerySettings,
 		resetVectorDB,
-		getEmbeddingModel,
-		updateEmbeddingModel
+		getEmbeddingConfig,
+		updateEmbeddingConfig
 	} from '$lib/apis/rag';

-	import { documents } from '$lib/stores';
+	import { documents, models } from '$lib/stores';
 	import { onMount, getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';

@@ -27,6 +27,8 @@
 	let showResetConfirm = false;

 	let embeddingEngine = '';
+	let embeddingModel = '';
+
 	let chunkSize = 0;
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;
@@ -36,8 +38,6 @@
 		k: 4
 	};

-	let embeddingModel = '';
-
 	const scanHandler = async () => {
 		scanDirLoading = true;
 		const res = await scanDocs(localStorage.token);
@@ -50,7 +50,16 @@
 	};

 	const embeddingModelUpdateHandler = async () => {
-		if (embeddingModel.split('/').length - 1 > 1) {
+		if (embeddingModel === '') {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+
+		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
 			toast.error(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
@@ -62,11 +71,17 @@
 		console.log('Update embedding model attempt:', embeddingModel);

 		updateEmbeddingModelLoading = true;
-		const res = await updateEmbeddingModel(localStorage.token, {
+		const res = await updateEmbeddingConfig(localStorage.token, {
+			embedding_engine: embeddingEngine,
 			embedding_model: embeddingModel
 		}).catch(async (error) => {
 			toast.error(error);
-			embeddingModel = (await getEmbeddingModel(localStorage.token)).embedding_model;
+
+			const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+			if (embeddingConfig) {
+				embeddingEngine = embeddingConfig.embedding_engine;
+				embeddingModel = embeddingConfig.embedding_model;
+			}
 			return null;
 		});
 		updateEmbeddingModelLoading = false;
@@ -102,7 +117,12 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}

-		embeddingModel = (await getEmbeddingModel(localStorage.token)).embedding_model;
+		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+
+		if (embeddingConfig) {
+			embeddingEngine = embeddingConfig.embedding_engine;
+			embeddingModel = embeddingConfig.embedding_model;
+		}

 		querySettings = await getQuerySettings(localStorage.token);
 	});
@@ -126,6 +146,9 @@
 							class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
 							bind:value={embeddingEngine}
 							placeholder="Select an embedding engine"
+							on:change={() => {
+								embeddingModel = '';
+							}}
 						>
@@ -136,10 +159,77 @@
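For reference, a minimal usage sketch of the two renamed frontend helpers against the endpoints introduced above. This is an illustration, not part of the change: the function name switchEmbeddingToOllama and the model name 'nomic-embed-text' are hypothetical, and the admin token is assumed to be passed in as token, the same value General.svelte reads from localStorage.token. Per the backend branch in main.py, an engine of 'ollama' clears sentence_transformer_ef on the server, while an empty engine string keeps the SentenceTransformer path.

// Hypothetical sketch: read the current embedding config, then switch it to Ollama.
import { getEmbeddingConfig, updateEmbeddingConfig } from '$lib/apis/rag';

export const switchEmbeddingToOllama = async (token: string) => {
	// GET ${RAG_API_BASE_URL}/embedding now returns the engine alongside the model.
	const current = await getEmbeddingConfig(token).catch(() => null);
	console.log('current embedding config:', current?.embedding_engine, current?.embedding_model);

	// POST ${RAG_API_BASE_URL}/embedding/update takes both fields in a single payload.
	const updated = await updateEmbeddingConfig(token, {
		embedding_engine: 'ollama',
		embedding_model: 'nomic-embed-text' // example model name, not prescribed by the diff
	}).catch((error) => {
		console.error(error);
		return null;
	});

	return updated;
};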