forked from open-webui/open-webui
Improve embedding model update & resolve network dependency
* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior
* Add RAG utils embedding_model_get_path() function to output the filesystem path in addition to updating the model using huggingface_hub
* Update and utilize the existing RAG functions in main: get_embedding_model() & update_embedding_model()
* Add GUI setting to execute the manual update process
parent 62392aa88a
commit 3b66aa55c0

5 changed files with 218 additions and 19 deletions
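For orientation, here is a rough usage sketch of the API surface this commit touches: the two RAG endpoints driven by the new GUI setting, and the response fields added below. The base URL, token handling, and the requests client are illustrative assumptions, not part of the commit; only the routes, the RAG_EMBEDDING_MODEL_AUTO_UPDATE variable, and the response fields come from this change.

import os

import requests  # assumed client; the web UI itself calls these endpoints via fetch

# Assumed values: adjust the RAG API base URL and admin token for your deployment
BASE_URL = os.getenv("RAG_API_BASE_URL", "http://localhost:8080/rag/api/v1")
HEADERS = {"Authorization": f"Bearer {os.getenv('WEBUI_ADMIN_TOKEN', '')}"}

# GET /embedding/model now also returns the resolved filesystem path
current = requests.get(f"{BASE_URL}/embedding/model", headers=HEADERS).json()
print(current["embedding_model"], current["embedding_model_path"])

# POST /embedding/model/update triggers the manual update added in this commit;
# "status" is False when the snapshot path did not change (no update was needed)
updated = requests.post(
    f"{BASE_URL}/embedding/model/update",
    headers=HEADERS,
    json={"embedding_model": "all-MiniLM-L6-v2"},
).json()
print(updated["status"], updated["embedding_model_path"])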
@@ -13,7 +13,6 @@ import os, shutil, logging, re
 from pathlib import Path
 from typing import List

-from sentence_transformers import SentenceTransformer
 from chromadb.utils import embedding_functions

 from langchain_community.document_loaders import (
@@ -45,7 +44,7 @@ from apps.web.models.documents import (
     DocumentResponse,
 )

-from apps.rag.utils import query_doc, query_collection
+from apps.rag.utils import query_doc, query_collection, embedding_model_get_path

 from utils.misc import (
     calculate_sha256,
@@ -60,6 +59,7 @@ from config import (
     DOCS_DIR,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
@@ -71,15 +71,6 @@ from constants import ERROR_MESSAGES
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

-#
-# if RAG_EMBEDDING_MODEL:
-#     sentence_transformer_ef = SentenceTransformer(
-#         model_name_or_path=RAG_EMBEDDING_MODEL,
-#         cache_folder=RAG_EMBEDDING_MODEL_DIR,
-#         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
-#     )
-
-
 app = FastAPI()

 app.state.PDF_EXTRACT_IMAGES = False
@@ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)
 app.state.TOP_K = 4

 app.state.sentence_transformer_ef = (
     embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=app.state.RAG_EMBEDDING_MODEL,
+        model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
     )
 )
@@ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)):
     return {
         "status": True,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
     }

@@ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel):
 async def update_embedding_model(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
+    status = True
+    old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
     app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model

+    log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
+    log.info(f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}")
+
+    try:
+        app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, True)
         app.state.sentence_transformer_ef = (
             embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name=app.state.RAG_EMBEDDING_MODEL,
+                model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
                 device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
             )
         )
+    except Exception as e:
+        log.exception(f"Problem updating embedding model: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=e,
+        )
+
+    if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path:
+        status = False
+
+    log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}")
+    log.debug(f"old_model_path: {old_model_path}")
+    log.debug(f"status: {status}")
+
     return {
-        "status": True,
+        "status": status,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
     }

@@ -1,6 +1,8 @@
+import os
 import re
 import logging
 from typing import List
+from huggingface_hub import snapshot_download

 from config import SRC_LOG_LEVELS, CHROMA_CLIENT

@@ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, embedding_function):
     messages[last_user_message_idx] = new_user_message

     return messages
+
+def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
+    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
+    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+    local_files_only = not update_embedding_model
+    snapshot_kwargs = {
+        "cache_dir": cache_dir,
+        "local_files_only": local_files_only,
+    }
+
+    log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}")
+    log.debug(f"embedding_model: {embedding_model}")
+    log.debug(f"update_embedding_model: {update_embedding_model}")
+    log.debug(f"local_files_only: {local_files_only}")
+
+    # Inspiration from upstream sentence_transformers
+    if (os.path.exists(embedding_model) or ("\\" in embedding_model or embedding_model.count("/") > 1) and local_files_only):
+        # If fully qualified path exists, return input, else set repo_id
+        return embedding_model
+    elif "/" not in embedding_model:
+        # Set valid repo_id for model short-name
+        embedding_model = "sentence-transformers" + "/" + embedding_model
+
+    snapshot_kwargs["repo_id"] = embedding_model
+
+    # Attempt to query the huggingface_hub library to determine the local path and/or to update
+    try:
+        embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
+        log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
+        return embedding_model_repo_path
+    except Exception as e:
+        log.exception(f"Cannot determine embedding model snapshot path: {e}")
+        return embedding_model
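A minimal usage sketch of the new helper (the model name below is just the project default from config.py; the import path mirrors how main.py imports it above):

from apps.rag.utils import embedding_model_get_path

# Offline resolution: only look for an existing huggingface_hub snapshot in the
# local cache (SENTENCE_TRANSFORMERS_HOME); falls back to returning the model
# name if no snapshot is cached.
local_path = embedding_model_get_path("all-MiniLM-L6-v2", update_embedding_model=False)

# With update enabled, snapshot_download is allowed to hit the network and
# download or refresh the model before returning its snapshot path.
fresh_path = embedding_model_get_path("all-MiniLM-L6-v2", update_embedding_model=True)

print(local_path, fresh_path)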
@@ -395,6 +395,9 @@ RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
 RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
     "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
 )
+RAG_EMBEDDING_MODEL_AUTO_UPDATE = False
+if os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true":
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE = True
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),
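The flag is read once at startup; a short sketch of its intended effect, assuming the backend environment (the variable and helper names come from this commit, the rest is illustrative):

import os

# Opt in to downloading/refreshing the embedding model at startup; leave unset
# (the default, False) to keep startup free of any network dependency.
os.environ["RAG_EMBEDDING_MODEL_AUTO_UPDATE"] = "true"

from config import RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
from apps.rag.utils import embedding_model_get_path

# Mirrors the app.state.RAG_EMBEDDING_MODEL_PATH initialization in the RAG backend above
model_path = embedding_model_get_path(RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)
print(model_path)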
@@ -345,3 +345,64 @@ export const resetVectorDB = async (token: string) => {

 	return res;
 };
+
+export const getEmbeddingModel = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+type EmbeddingModelUpdateForm = {
+	embedding_model: string;
+};
+
+export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...payload
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
@@ -6,7 +6,9 @@
 		getQuerySettings,
 		scanDocs,
 		updateQuerySettings,
-		resetVectorDB
+		resetVectorDB,
+		getEmbeddingModel,
+		updateEmbeddingModel
 	} from '$lib/apis/rag';

 	import { documents } from '$lib/stores';
@@ -18,6 +20,7 @@
 	export let saveHandler: Function;

 	let loading = false;
+	let loading1 = false;

 	let showResetConfirm = false;

|
@ -30,6 +33,10 @@
|
||||||
k: 4
|
k: 4
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let embeddingModel = {
|
||||||
|
embedding_model: '',
|
||||||
|
};
|
||||||
|
|
||||||
const scanHandler = async () => {
|
const scanHandler = async () => {
|
||||||
loading = true;
|
loading = true;
|
||||||
const res = await scanDocs(localStorage.token);
|
const res = await scanDocs(localStorage.token);
|
||||||
|
@ -41,6 +48,21 @@
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const embeddingModelUpdateHandler = async () => {
|
||||||
|
loading1 = true;
|
||||||
|
const res = await updateEmbeddingModel(localStorage.token, embeddingModel);
|
||||||
|
loading1 = false;
|
||||||
|
|
||||||
|
if (res) {
|
||||||
|
console.log('embeddingModelUpdateHandler:', res);
|
||||||
|
if (res.status == true) {
|
||||||
|
toast.success($i18n.t('Model {{embedding_model}} update complete!', res));
|
||||||
|
} else {
|
||||||
|
toast.error($i18n.t('Model {{embedding_model}} update failed or not required!', res));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const submitHandler = async () => {
|
const submitHandler = async () => {
|
||||||
const res = await updateRAGConfig(localStorage.token, {
|
const res = await updateRAGConfig(localStorage.token, {
|
||||||
pdf_extract_images: pdfExtractImages,
|
pdf_extract_images: pdfExtractImages,
|
||||||
|
@@ -62,6 +84,8 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}

+		embeddingModel = await getEmbeddingModel(localStorage.token);
+
 		querySettings = await getQuerySettings(localStorage.token);
 	});
 </script>
|
@ -137,6 +161,67 @@
|
||||||
{/if}
|
{/if}
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class=" flex w-full justify-between">
|
||||||
|
<div class=" self-center text-xs font-medium">
|
||||||
|
{$i18n.t('Update embedding model {{embedding_model}}', embeddingModel)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button
|
||||||
|
class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded flex flex-row space-x-1 items-center {loading1
|
||||||
|
? ' cursor-not-allowed'
|
||||||
|
: ''}"
|
||||||
|
on:click={() => {
|
||||||
|
embeddingModelUpdateHandler(embeddingModel);
|
||||||
|
console.log('Update embedding model:', embeddingModel.embedding_model);
|
||||||
|
}}
|
||||||
|
type="button"
|
||||||
|
disabled={loading1}
|
||||||
|
>
|
||||||
|
<div class="self-center font-medium">{$i18n.t('Update')}</div>
|
||||||
|
|
||||||
|
<!-- <svg
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
viewBox="0 0 16 16"
|
||||||
|
fill="currentColor"
|
||||||
|
class="w-3 h-3"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
fill-rule="evenodd"
|
||||||
|
d="M13.836 2.477a.75.75 0 0 1 .75.75v3.182a.75.75 0 0 1-.75.75h-3.182a.75.75 0 0 1 0-1.5h1.37l-.84-.841a4.5 4.5 0 0 0-7.08.932.75.75 0 0 1-1.3-.75 6 6 0 0 1 9.44-1.242l.842.84V3.227a.75.75 0 0 1 .75-.75Zm-.911 7.5A.75.75 0 0 1 13.199 11a6 6 0 0 1-9.44 1.241l-.84-.84v1.371a.75.75 0 0 1-1.5 0V9.591a.75.75 0 0 1 .75-.75H5.35a.75.75 0 0 1 0 1.5H3.98l.841.841a4.5 4.5 0 0 0 7.08-.932.75.75 0 0 1 1.025-.273Z"
|
||||||
|
clip-rule="evenodd"
|
||||||
|
/>
|
||||||
|
</svg> -->
|
||||||
|
|
||||||
|
{#if loading1}
|
||||||
|
<div class="ml-3 self-center">
|
||||||
|
<svg
|
||||||
|
class=" w-3 h-3"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
fill="currentColor"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
><style>
|
||||||
|
.spinner_ajPY {
|
||||||
|
transform-origin: center;
|
||||||
|
animation: spinner_AtaB 0.75s infinite linear;
|
||||||
|
}
|
||||||
|
@keyframes spinner_AtaB {
|
||||||
|
100% {
|
||||||
|
transform: rotate(360deg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style><path
|
||||||
|
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
|
||||||
|
opacity=".25"
|
||||||
|
/><path
|
||||||
|
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
|
||||||
|
class="spinner_ajPY"
|
||||||
|
/></svg
|
||||||
|
>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<hr class=" dark:border-gray-700" />
|
<hr class=" dark:border-gray-700" />
|
||||||
|
|