refac: more descriptive var names

This commit is contained in:
Timothy J. Baek 2024-02-18 11:16:10 -08:00
parent 4b88e7e44f
commit 0cb0358485
3 changed files with 26 additions and 21 deletions

View file

@ -38,7 +38,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" # for better persormance and multilangauge support use "intfloat/multilingual-e5-large"
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2" ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
WORKDIR /app/backend WORKDIR /app/backend
@ -55,7 +55,7 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# preload embedding model # preload embedding model
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])" RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'])"
# preload tts model # preload tts model
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"

View file

@ -51,7 +51,7 @@ from utils.utils import get_current_user, get_admin_user
from config import ( from config import (
UPLOAD_DIR, UPLOAD_DIR,
DOCS_DIR, DOCS_DIR,
SENTENCE_TRANSFORMER_EMBED_MODEL, RAG_EMBEDDING_MODEL,
CHROMA_CLIENT, CHROMA_CLIENT,
CHUNK_SIZE, CHUNK_SIZE,
CHUNK_OVERLAP, CHUNK_OVERLAP,
@ -60,7 +60,11 @@ from config import (
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
if RAG_EMBEDDING_MODEL:
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=RAG_EMBEDDING_MODEL
)
app = FastAPI() app = FastAPI()
@ -98,17 +102,18 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
try: try:
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef) collection = CHROMA_CLIENT.create_collection(
name=collection_name, embedding_function=sentence_transformer_ef
)
else: else:
# for local development use the default model # for local development use the default model
collection = CHROMA_CLIENT.create_collection(name=collection_name) collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection.add( collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
) )
return True return True
except Exception as e: except Exception as e:
print(e) print(e)
@ -188,16 +193,16 @@ def query_doc(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name, name=form_data.collection_name,
embedding_function=sentence_transformer_ef, embedding_function=sentence_transformer_ef,
) )
else: else:
# for local development use the default model # for local development use the default model
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name, name=form_data.collection_name,
) )
result = collection.query(query_texts=[form_data.query], n_results=form_data.k) result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result return result
@ -269,16 +274,16 @@ def query_collection(
for collection_name in form_data.collection_names: for collection_name in form_data.collection_names:
try: try:
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=collection_name, name=collection_name,
embedding_function=sentence_transformer_ef, embedding_function=sentence_transformer_ef,
) )
else: else:
# for local development use the default model # for local development use the default model
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=collection_name, name=collection_name,
) )
result = collection.query( result = collection.query(

View file

@ -137,7 +137,7 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL") RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "")
CHROMA_CLIENT = chromadb.PersistentClient( CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH, path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False), settings=Settings(allow_reset=True, anonymized_telemetry=False),