diff --git a/Dockerfile b/Dockerfile index 520c2964..03dccefe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,10 +30,24 @@ ENV WEBUI_SECRET_KEY "" ENV SCARF_NO_ANALYTICS true ENV DO_NOT_TRACK true -#Whisper TTS Settings +######## Preloaded models ######## +# whisper TTS Settings ENV WHISPER_MODEL="base" ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" +# RAG Embedding Model Settings +# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers +# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard +# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) +# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. +ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" +# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance +ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" +ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" +ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR + +######## Preloaded models ######## + WORKDIR /app/backend # install python dependencies @@ -48,9 +62,10 @@ RUN apt-get update \ && apt-get install -y pandoc netcat-openbsd \ && rm -rf /var/lib/apt/lists/* -# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')" -RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" - +# preload embedding model +RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])" +# preload tts model +RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" # copy embedding weight from build RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2 diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 86e79c47..d8cb415f 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -56,7 +56,7 @@ def transcribe( model = WhisperModel( WHISPER_MODEL, - device="cpu", + device="auto", compute_type="int8", download_root=WHISPER_MODEL_DIR, ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 6d06456f..4176d567 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -1,6 +1,5 @@ from fastapi import ( FastAPI, - Request, Depends, HTTPException, status, @@ -14,7 +13,8 @@ import os, shutil from pathlib import Path from typing import List -# from chromadb.utils import embedding_functions +from sentence_transformers import SentenceTransformer +from chromadb.utils import embedding_functions from langchain_community.document_loaders import ( WebBaseLoader, @@ -30,16 +30,12 @@ from langchain_community.document_loaders import ( UnstructuredExcelLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.chains import RetrievalQA -from langchain_community.vectorstores import Chroma - from pydantic import BaseModel from typing import Optional import mimetypes import uuid import json -import time from apps.web.models.documents import ( @@ -58,23 +54,37 @@ from utils.utils import get_current_user, get_admin_user from config import ( UPLOAD_DIR, DOCS_DIR, - EMBED_MODEL, + RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_DEVICE_TYPE, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP, RAG_TEMPLATE, ) + from constants import ERROR_MESSAGES -# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( -# model_name=EMBED_MODEL -# ) +# +# if RAG_EMBEDDING_MODEL: +# sentence_transformer_ef = SentenceTransformer( +# model_name_or_path=RAG_EMBEDDING_MODEL, +# cache_folder=RAG_EMBEDDING_MODEL_DIR, +# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, +# ) + app = FastAPI() app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.RAG_TEMPLATE = RAG_TEMPLATE +app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=app.state.RAG_EMBEDDING_MODEL, + device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, + ) +) origins = ["*"] @@ -106,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: - collection = CHROMA_CLIENT.create_collection(name=collection_name) + collection = CHROMA_CLIENT.create_collection( + name=collection_name, + embedding_function=app.state.sentence_transformer_ef, + ) collection.add( documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] @@ -126,6 +139,38 @@ async def get_status(): "status": True, "chunk_size": app.state.CHUNK_SIZE, "chunk_overlap": app.state.CHUNK_OVERLAP, + "template": app.state.RAG_TEMPLATE, + "embedding_model": app.state.RAG_EMBEDDING_MODEL, + } + + +@app.get("/embedding/model") +async def get_embedding_model(user=Depends(get_admin_user)): + return { + "status": True, + "embedding_model": app.state.RAG_EMBEDDING_MODEL, + } + + +class EmbeddingModelUpdateForm(BaseModel): + embedding_model: str + + +@app.post("/embedding/model/update") +async def update_embedding_model( + form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) +): + app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + app.state.sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=app.state.RAG_EMBEDDING_MODEL, + device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, + ) + ) + + return { + "status": True, + "embedding_model": app.state.RAG_EMBEDDING_MODEL, } @@ -190,8 +235,10 @@ def query_doc( user=Depends(get_current_user), ): try: + # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( name=form_data.collection_name, + embedding_function=app.state.sentence_transformer_ef, ) result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result @@ -263,9 +310,12 @@ def query_collection( for collection_name in form_data.collection_names: try: + # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( name=collection_name, + embedding_function=app.state.sentence_transformer_ef, ) + result = collection.query( query_texts=[form_data.query], n_results=form_data.k ) diff --git a/backend/config.py b/backend/config.py index 440256c4..b80bc081 100644 --- a/backend/config.py +++ b/backend/config.py @@ -136,7 +136,12 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": #################################### CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" -EMBED_MODEL = "all-MiniLM-L6-v2" +# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) +RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") +# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance +RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get( + "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu" +) CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False),