Improve embedding model update & resolve network dependency

* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior
* Add RAG utils embedding_model_get_path() function to output the filesystem path in addition to update of the model using huggingface_hub
* Update and utilize existing RAG functions in main: get_embedding_model() & update_embedding_model()
* Add GUI setting to execute manual update process
This commit is contained in:
Self Denial 2024-04-04 11:01:23 -06:00
parent 62392aa88a
commit 3b66aa55c0
5 changed files with 218 additions and 19 deletions

View file

@ -13,7 +13,6 @@ import os, shutil, logging, re
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions from chromadb.utils import embedding_functions
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
@ -45,7 +44,7 @@ from apps.web.models.documents import (
DocumentResponse, DocumentResponse,
) )
from apps.rag.utils import query_doc, query_collection from apps.rag.utils import query_doc, query_collection, embedding_model_get_path
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
@ -60,6 +59,7 @@ from config import (
DOCS_DIR, DOCS_DIR,
RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_DEVICE_TYPE, RAG_EMBEDDING_MODEL_DEVICE_TYPE,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
CHROMA_CLIENT, CHROMA_CLIENT,
CHUNK_SIZE, CHUNK_SIZE,
CHUNK_OVERLAP, CHUNK_OVERLAP,
@ -71,15 +71,6 @@ from constants import ERROR_MESSAGES
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
#
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )
app = FastAPI() app = FastAPI()
app.state.PDF_EXTRACT_IMAGES = False app.state.PDF_EXTRACT_IMAGES = False
@ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)
app.state.TOP_K = 4 app.state.TOP_K = 4
app.state.sentence_transformer_ef = ( app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL, model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
) )
) )
@ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
} }
@ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel):
async def update_embedding_model( async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
status = True
old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
model_name=app.state.RAG_EMBEDDING_MODEL, log.info(f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}")
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
try:
app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, True)
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
) )
) except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e,
)
if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path:
status = False
log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}")
log.debug(f"old_model_path: {old_model_path}")
log.debug(f"status: {status}")
return { return {
"status": True, "status": status,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
} }

View file

@ -1,6 +1,8 @@
import os
import re import re
import logging import logging
from typing import List from typing import List
from huggingface_hub import snapshot_download
from config import SRC_LOG_LEVELS, CHROMA_CLIENT from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, embedding_function):
messages[last_user_message_idx] = new_user_message messages[last_user_message_idx] = new_user_message
return messages return messages
def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
local_files_only = not update_embedding_model
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}")
log.debug(f"embedding_model: {embedding_model}")
log.debug(f"update_embedding_model: {update_embedding_model}")
log.debug(f"local_files_only: {local_files_only}")
# Inspiration from upstream sentence_transformers
if (os.path.exists(embedding_model) or ("\\" in embedding_model or embedding_model.count("/") > 1) and local_files_only):
# If fully qualified path exists, return input, else set repo_id
return embedding_model
elif "/" not in embedding_model:
# Set valid repo_id for model short-name
embedding_model = "sentence-transformers" + "/" + embedding_model
snapshot_kwargs["repo_id"] = embedding_model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try:
embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
return embedding_model_repo_path
except Exception as e:
log.exception(f"Cannot determine embedding model snapshot path: {e}")
return embedding_model

View file

@ -395,6 +395,9 @@ RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get( RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
"RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu" "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
) )
RAG_EMBEDDING_MODEL_AUTO_UPDATE = False
if os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true":
RAG_EMBEDDING_MODEL_AUTO_UPDATE = True
CHROMA_CLIENT = chromadb.PersistentClient( CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH, path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False), settings=Settings(allow_reset=True, anonymized_telemetry=False),

View file

