choose embedding model when using docker

This commit is contained in:
Jannik Streidl 2024-02-17 19:38:29 +01:00
parent 4c3edd0375
commit 1846c1e80d
3 changed files with 46 additions and 20 deletions

View file

@ -1,6 +1,5 @@
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
@ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
import os, shutil
from typing import List
# from chromadb.utils import embedding_functions
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import (
WebBaseLoader,
@ -28,24 +27,19 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from pydantic import BaseModel
from typing import Optional
import uuid
import time
from utils.misc import calculate_sha256, calculate_sha256_string
from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES
# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
# model_name=EMBED_MODEL
# )
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
app = FastAPI()
@ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs]
try:
collection = CHROMA_CLIENT.create_collection(name=collection_name)
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef)
else:
# for local development use the default model
collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
)
return True
except Exception as e:
print(e)
@ -109,9 +109,17 @@ def query_doc(
user=Depends(get_current_user),
):
try:
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
)
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=sentence_transformer_ef
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
)
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result
except Exception as e:
@ -182,9 +190,18 @@ def query_collection(
for collection_name in form_data.collection_names:
try:
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=sentence_transformer_ef
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
)
result = collection.query(
query_texts=[form_data.query], n_results=form_data.k
)