forked from open-webui/open-webui
		
	feat: external embeddings support
This commit is contained in:
		
							parent
							
								
									8b10b058e5
								
							
						
					
					
						commit
						2952e61167
					
				
					 6 changed files with 312 additions and 118 deletions
				
			
		|  | @ -654,6 +654,55 @@ async def generate_embeddings( | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
def generate_ollama_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
):
    """Request an embedding vector for a prompt from an Ollama backend.

    Args:
        form_data: A GenerateEmbeddingsForm, or (for convenience at internal
            call sites) a plain dict with at least "model" and "prompt" keys.
        url_idx: Index into app.state.OLLAMA_BASE_URLS; when None, a backend
            serving the requested model is chosen at random.

    Returns:
        The embedding vector (list of floats) from the Ollama response.

    Raises:
        HTTPException: if the model is not known to any configured backend.
        ValueError: on connection failure, a non-2xx response, or a response
            without an "embedding" field.
    """
    # Internal callers pass bare dicts; normalize both input shapes to a dict
    # payload so the request body is built the same way in either case.
    if isinstance(form_data, dict):
        payload = {k: v for k, v in form_data.items() if v is not None}
    else:
        payload = form_data.model_dump(exclude_none=True)

    if url_idx is None:
        model = payload.get("model", "")

        # Ollama model names default to the ":latest" tag when none is given.
        if ":" not in model:
            model = f"{model}:latest"

        if model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(payload.get("model")),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    # Initialize r so the except block can safely test it even when
    # requests.request itself raised before assignment.
    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            json=payload,
        )
        r.raise_for_status()

        data = r.json()

        if "embedding" in data:
            return data["embedding"]
        # Raising a string is a TypeError in Python 3; raise a real exception.
        raise ValueError("Ollama: response contained no 'embedding' field")
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise ValueError(error_detail) from e
| 
 | ||||
| 
 | ||||
| class GenerateCompletionForm(BaseModel): | ||||
|     model: str | ||||
|     prompt: str | ||||
|  |  | |||
|  | @ -39,13 +39,21 @@ import uuid | |||
| import json | ||||
| 
 | ||||
| 
 | ||||
| from apps.ollama.main import generate_ollama_embeddings | ||||
| 
 | ||||
| from apps.web.models.documents import ( | ||||
|     Documents, | ||||
|     DocumentForm, | ||||
|     DocumentResponse, | ||||
| ) | ||||
| 
 | ||||
| from apps.rag.utils import query_doc, query_collection, get_embedding_model_path | ||||
| from apps.rag.utils import ( | ||||
|     query_doc, | ||||
|     query_embeddings_doc, | ||||
|     query_collection, | ||||
|     query_embeddings_collection, | ||||
|     get_embedding_model_path, | ||||
| ) | ||||
| 
 | ||||
