forked from open-webui/open-webui
choose embedding model when using docker
This commit is contained in:
parent
4c3edd0375
commit
1846c1e80d
3 changed files with 46 additions and 20 deletions
12
Dockerfile
12
Dockerfile
|
@ -30,10 +30,16 @@ ENV WEBUI_SECRET_KEY ""
|
||||||
ENV SCARF_NO_ANALYTICS true
|
ENV SCARF_NO_ANALYTICS true
|
||||||
ENV DO_NOT_TRACK true
|
ENV DO_NOT_TRACK true
|
||||||
|
|
||||||
#Whisper TTS Settings
|
# whisper TTS Settings
|
||||||
ENV WHISPER_MODEL="base"
|
ENV WHISPER_MODEL="base"
|
||||||
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
||||||
|
|
||||||
|
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
||||||
|
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
||||||
|
# for better persormance and multilangauge support use "intfloat/multilingual-e5-large"
|
||||||
|
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
|
||||||
|
ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2"
|
||||||
|
|
||||||
WORKDIR /app/backend
|
WORKDIR /app/backend
|
||||||
|
|
||||||
# install python dependencies
|
# install python dependencies
|
||||||
|
@ -48,7 +54,9 @@ RUN apt-get update \
|
||||||
&& apt-get install -y pandoc netcat-openbsd \
|
&& apt-get install -y pandoc netcat-openbsd \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')"
|
# preload embedding model
|
||||||
|
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])"
|
||||||
|
# preload tts model
|
||||||
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
|
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
FastAPI,
|
FastAPI,
|
||||||
Request,
|
|
||||||
Depends,
|
Depends,
|
||||||
HTTPException,
|
HTTPException,
|
||||||
status,
|
status,
|
||||||
|
@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
import os, shutil
|
import os, shutil
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
# from chromadb.utils import embedding_functions
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
from langchain_community.document_loaders import (
|
from langchain_community.document_loaders import (
|
||||||
WebBaseLoader,
|
WebBaseLoader,
|
||||||
|
@ -28,24 +27,19 @@ from langchain_community.document_loaders import (
|
||||||
UnstructuredExcelLoader,
|
UnstructuredExcelLoader,
|
||||||
)
|
)
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
from langchain_community.vectorstores import Chroma
|
|
||||||
from langchain.chains import RetrievalQA
|
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
|
||||||
|
|
||||||
from utils.misc import calculate_sha256, calculate_sha256_string
|
from utils.misc import calculate_sha256, calculate_sha256_string
|
||||||
from utils.utils import get_current_user, get_admin_user
|
from utils.utils import get_current_user, get_admin_user
|
||||||
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
|
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
|
||||||
# model_name=EMBED_MODEL
|
|
||||||
# )
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -78,6 +72,12 @@ def store_data_in_vector_db(data, collection_name) -> bool:
|
||||||
metadatas = [doc.metadata for doc in docs]
|
metadatas = [doc.metadata for doc in docs]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
|
||||||
|
# if you use docker use the model from the environment variable
|
||||||
|
collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# for local development use the default model
|
||||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||||
|
|
||||||
collection.add(
|
collection.add(
|
||||||
|
@ -109,6 +109,14 @@ def query_doc(
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
|
||||||
|
# if you use docker use the model from the environment variable
|
||||||
|
collection = CHROMA_CLIENT.get_collection(
|
||||||
|
name=form_data.collection_name,
|
||||||
|
embedding_function=sentence_transformer_ef
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# for local development use the default model
|
||||||
collection = CHROMA_CLIENT.get_collection(
|
collection = CHROMA_CLIENT.get_collection(
|
||||||
name=form_data.collection_name,
|
name=form_data.collection_name,
|
||||||
)
|
)
|
||||||
|
@ -182,9 +190,18 @@ def query_collection(
|
||||||
|
|
||||||
for collection_name in form_data.collection_names:
|
for collection_name in form_data.collection_names:
|
||||||
try:
|
try:
|
||||||
|
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
|
||||||
|
# if you use docker use the model from the environment variable
|
||||||
collection = CHROMA_CLIENT.get_collection(
|
collection = CHROMA_CLIENT.get_collection(
|
||||||
name=collection_name,
|
name=form_data.collection_name,
|
||||||
|
embedding_function=sentence_transformer_ef
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# for local development use the default model
|
||||||
|
collection = CHROMA_CLIENT.get_collection(
|
||||||
|
name=form_data.collection_name,
|
||||||
|
)
|
||||||
|
|
||||||
result = collection.query(
|
result = collection.query(
|
||||||
query_texts=[form_data.query], n_results=form_data.k
|
query_texts=[form_data.query], n_results=form_data.k
|
||||||
)
|
)
|
||||||
|
|
|
@ -128,7 +128,8 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
||||||
EMBED_MODEL = "all-MiniLM-L6-v2"
|
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
|
||||||
|
SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL")
|
||||||
CHROMA_CLIENT = chromadb.PersistentClient(
|
CHROMA_CLIENT = chromadb.PersistentClient(
|
||||||
path=CHROMA_DATA_PATH,
|
path=CHROMA_DATA_PATH,
|
||||||
settings=Settings(allow_reset=True, anonymized_telemetry=False),
|
settings=Settings(allow_reset=True, anonymized_telemetry=False),
|
||||||
|
|
Loading…
Reference in a new issue