From 7c127c35fcb52cdd29d05ea1dc734ad170dc96f3 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Mon, 19 Feb 2024 11:05:45 -0800
Subject: [PATCH] feat: dynamic embedding model load

---
 backend/apps/rag/main.py | 92 ++++++++++++++++++++++++----------------
 1 file changed, 56 insertions(+), 36 deletions(-)

diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 656ba4e5..4176d567 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -35,6 +35,8 @@ from pydantic import BaseModel
 from typing import Optional
 import mimetypes
 import uuid
+import json
+

 from apps.web.models.documents import (
     Documents,
@@ -63,24 +65,26 @@ from config import (
 from constants import ERROR_MESSAGES

 #
-#if RAG_EMBEDDING_MODEL:
+# if RAG_EMBEDDING_MODEL:
 #     sentence_transformer_ef = SentenceTransformer(
 #         model_name_or_path=RAG_EMBEDDING_MODEL,
 #         cache_folder=RAG_EMBEDDING_MODEL_DIR,
 #         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
 #     )

-if RAG_EMBEDDING_MODEL:
-    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=RAG_EMBEDDING_MODEL,
-        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
-    )

 app = FastAPI()

 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
+app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.sentence_transformer_ef = (
+    embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name=app.state.RAG_EMBEDDING_MODEL,
+        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    )
+)

 origins = ["*"]

@@ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
     metadatas = [doc.metadata for doc in docs]

     try:
-        if RAG_EMBEDDING_MODEL:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.create_collection(
-                name=collection_name, embedding_function=sentence_transformer_ef
-            )
-        else:
-            # for local development use the default model
-            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )

         collection.add(
             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
@@ -139,6 +139,38 @@ async def get_status():
         "status": True,
         "chunk_size": app.state.CHUNK_SIZE,
         "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "template": app.state.RAG_TEMPLATE,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+@app.get("/embedding/model")
+async def get_embedding_model(user=Depends(get_admin_user)):
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+class EmbeddingModelUpdateForm(BaseModel):
+    embedding_model: str
+
+
+@app.post("/embedding/model/update")
+async def update_embedding_model(
+    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
+):
+    app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+    app.state.sentence_transformer_ef = (
+        embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name=app.state.RAG_EMBEDDING_MODEL,
+            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+        )
+    )
+
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }


@@ -203,17 +235,11 @@ def query_doc(
     user=Depends(get_current_user),
 ):
     try:
-        if RAG_EMBEDDING_MODEL:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=form_data.collection_name,
-                embedding_function=sentence_transformer_ef,
-            )
-        else:
-            # for local development use the default model
-            collection = CHROMA_CLIENT.get_collection(
-                name=form_data.collection_name,
-            )
+        # if you use docker use the model from the environment variable
+        collection = CHROMA_CLIENT.get_collection(
+            name=form_data.collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
         result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
     except Exception as e:
@@ -284,17 +310,11 @@ def query_collection(

     for collection_name in form_data.collection_names:
         try:
-            if RAG_EMBEDDING_MODEL:
-                # if you use docker use the model from the environment variable
-                collection = CHROMA_CLIENT.get_collection(
-                    name=collection_name,
-                    embedding_function=sentence_transformer_ef,
-                )
-            else:
-                # for local development use the default model
-                collection = CHROMA_CLIENT.get_collection(
-                    name=collection_name,
-                )
+            # if you use docker use the model from the environment variable
+            collection = CHROMA_CLIENT.get_collection(
+                name=collection_name,
+                embedding_function=app.state.sentence_transformer_ef,
+            )
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k