forked from open-webui/open-webui
		
	Improve embedding model update & resolve network dependency
* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior
* Add RAG utils function embedding_model_get_path(), which returns the model's filesystem path and can also update the model using huggingface_hub
* Update and utilize the existing RAG functions in main: get_embedding_model() & update_embedding_model()
* Add a GUI setting to trigger the update process manually
This commit is contained in:
		
							parent
							
								
									62392aa88a
								
							
						
					
					
						commit
						3b66aa55c0
					
				
					 5 changed files with 218 additions and 19 deletions
				
			
		|  | @ -13,7 +13,6 @@ import os, shutil, logging, re | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import List | from typing import List | ||||||
| 
 | 
 | ||||||
| from sentence_transformers import SentenceTransformer |  | ||||||
| from chromadb.utils import embedding_functions | from chromadb.utils import embedding_functions | ||||||
| 
 | 
 | ||||||
| from langchain_community.document_loaders import ( | from langchain_community.document_loaders import ( | ||||||
|  | @ -45,7 +44,7 @@ from apps.web.models.documents import ( | ||||||
|     DocumentResponse, |     DocumentResponse, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| from apps.rag.utils import query_doc, query_collection | from apps.rag.utils import query_doc, query_collection, embedding_model_get_path | ||||||
| 
 | 
 | ||||||
| from utils.misc import ( | from utils.misc import ( | ||||||
|     calculate_sha256, |     calculate_sha256, | ||||||
|  | @ -60,6 +59,7 @@ from config import ( | ||||||
|     DOCS_DIR, |     DOCS_DIR, | ||||||
|     RAG_EMBEDDING_MODEL, |     RAG_EMBEDDING_MODEL, | ||||||
|     RAG_EMBEDDING_MODEL_DEVICE_TYPE, |     RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||||
|  |     RAG_EMBEDDING_MODEL_AUTO_UPDATE, | ||||||
|     CHROMA_CLIENT, |     CHROMA_CLIENT, | ||||||
|     CHUNK_SIZE, |     CHUNK_SIZE, | ||||||
|     CHUNK_OVERLAP, |     CHUNK_OVERLAP, | ||||||
|  | @ -71,15 +71,6 @@ from constants import ERROR_MESSAGES | ||||||
| log = logging.getLogger(__name__) | log = logging.getLogger(__name__) | ||||||
| log.setLevel(SRC_LOG_LEVELS["RAG"]) | log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||||
| 
 | 
 | ||||||
| # |  | ||||||
| # if RAG_EMBEDDING_MODEL: |  | ||||||
| #    sentence_transformer_ef = SentenceTransformer( |  | ||||||
| #        model_name_or_path=RAG_EMBEDDING_MODEL, |  | ||||||
| #        cache_folder=RAG_EMBEDDING_MODEL_DIR, |  | ||||||
| #        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, |  | ||||||
| #    ) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| app = FastAPI() | app = FastAPI() | ||||||
| 
 | 
 | ||||||
| app.state.PDF_EXTRACT_IMAGES = False | app.state.PDF_EXTRACT_IMAGES = False | ||||||
|  | @ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE | ||||||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||||
| app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||||
|  | app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE) | ||||||
| app.state.TOP_K = 4 | app.state.TOP_K = 4 | ||||||
| 
 | 
 | ||||||
| app.state.sentence_transformer_ef = ( | app.state.sentence_transformer_ef = ( | ||||||
|     embedding_functions.SentenceTransformerEmbeddingFunction( |     embedding_functions.SentenceTransformerEmbeddingFunction( | ||||||
|         model_name=app.state.RAG_EMBEDDING_MODEL, |         model_name=app.state.RAG_EMBEDDING_MODEL_PATH, | ||||||
|         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, |         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||||
|     ) |     ) | ||||||
| ) | ) | ||||||
|  | @ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)): | ||||||
|     return { |     return { | ||||||
|         "status": True, |         "status": True, | ||||||
|         "embedding_model": app.state.RAG_EMBEDDING_MODEL, |         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||||
|  |         "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH, | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel): | ||||||
async def update_embedding_model(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the RAG embedding model and rebuild the embedding function.

    The requested model is resolved with ``embedding_model_get_path(..., True)``
    and a new ``SentenceTransformerEmbeddingFunction`` is built from the
    resulting path. ``app.state`` is only modified once both steps succeed,
    so a failed update leaves the previously active model in place.

    Args:
        form_data: Request body carrying the ``embedding_model`` name.
        user: Admin user injected by the ``get_admin_user`` dependency.

    Returns:
        A dict with ``status`` (``True`` only when the resolved model path
        differs from the previous one), ``embedding_model`` and
        ``embedding_model_path``.

    Raises:
        HTTPException: With status 500 when resolving the path or building
            the embedding function fails.
    """
    previous_model = app.state.RAG_EMBEDDING_MODEL
    old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
    requested_model = form_data.embedding_model

    log.debug(f"form_data.embedding_model: {requested_model}")
    # Logged before any state change so the old and new names really differ.
    log.info(f"Updating embedding model: {previous_model} to {requested_model}")

    try:
        new_model_path = embedding_model_get_path(requested_model, True)
        new_embedding_function = (
            embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=new_model_path,
                device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
            )
        )
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        # NOTE: no local variable may be called `status` in this function;
        # `status` must keep referring to the namespace that provides the
        # HTTP_* constants (presumably fastapi.status, imported outside this
        # view), otherwise this line fails with AttributeError.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            # Exception objects are not JSON serializable; send the message.
            detail=str(e),
        ) from e

    # Commit the new configuration only after everything above succeeded.
    app.state.RAG_EMBEDDING_MODEL = requested_model
    app.state.RAG_EMBEDDING_MODEL_PATH = new_model_path
    app.state.sentence_transformer_ef = new_embedding_function

    # An unchanged path is reported as "no update happened".
    update_status = new_model_path != old_model_path

    log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}")
    log.debug(f"old_model_path: {old_model_path}")
    log.debug(f"status: {update_status}")

    return {
        "status": update_status,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
        "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
    }
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,6 +1,8 @@ | ||||||
|  | import os | ||||||
| import re | import re | ||||||
| import logging | import logging | ||||||
| from typing import List | from typing import List | ||||||
|  | from huggingface_hub import snapshot_download | ||||||
| 
 | 
 | ||||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||||
| 
 | 
 | ||||||
|  | @ -188,3 +190,36 @@ def rag_messages(docs, messages, template, k, embedding_function): | ||||||
|     messages[last_user_message_idx] = new_user_message |     messages[last_user_message_idx] = new_user_message | ||||||
| 
 | 
 | ||||||
|     return messages |     return messages | ||||||
|  | 
 | ||||||
def embedding_model_get_path(embedding_model: str, update_embedding_model: bool = False):
    """Resolve an embedding model name to a path usable for loading it.

    Args:
        embedding_model: A filesystem path, a ``namespace/name`` repo id, or
            a bare model short-name.
        update_embedding_model: When ``False`` (default) ``snapshot_download``
            is called with ``local_files_only=True``; when ``True`` it is
            called with ``local_files_only=False``.

    Returns:
        * ``embedding_model`` unchanged when it is an existing path, or when
          it merely looks like a path (contains a backslash or more than one
          ``/``) and no update was requested;
        * otherwise the value returned by ``snapshot_download`` for the repo;
        * if ``snapshot_download`` raises, the repo id itself (the input,
          prefixed with ``sentence-transformers/`` when it contained no ``/``).
    """
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
    local_files_only = not update_embedding_model

    log.debug(f"SENTENCE_TRANSFORMERS_HOME cache_dir: {cache_dir}")
    log.debug(f"embedding_model: {embedding_model}")
    log.debug(f"update_embedding_model: {update_embedding_model}")
    log.debug(f"local_files_only: {local_files_only}")

    # Heuristic modelled on upstream sentence_transformers: an existing path
    # always wins, regardless of the update flag.
    if os.path.exists(embedding_model):
        return embedding_model

    # A path-like string that does not exist is only handed back as-is when
    # no update was requested; otherwise it is treated as a repo id below.
    resembles_path = "\\" in embedding_model or embedding_model.count("/") > 1
    if resembles_path and local_files_only:
        return embedding_model

    # Short names without a namespace get the sentence-transformers one.
    repo_id = embedding_model
    if "/" not in repo_id:
        repo_id = "sentence-transformers" + "/" + repo_id

    # Ask huggingface_hub for the snapshot location (and possibly an update).
    try:
        embedding_model_repo_path = snapshot_download(
            repo_id=repo_id,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
    except Exception as e:
        # Best effort: fall back to the repo id and let the caller load it.
        log.exception(f"Cannot determine embedding model snapshot path: {e}")
        return repo_id

    log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
    return embedding_model_repo_path
|  |  | ||||||
|  | @ -395,6 +395,9 @@ RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") | ||||||
| RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get( | RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get( | ||||||
|     "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu" |     "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu" | ||||||
| ) | ) | ||||||
# Opt-in switch: enabled only when the environment variable is set to the
# string "true" (compared case-insensitively); unset or any other value
# leaves it disabled.
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
| CHROMA_CLIENT = chromadb.PersistentClient( | CHROMA_CLIENT = chromadb.PersistentClient( | ||||||
|     path=CHROMA_DATA_PATH, |     path=CHROMA_DATA_PATH, | ||||||
|     settings=Settings(allow_reset=True, anonymized_telemetry=False), |     settings=Settings(allow_reset=True, anonymized_telemetry=False), | ||||||
|  |  | ||||||
|  | @ -345,3 +345,64 @@ export const resetVectorDB = async (token: string) => { | ||||||
| 
 | 
 | ||||||
| 	return res; | 	return res; | ||||||
| }; | }; | ||||||
|  | 
 | ||||||
/**
 * Fetch the embedding model settings from the RAG API.
 *
 * Resolves with the parsed JSON body on success. When the request fails and
 * the caught value carries a truthy `detail`, that detail is thrown; any
 * other failure is logged and resolves with `null`.
 */
export const getEmbeddingModel = async (token: string) => {
	let error = null;
	let res = null;

	try {
		const response = await fetch(`${RAG_API_BASE_URL}/embedding/model`, {
			method: 'GET',
			headers: {
				'Content-Type': 'application/json',
				Authorization: `Bearer ${token}`
			}
		});

		// Non-2xx responses: throw the parsed body so `detail` can be read below.
		if (!response.ok) throw await response.json();
		res = await response.json();
	} catch (err: any) {
		console.log(err);
		error = err.detail;
		res = null;
	}

	if (error) {
		throw error;
	}

	return res;
};
|  | 
 | ||||||
|  | type EmbeddingModelUpdateForm = { | ||||||
|  | 	embedding_model: string; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
/**
 * POST a new embedding model selection to the RAG API.
 *
 * Resolves with the parsed JSON body on success. When the request fails and
 * the caught value carries a truthy `detail`, that detail is thrown; any
 * other failure is logged and resolves with `null`.
 */
export const updateEmbeddingModel = async (token: string, payload: EmbeddingModelUpdateForm) => {
	let error = null;
	let res = null;

	try {
		const response = await fetch(`${RAG_API_BASE_URL}/embedding/model/update`, {
			method: 'POST',
			headers: {
				'Content-Type': 'application/json',
				Authorization: `Bearer ${token}`
			},
			body: JSON.stringify({
				...payload
			})
		});

		// Non-2xx responses: throw the parsed body so `detail` can be read below.
		if (!response.ok) throw await response.json();
		res = await response.json();
	} catch (err: any) {
		console.log(err);
		error = err.detail;
		res = null;
	}

	if (error) {
		throw error;
	}

	return res;
};
|  |  | ||||||
|  | @ -6,7 +6,9 @@ | ||||||
| 		getQuerySettings, | 		getQuerySettings, | ||||||
| 		scanDocs, | 		scanDocs, | ||||||
| 		updateQuerySettings, | 		updateQuerySettings, | ||||||
| 		resetVectorDB | 		resetVectorDB, | ||||||
|  | 		getEmbeddingModel, | ||||||
|  | 		updateEmbeddingModel | ||||||
| 	} from '$lib/apis/rag'; | 	} from '$lib/apis/rag'; | ||||||
| 
 | 
 | ||||||
| 	import { documents } from '$lib/stores'; | 	import { documents } from '$lib/stores'; | ||||||
|  | @ -18,6 +20,7 @@ | ||||||
| 	export let saveHandler: Function; | 	export let saveHandler: Function; | ||||||
| 
 | 
 | ||||||
| 	let loading = false; | 	let loading = false; | ||||||
|  | 	let loading1 = false; | ||||||
| 
 | 
 | ||||||
| 	let showResetConfirm = false; | 	let showResetConfirm = false; | ||||||
| 
 | 
 | ||||||
|  | @ -30,6 +33,10 @@ | ||||||
| 		k: 4 | 		k: 4 | ||||||
| 	}; | 	}; | ||||||
| 
 | 
 | ||||||
|  | 	let embeddingModel = { | ||||||
|  | 		embedding_model: '', | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
| 	const scanHandler = async () => { | 	const scanHandler = async () => { | ||||||
| 		loading = true; | 		loading = true; | ||||||
| 		const res = await scanDocs(localStorage.token); | 		const res = await scanDocs(localStorage.token); | ||||||
|  | @ -41,6 +48,21 @@ | ||||||
| 		} | 		} | ||||||
| 	}; | 	}; | ||||||
| 
 | 
 | ||||||
// Trigger a manual embedding model update and report the outcome as a toast.
// `loading1` drives the Update button's disabled state and spinner.
const embeddingModelUpdateHandler = async () => {
	loading1 = true;

	let res = null;
	try {
		res = await updateEmbeddingModel(localStorage.token, embeddingModel);
	} catch (error) {
		// updateEmbeddingModel rethrows the error detail of a failed request;
		// surface it instead of leaving an unhandled rejection.
		console.log('embeddingModelUpdateHandler:', error);
		toast.error(`${error}`);
	} finally {
		// Always re-enable the button, even when the request failed.
		loading1 = false;
	}

	if (res) {
		console.log('embeddingModelUpdateHandler:', res);
		if (res.status == true) {
			toast.success($i18n.t('Model {{embedding_model}} update complete!', res));
		} else {
			toast.error($i18n.t('Model {{embedding_model}} update failed or not required!', res));
		}
	}
};
|  | 
 | ||||||
| 	const submitHandler = async () => { | 	const submitHandler = async () => { | ||||||
| 		const res = await updateRAGConfig(localStorage.token, { | 		const res = await updateRAGConfig(localStorage.token, { | ||||||
| 			pdf_extract_images: pdfExtractImages, | 			pdf_extract_images: pdfExtractImages, | ||||||
|  | @ -62,6 +84,8 @@ | ||||||
| 			chunkOverlap = res.chunk.chunk_overlap; | 			chunkOverlap = res.chunk.chunk_overlap; | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		embeddingModel = await getEmbeddingModel(localStorage.token); | ||||||
|  | 
 | ||||||
| 		querySettings = await getQuerySettings(localStorage.token); | 		querySettings = await getQuerySettings(localStorage.token); | ||||||
| 	}); | 	}); | ||||||
| </script> | </script> | ||||||
|  | @ -137,6 +161,67 @@ | ||||||
| 					{/if} | 					{/if} | ||||||
| 				</button> | 				</button> | ||||||
| 			</div> | 			</div> | ||||||
|  | 
 | ||||||
|  | 			<div class="  flex w-full justify-between"> | ||||||
|  | 				<div class=" self-center text-xs font-medium"> | ||||||
|  | 					{$i18n.t('Update embedding model {{embedding_model}}', embeddingModel)} | ||||||
|  | 				</div> | ||||||
|  | 
 | ||||||
|  | 				<button | ||||||
|  | 					class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded flex flex-row space-x-1 items-center {loading1 | ||||||
|  | 						? ' cursor-not-allowed' | ||||||
|  | 						: ''}" | ||||||
|  | 					on:click={() => { | ||||||
|  | 						embeddingModelUpdateHandler(embeddingModel); | ||||||
|  | 						console.log('Update embedding model:', embeddingModel.embedding_model); | ||||||
|  | 					}} | ||||||
|  | 					type="button" | ||||||
|  | 					disabled={loading1} | ||||||
|  | 				> | ||||||
|  | 					<div class="self-center font-medium">{$i18n.t('Update')}</div> | ||||||
|  | 
 | ||||||
|  | 					<!-- <svg | ||||||
|  | 						xmlns="http://www.w3.org/2000/svg" | ||||||
|  | 						viewBox="0 0 16 16" | ||||||
|  | 						fill="currentColor" | ||||||
|  | 						class="w-3 h-3" | ||||||
|  | 					> | ||||||
|  | 						<path | ||||||
|  | 							fill-rule="evenodd" | ||||||
|  | 							d="M13.836 2.477a.75.75 0 0 1 .75.75v3.182a.75.75 0 0 1-.75.75h-3.182a.75.75 0 0 1 0-1.5h1.37l-.84-.841a4.5 4.5 0 0 0-7.08.932.75.75 0 0 1-1.3-.75 6 6 0 0 1 9.44-1.242l.842.84V3.227a.75.75 0 0 1 .75-.75Zm-.911 7.5A.75.75 0 0 1 13.199 11a6 6 0 0 1-9.44 1.241l-.84-.84v1.371a.75.75 0 0 1-1.5 0V9.591a.75.75 0 0 1 .75-.75H5.35a.75.75 0 0 1 0 1.5H3.98l.841.841a4.5 4.5 0 0 0 7.08-.932.75.75 0 0 1 1.025-.273Z" | ||||||
|  | 							clip-rule="evenodd" | ||||||
|  | 						/> | ||||||
|  | 					</svg> --> | ||||||
|  | 
 | ||||||
|  | 					{#if loading1} | ||||||
|  | 						<div class="ml-3 self-center"> | ||||||
|  | 							<svg | ||||||
|  | 								class=" w-3 h-3" | ||||||
|  | 								viewBox="0 0 24 24" | ||||||
|  | 								fill="currentColor" | ||||||
|  | 								xmlns="http://www.w3.org/2000/svg" | ||||||
|  | 								><style> | ||||||
|  | 									.spinner_ajPY { | ||||||
|  | 										transform-origin: center; | ||||||
|  | 										animation: spinner_AtaB 0.75s infinite linear; | ||||||
|  | 									} | ||||||
|  | 									@keyframes spinner_AtaB { | ||||||
|  | 										100% { | ||||||
|  | 											transform: rotate(360deg); | ||||||
|  | 										} | ||||||
|  | 									} | ||||||
|  | 								</style><path | ||||||
|  | 									d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z" | ||||||
|  | 									opacity=".25" | ||||||
|  | 								/><path | ||||||
|  | 									d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z" | ||||||
|  | 									class="spinner_ajPY" | ||||||
|  | 								/></svg | ||||||
|  | 							> | ||||||
|  | 						</div> | ||||||
|  | 					{/if} | ||||||
|  | 				</button> | ||||||
|  | 			</div> | ||||||
| 		</div> | 		</div> | ||||||
| 
 | 
 | ||||||
| 		<hr class=" dark:border-gray-700" /> | 		<hr class=" dark:border-gray-700" /> | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue