forked from open-webui/open-webui
		
	feat: dynamic embedding model load
This commit is contained in:
		
							parent
							
								
									ab104d5905
								
							
						
					
					
						commit
						7c127c35fc
					
				
					 1 changed file with 56 additions and 36 deletions
				
			
		|  | @ -35,6 +35,8 @@ from pydantic import BaseModel | ||||||
| from typing import Optional | from typing import Optional | ||||||
| import mimetypes | import mimetypes | ||||||
| import uuid | import uuid | ||||||
|  | import json | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| from apps.web.models.documents import ( | from apps.web.models.documents import ( | ||||||
|     Documents, |     Documents, | ||||||
|  | @ -63,24 +65,26 @@ from config import ( | ||||||
| from constants import ERROR_MESSAGES | from constants import ERROR_MESSAGES | ||||||
| 
 | 
 | ||||||
| # | # | ||||||
| #if RAG_EMBEDDING_MODEL: | # if RAG_EMBEDDING_MODEL: | ||||||
| #    sentence_transformer_ef = SentenceTransformer( | #    sentence_transformer_ef = SentenceTransformer( | ||||||
| #        model_name_or_path=RAG_EMBEDDING_MODEL, | #        model_name_or_path=RAG_EMBEDDING_MODEL, | ||||||
| #        cache_folder=RAG_EMBEDDING_MODEL_DIR, | #        cache_folder=RAG_EMBEDDING_MODEL_DIR, | ||||||
| #        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | #        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||||
| #    ) | #    ) | ||||||
| 
 | 
 | ||||||
| if RAG_EMBEDDING_MODEL: |  | ||||||
|     sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction( |  | ||||||
|         model_name=RAG_EMBEDDING_MODEL, |  | ||||||
|         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
| app = FastAPI() | app = FastAPI() | ||||||
| 
 | 
 | ||||||
| app.state.CHUNK_SIZE = CHUNK_SIZE | app.state.CHUNK_SIZE = CHUNK_SIZE | ||||||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||||
|  | app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||||
|  | app.state.sentence_transformer_ef = ( | ||||||
|  |     embedding_functions.SentenceTransformerEmbeddingFunction( | ||||||
|  |         model_name=app.state.RAG_EMBEDDING_MODEL, | ||||||
|  |         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||||
|  |     ) | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| origins = ["*"] | origins = ["*"] | ||||||
|  | @ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool: | ||||||
|     metadatas = [doc.metadata for doc in docs] |     metadatas = [doc.metadata for doc in docs] | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         if RAG_EMBEDDING_MODEL: |         collection = CHROMA_CLIENT.create_collection( | ||||||
|             # if you use docker use the model from the environment variable |             name=collection_name, | ||||||
|             collection = CHROMA_CLIENT.create_collection( |             embedding_function=app.state.sentence_transformer_ef, | ||||||
|                 name=collection_name, embedding_function=sentence_transformer_ef |         ) | ||||||
|             ) |  | ||||||
|         else: |  | ||||||
|             # for local development use the default model |  | ||||||
|             collection = CHROMA_CLIENT.create_collection(name=collection_name) |  | ||||||
| 
 | 
 | ||||||
|         collection.add( |         collection.add( | ||||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] |             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||||
|  | @ -139,6 +139,38 @@ async def get_status(): | ||||||
|         "status": True, |         "status": True, | ||||||
|         "chunk_size": app.state.CHUNK_SIZE, |         "chunk_size": app.state.CHUNK_SIZE, | ||||||
|         "chunk_overlap": app.state.CHUNK_OVERLAP, |         "chunk_overlap": app.state.CHUNK_OVERLAP, | ||||||
|  |         "template": app.state.RAG_TEMPLATE, | ||||||
|  |         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @app.get("/embedding/model") | ||||||
|  | async def get_embedding_model(user=Depends(get_admin_user)): | ||||||
|  |     return { | ||||||
|  |         "status": True, | ||||||
|  |         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class EmbeddingModelUpdateForm(BaseModel): | ||||||
|  |     embedding_model: str | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @app.post("/embedding/model/update") | ||||||
|  | async def update_embedding_model( | ||||||
|  |     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) | ||||||
|  | ): | ||||||
|  |     app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model | ||||||
|  |     app.state.sentence_transformer_ef = ( | ||||||
|  |         embedding_functions.SentenceTransformerEmbeddingFunction( | ||||||
|  |             model_name=app.state.RAG_EMBEDDING_MODEL, | ||||||
|  |             device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     return { | ||||||
|  |         "status": True, | ||||||
|  |         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -203,17 +235,11 @@ def query_doc( | ||||||
|     user=Depends(get_current_user), |     user=Depends(get_current_user), | ||||||
| ): | ): | ||||||
|     try: |     try: | ||||||
|         if RAG_EMBEDDING_MODEL: |         # if you use docker use the model from the environment variable | ||||||
|             # if you use docker use the model from the environment variable |         collection = CHROMA_CLIENT.get_collection( | ||||||
|             collection = CHROMA_CLIENT.get_collection( |             name=form_data.collection_name, | ||||||
|                 name=form_data.collection_name, |             embedding_function=app.state.sentence_transformer_ef, | ||||||
|                 embedding_function=sentence_transformer_ef, |         ) | ||||||
|             ) |  | ||||||
|         else: |  | ||||||
|             # for local development use the default model |  | ||||||
|             collection = CHROMA_CLIENT.get_collection( |  | ||||||
|                 name=form_data.collection_name, |  | ||||||
|             ) |  | ||||||
|         result = collection.query(query_texts=[form_data.query], n_results=form_data.k) |         result = collection.query(query_texts=[form_data.query], n_results=form_data.k) | ||||||
|         return result |         return result | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|  | @ -284,17 +310,11 @@ def query_collection( | ||||||
| 
 | 
 | ||||||
|     for collection_name in form_data.collection_names: |     for collection_name in form_data.collection_names: | ||||||
|         try: |         try: | ||||||
|             if RAG_EMBEDDING_MODEL: |             # if you use docker use the model from the environment variable | ||||||
|                 # if you use docker use the model from the environment variable |             collection = CHROMA_CLIENT.get_collection( | ||||||
|                 collection = CHROMA_CLIENT.get_collection( |                 name=collection_name, | ||||||
|                     name=collection_name, |                 embedding_function=app.state.sentence_transformer_ef, | ||||||
|                     embedding_function=sentence_transformer_ef, |             ) | ||||||
|                 ) |  | ||||||
|             else: |  | ||||||
|                 # for local development use the default model |  | ||||||
|                 collection = CHROMA_CLIENT.get_collection( |  | ||||||
|                     name=collection_name, |  | ||||||
|                 ) |  | ||||||
| 
 | 
 | ||||||
|             result = collection.query( |             result = collection.query( | ||||||
|                 query_texts=[form_data.query], n_results=form_data.k |                 query_texts=[form_data.query], n_results=form_data.k | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek