forked from open-webui/open-webui
Merge pull request #772 from jannikstdl/choose-embedding-model
feat: choose embedding model when using docker
This commit is contained in:
commit
c3916927bb
4 changed files with 87 additions and 17 deletions
23
Dockerfile
23
Dockerfile
|
@ -30,10 +30,24 @@ ENV WEBUI_SECRET_KEY ""
|
|||
ENV SCARF_NO_ANALYTICS true
|
||||
ENV DO_NOT_TRACK true
|
||||
|
||||
#Whisper TTS Settings
|
||||
######## Preloaded models ########
|
||||
# whisper TTS Settings
|
||||
ENV WHISPER_MODEL="base"
|
||||
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
||||
|
||||
# RAG Embedding Model Settings
|
||||
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
||||
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
||||
# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
|
||||
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
|
||||
ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
|
||||
# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
||||
ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
|
||||
ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
|
||||
ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
|
||||
|
||||
######## Preloaded models ########
|
||||
|
||||
WORKDIR /app/backend
|
||||
|
||||
# install python dependencies
|
||||
|
@ -48,9 +62,10 @@ RUN apt-get update \
|
|||
&& apt-get install -y pandoc netcat-openbsd \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')"
|
||||
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
|
||||
|
||||
# preload embedding model
|
||||
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
|
||||
# preload tts model
|
||||
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
|
||||
|
||||
# copy embedding weight from build
|
||||
RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
|
||||
|
|
|
@ -56,7 +56,7 @@ def transcribe(
|
|||
|
||||
model = WhisperModel(
|
||||
WHISPER_MODEL,
|
||||
device="cpu",
|
||||
device="auto",
|
||||
compute_type="int8",
|
||||
download_root=WHISPER_MODEL_DIR,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from fastapi import (
|
||||
FastAPI,
|
||||
Request,
|
||||
Depends,
|
||||
HTTPException,
|
||||
status,
|
||||
|
@ -14,7 +13,8 @@ import os, shutil
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# from chromadb.utils import embedding_functions
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
WebBaseLoader,
|
||||
|
@ -30,16 +30,12 @@ from langchain_community.document_loaders import (
|
|||
UnstructuredExcelLoader,
|
||||
)
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain_community.vectorstores import Chroma
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import mimetypes
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
from apps.web.models.documents import (
|
||||
|
@ -58,23 +54,37 @@ from utils.utils import get_current_user, get_admin_user
|
|||
from config import (
|
||||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
EMBED_MODEL,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
CHROMA_CLIENT,
|
||||
CHUNK_SIZE,
|
||||
CHUNK_OVERLAP,
|
||||
RAG_TEMPLATE,
|
||||
)
|
||||
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
# model_name=EMBED_MODEL
|
||||
#
|
||||
# if RAG_EMBEDDING_MODEL:
|
||||
# sentence_transformer_ef = SentenceTransformer(
|
||||
# model_name_or_path=RAG_EMBEDDING_MODEL,
|
||||
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
|
||||
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
# )
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.state.CHUNK_SIZE = CHUNK_SIZE
|
||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
origins = ["*"]
|
||||
|
@ -106,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
|
|||
metadatas = [doc.metadata for doc in docs]
|
||||
|
||||
try:
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
collection = CHROMA_CLIENT.create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
|
||||
collection.add(
|
||||
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
|
||||
|
@ -126,6 +139,38 @@ async def get_status():
|
|||
"status": True,
|
||||
"chunk_size": app.state.CHUNK_SIZE,
|
||||
"chunk_overlap": app.state.CHUNK_OVERLAP,
|
||||
"template": app.state.RAG_TEMPLATE,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/embedding/model")
|
||||
async def get_embedding_model(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
}
|
||||
|
||||
|
||||
class EmbeddingModelUpdateForm(BaseModel):
|
||||
embedding_model: str
|
||||
|
||||
|
||||
@app.post("/embedding/model/update")
|
||||
async def update_embedding_model(
|
||||
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
}
|
||||
|
||||
|
||||
|
@ -190,8 +235,10 @@ def query_doc(
|
|||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
# if you use docker use the model from the environment variable
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=form_data.collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
||||
return result
|
||||
|
@ -263,9 +310,12 @@ def query_collection(
|
|||
|
||||
for collection_name in form_data.collection_names:
|
||||
try:
|
||||
# if you use docker use the model from the environment variable
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
|
||||
result = collection.query(
|
||||
query_texts=[form_data.query], n_results=form_data.k
|
||||
)
|
||||
|
|
|
@ -136,7 +136,12 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
|||
####################################
|
||||
|
||||
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
||||
EMBED_MODEL = "all-MiniLM-L6-v2"
|
||||
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
|
||||
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
|
||||
# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
||||
RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
|
||||
"RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
|
||||
)
|
||||
CHROMA_CLIENT = chromadb.PersistentClient(
|
||||
path=CHROMA_DATA_PATH,
|
||||
settings=Settings(allow_reset=True, anonymized_telemetry=False),
|
||||
|
|
Loading…
Reference in a new issue