diff --git a/Dockerfile b/Dockerfile index 72230348..38f2a53f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,7 +38,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. -ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2" +ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" WORKDIR /app/backend @@ -55,7 +55,7 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* # preload embedding model -RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])" +RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'])" # preload tts model RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 9e90c839..5ab3b843 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -51,7 +51,7 @@ from utils.utils import get_current_user, get_admin_user from config import ( UPLOAD_DIR, DOCS_DIR, - SENTENCE_TRANSFORMER_EMBED_MODEL, + RAG_EMBEDDING_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP, @@ -60,7 +60,11 @@ from config import ( from constants import ERROR_MESSAGES -sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL) + +if RAG_EMBEDDING_MODEL: + sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=RAG_EMBEDDING_MODEL + ) app = FastAPI() @@ -98,17 +102,18 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: - if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef) - + if RAG_EMBEDDING_MODEL: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.create_collection( + name=collection_name, embedding_function=sentence_transformer_ef + ) else: - # for local development use the default model + # for local development use the default model collection = CHROMA_CLIENT.create_collection(name=collection_name) collection.add( - documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] - ) + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) return True except Exception as e: print(e) @@ -188,16 +193,16 @@ def query_doc( user=Depends(get_current_user), ): try: - if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: - # if you use docker use the model from the environment variable + if RAG_EMBEDDING_MODEL: + # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( name=form_data.collection_name, embedding_function=sentence_transformer_ef, ) else: - # for local development use the default model + # for local development use the default model collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, + name=form_data.collection_name, ) result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result @@ -269,18 +274,18 @@ def query_collection( for collection_name in form_data.collection_names: try: - if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: - # if you use docker use the model from the environment variable + if RAG_EMBEDDING_MODEL: + # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( name=collection_name, embedding_function=sentence_transformer_ef, ) else: - # for local development use the default model + # for local development use the default model collection = CHROMA_CLIENT.get_collection( - name=collection_name, + name=collection_name, ) - + result = collection.query( query_texts=[form_data.query], n_results=form_data.k ) diff --git a/backend/config.py b/backend/config.py index 76911e34..2cc6c2a5 100644 --- a/backend/config.py +++ b/backend/config.py @@ -137,7 +137,7 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) -SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL") +RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "") CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False),