forked from open-webui/open-webui
		
Improve embedding model update & resolve network dependency

* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior
* Add RAG utils embedding_model_get_path() function to output the filesystem path, in addition to updating the model via huggingface_hub
* Update and utilize the existing RAG functions in main: get_embedding_model() & update_embedding_model()
* Add a GUI setting to trigger the manual update process
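
The RAG_EMBEDDING_MODEL_AUTO_UPDATE flag from the first bullet is defined in config.py, one of the three changed files not shown in this excerpt. When the flag is off, embedding_model_get_path() is called with local_files_only enabled, so an already-cached model resolves at startup without network access (the "network dependency" the title refers to). A rough sketch of how such a flag could be defined, assuming config.py keeps its usual pattern of environment-variable lookups; the default and parsing shown here are assumptions, not taken from the commit:

    import os

    # Sketch only: an environment-driven boolean flag. The actual definition in
    # config.py may use a different default or parsing helper.
    RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
        os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
    )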
parent 62392aa88a
commit 3b66aa55c0
author: Self Denial

5 changed files with 218 additions and 19 deletions. Two of the changed files are shown below: first the RAG FastAPI app module, then the RAG utils module that gains the new helper.
@@ -13,7 +13,6 @@ import os, shutil, logging, re
 from pathlib import Path
 from typing import List
 
-from sentence_transformers import SentenceTransformer
 from chromadb.utils import embedding_functions
 
 from langchain_community.document_loaders import (
@@ -45,7 +44,7 @@ from apps.web.models.documents import (
     DocumentResponse,
 )
 
-from apps.rag.utils import query_doc, query_collection
+from apps.rag.utils import query_doc, query_collection, embedding_model_get_path
 
 from utils.misc import (
     calculate_sha256,
@@ -60,6 +59,7 @@ from config import (
     DOCS_DIR,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
@@ -71,15 +71,6 @@ from constants import ERROR_MESSAGES
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
-#
-# if RAG_EMBEDDING_MODEL:
-#    sentence_transformer_ef = SentenceTransformer(
-#        model_name_or_path=RAG_EMBEDDING_MODEL,
-#        cache_folder=RAG_EMBEDDING_MODEL_DIR,
-#        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
-#    )
-
-
 app = FastAPI()
 
 app.state.PDF_EXTRACT_IMAGES = False
@@ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)
 app.state.TOP_K = 4
 
 app.state.sentence_transformer_ef = (
     embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=app.state.RAG_EMBEDDING_MODEL,
+        model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
     )
 )
@@ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)):
     return {
         "status": True,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
     }
 
 
@@ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel):
 async def update_embedding_model(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
+    status = True
+    old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
     app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-    app.state.sentence_transformer_ef = (
-        embedding_functions.SentenceTransformerEmbeddingFunction(
-            model_name=app.state.RAG_EMBEDDING_MODEL,
-            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+
+    log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
+    log.info(f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}")
+
+    try:
+        app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, True)
+        app.state.sentence_transformer_ef = (
+            embedding_functions.SentenceTransformerEmbeddingFunction(
+                model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
+                device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+            )
         )
-    )
+    except Exception as e:
+        log.exception(f"Problem updating embedding model: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=e,
+        )
 
+    if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path:
+      status = False
+
+    log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}")
+    log.debug(f"old_model_path: {old_model_path}")
+    log.debug(f"status: {status}")
+
     return {
-        "status": True,
+        "status": status,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
     }
 
 
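
A note for readers of the new handler: the local flag status = True shadows the status name used as status.HTTP_500_INTERNAL_SERVER_ERROR in the except branch (assuming that name comes from the usual from fastapi import status in this module), and detail=e passes the raw exception object rather than a string. A minimal sketch of the same error path with both points avoided; the model_updated name is hypothetical and not part of the commit:

    from fastapi import HTTPException, status

    model_updated = True  # hypothetical rename so the fastapi.status module stays reachable
    try:
        new_path = embedding_model_get_path(form_data.embedding_model, True)
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),  # a string serializes cleanly in the error response
        )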
The second file shown is the RAG utils module, which gains the embedding_model_get_path() helper:
@@ -1,6 +1,8 @@
+import os
 import re
 import logging
 from typing import List
+from huggingface_hub import snapshot_download
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
@@ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, embedding_function):
     messages[last_user_message_idx] = new_user_message
 
     return messages
+
+def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
+    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
+    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+    local_files_only = not update_embedding_model
+    snapshot_kwargs = {
+        "cache_dir": cache_dir,
+        "local_files_only": local_files_only,
+    }
+
+    log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}")
+    log.debug(f"embedding_model: {embedding_model}")
+    log.debug(f"update_embedding_model: {update_embedding_model}")
+    log.debug(f"local_files_only: {local_files_only}")
+
+    # Inspiration from upstream sentence_transformers
+    if (os.path.exists(embedding_model) or ("\\" in embedding_model or embedding_model.count("/") > 1) and local_files_only):
+        # If fully qualified path exists, return input, else set repo_id
+        return embedding_model
+    elif "/" not in embedding_model:
+        # Set valid repo_id for model short-name
+        embedding_model = "sentence-transformers" + "/" + embedding_model
+
+    snapshot_kwargs["repo_id"] = embedding_model
+
+    # Attempt to query the huggingface_hub library to determine the local path and/or to update
+    try:
+        embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
+        log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
+        return embedding_model_repo_path
+    except Exception as e:
+        log.exception(f"Cannot determine embedding model snapshot path: {e}")
+        return embedding_model
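
To make the helper's branches concrete, a usage sketch (the model names and return values here are illustrative assumptions; actual paths depend on the local huggingface_hub cache):

    # Existing local path: returned unchanged.
    embedding_model_get_path("/models/my-model")

    # Bare model name: prefixed as "sentence-transformers/all-MiniLM-L6-v2", then
    # resolved to its cached snapshot directory if one exists, e.g.
    # ".../models--sentence-transformers--all-MiniLM-L6-v2/snapshots/<hash>";
    # if resolution fails (nothing cached, no update requested), the prefixed
    # name is returned as a fallback.
    embedding_model_get_path("all-MiniLM-L6-v2")

    # org/name repo id with update requested: local_files_only is False, so
    # snapshot_download() may hit the network to download or refresh the model
    # before returning its snapshot path.
    embedding_model_get_path("intfloat/e5-small-v2", update_embedding_model=True)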