feat: dynamic embedding model load

This commit is contained in:
Timothy J. Baek 2024-02-19 11:05:45 -08:00
parent ab104d5905
commit 7c127c35fc

View file

@ -35,6 +35,8 @@ from pydantic import BaseModel
from typing import Optional from typing import Optional
import mimetypes import mimetypes
import uuid import uuid
import json
from apps.web.models.documents import ( from apps.web.models.documents import (
Documents, Documents,
@ -63,24 +65,26 @@ from config import (
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
# NOTE(review): a commented-out legacy bootstrap that built a bare
# SentenceTransformer here was removed — it was dead code superseded by the
# Chroma embedding-function setup below.

if RAG_EMBEDDING_MODEL:
    # Eagerly build a Chroma-compatible embedding function from the configured
    # model name. Kept at module level for backward compatibility with any
    # code that still references `sentence_transformer_ef` directly; new code
    # should use `app.state.sentence_transformer_ef` instead.
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=RAG_EMBEDDING_MODEL,
        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
    )

app = FastAPI()

# Runtime-tunable RAG settings live on app.state so the admin endpoints can
# change them without restarting the server.
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL

# The embedding function used for Chroma collection create/get calls; it is
# rebuilt by the /embedding/model/update endpoint when the model changes.
app.state.sentence_transformer_ef = (
    embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=app.state.RAG_EMBEDDING_MODEL,
        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
    )
)

origins = ["*"]
@ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
try: try:
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.create_collection( collection = CHROMA_CLIENT.create_collection(
name=collection_name, embedding_function=sentence_transformer_ef name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
) )
else:
# for local development use the default model
collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection.add( collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
@ -139,6 +139,38 @@ async def get_status():
"status": True, "status": True,
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
@app.get("/embedding/model")
async def get_embedding_model(user=Depends(get_admin_user)):
    """Report the RAG embedding model currently held in app state (admin only)."""
    model_name = app.state.RAG_EMBEDDING_MODEL
    return {"status": True, "embedding_model": model_name}
class EmbeddingModelUpdateForm(BaseModel):
    """Request payload for switching the active RAG embedding model."""

    # Name or path of the sentence-transformers model to load.
    embedding_model: str
@app.post("/embedding/model/update")
async def update_embedding_model(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the active RAG embedding model and rebuild the embedding function.

    Admin only. NOTE(review): constructing the SentenceTransformer embedding
    function may download/load model weights synchronously — confirm blocking
    the event loop here is acceptable.
    """
    model_name = form_data.embedding_model
    app.state.RAG_EMBEDDING_MODEL = model_name
    app.state.sentence_transformer_ef = (
        embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=model_name,
            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
        )
    )
    return {
        "status": True,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
    }
@ -203,16 +235,10 @@ def query_doc(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name, name=form_data.collection_name,
embedding_function=sentence_transformer_ef, embedding_function=app.state.sentence_transformer_ef,
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
) )
result = collection.query(query_texts=[form_data.query], n_results=form_data.k) result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result return result
@ -284,16 +310,10 @@ def query_collection(
for collection_name in form_data.collection_names: for collection_name in form_data.collection_names:
try: try:
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=collection_name, name=collection_name,
embedding_function=sentence_transformer_ef, embedding_function=app.state.sentence_transformer_ef,
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
) )
result = collection.query( result = collection.query(