Improve embedding model update & resolve network dependency

* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior
* Add RAG utils embedding_model_get_path() function to output the filesystem path in addition to update of the model using huggingface_hub
* Update and utilize existing RAG functions in main: get_embedding_model() & update_embedding_model()
* Add GUI setting to execute manual update process
This commit is contained in:
Self Denial 2024-04-04 11:01:23 -06:00
parent 62392aa88a
commit 3b66aa55c0
5 changed files with 218 additions and 19 deletions

View file

@ -1,6 +1,8 @@
import os
import re
import logging
from typing import List
from huggingface_hub import snapshot_download
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, embedding_function):
messages[last_user_message_idx] = new_user_message
return messages
def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
local_files_only = not update_embedding_model
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}")
log.debug(f"embedding_model: {embedding_model}")
log.debug(f"update_embedding_model: {update_embedding_model}")
log.debug(f"local_files_only: {local_files_only}")
# Inspiration from upstream sentence_transformers
if (os.path.exists(embedding_model) or ("\\" in embedding_model or embedding_model.count("/") > 1) and local_files_only):
# If fully qualified path exists, return input, else set repo_id
return embedding_model
elif "/" not in embedding_model:
# Set valid repo_id for model short-name
embedding_model = "sentence-transformers" + "/" + embedding_model
snapshot_kwargs["repo_id"] = embedding_model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try:
embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
return embedding_model_repo_path
except Exception as e:
log.exception(f"Cannot determine embedding model snapshot path: {e}")
return embedding_model