Improve embedding model update & resolve network dependency

* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior
* Add RAG utils embedding_model_get_path() function to output the filesystem path in addition to update of the model using huggingface_hub
* Update and utilize existing RAG functions in main: get_embedding_model() & update_embedding_model()
* Add GUI setting to execute manual update process
This commit is contained in:
Self Denial 2024-04-04 11:01:23 -06:00
parent 62392aa88a
commit 3b66aa55c0
5 changed files with 218 additions and 19 deletions

View file

@ -13,7 +13,6 @@ import os, shutil, logging, re
from pathlib import Path
from typing import List
from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import (
@ -45,7 +44,7 @@ from apps.web.models.documents import (
DocumentResponse,
)
from apps.rag.utils import query_doc, query_collection
from apps.rag.utils import query_doc, query_collection, embedding_model_get_path
from utils.misc import (
calculate_sha256,
@ -60,6 +59,7 @@ from config import (
DOCS_DIR,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
@ -71,15 +71,6 @@ from constants import ERROR_MESSAGES
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
#
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )
app = FastAPI()
app.state.PDF_EXTRACT_IMAGES = False
@ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)
app.state.TOP_K = 4
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
@ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)):
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
}
@ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel):
async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
status = True
old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
log.info(f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}")
try:
app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, True)
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
)
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e,
)
if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path:
status = False
log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}")
log.debug(f"old_model_path: {old_model_path}")
log.debug(f"status: {status}")
return {
"status": True,
"status": status,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
}

View file

@ -1,6 +1,8 @@
import os
import re
import logging
from typing import List
from huggingface_hub import snapshot_download
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, embedding_function):
messages[last_user_message_idx] = new_user_message
return messages
def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
local_files_only = not update_embedding_model
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}")
log.debug(f"embedding_model: {embedding_model}")
log.debug(f"update_embedding_model: {update_embedding_model}")
log.debug(f"local_files_only: {local_files_only}")
# Inspiration from upstream sentence_transformers
if (os.path.exists(embedding_model) or ("\\" in embedding_model or embedding_model.count("/") > 1) and local_files_only):
# If fully qualified path exists, return input, else set repo_id
return embedding_model
elif "/" not in embedding_model:
# Set valid repo_id for model short-name
embedding_model = "sentence-transformers" + "/" + embedding_model
snapshot_kwargs["repo_id"] = embedding_model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try:
embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
return embedding_model_repo_path
except Exception as e:
log.exception(f"Cannot determine embedding model snapshot path: {e}")
return embedding_model