| from utils.misc import ( | ||||
|     calculate_sha256, | ||||
|  | @ -58,6 +66,7 @@ from config import ( | |||
|     SRC_LOG_LEVELS, | ||||
|     UPLOAD_DIR, | ||||
|     DOCS_DIR, | ||||
|     RAG_EMBEDDING_ENGINE, | ||||
|     RAG_EMBEDDING_MODEL, | ||||
|     RAG_EMBEDDING_MODEL_AUTO_UPDATE, | ||||
|     DEVICE_TYPE, | ||||
|  | @ -74,17 +83,20 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) | |||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
| app.state.PDF_EXTRACT_IMAGES = False | ||||
| 
 | ||||
| app.state.TOP_K = 4 | ||||
| app.state.CHUNK_SIZE = CHUNK_SIZE | ||||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||
| 
 | ||||
| 
 | ||||
| app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE | ||||
| app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||
| 
 | ||||
| 
 | ||||
| app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||
| app.state.PDF_EXTRACT_IMAGES = False | ||||
| 
 | ||||
| 
 | ||||
| app.state.TOP_K = 4 | ||||
| 
 | ||||
| app.state.sentence_transformer_ef = ( | ||||
|     embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|         model_name=get_embedding_model_path( | ||||
|  | @ -121,6 +133,7 @@ async def get_status(): | |||
|         "chunk_size": app.state.CHUNK_SIZE, | ||||
|         "chunk_overlap": app.state.CHUNK_OVERLAP, | ||||
|         "template": app.state.RAG_TEMPLATE, | ||||
|         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, | ||||
|         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||
|     } | ||||
| 
 | ||||
|  | @ -252,6 +265,17 @@ def query_doc_handler( | |||
| ): | ||||
| 
 | ||||
|     try: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             query_embeddings = generate_ollama_embeddings( | ||||
|                 {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query} | ||||
|             ) | ||||
| 
 | ||||
|             return query_embeddings_doc( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query_embeddings=query_embeddings, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
|         else: | ||||
|             return query_doc( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query=form_data.query, | ||||
|  | @ -277,12 +301,30 @@ def query_collection_handler( | |||
|     form_data: QueryCollectionsForm, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             query_embeddings = generate_ollama_embeddings( | ||||
|                 {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query} | ||||
|             ) | ||||
| 
 | ||||
|             return query_embeddings_collection( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query_embeddings=query_embeddings, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
|         else: | ||||
|             return query_collection( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query=form_data.query, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|                 embedding_function=app.state.sentence_transformer_ef, | ||||
|             ) | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(e), | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/web") | ||||
|  | @ -317,6 +359,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b | |||
|         chunk_overlap=app.state.CHUNK_OVERLAP, | ||||
|         add_start_index=True, | ||||
|     ) | ||||
| 
 | ||||
|     docs = text_splitter.split_documents(data) | ||||
| 
 | ||||
|     if len(docs) > 0: | ||||
|  | @ -337,7 +380,9 @@ def store_text_in_vector_db( | |||
|     return store_docs_in_vector_db(docs, collection_name, overwrite) | ||||
| 
 | ||||
| 
 | ||||
| def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: | ||||
| async def store_docs_in_vector_db( | ||||
|     docs, collection_name, overwrite: bool = False | ||||
| ) -> bool: | ||||
| 
 | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
|  | @ -349,6 +394,22 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b | |||
|                     log.info(f"deleting existing collection {collection_name}") | ||||
|                     CHROMA_CLIENT.delete_collection(name=collection_name) | ||||
| 
 | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
| 
 | ||||
|             for batch in create_batches( | ||||
|                 api=CHROMA_CLIENT, | ||||
|                 ids=[str(uuid.uuid1()) for _ in texts], | ||||
|                 metadatas=metadatas, | ||||
|                 embeddings=[ | ||||
|                     generate_ollama_embeddings( | ||||
|                         {"model": RAG_EMBEDDING_MODEL, "prompt": text} | ||||
|                     ) | ||||
|                     for text in texts | ||||
|                 ], | ||||
|             ): | ||||
|                 collection.add(*batch) | ||||
|         else: | ||||
|             collection = CHROMA_CLIENT.create_collection( | ||||
|                 name=collection_name, | ||||
|                 embedding_function=app.state.sentence_transformer_ef, | ||||
|  |  | |||
|  | @ -2,6 +2,9 @@ import os | |||
| import re | ||||
| import logging | ||||
| from typing import List | ||||
| import requests | ||||
| 
 | ||||
| 
 | ||||
| from huggingface_hub import snapshot_download | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||
|  | @ -26,6 +29,21 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function): | |||
|         raise e | ||||
| 
 | ||||
| 
 | ||||
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
    """Query one Chroma collection using a precomputed query embedding.

    Args:
        collection_name: Name of an existing Chroma collection.
        query_embeddings: A single embedding vector (sequence of floats);
            it is wrapped in a list because Chroma expects a batch.
        k: Number of nearest results to return.

    Returns:
        The Chroma query result dict (ids/documents/metadatas/distances).

    Raises:
        Propagates any error from the Chroma client (e.g. unknown collection).
    """
    # The previous try/except that only re-raised added nothing but a new
    # raise site in the traceback; let errors propagate naturally.
    collection = CHROMA_CLIENT.get_collection(name=collection_name)
    return collection.query(
        query_embeddings=[query_embeddings],
        n_results=k,
    )
| 
 | ||||
| 
 | ||||
| def merge_and_sort_query_results(query_results, k): | ||||
|     # Initialize lists to store combined data | ||||
|     combined_ids = [] | ||||
|  | @ -96,6 +114,24 @@ def query_collection( | |||
|     return merge_and_sort_query_results(results, k) | ||||
| 
 | ||||
| 
 | ||||
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
    """Query several Chroma collections with one precomputed query embedding.

    Best-effort: collections that are missing or fail to query are skipped
    rather than failing the whole request.

    Args:
        collection_names: Names of the Chroma collections to query.
        query_embeddings: A single embedding vector (sequence of floats).
        k: Number of nearest results to return per collection and overall.

    Returns:
        The merged, distance-sorted result across all queryable collections.
    """
    results = []
    for collection_name in collection_names:
        try:
            collection = CHROMA_CLIENT.get_collection(name=collection_name)

            result = collection.query(
                query_embeddings=[query_embeddings],
                n_results=k,
            )
            results.append(result)
        except Exception:
            # Deliberate best-effort skip, but narrowed from a bare `except:`
            # so KeyboardInterrupt/SystemExit are no longer swallowed.
            pass

    return merge_and_sort_query_results(results, k)
| 
 | ||||
| 
 | ||||
| def rag_template(template: str, context: str, query: str): | ||||
|     template = template.replace("[context]", context) | ||||
|     template = template.replace("[query]", query) | ||||
|  |  | |||
|  | @ -405,6 +405,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": | |||
| 
 | ||||
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
# this uses the model defined in the Dockerfile ENV variable. If you don't use
# docker or docker-based deployments such as k8s, the default embedding model
# will be used (all-MiniLM-L6-v2)

# "" selects the built-in SentenceTransformer engine; "ollama" delegates
# embedding generation to a configured Ollama backend.
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")

RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
# NOTE: the previous trailing comma turned this statement into a one-element
# tuple expression; the log call itself was unaffected, but the comma is gone.
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}")
| 
 | ||||
|  |  | |||
|  | @ -220,6 +220,32 @@ export const generatePrompt = async (token: string = '', model: string, conversa | |||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| export const generateEmbeddings = async (token: string = '', model: string, text: string) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/embeddings`, { | ||||
| 		method: 'POST', | ||||
| 		headers: { | ||||
| 			Accept: 'application/json', | ||||
| 			'Content-Type': 'application/json', | ||||
| 			Authorization: `Bearer ${token}` | ||||
| 		}, | ||||
| 		body: JSON.stringify({ | ||||
| 			model: model, | ||||
| 			prompt: text | ||||
| 		}) | ||||
| 	}).catch((err) => { | ||||
| 		error = err; | ||||
| 		return null; | ||||
| 	}); | ||||
| 
 | ||||
| 	if (error) { | ||||
| 		throw error; | ||||
| 	} | ||||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| export const generateTextCompletion = async (token: string = '', model: string, text: string) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
|  |  | |||
|  | @ -26,6 +26,7 @@ | |||
| 
 | ||||
| 	let showResetConfirm = false; | ||||
| 
 | ||||
| 	let embeddingEngine = ''; | ||||
| 	let chunkSize = 0; | ||||
| 	let chunkOverlap = 0; | ||||
| 	let pdfExtractImages = true; | ||||
|  | @ -119,58 +120,25 @@ | |||
| 			<div class=" mb-2 text-sm font-medium">{$i18n.t('General Settings')}</div> | ||||
| 
 | ||||
| 			<div class=" flex w-full justify-between"> | ||||
| 				<div class=" self-center text-xs font-medium"> | ||||
| 					{$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })} | ||||
| 				</div> | ||||
| 
 | ||||
| 				<button | ||||
| 					class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded-lg flex flex-row space-x-1 items-center {scanDirLoading | ||||
| 						? ' cursor-not-allowed' | ||||
| 						: ''}" | ||||
| 					on:click={() => { | ||||
| 						scanHandler(); | ||||
| 						console.log('check'); | ||||
| 					}} | ||||
| 					type="button" | ||||
| 					disabled={scanDirLoading} | ||||
| 				<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Engine')}</div> | ||||
| 				<div class="flex items-center relative"> | ||||
| 					<select | ||||
| 						class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right" | ||||
| 						bind:value={embeddingEngine} | ||||
| 						placeholder="Select an embedding engine" | ||||
| 					> | ||||
| 					<div class="self-center font-medium">{$i18n.t('Scan')}</div> | ||||
| 
 | ||||
| 					{#if scanDirLoading} | ||||
| 						<div class="ml-3 self-center"> | ||||
| 							<svg | ||||
| 								class=" w-3 h-3" | ||||
| 								viewBox="0 0 24 24" | ||||
| 								fill="currentColor" | ||||
| 								xmlns="http://www.w3.org/2000/svg" | ||||
| 								><style> | ||||
| 									.spinner_ajPY { | ||||
| 										transform-origin: center; | ||||
| 										animation: spinner_AtaB 0.75s infinite linear; | ||||
| 									} | ||||
| 									@keyframes spinner_AtaB { | ||||
| 										100% { | ||||
| 											transform: rotate(360deg); | ||||
| 										} | ||||
| 									} | ||||
| 								</style><path | ||||
| 									d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z" | ||||
| 									opacity=".25" | ||||
| 								/><path | ||||
| 									d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z" | ||||
| 									class="spinner_ajPY" | ||||
| 								/></svg | ||||
| 							> | ||||
| 						</div> | ||||
| 					{/if} | ||||
| 				</button> | ||||
| 						<option value="">{$i18n.t('Default (SentenceTransformer)')}</option> | ||||
| 						<option value="ollama">{$i18n.t('Ollama')}</option> | ||||
| 					</select> | ||||
| 				</div> | ||||
| 			</div> | ||||
| 		</div> | ||||
| 
 | ||||
| 		<hr class=" dark:border-gray-700" /> | ||||
| 
 | ||||
| 		<div class="space-y-2"> | ||||
| 			<div> | ||||
| 				{#if embeddingEngine === 'ollama'} | ||||
| 					<div>da</div> | ||||
| 				{:else} | ||||
| 					<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Embedding Model')}</div> | ||||
| 					<div class="flex w-full"> | ||||
| 						<div class="flex-1 mr-2"> | ||||
|  | @ -238,6 +206,57 @@ | |||
| 							'Warning: If you update or change your embedding model, you will need to re-import all documents.' | ||||
| 						)} | ||||
| 					</div> | ||||
| 				{/if} | ||||
| 
 | ||||
| 				<hr class=" dark:border-gray-700 my-3" /> | ||||
| 
 | ||||
| 				<div class="  flex w-full justify-between"> | ||||
| 					<div class=" self-center text-xs font-medium"> | ||||
| 						{$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })} | ||||
| 					</div> | ||||
| 
 | ||||
| 					<button | ||||
| 						class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded-lg flex flex-row space-x-1 items-center {scanDirLoading | ||||
| 							? ' cursor-not-allowed' | ||||
| 							: ''}" | ||||
| 						on:click={() => { | ||||
| 							scanHandler(); | ||||
| 							console.log('check'); | ||||
| 						}} | ||||
| 						type="button" | ||||
| 						disabled={scanDirLoading} | ||||
| 					> | ||||
| 						<div class="self-center font-medium">{$i18n.t('Scan')}</div> | ||||
| 
 | ||||
| 						{#if scanDirLoading} | ||||
| 							<div class="ml-3 self-center"> | ||||
| 								<svg | ||||
| 									class=" w-3 h-3" | ||||
| 									viewBox="0 0 24 24" | ||||
| 									fill="currentColor" | ||||
| 									xmlns="http://www.w3.org/2000/svg" | ||||
| 									><style> | ||||
| 										.spinner_ajPY { | ||||
| 											transform-origin: center; | ||||
| 											animation: spinner_AtaB 0.75s infinite linear; | ||||
| 										} | ||||
| 										@keyframes spinner_AtaB { | ||||
| 											100% { | ||||
| 												transform: rotate(360deg); | ||||
| 											} | ||||
| 										} | ||||
| 									</style><path | ||||
| 										d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z" | ||||
| 										opacity=".25" | ||||
| 									/><path | ||||
| 										d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z" | ||||
| 										class="spinner_ajPY" | ||||
| 									/></svg | ||||
| 								> | ||||
| 							</div> | ||||
| 						{/if} | ||||
| 					</button> | ||||
| 				</div> | ||||
| 
 | ||||
| 				<hr class=" dark:border-gray-700 my-3" /> | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek