forked from open-webui/open-webui
Merge pull request #1419 from lainedfles/embedding-model-fix-and-manual-update
feat: improve embedding model update & resolve network dependency
This commit is contained in:
commit
b9cadff16b
6 changed files with 438 additions and 210 deletions
|
@ -13,7 +13,6 @@ import os, shutil, logging, re
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from chromadb.utils import embedding_functions
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
|
@ -46,7 +45,7 @@ from apps.web.models.documents import (
|
|||
DocumentResponse,
|
||||
)
|
||||
|
||||
from apps.rag.utils import query_doc, query_collection
|
||||
from apps.rag.utils import query_doc, query_collection, get_embedding_model_path
|
||||
|
||||
from utils.misc import (
|
||||
calculate_sha256,
|
||||
|
@ -60,6 +59,7 @@ from config import (
|
|||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
DEVICE_TYPE,
|
||||
CHROMA_CLIENT,
|
||||
CHUNK_SIZE,
|
||||
|
@ -78,12 +78,18 @@ app.state.PDF_EXTRACT_IMAGES = False
|
|||
app.state.CHUNK_SIZE = CHUNK_SIZE
|
||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
|
||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
|
||||
|
||||
app.state.TOP_K = 4
|
||||
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
model_name=get_embedding_model_path(
|
||||
app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
|
||||
),
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
|
@ -135,17 +141,33 @@ class EmbeddingModelUpdateForm(BaseModel):
|
|||
async def update_embedding_model(
|
||||
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
}
|
||||
|
||||
try:
|
||||
sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=get_embedding_model_path(form_data.embedding_model, True),
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = sentence_transformer_ef
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating embedding model: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import re
|
||||
import logging
|
||||
from typing import List
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||
|
||||
|
@ -188,3 +190,43 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
|||
messages[last_user_message_idx] = new_user_message
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def get_embedding_model_path(
    embedding_model: str, update_embedding_model: bool = False
):
    """Resolve an embedding model name to the value handed to the model loader.

    Args:
        embedding_model: A filesystem path, a ``namespace/name`` repo id, or a
            bare model short-name (no ``/``).
        update_embedding_model: When True, ``snapshot_download`` is called
            with ``local_files_only=False``; otherwise with ``True``.

    Returns:
        The input unchanged when it exists on disk, or when it looks like a
        path (contains a backslash or more than one ``/``) and only local
        files are allowed. Otherwise the value returned by
        ``snapshot_download``; if that call raises, the repo id (short-names
        prefixed with ``sentence-transformers/``) is returned instead.
    """
    offline_only = not update_embedding_model

    # Keyword arguments for huggingface_hub; the cache directory comes from
    # the SENTENCE_TRANSFORMERS_HOME environment variable (None when unset).
    snapshot_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": offline_only,
    }

    log.debug(f"embedding_model: {embedding_model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers.
    # An existing path always wins.
    if os.path.exists(embedding_model):
        return embedding_model

    # Something shaped like a path is returned as-is when updates are off.
    if (
        "\\" in embedding_model or embedding_model.count("/") > 1
    ) and offline_only:
        return embedding_model

    # A bare short-name gets the sentence-transformers namespace as repo_id.
    if "/" not in embedding_model:
        embedding_model = "/".join(("sentence-transformers", embedding_model))

    snapshot_kwargs["repo_id"] = embedding_model

    # Ask huggingface_hub for the snapshot path (and/or to update it); on any
    # failure fall back to the repo id itself.
    try:
        embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
        return embedding_model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine embedding model snapshot path: {e}")
        return embedding_model
|
||||
|
|
|
@ -403,6 +403,12 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
|||
# This uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2).
|
||||
# Embedding model name, read from the environment with a built-in default.
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}")

# True only when the environment variable equals "true" (case-insensitive);
# unset or any other value yields False.
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
|
||||
|
||||
|
||||
# device type for embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
||||
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue