forked from open-webui/open-webui
		
	choose embedding model when using docker
This commit is contained in:
		
							parent
							
								
									4c3edd0375
								
							
						
					
					
						commit
						1846c1e80d
					
				
					 3 changed files with 46 additions and 20 deletions
				
			
		|  | @ -1,6 +1,5 @@ | |||
| from fastapi import ( | ||||
|     FastAPI, | ||||
|     Request, | ||||
|     Depends, | ||||
|     HTTPException, | ||||
|     status, | ||||
|  | @ -12,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware | |||
| import os, shutil | ||||
| from typing import List | ||||
| 
 | ||||
| # from chromadb.utils import embedding_functions | ||||
| from chromadb.utils import embedding_functions | ||||
| 
 | ||||
| from langchain_community.document_loaders import ( | ||||
|     WebBaseLoader, | ||||
|  | @ -28,24 +27,19 @@ from langchain_community.document_loaders import ( | |||
|     UnstructuredExcelLoader, | ||||
| ) | ||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | ||||
| from langchain_community.vectorstores import Chroma | ||||
| from langchain.chains import RetrievalQA | ||||
| 
 | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from typing import Optional | ||||
| 
 | ||||
| import uuid | ||||
| import time | ||||
| 
 | ||||
| from utils.misc import calculate_sha256, calculate_sha256_string | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | ||||
| from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| # EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
| #     model_name=EMBED_MODEL | ||||
| # ) | ||||
| sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
|  | @ -78,11 +72,17 @@ def store_data_in_vector_db(data, collection_name) -> bool: | |||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
|         if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: | ||||
|     # When running in Docker, use the embedding model configured via the environment variable. | ||||
|             collection = CHROMA_CLIENT.create_collection(name=collection_name, embedding_function=sentence_transformer_ef) | ||||
|          | ||||
|         else: | ||||
|     # For local development, fall back to the default embedding model. | ||||
|             collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|         ) | ||||
|         documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|          ) | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|  | @ -109,9 +109,17 @@ def query_doc( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         collection = CHROMA_CLIENT.get_collection( | ||||
|             name=form_data.collection_name, | ||||
|         ) | ||||
|         if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: | ||||
|         # When running in Docker, use the embedding model configured via the environment variable. | ||||
|             collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=form_data.collection_name, | ||||
|                 embedding_function=sentence_transformer_ef | ||||
|             ) | ||||
|         else: | ||||
|         # For local development, fall back to the default embedding model. | ||||
|                 collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=form_data.collection_name, | ||||
|             ) | ||||
|         result = collection.query(query_texts=[form_data.query], n_results=form_data.k) | ||||
|         return result | ||||
|     except Exception as e: | ||||
|  | @ -182,9 +190,18 @@ def query_collection( | |||
| 
 | ||||
|     for collection_name in form_data.collection_names: | ||||
|         try: | ||||
|             collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=collection_name, | ||||
|             if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: | ||||
|              # When running in Docker, use the embedding model configured via the environment variable. | ||||
|                 collection = CHROMA_CLIENT.get_collection( | ||||
|                     name=form_data.collection_name, | ||||
|                     embedding_function=sentence_transformer_ef | ||||
|                 ) | ||||
|             else: | ||||
|             # For local development, fall back to the default embedding model. | ||||
|                 collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=form_data.collection_name, | ||||
|             ) | ||||
|                  | ||||
|             result = collection.query( | ||||
|                 query_texts=[form_data.query], n_results=form_data.k | ||||
|             ) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jannik Streidl
						Jannik Streidl