forked from open-webui/open-webui

feat: openai embeddings integration

Author: Timothy J. Baek
parent: b48e73fa43
commit: b1b72441bb

6 changed files with 155 additions and 46 deletions
@@ -659,7 +659,7 @@ def generate_ollama_embeddings(
     url_idx: Optional[int] = None,
 ):
 
-    log.info("generate_ollama_embeddings", form_data)
+    log.info(f"generate_ollama_embeddings {form_data}")
 
     if url_idx == None:
         model = form_data.model
@@ -688,7 +688,7 @@ def generate_ollama_embeddings(
 
         data = r.json()
 
-        log.info("generate_ollama_embeddings", data)
+        log.info(f"generate_ollama_embeddings {data}")
 
         if "embedding" in data:
             return data["embedding"]
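Note on the logging fix above: the stdlib logger treats extra positional arguments as %-style format args for the message string, so the old `log.info("generate_ollama_embeddings", form_data)` never appended the payload; with no `%s` placeholder in the message, logging reports a formatting error instead. A minimal, self-contained sketch of the difference (illustrative, not part of the commit; the payload is a made-up example):

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

payload = {"model": "all-minilm", "prompt": "hello"}

# Old style: `payload` is consumed as a %-format argument. The message has no
# "%s" placeholder, so formatting fails and logging prints a
# "--- Logging error ---" traceback instead of the payload.
log.info("generate_ollama_embeddings", payload)

# New style: the f-string bakes the payload into the message itself.
log.info(f"generate_ollama_embeddings {payload}")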
@@ -421,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
-        log.info("store_data_in_vector_db", "store_docs_in_vector_db")
+        log.info(f"store_data_in_vector_db {docs}")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -440,7 +440,7 @@ def store_text_in_vector_db(
 
 
 def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
-    log.info("store_docs_in_vector_db", docs, collection_name)
+    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
@@ -468,6 +468,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                 collection.add(*batch)
 
         else:
+            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+
             if app.state.RAG_EMBEDDING_ENGINE == "ollama":
                 embeddings = [
                     generate_ollama_embeddings(
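The `store_docs_in_vector_db` hunk above makes the non-default engines bypass Chroma's built-in embedding function: the collection is created bare, one embedding is generated per chunk (via Ollama here), and the precomputed vectors are handed to `collection.add`. A minimal standalone sketch of that pattern with the `chromadb` client (collection name and vectors are placeholders, not the app's code):

import uuid
import chromadb

client = chromadb.Client()
collection = client.create_collection(name="demo-collection")

texts = ["first chunk", "second chunk"]
# Stand-ins for whatever the configured embedding engine returned,
# one vector per chunk.
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

# Supplying `embeddings` explicitly means Chroma stores these vectors as-is
# instead of embedding `documents` with its default embedding function.
collection.add(
    ids=[str(uuid.uuid4()) for _ in texts],
    documents=texts,
    embeddings=embeddings,
)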
@@ -6,9 +6,12 @@ import requests
 
 
 from huggingface_hub import snapshot_download
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+
+
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
-        log.info("query_embeddings_doc", query_embeddings)
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
@@ -118,7 +121,7 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
 
     results = []
-    log.info("query_embeddings_collection", query_embeddings)
+    log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
@@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
     return template
 
 
-def rag_messages(docs, messages, template, k, embedding_function):
+def rag_messages(
+    docs,
+    messages,
+    template,
+    k,
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
     log.debug(f"docs: {docs}")
 
     last_user_message_idx = None
@@ -175,6 +188,11 @@ def rag_messages(docs, messages, template, k, embedding_function):
         context = None
 
         try:
+
+            if doc["type"] == "text":
+                context = doc["content"]
+            else:
+                if embedding_engine == "":
                     if doc["type"] == "collection":
                         context = query_collection(
                             collection_names=doc["collection_names"],
@@ -182,8 +200,6 @@ def rag_messages(docs, messages, template, k, embedding_function):
                             k=k,
                             embedding_function=embedding_function,
                         )
-            elif doc["type"] == "text":
-                context = doc["content"]
                     else:
                         context = query_doc(
                             collection_name=doc["collection_name"],
@@ -191,6 +207,38 @@ def rag_messages(docs, messages, template, k, embedding_function):
                             query=query,
                             k=k,
                             embedding_function=embedding_function,
+
+                else:
+                    if embedding_engine == "ollama":
+                        query_embeddings = generate_ollama_embeddings(
+                            GenerateEmbeddingsForm(
+                                **{
+                                    "model": embedding_model,
+                                    "prompt": query,
+                                }
+                            )
+                        )
+                    elif embedding_engine == "openai":
+                        query_embeddings = generate_openai_embeddings(
+                            model=embedding_model,
+                            text=query,
+                            key=openai_key,
+                            url=openai_url,
+                        )
+
+                    if doc["type"] == "collection":
+                        context = query_embeddings_collection(
+                            collection_names=doc["collection_names"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+                    else:
+                        context = query_embeddings_doc(
+                            collection_name=doc["collection_name"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+
         except Exception as e:
             log.exception(e)
             context = None
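The new `openai` branch above calls `generate_openai_embeddings(model=..., text=..., key=..., url=...)`, whose definition is not part of the hunks shown here. A minimal sketch of what such a helper could look like, assuming it targets an OpenAI-compatible `POST {url}/embeddings` endpoint that returns `{"data": [{"embedding": [...]}]}`:

import logging

import requests

log = logging.getLogger(__name__)


def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    try:
        # Send the query text to the embeddings endpoint with a bearer token.
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        # A single string input yields exactly one entry in "data".
        return data["data"][0]["embedding"]
    except Exception as e:
        log.exception(e)
        return None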
@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     data["messages"],
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
+                    rag_app.state.RAG_EMBEDDING_ENGINE,
+                    rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
+                    rag_app.state.RAG_OPENAI_API_KEY,
+                    rag_app.state.RAG_OPENAI_API_BASE_URL,
                 )
                 del data["docs"]
 
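The arguments added here line up positionally with the new `rag_messages` signature: `RAG_EMBEDDING_ENGINE` and `RAG_EMBEDDING_MODEL` precede the existing `sentence_transformer_ef` (the `embedding_function` parameter), and the OpenAI key and base URL fill `openai_key` and `openai_url` at the end.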
@@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
 	return res;
 };
 
+type OpenAIConfigForm = {
+	key: string;
+	url: string;
+};
+
 type EmbeddingModelUpdateForm = {
+	openai_config?: OpenAIConfigForm;
 	embedding_engine: string;
 	embedding_model: string;
 };
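`openai_config` is optional so callers that keep the default or Ollama engines can omit it; the settings handler below spreads it into the update payload only when the selected engine is `openai`.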
@@ -29,6 +29,9 @@
 	let embeddingEngine = '';
 	let embeddingModel = '';
 
+	let openAIKey = '';
+	let openAIUrl = '';
+
 	let chunkSize = 0;
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;
@@ -50,15 +53,6 @@
 	};
 
 	const embeddingModelUpdateHandler = async () => {
-		if (embeddingModel === '') {
-			toast.error(
-				$i18n.t(
-					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
-				)
-			);
-			return;
-		}
-
 		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
 			toast.error(
 				$i18n.t(
@@ -67,21 +61,46 @@
 			);
 			return;
 		}
+		if (embeddingEngine === 'ollama' && embeddingModel === '') {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+
+		if (embeddingEngine === 'openai' && embeddingModel === '') {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+
+		if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') {
+			toast.error($i18n.t('OpenAI URL/Key required.'));
+			return;
+		}
 
 		console.log('Update embedding model attempt:', embeddingModel);
 
 		updateEmbeddingModelLoading = true;
 		const res = await updateEmbeddingConfig(localStorage.token, {
 			embedding_engine: embeddingEngine,
-			embedding_model: embeddingModel
+			embedding_model: embeddingModel,
+			...(embeddingEngine === 'openai'
+				? {
+						openai_config: {
+							key: openAIKey,
+							url: openAIUrl
+						}
+				  }
+				: {})
 		}).catch(async (error) => {
 			toast.error(error);
 
-			const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-			if (embeddingConfig) {
-				embeddingEngine = embeddingConfig.embedding_engine;
-				embeddingModel = embeddingConfig.embedding_model;
-			}
+			await setEmbeddingConfig();
 			return null;
 		});
 		updateEmbeddingModelLoading = false;
@@ -89,7 +108,7 @@
 		if (res) {
 			console.log('embeddingModelUpdateHandler:', res);
 			if (res.status === true) {
-				toast.success($i18n.t('Model {{embedding_model}} update complete!', res), {
+				toast.success($i18n.t('Embedding model set to "{{embedding_model}}"', res), {
 					duration: 1000 * 10
 				});
 			}
@@ -107,6 +126,18 @@
 		querySettings = await updateQuerySettings(localStorage.token, querySettings);
 	};
 
+	const setEmbeddingConfig = async () => {
+		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+
+		if (embeddingConfig) {
+			embeddingEngine = embeddingConfig.embedding_engine;
+			embeddingModel = embeddingConfig.embedding_model;
+
+			openAIKey = embeddingConfig.openai_config.key;
+			openAIUrl = embeddingConfig.openai_config.url;
+		}
+	};
+
 	onMount(async () => {
 		const res = await getRAGConfig(localStorage.token);
 
@@ -117,12 +148,7 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}
 
-		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-
-		if (embeddingConfig) {
-			embeddingEngine = embeddingConfig.embedding_engine;
-			embeddingModel = embeddingConfig.embedding_model;
-		}
+		await setEmbeddingConfig();
 
 		querySettings = await getQuerySettings(localStorage.token);
 	});
@@ -146,15 +172,38 @@
 						class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
 						bind:value={embeddingEngine}
 						placeholder="Select an embedding engine"
-						on:change={() => {
+						on:change={(e) => {
+							if (e.target.value === 'ollama') {
+								embeddingModel = '';
+							} else if (e.target.value === 'openai') {
+								embeddingModel = 'text-embedding-3-small';
+							}
 						}}
 					>
 						<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
 						<option value="ollama">{$i18n.t('Ollama')}</option>
+						<option value="openai">{$i18n.t('OpenAI')}</option>
 					</select>
 				</div>
 			</div>
 
+			{#if embeddingEngine === 'openai'}
+				<div class="mt-1 flex gap-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('API Base URL')}
+						bind:value={openAIUrl}
+						required
+					/>
+
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('API Key')}
+						bind:value={openAIKey}
+						required
+					/>
+				</div>
+			{/if}
 		</div>
 
 		<div class="space-y-2">