@ -345,3 +345,64 @@ export const resetVectorDB = async (token: string) => {
return res; return res;
}; };
export const getEmbeddingModel = async (token: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/embedding/model`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
type EmbeddingModelUpdateForm = {
embedding_model: string;
};
export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
...payload
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@ -6,7 +6,9 @@
getQuerySettings, getQuerySettings,
scanDocs, scanDocs,
updateQuerySettings, updateQuerySettings,
resetVectorDB resetVectorDB,
getEmbeddingModel,
updateEmbeddingModel
} from '$lib/apis/rag'; } from '$lib/apis/rag';
import { documents } from '$lib/stores'; import { documents } from '$lib/stores';
@ -18,6 +20,7 @@
export let saveHandler: Function; export let saveHandler: Function;
let loading = false; let loading = false;
let loading1 = false;
let showResetConfirm = false; let showResetConfirm = false;
@ -30,6 +33,10 @@
k: 4 k: 4
}; };
let embeddingModel = {
embedding_model: '',
};
const scanHandler = async () => { const scanHandler = async () => {
loading = true; loading = true;
const res = await scanDocs(localStorage.token); const res = await scanDocs(localStorage.token);
@ -41,6 +48,21 @@
} }
}; };
const embeddingModelUpdateHandler = async () => {
loading1 = true;
const res = await updateEmbeddingModel(localStorage.token, embeddingModel);
loading1 = false;
if (res) {
console.log('embeddingModelUpdateHandler:', res);
if (res.status == true) {
toast.success($i18n.t('Model {{embedding_model}} update complete!', res));
} else {
toast.error($i18n.t('Model {{embedding_model}} update failed or not required!', res));
}
}
};
const submitHandler = async () => { const submitHandler = async () => {
const res = await updateRAGConfig(localStorage.token, { const res = await updateRAGConfig(localStorage.token, {
pdf_extract_images: pdfExtractImages, pdf_extract_images: pdfExtractImages,
@ -62,6 +84,8 @@
chunkOverlap = res.chunk.chunk_overlap; chunkOverlap = res.chunk.chunk_overlap;
} }
embeddingModel = await getEmbeddingModel(localStorage.token);
querySettings = await getQuerySettings(localStorage.token); querySettings = await getQuerySettings(localStorage.token);
}); });
</script> </script>
@ -137,6 +161,67 @@
{/if} {/if}
</button> </button>
</div> </div>
<div class=" flex w-full justify-between">
<div class=" self-center text-xs font-medium">
{$i18n.t('Update embedding model {{embedding_model}}', embeddingModel)}
</div>
<button
class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded flex flex-row space-x-1 items-center {loading1
? ' cursor-not-allowed'
: ''}"
on:click={() => {
embeddingModelUpdateHandler(embeddingModel);
console.log('Update embedding model:', embeddingModel.embedding_model);
}}
type="button"
disabled={loading1}
>
<div class="self-center font-medium">{$i18n.t('Update')}</div>
<!-- <svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-3 h-3"
>
<path
fill-rule="evenodd"
d="M13.836 2.477a.75.75 0 0 1 .75.75v3.182a.75.75 0 0 1-.75.75h-3.182a.75.75 0 0 1 0-1.5h1.37l-.84-.841a4.5 4.5 0 0 0-7.08.932.75.75 0 0 1-1.3-.75 6 6 0 0 1 9.44-1.242l.842.84V3.227a.75.75 0 0 1 .75-.75Zm-.911 7.5A.75.75 0 0 1 13.199 11a6 6 0 0 1-9.44 1.241l-.84-.84v1.371a.75.75 0 0 1-1.5 0V9.591a.75.75 0 0 1 .75-.75H5.35a.75.75 0 0 1 0 1.5H3.98l.841.841a4.5 4.5 0 0 0 7.08-.932.75.75 0 0 1 1.025-.273Z"
clip-rule="evenodd"
/>
</svg> -->
{#if loading1}
<div class="ml-3 self-center">
<svg
class=" w-3 h-3"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style><path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/><path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/></svg
>
</div>
{/if}
</button>
</div>
</div> </div>
<hr class=" dark:border-gray-700" /> <hr class=" dark:border-gray-700" />