feat: dynamic embedding model load

This commit is contained in:
Timothy J. Baek 2024-02-19 11:05:45 -08:00
parent ab104d5905
commit 7c127c35fc

View file

@ -35,6 +35,8 @@ from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
import json
from apps.web.models.documents import (
Documents,
@ -63,24 +65,26 @@ from config import (
from constants import ERROR_MESSAGES
#
#if RAG_EMBEDDING_MODEL:
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )
if RAG_EMBEDDING_MODEL:
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
# FastAPI application for this module; the settings below are stored on
# `app.state` so that request handlers can read them from one place.
app = FastAPI()
# Seed the state from the values imported from `config`.
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
# Embedding function built from the model name stored just above and the
# configured device type; it is constructed once here, at import time.
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
origins = ["*"]
@ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs]
try:
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.create_collection(
name=collection_name, embedding_function=sentence_transformer_ef
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.create_collection(name=collection_name)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
@ -139,6 +139,38 @@ async def get_status():
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
@app.get("/embedding/model")
async def get_embedding_model(user=Depends(get_admin_user)):
    """Report the embedding model name currently held in application state.

    Args:
        user: Value resolved through ``Depends(get_admin_user)``; it is not
            used in the body.

    Returns:
        A dict with ``"status"`` set to ``True`` and ``"embedding_model"``
        set to ``app.state.RAG_EMBEDDING_MODEL``.
    """
    current_model = app.state.RAG_EMBEDDING_MODEL
    return dict(status=True, embedding_model=current_model)
class EmbeddingModelUpdateForm(BaseModel):
    """Request body accepted by the ``/embedding/model/update`` endpoint."""

    # Model name that the update endpoint passes as `model_name` when it
    # builds a new SentenceTransformerEmbeddingFunction.
    embedding_model: str
@app.post("/embedding/model/update")
async def update_embedding_model(
    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    """Switch the embedding model used by the app and rebuild its embedding function.

    Args:
        form_data: Request body carrying the new ``embedding_model`` name.
        user: Value resolved through ``Depends(get_admin_user)``; it is not
            used in the body.

    Returns:
        A dict with ``"status"`` set to ``True`` and ``"embedding_model"``
        set to the model name now stored in ``app.state``.

    Raises:
        Whatever ``SentenceTransformerEmbeddingFunction`` raises during
        construction. In that case ``app.state`` is left unchanged.
    """
    requested_model = form_data.embedding_model

    # Build the new embedding function BEFORE touching app.state. If
    # construction raises, the stored model name and the embedding function
    # must keep describing the same model rather than drifting apart.
    new_embedding_function = (
        embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=requested_model,
            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
        )
    )

    # Construction succeeded: commit both pieces of state together.
    app.state.RAG_EMBEDDING_MODEL = requested_model
    app.state.sentence_transformer_ef = new_embedding_function

    return {
        "status": True,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
    }
@ -203,16 +235,10 @@ def query_doc(
user=Depends(get_current_user),
):
try:
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=sentence_transformer_ef,
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result
@ -284,16 +310,10 @@ def query_collection(
for collection_name in form_data.collection_names:
try:
if RAG_EMBEDDING_MODEL:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=sentence_transformer_ef,
)
else:
# for local development use the default model
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(