forked from open-webui/open-webui

feat: openai embeddings integration

Author: Timothy J. Baek
Parent: b48e73fa43
Commit: b1b72441bb

6 changed files with 155 additions and 46 deletions
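This commit adds OpenAI as a third embedding engine for RAG, alongside the default SentenceTransformer and Ollama. The backend learns to generate query embeddings via the OpenAI API (configured through `RAG_OPENAI_API_KEY` and `RAG_OPENAI_API_BASE_URL`), `rag_messages` grows engine-dispatch logic, and the settings UI gains an OpenAI option with URL/key inputs. Along the way it also fixes several log calls that passed objects as extra positional arguments instead of interpolating them into the message.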
Ollama backend: generate_ollama_embeddings

@@ -659,7 +659,7 @@ def generate_ollama_embeddings(
     url_idx: Optional[int] = None,
 ):
 
-    log.info("generate_ollama_embeddings", form_data)
+    log.info(f"generate_ollama_embeddings {form_data}")
 
     if url_idx == None:
         model = form_data.model
@@ -688,7 +688,7 @@ def generate_ollama_embeddings(
 
         data = r.json()
 
-        log.info("generate_ollama_embeddings", data)
+        log.info(f"generate_ollama_embeddings {data}")
 
         if "embedding" in data:
             return data["embedding"]
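The logging fixes in these two hunks (and the matching ones below) matter more than they look: the stdlib `logging` module treats extra positional arguments as %-format parameters for the message string. Since the old messages contained no `%s` placeholder, each call produced a formatting error on stderr and the payload never reached the log. A minimal standalone illustration:

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

data = {"embedding": [0.1, 0.2, 0.3]}

# Broken (the old code): `data` is taken as a %-format argument, but the
# message has no placeholder, so logging prints "--- Logging error ---"
# with a TypeError to stderr instead of emitting the payload.
log.info("generate_ollama_embeddings", data)

# The commit's fix: eager f-string interpolation.
log.info(f"generate_ollama_embeddings {data}")

# Idiomatic alternative: lazy %-style formatting, which defers the string
# work until the record is actually emitted.
log.info("generate_ollama_embeddings %s", data)
```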
RAG backend: vector store

@@ -421,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
-        log.info("store_data_in_vector_db", "store_docs_in_vector_db")
+        log.info(f"store_data_in_vector_db {docs}")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -440,7 +440,7 @@ def store_text_in_vector_db(
 
 
 def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
-    log.info("store_docs_in_vector_db", docs, collection_name)
+    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
@@ -468,6 +468,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                 collection.add(*batch)
 
         else:
+            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+
             if app.state.RAG_EMBEDDING_ENGINE == "ollama":
                 embeddings = [
                     generate_ollama_embeddings(
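The added `create_collection` call matters because the branch below computes engine-specific embeddings and presumably adds them to `collection`, which was never created on this path (the surrounding control flow isn't fully visible in the hunk). The pattern it sets up is Chroma's add-with-precomputed-embeddings. A minimal sketch, assuming the `chromadb` client and made-up sample data:

```python
import chromadb

CHROMA_CLIENT = chromadb.Client()

collection = CHROMA_CLIENT.create_collection(name="docs")

texts = ["first chunk", "second chunk"]
# Precomputed vectors, e.g. from Ollama or OpenAI rather than Chroma's
# default embedding function.
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

# Passing `embeddings` explicitly bypasses Chroma's built-in embedder.
collection.add(
    ids=[str(i) for i in range(len(texts))],
    documents=texts,
    embeddings=embeddings,
)
```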
RAG utilities: query helpers and rag_messages

@@ -6,9 +6,12 @@ import requests
 
 
 from huggingface_hub import snapshot_download
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
+
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
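One structural note: with this import, the RAG utilities now depend directly on the Ollama app (`generate_ollama_embeddings`, `GenerateEmbeddingsForm`), so changes to the Ollama backend's embedding interface ripple into RAG as well.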
@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
-        log.info("query_embeddings_doc", query_embeddings)
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
@@ -118,7 +121,7 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
 
     results = []
-    log.info("query_embeddings_collection", query_embeddings)
+    log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
@@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
     return template
 
 
-def rag_messages(docs, messages, template, k, embedding_function):
+def rag_messages(
+    docs,
+    messages,
+    template,
+    k,
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
     log.debug(f"docs: {docs}")
 
     last_user_message_idx = None
|  | @ -175,22 +188,57 @@ def rag_messages(docs, messages, template, k, embedding_function): | ||||||
|         context = None |         context = None | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             if doc["type"] == "collection": | 
 | ||||||
|                 context = query_collection( |             if doc["type"] == "text": | ||||||
|                     collection_names=doc["collection_names"], |  | ||||||
|                     query=query, |  | ||||||
|                     k=k, |  | ||||||
|                     embedding_function=embedding_function, |  | ||||||
|                 ) |  | ||||||
|             elif doc["type"] == "text": |  | ||||||
|                 context = doc["content"] |                 context = doc["content"] | ||||||
|             else: |             else: | ||||||
|                 context = query_doc( |                 if embedding_engine == "": | ||||||
|                     collection_name=doc["collection_name"], |                     if doc["type"] == "collection": | ||||||
|                     query=query, |                         context = query_collection( | ||||||
|                     k=k, |                             collection_names=doc["collection_names"], | ||||||
|                     embedding_function=embedding_function, |                             query=query, | ||||||
|                 ) |                             k=k, | ||||||
|  |                             embedding_function=embedding_function, | ||||||
|  |                         ) | ||||||
|  |                     else: | ||||||
|  |                         context = query_doc( | ||||||
|  |                             collection_name=doc["collection_name"], | ||||||
|  |                             query=query, | ||||||
|  |                             k=k, | ||||||
|  |                             embedding_function=embedding_function, | ||||||
|  |                         ) | ||||||
|  | 
 | ||||||
|  |                 else: | ||||||
|  |                     if embedding_engine == "ollama": | ||||||
|  |                         query_embeddings = generate_ollama_embeddings( | ||||||
|  |                             GenerateEmbeddingsForm( | ||||||
|  |                                 **{ | ||||||
|  |                                     "model": embedding_model, | ||||||
|  |                                     "prompt": query, | ||||||
|  |                                 } | ||||||
|  |                             ) | ||||||
|  |                         ) | ||||||
|  |                     elif embedding_engine == "openai": | ||||||
|  |                         query_embeddings = generate_openai_embeddings( | ||||||
|  |                             model=embedding_model, | ||||||
|  |                             text=query, | ||||||
|  |                             key=openai_key, | ||||||
|  |                             url=openai_url, | ||||||
|  |                         ) | ||||||
|  | 
 | ||||||
|  |                     if doc["type"] == "collection": | ||||||
|  |                         context = query_embeddings_collection( | ||||||
|  |                             collection_names=doc["collection_names"], | ||||||
|  |                             query_embeddings=query_embeddings, | ||||||
|  |                             k=k, | ||||||
|  |                         ) | ||||||
|  |                     else: | ||||||
|  |                         context = query_embeddings_doc( | ||||||
|  |                             collection_name=doc["collection_name"], | ||||||
|  |                             query_embeddings=query_embeddings, | ||||||
|  |                             k=k, | ||||||
|  |                         ) | ||||||
|  | 
 | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             log.exception(e) |             log.exception(e) | ||||||
|             context = None |             context = None | ||||||
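The dispatch above routes an empty `embedding_engine` through the existing SentenceTransformer `embedding_function` path, while `ollama` and `openai` pre-compute the query embedding and hand it to the `query_embeddings_*` helpers. `generate_openai_embeddings` is part of this commit but its definition isn't captured in these hunks; a plausible minimal sketch against the OpenAI embeddings endpoint, using the `requests` import already present in this file (treat the exact signature and error handling as assumptions):

```python
import requests


def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    # Hypothetical sketch: POST the query text to {url}/embeddings and
    # unwrap the first embedding vector from the response.
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": [text], "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return data["data"][0]["embedding"]
        return None
    except Exception as e:
        log.exception(e)
        return None
```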
App middleware: RAGMiddleware

@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     data["messages"],
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
+                    rag_app.state.RAG_EMBEDDING_ENGINE,
+                    rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
+                    rag_app.state.RAG_OPENAI_API_KEY,
+                    rag_app.state.RAG_OPENAI_API_BASE_URL,
                 )
                 del data["docs"]
 
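Because the middleware passes these positionally, their order must mirror the new `rag_messages` signature exactly; note how the two inserted pairs straddle `sentence_transformer_ef`. Spelled out with keywords for clarity:

```python
# How the middleware call lines up with the new rag_messages signature.
# The first argument (the docs list) sits just above the hunk's context
# and is assumed here, not visible in the diff.
rag_messages(
    docs=data["docs"],
    messages=data["messages"],
    template=rag_app.state.RAG_TEMPLATE,
    k=rag_app.state.TOP_K,
    embedding_engine=rag_app.state.RAG_EMBEDDING_ENGINE,
    embedding_model=rag_app.state.RAG_EMBEDDING_MODEL,
    embedding_function=rag_app.state.sentence_transformer_ef,
    openai_key=rag_app.state.RAG_OPENAI_API_KEY,
    openai_url=rag_app.state.RAG_OPENAI_API_BASE_URL,
)
```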
Frontend API client (TypeScript)

@@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
	return res;
};
 
+type OpenAIConfigForm = {
+	key: string;
+	url: string;
+};
+
 type EmbeddingModelUpdateForm = {
+	openai_config?: OpenAIConfigForm;
	embedding_engine: string;
	embedding_model: string;
 };
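The `openai_config` field is optional so non-OpenAI engines can omit it from the update payload. The server-side model that receives this form is not shown in these hunks; a hedged sketch of what a pydantic counterpart could look like, with names assumed to mirror the TypeScript:

```python
from typing import Optional

from pydantic import BaseModel


# Hypothetical backend counterpart to the TypeScript forms above; the
# actual models in this commit are not visible in these hunks.
class OpenAIConfigForm(BaseModel):
    key: str
    url: str


class EmbeddingModelUpdateForm(BaseModel):
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str
```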
Settings UI (Svelte): document settings component

@@ -29,6 +29,9 @@
	let embeddingEngine = '';
	let embeddingModel = '';
 
+	let openAIKey = '';
+	let openAIUrl = '';
+
	let chunkSize = 0;
	let chunkOverlap = 0;
	let pdfExtractImages = true;
@@ -50,15 +53,6 @@
	};
 
	const embeddingModelUpdateHandler = async () => {
-		if (embeddingModel === '') {
-			toast.error(
-				$i18n.t(
-					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
-				)
-			);
-			return;
-		}
-
		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
			toast.error(
				$i18n.t(
@@ -67,21 +61,46 @@
			);
			return;
		}
+		if (embeddingEngine === 'ollama' && embeddingModel === '') {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+
+		if (embeddingEngine === 'openai' && embeddingModel === '') {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+
+		if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') {
+			toast.error($i18n.t('OpenAI URL/Key required.'));
+			return;
+		}
 
		console.log('Update embedding model attempt:', embeddingModel);
 
		updateEmbeddingModelLoading = true;
		const res = await updateEmbeddingConfig(localStorage.token, {
			embedding_engine: embeddingEngine,
-			embedding_model: embeddingModel
+			embedding_model: embeddingModel,
+			...(embeddingEngine === 'openai'
+				? {
+						openai_config: {
+							key: openAIKey,
+							url: openAIUrl
+						}
+				  }
+				: {})
		}).catch(async (error) => {
			toast.error(error);
-
-			const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-			if (embeddingConfig) {
-				embeddingEngine = embeddingConfig.embedding_engine;
-				embeddingModel = embeddingConfig.embedding_model;
-			}
+			await setEmbeddingConfig();
			return null;
		});
		updateEmbeddingModelLoading = false;
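One nit worth flagging in the new guard: `(embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === ''` groups so that an empty `openAIUrl` blocks the update for every engine, not just OpenAI; the intent is presumably `embeddingEngine === 'openai' && (openAIKey === '' || openAIUrl === '')`. The two model-name checks also reuse the "Model filesystem path detected…" string, which fits the Ollama case but reads oddly for an empty OpenAI model field.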
@@ -89,7 +108,7 @@
		if (res) {
			console.log('embeddingModelUpdateHandler:', res);
			if (res.status === true) {
-				toast.success($i18n.t('Model {{embedding_model}} update complete!', res), {
+				toast.success($i18n.t('Embedding model set to "{{embedding_model}}"', res), {
					duration: 1000 * 10
				});
			}
@@ -107,6 +126,18 @@
		querySettings = await updateQuerySettings(localStorage.token, querySettings);
	};
 
+	const setEmbeddingConfig = async () => {
+		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+
+		if (embeddingConfig) {
+			embeddingEngine = embeddingConfig.embedding_engine;
+			embeddingModel = embeddingConfig.embedding_model;
+
+			openAIKey = embeddingConfig.openai_config.key;
+			openAIUrl = embeddingConfig.openai_config.url;
+		}
+	};
+
	onMount(async () => {
		const res = await getRAGConfig(localStorage.token);
 
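A related caveat: `setEmbeddingConfig` dereferences `embeddingConfig.openai_config.key` unconditionally, so it assumes the config endpoint always returns an `openai_config` object; if the backend can omit it for non-OpenAI engines, optional chaining (`embeddingConfig.openai_config?.key ?? ''`) would be the safer read.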
@@ -117,12 +148,7 @@
			chunkOverlap = res.chunk.chunk_overlap;
		}
 
-		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-
-		if (embeddingConfig) {
-			embeddingEngine = embeddingConfig.embedding_engine;
-			embeddingModel = embeddingConfig.embedding_model;
-		}
+		await setEmbeddingConfig();
 
		querySettings = await getQuerySettings(localStorage.token);
	});
@@ -146,15 +172,38 @@
						class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
						bind:value={embeddingEngine}
						placeholder="Select an embedding engine"
-						on:change={() => {
-							embeddingModel = '';
+						on:change={(e) => {
+							if (e.target.value === 'ollama') {
+								embeddingModel = '';
+							} else if (e.target.value === 'openai') {
+								embeddingModel = 'text-embedding-3-small';
+							}
						}}
					>
						<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
						<option value="ollama">{$i18n.t('Ollama')}</option>
+						<option value="openai">{$i18n.t('OpenAI')}</option>
					</select>
				</div>
			</div>
+
+			{#if embeddingEngine === 'openai'}
+				<div class="mt-1 flex gap-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('API Base URL')}
+						bind:value={openAIUrl}
+						required
+					/>
+
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('API Key')}
+						bind:value={openAIKey}
+						required
+					/>
+				</div>
+			{/if}
		</div>
 
		<div class="space-y-2">
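Taken together, the six files wire the feature end to end: the settings UI collects engine, model, and OpenAI URL/key and submits them via `updateEmbeddingConfig`; the backend stores them on `rag_app.state`; the middleware threads that state into `rag_messages`; and at query time `rag_messages` dispatches to SentenceTransformer, Ollama, or OpenAI to embed the user's query before querying Chroma. Selecting OpenAI in the dropdown pre-fills `text-embedding-3-small` as the model and reveals the URL/key inputs.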