forked from open-webui/open-webui
		
	feat: dynamic embedding model load
This commit is contained in:
		
							parent
							
								
									ab104d5905
								
							
						
					
					
						commit
						7c127c35fc
					
				
					 1 changed files with 56 additions and 36 deletions
				
			
		|  | @ -35,6 +35,8 @@ from pydantic import BaseModel | |||
| from typing import Optional | ||||
| import mimetypes | ||||
| import uuid | ||||
| import json | ||||
| 
 | ||||
| 
 | ||||
| from apps.web.models.documents import ( | ||||
|     Documents, | ||||
|  | @ -70,17 +72,19 @@ from constants import ERROR_MESSAGES | |||
| #        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
| #    ) | ||||
| 
 | ||||
| if RAG_EMBEDDING_MODEL: | ||||
|     sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|         model_name=RAG_EMBEDDING_MODEL, | ||||
|         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
|     ) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
| app.state.CHUNK_SIZE = CHUNK_SIZE | ||||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||
| app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||
| app.state.sentence_transformer_ef = ( | ||||
|     embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|         model_name=app.state.RAG_EMBEDDING_MODEL, | ||||
|         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
|     ) | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| origins = ["*"] | ||||
|  | @ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool: | |||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         if RAG_EMBEDDING_MODEL: | ||||
|             # if you use docker use the model from the environment variable | ||||
|         collection = CHROMA_CLIENT.create_collection( | ||||
|                 name=collection_name, embedding_function=sentence_transformer_ef | ||||
|             name=collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
|         else: | ||||
|             # for local development use the default model | ||||
|             collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|  | @ -139,6 +139,38 @@ async def get_status(): | |||
|         "status": True, | ||||
|         "chunk_size": app.state.CHUNK_SIZE, | ||||
|         "chunk_overlap": app.state.CHUNK_OVERLAP, | ||||
|         "template": app.state.RAG_TEMPLATE, | ||||
|         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/embedding/model") | ||||
| async def get_embedding_model(user=Depends(get_admin_user)): | ||||
|     return { | ||||
|         "status": True, | ||||
|         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class EmbeddingModelUpdateForm(BaseModel): | ||||
|     embedding_model: str | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/embedding/model/update") | ||||
| async def update_embedding_model( | ||||
|     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model | ||||
|     app.state.sentence_transformer_ef = ( | ||||
|         embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|             model_name=app.state.RAG_EMBEDDING_MODEL, | ||||
|             device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     return { | ||||
|         "status": True, | ||||
|         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -203,16 +235,10 @@ def query_doc( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         if RAG_EMBEDDING_MODEL: | ||||
|         # if you use docker use the model from the environment variable | ||||
|         collection = CHROMA_CLIENT.get_collection( | ||||
|             name=form_data.collection_name, | ||||
|                 embedding_function=sentence_transformer_ef, | ||||
|             ) | ||||
|         else: | ||||
|             # for local development use the default model | ||||
|             collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=form_data.collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
|         result = collection.query(query_texts=[form_data.query], n_results=form_data.k) | ||||
|         return result | ||||
|  | @ -284,16 +310,10 @@ def query_collection( | |||
| 
 | ||||
|     for collection_name in form_data.collection_names: | ||||
|         try: | ||||
|             if RAG_EMBEDDING_MODEL: | ||||
|             # if you use docker use the model from the environment variable | ||||
|             collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=collection_name, | ||||
|                     embedding_function=sentence_transformer_ef, | ||||
|                 ) | ||||
|             else: | ||||
|                 # for local development use the default model | ||||
|                 collection = CHROMA_CLIENT.get_collection( | ||||
|                     name=collection_name, | ||||
|                 embedding_function=app.state.sentence_transformer_ef, | ||||
|             ) | ||||
| 
 | ||||
|             result = collection.query( | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek