forked from open-webui/open-webui
		
	Improve embedding model update & resolve network dependency
* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior
* Add RAG utils function embedding_model_get_path(), which returns the model's filesystem path and can optionally update the model using huggingface_hub
* Update and utilize existing RAG functions in main: get_embedding_model() & update_embedding_model()
* Add GUI setting to execute manual update process
This commit is contained in:
		
							parent
							
								
									62392aa88a
								
							
						
					
					
						commit
						3b66aa55c0
					
				
					 5 changed files with 218 additions and 19 deletions
				
			
		|  | @ -13,7 +13,6 @@ import os, shutil, logging, re | |||
| from pathlib import Path | ||||
| from typing import List | ||||
| 
 | ||||
| from sentence_transformers import SentenceTransformer | ||||
| from chromadb.utils import embedding_functions | ||||
| 
 | ||||
| from langchain_community.document_loaders import ( | ||||
|  | @ -45,7 +44,7 @@ from apps.web.models.documents import ( | |||
|     DocumentResponse, | ||||
| ) | ||||
| 
 | ||||
| from apps.rag.utils import query_doc, query_collection | ||||
| from apps.rag.utils import query_doc, query_collection, embedding_model_get_path | ||||
| 
 | ||||
| from utils.misc import ( | ||||
|     calculate_sha256, | ||||
|  | @ -60,6 +59,7 @@ from config import ( | |||
|     DOCS_DIR, | ||||
|     RAG_EMBEDDING_MODEL, | ||||
|     RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
|     RAG_EMBEDDING_MODEL_AUTO_UPDATE, | ||||
|     CHROMA_CLIENT, | ||||
|     CHUNK_SIZE, | ||||
|     CHUNK_OVERLAP, | ||||
|  | @ -71,15 +71,6 @@ from constants import ERROR_MESSAGES | |||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||
| 
 | ||||
| # | ||||
| # if RAG_EMBEDDING_MODEL: | ||||
| #    sentence_transformer_ef = SentenceTransformer( | ||||
| #        model_name_or_path=RAG_EMBEDDING_MODEL, | ||||
| #        cache_folder=RAG_EMBEDDING_MODEL_DIR, | ||||
| #        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
| #    ) | ||||
| 
 | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
| app.state.PDF_EXTRACT_IMAGES = False | ||||
|  | @ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE | |||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||
| app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||
| app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE) | ||||
| app.state.TOP_K = 4 | ||||
| 
 | ||||
| app.state.sentence_transformer_ef = ( | ||||
|     embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|         model_name=app.state.RAG_EMBEDDING_MODEL, | ||||
|         model_name=app.state.RAG_EMBEDDING_MODEL_PATH, | ||||
|         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
|     ) | ||||
| ) | ||||
|  | @ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)): | |||
|     return { | ||||
|         "status": True, | ||||
|         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||
|         "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel): | |||
async def update_embedding_model(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch to (and refresh) the requested embedding model.

    The model path is resolved with updating enabled and a new embedding
    function is built from it. Application state is only modified once both
    steps have succeeded, so a failed update leaves the previously active
    model untouched.

    Args:
        form_data: Request body carrying the requested ``embedding_model``.
        user: Authenticated admin user (dependency-injected; not used directly).

    Returns:
        A dict with ``status`` (``True`` when the resolved model path differs
        from the previous one, ``False`` when it is unchanged),
        ``embedding_model`` and ``embedding_model_path``.

    Raises:
        HTTPException: 500 if resolving the path or building the embedding
            function fails.
    """
    previous_model = app.state.RAG_EMBEDDING_MODEL
    previous_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
    requested_model = form_data.embedding_model

    log.debug(f"form_data.embedding_model: {requested_model}")
    log.info(f"Updating embedding model: {previous_model} to {requested_model}")

    try:
        # Build everything in locals first; app.state is committed below only
        # when nothing raised.
        new_model_path = embedding_model_get_path(requested_model, True)
        new_embedding_function = (
            embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=new_model_path,
                device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
            )
        )
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        # NOTE: no local variable may be called `status` in this function,
        # otherwise it would shadow the module used for the status code here.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        ) from e

    app.state.RAG_EMBEDDING_MODEL = requested_model
    app.state.RAG_EMBEDDING_MODEL_PATH = new_model_path
    app.state.sentence_transformer_ef = new_embedding_function

    # An unchanged path means nothing new was fetched ("update not required").
    updated = new_model_path != previous_model_path

    log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}")
    log.debug(f"old_model_path: {previous_model_path}")
    log.debug(f"status: {updated}")

    return {
        "status": updated,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
    }
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,6 +1,8 @@ | |||
| import os | ||||
| import re | ||||
| import logging | ||||
| from typing import List | ||||
| from huggingface_hub import snapshot_download | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||
| 
 | ||||
|  | @ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, embedding_function): | |||
|     messages[last_user_message_idx] = new_user_message | ||||
| 
 | ||||
|     return messages | ||||
| 
 | ||||
def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
    """Resolve an embedding model name to a local path, optionally updating it.

    Args:
        embedding_model: A filesystem path, a Hub repo id
            (``namespace/name``) or a short model name, which is looked up
            under the ``sentence-transformers`` namespace.
        update_embedding_model: When ``True`` the Hub lookup may use the
            network to download/update the snapshot; when ``False`` only
            local files are consulted.

    Returns:
        The local snapshot path when it can be determined. Otherwise the input
        itself (for filesystem-like names) or the repo id (when the Hub lookup
        fails), so the caller can still try to load the model by name.
    """
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
    local_files_only = not update_embedding_model

    log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}")
    log.debug(f"embedding_model: {embedding_model}")
    log.debug(f"update_embedding_model: {update_embedding_model}")
    log.debug(f"local_files_only: {local_files_only}")

    # Inspiration from upstream sentence_transformers
    if os.path.exists(embedding_model):
        # An existing local path is used as-is.
        return embedding_model

    if "\\" in embedding_model or embedding_model.count("/") > 1:
        # Looks like a filesystem path but does not exist. Such a string is
        # not a usable Hub repo id (those have the form "name" or
        # "namespace/name"), so never send it to the Hub or prefix it.
        log.warning(f"Embedding model path not found: {embedding_model}")
        return embedding_model

    if "/" in embedding_model:
        repo_id = embedding_model
    else:
        # Set valid repo_id for model short-name
        repo_id = "sentence-transformers" + "/" + embedding_model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        embedding_model_repo_path = snapshot_download(
            repo_id=repo_id,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
    except Exception as e:
        # Best effort: fall back to the repo id and let the model loader
        # resolve (or reject) it.
        log.exception(f"Cannot determine embedding model snapshot path: {e}")
        return repo_id

    log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
    return embedding_model_repo_path
|  |  | |||
|  | @ -395,6 +395,9 @@ RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") | |||
| RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get( | ||||
|     "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu" | ||||
| ) | ||||
# Auto-update of the embedding model is opt-in: enabled only when the
# environment variable is set to "true" (case-insensitive).
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
| CHROMA_CLIENT = chromadb.PersistentClient( | ||||
|     path=CHROMA_DATA_PATH, | ||||
|     settings=Settings(allow_reset=True, anonymized_telemetry=False), | ||||
|  |  | |||
|  | @ -345,3 +345,64 @@ export const resetVectorDB = async (token: string) => { | |||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
/**
 * Fetch the currently configured embedding model from the RAG API.
 *
 * Resolves with the parsed JSON response. If the request fails and the
 * failure carries a `detail` field, that detail is thrown; otherwise the
 * function resolves with `null`.
 */
export const getEmbeddingModel = async (token: string) => {
	let failure = null;
	let payload = null;

	try {
		const response = await fetch(`${RAG_API_BASE_URL}/embedding/model`, {
			method: 'GET',
			headers: {
				'Content-Type': 'application/json',
				Authorization: `Bearer ${token}`
			}
		});

		// Non-2xx responses are surfaced through their JSON error body.
		if (!response.ok) throw await response.json();
		payload = await response.json();
	} catch (err: any) {
		console.log(err);
		failure = err.detail;
		payload = null;
	}

	if (failure) {
		throw failure;
	}

	return payload;
};
| 
 | ||||
/** Request body for the embedding model update endpoint. */
type EmbeddingModelUpdateForm = {
	embedding_model: string;
};

/**
 * Ask the RAG API to update the embedding model described by `payload`.
 *
 * Resolves with the parsed JSON response. If the request fails and the
 * failure carries a `detail` field, that detail is thrown; otherwise the
 * function resolves with `null`.
 */
export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => {
	let failure = null;
	let result = null;

	try {
		const response = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, {
			method: 'POST',
			headers: {
				'Content-Type': 'application/json',
				Authorization: `Bearer ${token}`
			},
			body: JSON.stringify({
				...payload
			})
		});

		// Non-2xx responses are surfaced through their JSON error body.
		if (!response.ok) throw await response.json();
		result = await response.json();
	} catch (err: any) {
		console.log(err);
		failure = err.detail;
		result = null;
	}

	if (failure) {
		throw failure;
	}

	return result;
};
|  |  | |||
|  | @ -6,7 +6,9 @@ | |||
| 		getQuerySettings, | ||||
| 		scanDocs, | ||||
| 		updateQuerySettings, | ||||
| 		resetVectorDB | ||||
| 		resetVectorDB, | ||||
| 		getEmbeddingModel, | ||||
| 		updateEmbeddingModel | ||||
| 	} from '$lib/apis/rag'; | ||||
| 
 | ||||
| 	import { documents } from '$lib/stores'; | ||||
|  | @ -18,6 +20,7 @@ | |||
| 	export let saveHandler: Function; | ||||
| 
 | ||||
| 	let loading = false; | ||||
| 	let loading1 = false; | ||||
| 
 | ||||
| 	let showResetConfirm = false; | ||||
| 
 | ||||
|  | @ -30,6 +33,10 @@ | |||
| 		k: 4 | ||||
| 	}; | ||||
| 
 | ||||
| 	let embeddingModel = { | ||||
| 		embedding_model: '', | ||||
| 	}; | ||||
| 
 | ||||
| 	const scanHandler = async () => { | ||||
| 		loading = true; | ||||
| 		const res = await scanDocs(localStorage.token); | ||||
|  | @ -41,6 +48,21 @@ | |||
| 		} | ||||
| 	}; | ||||
| 
 | ||||
| 	const embeddingModelUpdateHandler = async () => { | ||||
| 		loading1 = true; | ||||
| 		const res = await updateEmbeddingModel(localStorage.token, embeddingModel); | ||||
| 		loading1 = false; | ||||
| 
 | ||||
| 		if (res) { | ||||
| 			console.log('embeddingModelUpdateHandler:', res); | ||||
| 			if (res.status == true) { | ||||
| 				toast.success($i18n.t('Model {{embedding_model}} update complete!', res)); | ||||
| 			} else { | ||||
| 				toast.error($i18n.t('Model {{embedding_model}} update failed or not required!', res)); | ||||
| 			} | ||||
| 		} | ||||
| 	}; | ||||
| 
 | ||||
| 	const submitHandler = async () => { | ||||
| 		const res = await updateRAGConfig(localStorage.token, { | ||||
| 			pdf_extract_images: pdfExtractImages, | ||||
|  | @ -62,6 +84,8 @@ | |||
| 			chunkOverlap = res.chunk.chunk_overlap; | ||||
| 		} | ||||
| 
 | ||||
| 		embeddingModel = await getEmbeddingModel(localStorage.token); | ||||
| 
 | ||||
| 		querySettings = await getQuerySettings(localStorage.token); | ||||
| 	}); | ||||
| </script> | ||||
|  | @ -137,6 +161,67 @@ | |||
| 					{/if} | ||||
| 				</button> | ||||
| 			</div> | ||||
| 
 | ||||
| 			<div class="  flex w-full justify-between"> | ||||
| 				<div class=" self-center text-xs font-medium"> | ||||
| 					{$i18n.t('Update embedding model {{embedding_model}}', embeddingModel)} | ||||
| 				</div> | ||||
| 
 | ||||
| 				<button | ||||
| 					class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded flex flex-row space-x-1 items-center {loading1 | ||||
| 						? ' cursor-not-allowed' | ||||
| 						: ''}" | ||||
| 					on:click={() => { | ||||
| 						embeddingModelUpdateHandler(embeddingModel); | ||||
| 						console.log('Update embedding model:', embeddingModel.embedding_model); | ||||
| 					}} | ||||
| 					type="button" | ||||
| 					disabled={loading1} | ||||
| 				> | ||||
| 					<div class="self-center font-medium">{$i18n.t('Update')}</div> | ||||
| 
 | ||||
| 					<!-- <svg | ||||
| 						xmlns="http://www.w3.org/2000/svg" | ||||
| 						viewBox="0 0 16 16" | ||||
| 						fill="currentColor" | ||||
| 						class="w-3 h-3" | ||||
| 					> | ||||
| 						<path | ||||
| 							fill-rule="evenodd" | ||||
| 							d="M13.836 2.477a.75.75 0 0 1 .75.75v3.182a.75.75 0 0 1-.75.75h-3.182a.75.75 0 0 1 0-1.5h1.37l-.84-.841a4.5 4.5 0 0 0-7.08.932.75.75 0 0 1-1.3-.75 6 6 0 0 1 9.44-1.242l.842.84V3.227a.75.75 0 0 1 .75-.75Zm-.911 7.5A.75.75 0 0 1 13.199 11a6 6 0 0 1-9.44 1.241l-.84-.84v1.371a.75.75 0 0 1-1.5 0V9.591a.75.75 0 0 1 .75-.75H5.35a.75.75 0 0 1 0 1.5H3.98l.841.841a4.5 4.5 0 0 0 7.08-.932.75.75 0 0 1 1.025-.273Z" | ||||
| 							clip-rule="evenodd" | ||||
| 						/> | ||||
| 					</svg> --> | ||||
| 
 | ||||
| 					{#if loading1} | ||||
| 						<div class="ml-3 self-center"> | ||||
| 							<svg | ||||
| 								class=" w-3 h-3" | ||||
| 								viewBox="0 0 24 24" | ||||
| 								fill="currentColor" | ||||
| 								xmlns="http://www.w3.org/2000/svg" | ||||
| 								><style> | ||||
| 									.spinner_ajPY { | ||||
| 										transform-origin: center; | ||||
| 										animation: spinner_AtaB 0.75s infinite linear; | ||||
| 									} | ||||
| 									@keyframes spinner_AtaB { | ||||
| 										100% { | ||||
| 											transform: rotate(360deg); | ||||
| 										} | ||||
| 									} | ||||
| 								</style><path | ||||
| 									d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z" | ||||
| 									opacity=".25" | ||||
| 								/><path | ||||
| 									d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z" | ||||
| 									class="spinner_ajPY" | ||||
| 								/></svg | ||||
| 							> | ||||
| 						</div> | ||||
| 					{/if} | ||||
| 				</button> | ||||
| 			</div> | ||||
| 		</div> | ||||
| 
 | ||||
| 		<hr class=" dark:border-gray-700" /> | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Self Denial
						Self Denial