forked from open-webui/open-webui
Improve embedding model update & resolve network dependency
* Add config variable RAG_EMBEDDING_MODEL_AUTO_UPDATE to control update behavior * Add RAG utils embedding_model_get_path() function to output the filesystem path in addition to update of the model using huggingface_hub * Update and utilize existing RAG functions in main: get_embedding_model() & update_embedding_model() * Add GUI setting to execute manual update process
This commit is contained in:
parent
62392aa88a
commit
3b66aa55c0
5 changed files with 218 additions and 19 deletions
|
@ -13,7 +13,6 @@ import os, shutil, logging, re
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
|
@ -45,7 +44,7 @@ from apps.web.models.documents import (
|
|||
DocumentResponse,
|
||||
)
|
||||
|
||||
from apps.rag.utils import query_doc, query_collection
|
||||
from apps.rag.utils import query_doc, query_collection, embedding_model_get_path
|
||||
|
||||
from utils.misc import (
|
||||
calculate_sha256,
|
||||
|
@ -60,6 +59,7 @@ from config import (
|
|||
DOCS_DIR,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
CHROMA_CLIENT,
|
||||
CHUNK_SIZE,
|
||||
CHUNK_OVERLAP,
|
||||
|
@ -71,15 +71,6 @@ from constants import ERROR_MESSAGES
|
|||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
#
|
||||
# if RAG_EMBEDDING_MODEL:
|
||||
# sentence_transformer_ef = SentenceTransformer(
|
||||
# model_name_or_path=RAG_EMBEDDING_MODEL,
|
||||
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
|
||||
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
# )
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.state.PDF_EXTRACT_IMAGES = False
|
||||
|
@ -87,11 +78,12 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
|
|||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)
|
||||
app.state.TOP_K = 4
|
||||
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
|
@ -132,6 +124,7 @@ async def get_embedding_model(user=Depends(get_admin_user)):
|
|||
return {
|
||||
"status": True,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
|
||||
}
|
||||
|
||||
|
||||
|
@ -143,17 +136,39 @@ class EmbeddingModelUpdateForm(BaseModel):
|
|||
async def update_embedding_model(
|
||||
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
status = True
|
||||
old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
|
||||
log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
|
||||
log.info(f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}")
|
||||
|
||||
try:
|
||||
app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, True)
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating embedding model: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=e,
|
||||
)
|
||||
|
||||
if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path:
|
||||
status = False
|
||||
|
||||
log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}")
|
||||
log.debug(f"old_model_path: {old_model_path}")
|
||||
log.debug(f"status: {status}")
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"status": status,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue