From 1846c1e80dc597d83ad70759742abce67884c0e0 Mon Sep 17 00:00:00 2001 From: Jannik Streidl Date: Sat, 17 Feb 2024 19:38:29 +0100 Subject: [PATCH 1/6] choose embedding model when using docker --- Dockerfile | 12 ++++++++-- backend/apps/rag/main.py | 51 ++++++++++++++++++++++++++-------------- backend/config.py | 3 ++- 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/Dockerfile b/Dockerfile index 520c2964..72230348 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,10 +30,16 @@ ENV WEBUI_SECRET_KEY "" ENV SCARF_NO_ANALYTICS true ENV DO_NOT_TRACK true -#Whisper TTS Settings +# whisper TTS Settings ENV WHISPER_MODEL="base" ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" +# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers +# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard +# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" +# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. +ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2" + WORKDIR /app/backend # install python dependencies @@ -48,7 +54,9 @@ RUN apt-get update \ && apt-get install -y pandoc netcat-openbsd \ && rm -rf /var/lib/apt/lists/* -# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')" +# preload embedding model +RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])" +# preload tts model RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 07a30ade..defe10f9 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -1,6 +1,5 @@ from fastapi import ( FastAPI, - Request, Depends, HTTPException, status, @@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware import os, shutil from typing import List -# from chromadb.utils import embedding_functions +from chromadb.utils import embedding_functions from langchain_community.document_loaders import ( WebBaseLoader, @@ -28,24 +27,19 @@ from langchain_community.document_loaders import ( UnstructuredExcelLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_community.vectorstores import Chroma -from langchain.chains import RetrievalQA from pydantic import BaseModel from typing import Optional import uuid -import time from utils.misc import calculate_sha256, calculate_sha256_string from utils.utils import get_current_user, get_admin_user -from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP +from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from constants import ERROR_MESSAGES -# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( -# model_name=EMBED_MODEL -# ) +sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL) app = FastAPI() @@ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: - collection = CHROMA_CLIENT.create_collection(name=collection_name) + if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef) + + else: + # for local development use the default model + collection = CHROMA_CLIENT.create_collection(name=collection_name) collection.add( - documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] - ) + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) return True except Exception as e: print(e) @@ -109,9 +109,17 @@ def query_doc( user=Depends(get_current_user), ): try: - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, - ) + if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, + embedding_function=sentence_transformer_ef + ) + else: + # for local development use the default model + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, + ) result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result except Exception as e: @@ -182,9 +190,18 @@ def query_collection( for collection_name in form_data.collection_names: try: - collection = CHROMA_CLIENT.get_collection( - name=collection_name, + if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, + embedding_function=sentence_transformer_ef + ) + else: + # for local development use the default model + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, ) + result = collection.query( query_texts=[form_data.query], n_results=form_data.k ) diff --git a/backend/config.py b/backend/config.py index d7c89b3b..023954a4 100644 --- a/backend/config.py +++ b/backend/config.py @@ -128,7 +128,8 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": #################################### CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" -EMBED_MODEL = "all-MiniLM-L6-v2" +# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) +SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL") CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False), From bc3dd34d8b7980668aa97041d804a84bc3e24e65 Mon Sep 17 00:00:00 2001 From: Jannik Streidl Date: Sun, 18 Feb 2024 09:17:43 +0100 Subject: [PATCH 2/6] collection query fix --- backend/apps/rag/main.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index defe10f9..8a5a12d3 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -29,11 +29,13 @@ from langchain_community.document_loaders import ( from langchain.text_splitter import RecursiveCharacterTextSplitter + from pydantic import BaseModel from typing import Optional import uuid + from utils.misc import calculate_sha256, calculate_sha256_string from utils.utils import get_current_user, get_admin_user from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP @@ -113,12 +115,12 @@ def query_doc( # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( name=form_data.collection_name, - embedding_function=sentence_transformer_ef + embedding_function=sentence_transformer_ef, ) else: # for local development use the default model - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, ) result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result @@ -191,16 +193,16 @@ def query_collection( for collection_name in form_data.collection_names: try: if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: - # if you use docker use the model from the environment variable + # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, - embedding_function=sentence_transformer_ef + name=collection_name, + embedding_function=sentence_transformer_ef, ) else: # for local development use the default model collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, - ) + name=collection_name, + ) result = collection.query( query_texts=[form_data.query], n_results=form_data.k From 0cb035848531aea96d20882779ab9b80d028ca48 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 18 Feb 2024 11:16:10 -0800 Subject: [PATCH 3/6] refac: more descriptive var names --- Dockerfile | 4 ++-- backend/apps/rag/main.py | 41 ++++++++++++++++++++++------------------ backend/config.py | 2 +- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/Dockerfile b/Dockerfile index 72230348..38f2a53f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,7 +38,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. -ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2" +ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" WORKDIR /app/backend @@ -55,7 +55,7 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* # preload embedding model -RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])" +RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'])" # preload tts model RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 9e90c839..5ab3b843 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -51,7 +51,7 @@ from utils.utils import get_current_user, get_admin_user from config import ( UPLOAD_DIR, DOCS_DIR, - SENTENCE_TRANSFORMER_EMBED_MODEL, + RAG_EMBEDDING_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP, @@ -60,7 +60,11 @@ from config import ( from constants import ERROR_MESSAGES -sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL) + +if RAG_EMBEDDING_MODEL: + sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=RAG_EMBEDDING_MODEL + ) app = FastAPI() @@ -98,17 +102,18 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: - if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef) - + if RAG_EMBEDDING_MODEL: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.create_collection( + name=collection_name, embedding_function=sentence_transformer_ef + ) else: - # for local development use the default model + # for local development use the default model collection = CHROMA_CLIENT.create_collection(name=collection_name) collection.add( - documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] - ) + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) return True except Exception as e: print(e) @@ -188,16 +193,16 @@ def query_doc( user=Depends(get_current_user), ): try: - if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: - # if you use docker use the model from the environment variable + if RAG_EMBEDDING_MODEL: + # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( name=form_data.collection_name, embedding_function=sentence_transformer_ef, ) else: - # for local development use the default model + # for local development use the default model collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, + name=form_data.collection_name, ) result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result @@ -269,18 +274,18 @@ def query_collection( for collection_name in form_data.collection_names: try: - if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: - # if you use docker use the model from the environment variable + if RAG_EMBEDDING_MODEL: + # if you use docker use the model from the environment variable collection = CHROMA_CLIENT.get_collection( name=collection_name, embedding_function=sentence_transformer_ef, ) else: - # for local development use the default model + # for local development use the default model collection = CHROMA_CLIENT.get_collection( - name=collection_name, + name=collection_name, ) - + result = collection.query( query_texts=[form_data.query], n_results=form_data.k ) diff --git a/backend/config.py b/backend/config.py index 76911e34..2cc6c2a5 100644 --- a/backend/config.py +++ b/backend/config.py @@ -137,7 +137,7 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) -SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL") +RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "") CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False), From acf999013bbf9d5d9e41596dcbfc79c4d1288ae1 Mon Sep 17 00:00:00 2001 From: Jannik Streidl Date: Mon, 19 Feb 2024 07:51:17 +0100 Subject: [PATCH 4/6] storing vectordb in project cache folder + device types --- Dockerfile | 12 +++++++++--- backend/apps/audio/main.py | 2 +- backend/apps/rag/main.py | 12 +++++++++++- backend/config.py | 3 +++ 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 38f2a53f..a7692fdb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,15 +30,21 @@ ENV WEBUI_SECRET_KEY "" ENV SCARF_NO_ANALYTICS true ENV DO_NOT_TRACK true +######## Preloaded models ######## # whisper TTS Settings ENV WHISPER_MODEL="base" ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" +# RAG Embedding Model Settings # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard -# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" +# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" +ENV SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" +# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance +ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" +######## Preloaded models ######## WORKDIR /app/backend @@ -55,9 +61,9 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* # preload embedding model -RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'])" +RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])" # preload tts model -RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" +RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" # copy embedding weight from build diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 86e79c47..d8cb415f 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -56,7 +56,7 @@ def transcribe( model = WhisperModel( WHISPER_MODEL, - device="cpu", + device="auto", compute_type="int8", download_root=WHISPER_MODEL_DIR, ) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 5ab3b843..656ba4e5 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -13,6 +13,7 @@ import os, shutil from pathlib import Path from typing import List +from sentence_transformers import SentenceTransformer from chromadb.utils import embedding_functions from langchain_community.document_loaders import ( @@ -52,6 +53,7 @@ from config import ( UPLOAD_DIR, DOCS_DIR, RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_DEVICE_TYPE, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP, @@ -60,10 +62,18 @@ from config import ( from constants import ERROR_MESSAGES +# +#if RAG_EMBEDDING_MODEL: +# sentence_transformer_ef = SentenceTransformer( +# model_name_or_path=RAG_EMBEDDING_MODEL, +# cache_folder=RAG_EMBEDDING_MODEL_DIR, +# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, +# ) if RAG_EMBEDDING_MODEL: sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=RAG_EMBEDDING_MODEL + model_name=RAG_EMBEDDING_MODEL, + device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, ) app = FastAPI() diff --git a/backend/config.py b/backend/config.py index 2cc6c2a5..175b228e 100644 --- a/backend/config.py +++ b/backend/config.py @@ -138,6 +138,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "") + +# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance +RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get("RAG_EMBEDDING_MODEL_DEVICE_TYPE", "") CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False), From ab104d5905105ac62d9d4502573c859073aef991 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 19 Feb 2024 10:56:50 -0800 Subject: [PATCH 5/6] refac --- Dockerfile | 5 +++-- backend/config.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index a7692fdb..03dccefe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -41,9 +41,11 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" -ENV SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" # device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" +ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" +ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR + ######## Preloaded models ######## WORKDIR /app/backend @@ -65,7 +67,6 @@ RUN python -c "import os; from chromadb.utils import embedding_functions; senten # preload tts model RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" - # copy embedding weight from build RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2 COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx diff --git a/backend/config.py b/backend/config.py index 175b228e..b80bc081 100644 --- a/backend/config.py +++ b/backend/config.py @@ -137,10 +137,11 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) -RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "") - +RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") # device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance -RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get("RAG_EMBEDDING_MODEL_DEVICE_TYPE", "") +RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get( + "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu" +) CHROMA_CLIENT = chromadb.PersistentClient( path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True, anonymized_telemetry=False), From 7c127c35fcb52cdd29d05ea1dc734ad170dc96f3 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 19 Feb 2024 11:05:45 -0800 Subject: [PATCH 6/6] feat: dynamic embedding model load --- backend/apps/rag/main.py | 92 ++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 36 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 656ba4e5..4176d567 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -35,6 +35,8 @@ from pydantic import BaseModel from typing import Optional import mimetypes import uuid +import json + from apps.web.models.documents import ( Documents, @@ -63,24 +65,26 @@ from config import ( from constants import ERROR_MESSAGES # -#if RAG_EMBEDDING_MODEL: +# if RAG_EMBEDDING_MODEL: # sentence_transformer_ef = SentenceTransformer( # model_name_or_path=RAG_EMBEDDING_MODEL, # cache_folder=RAG_EMBEDDING_MODEL_DIR, # device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, # ) -if RAG_EMBEDDING_MODEL: - sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=RAG_EMBEDDING_MODEL, - device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, - ) app = FastAPI() app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.RAG_TEMPLATE = RAG_TEMPLATE +app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=app.state.RAG_EMBEDDING_MODEL, + device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, + ) +) origins = ["*"] @@ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool: metadatas = [doc.metadata for doc in docs] try: - if RAG_EMBEDDING_MODEL: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.create_collection( - name=collection_name, embedding_function=sentence_transformer_ef - ) - else: - # for local development use the default model - collection = CHROMA_CLIENT.create_collection(name=collection_name) + collection = CHROMA_CLIENT.create_collection( + name=collection_name, + embedding_function=app.state.sentence_transformer_ef, + ) collection.add( documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] @@ -139,6 +139,38 @@ async def get_status(): "status": True, "chunk_size": app.state.CHUNK_SIZE, "chunk_overlap": app.state.CHUNK_OVERLAP, + "template": app.state.RAG_TEMPLATE, + "embedding_model": app.state.RAG_EMBEDDING_MODEL, + } + + +@app.get("/embedding/model") +async def get_embedding_model(user=Depends(get_admin_user)): + return { + "status": True, + "embedding_model": app.state.RAG_EMBEDDING_MODEL, + } + + +class EmbeddingModelUpdateForm(BaseModel): + embedding_model: str + + +@app.post("/embedding/model/update") +async def update_embedding_model( + form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) +): + app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model + app.state.sentence_transformer_ef = ( + embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=app.state.RAG_EMBEDDING_MODEL, + device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, + ) + ) + + return { + "status": True, + "embedding_model": app.state.RAG_EMBEDDING_MODEL, } @@ -203,17 +235,11 @@ def query_doc( user=Depends(get_current_user), ): try: - if RAG_EMBEDDING_MODEL: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, - embedding_function=sentence_transformer_ef, - ) - else: - # for local development use the default model - collection = CHROMA_CLIENT.get_collection( - name=form_data.collection_name, - ) + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=form_data.collection_name, + embedding_function=app.state.sentence_transformer_ef, + ) result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result except Exception as e: @@ -284,17 +310,11 @@ def query_collection( for collection_name in form_data.collection_names: try: - if RAG_EMBEDDING_MODEL: - # if you use docker use the model from the environment variable - collection = CHROMA_CLIENT.get_collection( - name=collection_name, - embedding_function=sentence_transformer_ef, - ) - else: - # for local development use the default model - collection = CHROMA_CLIENT.get_collection( - name=collection_name, - ) + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + embedding_function=app.state.sentence_transformer_ef, + ) result = collection.query( query_texts=[form_data.query], n_results=form_data.k