choose embedding model when using docker

commit 1846c1e80d (parent 4c3edd0375)
Jannik Streidl, 2024-02-17 19:38:29 +01:00
3 changed files with 46 additions and 20 deletions

--- a/Dockerfile
+++ b/Dockerfile

@@ -30,10 +30,16 @@ ENV WEBUI_SECRET_KEY ""
 ENV SCARF_NO_ANALYTICS true
 ENV DO_NOT_TRACK true
 
-#Whisper TTS Settings
+# Whisper speech-to-text settings
 ENV WHISPER_MODEL="base"
 ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
 
+# any sentence-transformers model; available models are listed at https://huggingface.co/models?library=sentence-transformers
+# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
+# for better performance and multilanguage support, use "intfloat/multilingual-e5-large"
+# IMPORTANT: if you switch away from the default model (all-MiniLM-L6-v2), or later switch back, documents already loaded in the WebUI cannot be used for RAG chat until they are re-embedded
+ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2"
+
 WORKDIR /app/backend
 
 # install python dependencies
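
The re-embedding caveat comes down to vector dimensions: each sentence-transformers model emits fixed-size vectors, and the sizes differ between models, so vectors written by one model are meaningless to another. A minimal sketch of the mismatch, not part of the commit (assumes chromadb and sentence-transformers are installed):

    # hypothetical check: compare embedding dimensions of the two models named above
    from chromadb.utils import embedding_functions

    small = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
    large = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="intfloat/multilingual-e5-large")

    print(len(small(["hello"])[0]))  # 384 dimensions
    print(len(large(["hello"])[0]))  # 1024 dimensions; stored 384-dim vectors can't be queried with this model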
@@ -48,7 +54,9 @@ RUN apt-get update \
     && apt-get install -y pandoc netcat-openbsd \
     && rm -rf /var/lib/apt/lists/*
 
-# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')"
+# preload embedding model
+RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])"
+# preload speech-to-text model
 RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
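
Both RUN lines serve the same purpose: trigger the model download at image build time so the weights are baked into the image and the container never downloads them on first use. In plain Python terms, the embedding preload does roughly this (hypothetical standalone script, same environment variable as above):

    # constructing the embedding function downloads the weights into the
    # Hugging Face cache on first use; inside a RUN step that cache becomes
    # part of the image layer
    import os
    from chromadb.utils import embedding_functions

    model = os.environ.get("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL", "all-MiniLM-L6-v2")
    embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)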

--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py

@@ -1,6 +1,5 @@
 from fastapi import (
     FastAPI,
-    Request,
     Depends,
     HTTPException,
     status,
@@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil
 from typing import List
 
-# from chromadb.utils import embedding_functions
+from chromadb.utils import embedding_functions
 
 from langchain_community.document_loaders import (
     WebBaseLoader,
@@ -28,24 +27,19 @@ from langchain_community.document_loaders import (
     UnstructuredExcelLoader,
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
-from langchain.chains import RetrievalQA
 
 from pydantic import BaseModel
 from typing import Optional
 import uuid
-import time
 
 from utils.misc import calculate_sha256, calculate_sha256_string
 from utils.utils import get_current_user, get_admin_user
-from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
+from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 
 from constants import ERROR_MESSAGES
 
-# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
-#     model_name=EMBED_MODEL
-# )
+sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
 
 app = FastAPI()
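
For context on what the module-level sentence_transformer_ef is for: Chroma invokes a collection's embedding function on every document added and on every query text, so the same function must be attached on both paths. A standalone sketch with an in-memory client and a hypothetical collection name, illustrative only:

    import chromadb
    from chromadb.utils import embedding_functions

    client = chromadb.Client()  # in-memory client, not the app's PersistentClient
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")

    collection = client.create_collection(name="demo", embedding_function=ef)
    collection.add(ids=["1"], documents=["hello world"])             # embedded via ef
    print(collection.query(query_texts=["greeting"], n_results=1))  # query text embedded via ef too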
@@ -78,6 +72,12 @@ def store_data_in_vector_db(data, collection_name) -> bool:
     metadatas = [doc.metadata for doc in docs]
 
     try:
-        collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
+            # if running in Docker, use the model from the environment variable
+            collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef)
+        else:
+            # for local development, use Chroma's default model
+            collection = CHROMA_CLIENT.create_collection(name=collection_name)
 
         collection.add(
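
The Docker-or-local branch introduced here is repeated in the two query paths below; a small helper would keep the decision in one place. A hypothetical refactoring sketch, not part of the commit:

    import os

    def collection_kwargs() -> dict:
        # attach the custom embedding function only when the Docker variable is set
        if "DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL" in os.environ:
            return {"embedding_function": sentence_transformer_ef}
        return {}

    # usage: CHROMA_CLIENT.create_collection(name=collection_name, **collection_kwargs())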
@@ -109,6 +109,14 @@ def query_doc(
     user=Depends(get_current_user),
 ):
     try:
-        collection = CHROMA_CLIENT.get_collection(
-            name=form_data.collection_name,
-        )
+        if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
+            # if running in Docker, use the model from the environment variable
+            collection = CHROMA_CLIENT.get_collection(
+                name=form_data.collection_name,
+                embedding_function=sentence_transformer_ef
+            )
+        else:
+            # for local development, use Chroma's default model
+            collection = CHROMA_CLIENT.get_collection(
+                name=form_data.collection_name,
+            )
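
For reference on what these query endpoints hand back: collection.query returns a dict of lists-of-lists, one inner list per query text. A tiny standalone illustration with an in-memory client and made-up data:

    import chromadb

    client = chromadb.Client()
    collection = client.create_collection(name="demo")  # default embedding function
    collection.add(ids=["a", "b"], documents=["first chunk", "second chunk"])

    result = collection.query(query_texts=["first"], n_results=1)
    # e.g. {"ids": [["a"]], "documents": [["first chunk"]],
    #       "distances": [[0.55]], "metadatas": [[None]], ...}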
@@ -182,9 +190,18 @@ def query_collection(
     for collection_name in form_data.collection_names:
         try:
-            collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
-            )
+            if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
+                # if running in Docker, use the model from the environment variable
+                collection = CHROMA_CLIENT.get_collection(
+                    name=collection_name,
+                    embedding_function=sentence_transformer_ef
+                )
+            else:
+                # for local development, use Chroma's default model
+                collection = CHROMA_CLIENT.get_collection(
+                    name=collection_name,
+                )
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k
             )
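
query_collection runs the query against each collection independently, producing one result dict per collection, so any caller that wants a single ranked list has to merge them. A hypothetical merge step, with names and shapes assumed rather than taken from the commit:

    # merge per-collection Chroma results into one list, closest matches first
    def merge_results(results: list[dict], k: int) -> list[tuple[float, str]]:
        scored = []
        for r in results:
            scored.extend(zip(r["distances"][0], r["documents"][0]))
        scored.sort(key=lambda pair: pair[0])  # smaller distance = more similar
        return scored[:k]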

--- a/backend/config.py
+++ b/backend/config.py

@@ -128,7 +128,8 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 ####################################
 
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-EMBED_MODEL = "all-MiniLM-L6-v2"
+# this uses the model defined in the Dockerfile ENV variable. If you don't use Docker (or Docker-based deployments such as k8s), the variable is unset and Chroma's default embedding model (all-MiniLM-L6-v2) is used
+SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL")
 CHROMA_CLIENT = chromadb.PersistentClient(
     path=CHROMA_DATA_PATH,
     settings=Settings(allow_reset=True, anonymized_telemetry=False),
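
Because os.getenv is called without a default here, SENTENCE_TRANSFORMER_EMBED_MODEL is None whenever the Dockerfile variable is absent, and the 'in os.environ' checks in the RAG app test the same variable, so both sides agree on which mode they are in. A quick standalone illustration:

    import os

    model = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL")  # None when unset
    if model is None:
        print("local run: Chroma's built-in default (all-MiniLM-L6-v2) does the embedding")
    else:
        print(f"docker run: embeddings computed by {model}")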