Merge pull request #772 from jannikstdl/choose-embedding-model

feat: choose embedding model when using docker
Timothy Jaeryang Baek committed 2024-02-19 14:39:41 -05:00 (committed via GitHub)
commit c3916927bb
4 changed files with 87 additions and 17 deletions

Dockerfile

@@ -30,10 +30,24 @@ ENV WEBUI_SECRET_KEY ""
 ENV SCARF_NO_ANALYTICS true
 ENV DO_NOT_TRACK true

-#Whisper TTS Settings
+######## Preloaded models ########
+# whisper TTS Settings
 ENV WHISPER_MODEL="base"
 ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"

+# RAG Embedding Model Settings
+# any sentence-transformers model works; available models are listed at https://huggingface.co/models?library=sentence-transformers
+# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
+# for better performance and multilingual support, use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
+# IMPORTANT: If you switch away from the default model (all-MiniLM-L6-v2), or back to it, you can no longer use RAG Chat with documents already loaded in the WebUI! You need to re-embed them.
+ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
+# device type for the whisper TTS and embedding models - "cpu" (default), "cuda" (NVIDIA GPU with CUDA required) or "mps" (Apple Silicon) - choosing the right one can lead to better performance
+ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
+ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
+ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
+######## Preloaded models ########

 WORKDIR /app/backend

 # install python dependencies
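Note: the re-embed warning above comes down to vector geometry. Different sentence-transformers models produce vectors of different dimensionality (and different semantics), so a collection built with one model cannot be queried with another. A minimal sketch, separate from the PR itself:

# illustrative only (not part of this PR): embeddings from different
# models are not interchangeable, which is why documents must be re-embedded
from sentence_transformers import SentenceTransformer

minilm = SentenceTransformer("all-MiniLM-L6-v2")
e5 = SentenceTransformer("intfloat/multilingual-e5-base")

print(len(minilm.encode("hello")))  # 384-dimensional vector
print(len(e5.encode("hello")))      # 768-dimensional vector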
@@ -48,9 +62,10 @@ RUN apt-get update \
     && apt-get install -y pandoc netcat-openbsd \
     && rm -rf /var/lib/apt/lists/*

-# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')"
-RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
+# preload embedding model
+RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
+# preload tts model
+RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"

 # copy embedding weight from build
 RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
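Both RUN commands execute at image build time, so the model weights are baked into the image and nothing is downloaded on first request. Expanded for readability, the embedding preload one-liner is equivalent to this script (same calls, just reformatted):

# build-time preload, expanded from the one-line RUN command above
import os

from chromadb.utils import embedding_functions

# instantiating the embedding function downloads the model into
# SENTENCE_TRANSFORMERS_HOME, which the Dockerfile points at RAG_EMBEDDING_MODEL_DIR
embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=os.environ["RAG_EMBEDDING_MODEL"],
    device=os.environ["RAG_EMBEDDING_MODEL_DEVICE_TYPE"],
)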

backend/apps/audio/main.py

@@ -56,7 +56,7 @@ def transcribe(
     model = WhisperModel(
         WHISPER_MODEL,
-        device="cpu",
+        device="auto",
         compute_type="int8",
         download_root=WHISPER_MODEL_DIR,
     )
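Note: the device="cpu" to device="auto" change lets faster-whisper use a GPU when one is present instead of forcing CPU inference. A minimal standalone sketch of the call (the audio file path is a placeholder, not from this PR):

# illustrative: faster-whisper resolves device="auto" to "cuda" when a GPU
# is usable, otherwise "cpu"; int8 keeps memory use low on either device
from faster_whisper import WhisperModel

model = WhisperModel("base", device="auto", compute_type="int8")
segments, info = model.transcribe("sample.wav")  # placeholder path
for segment in segments:
    print(segment.start, segment.end, segment.text)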

backend/apps/rag/main.py

@@ -1,6 +1,5 @@
 from fastapi import (
     FastAPI,
-    Request,
     Depends,
     HTTPException,
     status,
@@ -14,7 +13,8 @@ import os, shutil
 from pathlib import Path
 from typing import List

-# from chromadb.utils import embedding_functions
+from sentence_transformers import SentenceTransformer
+from chromadb.utils import embedding_functions

 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -30,16 +30,12 @@ from langchain_community.document_loaders import (
     UnstructuredExcelLoader,
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.chains import RetrievalQA
-from langchain_community.vectorstores import Chroma

 from pydantic import BaseModel
 from typing import Optional
 import mimetypes
 import uuid
 import json
-import time

 from apps.web.models.documents import (
@@ -58,23 +54,37 @@ from utils.utils import get_current_user, get_admin_user
 from config import (
     UPLOAD_DIR,
     DOCS_DIR,
-    EMBED_MODEL,
+    RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_DEVICE_TYPE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
     CHUNK_OVERLAP,
     RAG_TEMPLATE,
 )

 from constants import ERROR_MESSAGES

-# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
-#     model_name=EMBED_MODEL
+#
+# if RAG_EMBEDDING_MODEL:
+#     sentence_transformer_ef = SentenceTransformer(
+#         model_name_or_path=RAG_EMBEDDING_MODEL,
+#         cache_folder=RAG_EMBEDDING_MODEL_DIR,
+#         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
 # )

 app = FastAPI()

 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
+app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+
+app.state.sentence_transformer_ef = (
+    embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name=app.state.RAG_EMBEDDING_MODEL,
+        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    )
+)

 origins = ["*"]
@@ -106,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
     metadatas = [doc.metadata for doc in docs]

     try:
-        collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )

         collection.add(
             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
@@ -126,6 +139,38 @@ async def get_status():
         "status": True,
         "chunk_size": app.state.CHUNK_SIZE,
         "chunk_overlap": app.state.CHUNK_OVERLAP,
         "template": app.state.RAG_TEMPLATE,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+@app.get("/embedding/model")
+async def get_embedding_model(user=Depends(get_admin_user)):
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+class EmbeddingModelUpdateForm(BaseModel):
+    embedding_model: str
+
+
+@app.post("/embedding/model/update")
+async def update_embedding_model(
+    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
+):
+    app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+    app.state.sentence_transformer_ef = (
+        embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name=app.state.RAG_EMBEDDING_MODEL,
+            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+        )
+    )
+
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }
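A quick way to exercise the two new admin endpoints. This is a hedged sketch: the base URL (wherever this FastAPI sub-app is mounted) and the token are illustrative assumptions, not values taken from this PR.

# hypothetical client for the new endpoints; URL and token are assumptions
import requests

BASE = "http://localhost:8080/rag"  # assumed mount point of the RAG sub-app
HEADERS = {"Authorization": "Bearer <admin-token>"}  # both endpoints require an admin user

# read the currently active embedding model
print(requests.get(f"{BASE}/embedding/model", headers=HEADERS).json())

# switch models; remember that existing collections must be re-embedded
resp = requests.post(
    f"{BASE}/embedding/model/update",
    json={"embedding_model": "intfloat/multilingual-e5-base"},
    headers=HEADERS,
)
print(resp.json())  # {"status": True, "embedding_model": "intfloat/multilingual-e5-base"}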
@@ -190,8 +235,10 @@ def query_doc(
     user=Depends(get_current_user),
 ):
     try:
+        # when running under Docker, this uses the embedding model configured via the environment variable
         collection = CHROMA_CLIENT.get_collection(
             name=form_data.collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
         )
         result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
@@ -263,9 +310,12 @@ def query_collection(
     for collection_name in form_data.collection_names:
         try:
+            # when running under Docker, this uses the embedding model configured via the environment variable
             collection = CHROMA_CLIENT.get_collection(
                 name=collection_name,
+                embedding_function=app.state.sentence_transformer_ef,
             )
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k
             )
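Passing embedding_function to get_collection matters because Chroma embeds query_texts with the collection's embedding function; when none is supplied it falls back to its default ONNX MiniLM embedder, which may not match the model the documents were stored with. A standalone sketch of that contract (illustrative, not code from this PR):

# standalone chromadb sketch: the same embedding function must be used
# when a collection is created and when it is fetched for querying
import chromadb
from chromadb.utils import embedding_functions

client = chromadb.Client()  # in-memory client, for the demo only
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2", device="cpu"
)

collection = client.create_collection(name="demo", embedding_function=ef)
collection.add(documents=["hello world"], ids=["1"])

# supply the matching function again on fetch, or queries embed differently
collection = client.get_collection(name="demo", embedding_function=ef)
print(collection.query(query_texts=["hello"], n_results=1))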

backend/config.py

@@ -136,7 +136,12 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 ####################################

 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-EMBED_MODEL = "all-MiniLM-L6-v2"
+# this uses the model defined in the Dockerfile ENV variable. If you don't use Docker or a Docker-based deployment such as k8s, the default embedding model (all-MiniLM-L6-v2) will be used
+RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
+# device type for embedding models - "cpu" (default), "cuda" (NVIDIA GPU required) or "mps" (Apple Silicon) - choosing the right one can lead to better performance
+RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
+    "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
+)
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),
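Since the device type is only read from the environment with a "cpu" fallback, choosing "cuda" or "mps" is left to the operator. A hedged sketch of how one could auto-detect a sensible value instead (illustrative only, not what this PR does; it assumes torch, which sentence-transformers already pulls in):

# hypothetical helper: auto-detect the best available device
import torch

def detect_device() -> str:
    if torch.cuda.is_available():          # NVIDIA GPU with CUDA
        return "cuda"
    if torch.backends.mps.is_available():  # Apple Silicon
        return "mps"
    return "cpu"

# e.g. RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
#          "RAG_EMBEDDING_MODEL_DEVICE_TYPE", detect_device())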