forked from open-webui/open-webui
		
	choose embedding model when using docker
This commit is contained in:
		
							parent
							
								
									4c3edd0375
								
							
						
					
					
						commit
						1846c1e80d
					
				
					 3 changed files with 46 additions and 20 deletions
				
			
		
							
								
								
									
										12
									
								
								Dockerfile
									
										
									
									
									
								
							
							
						
						
									
										12
									
								
								Dockerfile
									
										
									
									
									
								
							|  | @ -30,10 +30,16 @@ ENV WEBUI_SECRET_KEY "" | |||
| ENV SCARF_NO_ANALYTICS true | ||||
| ENV DO_NOT_TRACK true | ||||
| 
 | ||||
| #Whisper TTS Settings | ||||
| # whisper TTS Settings | ||||
| ENV WHISPER_MODEL="base" | ||||
| ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" | ||||
| 
 | ||||
| # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers | ||||
| # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard  | ||||
| # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" | ||||
| # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. | ||||
| ENV DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL="all-MiniLM-L6-v2" | ||||
| 
 | ||||
| WORKDIR /app/backend | ||||
| 
 | ||||
| # install python dependencies | ||||
|  | @ -48,7 +54,9 @@ RUN apt-get update \ | |||
|     && apt-get install -y pandoc netcat-openbsd \ | ||||
|     && rm -rf /var/lib/apt/lists/* | ||||
| 
 | ||||
| # RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')" | ||||
| # preload embedding model | ||||
| RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL'])" | ||||
| # preload tts model | ||||
| RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,6 +1,5 @@ | |||
| from fastapi import ( | ||||
|     FastAPI, | ||||
|     Request, | ||||
|     Depends, | ||||
|     HTTPException, | ||||
|     status, | ||||
|  | @ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware | |||
| import os, shutil | ||||
| from typing import List | ||||
| 
 | ||||
| # from chromadb.utils import embedding_functions | ||||
| from chromadb.utils import embedding_functions | ||||
| 
 | ||||
| from langchain_community.document_loaders import ( | ||||
|     WebBaseLoader, | ||||
|  | @ -28,24 +27,19 @@ from langchain_community.document_loaders import ( | |||
|     UnstructuredExcelLoader, | ||||
| ) | ||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | ||||
| from langchain_community.vectorstores import Chroma | ||||
| from langchain.chains import RetrievalQA | ||||
| 
 | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from typing import Optional | ||||
| 
 | ||||
| import uuid | ||||
| import time | ||||
| 
 | ||||
| from utils.misc import calculate_sha256, calculate_sha256_string | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | ||||
| from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| # EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
| #     model_name=EMBED_MODEL | ||||
| # ) | ||||
| sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
|  | @ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool: | |||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
|         if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: | ||||
|     # if you use docker use the model from the environment variable | ||||
|             collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef) | ||||
|          | ||||
|         else: | ||||
|     # for local development use the default model | ||||
|             collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|         ) | ||||
|         documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|          ) | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|  | @ -109,9 +109,17 @@ def query_doc( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         collection = CHROMA_CLIENT.get_collection( | ||||
|             name=form_data.collection_name, | ||||
|         ) | ||||
|         if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: | ||||
|         # if you use docker use the model from the environment variable | ||||
|             collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=form_data.collection_name, | ||||
|                 embedding_function=sentence_transformer_ef | ||||
|             ) | ||||
|         else: | ||||
|         # for local development use the default model | ||||
|                 collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=form_data.collection_name, | ||||
|             ) | ||||
|         result = collection.query(query_texts=[form_data.query], n_results=form_data.k) | ||||
|         return result | ||||
|     except Exception as e: | ||||
|  | @ -182,9 +190,18 @@ def query_collection( | |||
| 
 | ||||
|     for collection_name in form_data.collection_names: | ||||
|         try: | ||||
|             collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=collection_name, | ||||
|             if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: | ||||
|              # if you use docker use the model from the environment variable | ||||
|                 collection = CHROMA_CLIENT.get_collection( | ||||
|                     name=form_data.collection_name, | ||||
|                     embedding_function=sentence_transformer_ef | ||||
|                 ) | ||||
|             else: | ||||
|             # for local development use the default model | ||||
|                 collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=form_data.collection_name, | ||||
|             ) | ||||
|                  | ||||
|             result = collection.query( | ||||
|                 query_texts=[form_data.query], n_results=form_data.k | ||||
|             ) | ||||
|  |  | |||
|  | @ -128,7 +128,8 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": | |||
| #################################### | ||||
| 
 | ||||
| CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" | ||||
| EMBED_MODEL = "all-MiniLM-L6-v2" | ||||
| # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) | ||||
| SENTENCE_TRANSFORMER_EMBED_MODEL = os.getenv("DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL") | ||||
| CHROMA_CLIENT = chromadb.PersistentClient( | ||||
|     path=CHROMA_DATA_PATH, | ||||
|     settings=Settings(allow_reset=True, anonymized_telemetry=False), | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jannik Streidl
						Jannik Streidl