refac: more descriptive var names

This commit is contained in:
Timothy J. Baek 2024-02-18 11:16:10 -08:00
parent 4b88e7e44f
commit 0cb0358485
3 changed files with 26 additions and 21 deletions

View file

@@ -51,7 +51,7 @@ from utils.utils import get_current_user, get_admin_user
from config import (
UPLOAD_DIR,
DOCS_DIR,
SENTENCE_TRANSFORMER_EMBED_MODEL,
RAG_EMBEDDING_MODEL,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
@@ -60,7 +60,11 @@ from config import (
from constants import ERROR_MESSAGES
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
if RAG_EMBEDDING_MODEL:
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=RAG_EMBEDDING_MODEL
)
app = FastAPI()
@@ -98,17 +102,18 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs]
try:
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef)
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.create_collection(
name=collection_name, embedding_function=sentence_transformer_ef
)
else:
# for local development use the default model
# for local development use the default model
collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
print(e)
@@ -188,16 +193,16 @@ def query_doc(
user=Depends(get_current_user),
):
try:
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
# if you use docker use the model from the environment variable
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=sentence_transformer_ef,
)
else:
# for local development use the default model
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
name=form_data.collection_name,
)
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result
@@ -269,18 +274,18 @@ def query_collection(
for collection_name in form_data.collection_names:
try:
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
# if you use docker use the model from the environment variable
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=sentence_transformer_ef,
)
else:
# for local development use the default model
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
name=collection_name,
)
result = collection.query(
query_texts=[form_data.query], n_results=form_data.k
)