forked from open-webui/open-webui
		
	feat: openai embeddings support
This commit is contained in:
		
							parent
							
								
									36ce157907
								
							
						
					
					
						commit
						b48e73fa43
					
				
					 2 changed files with 127 additions and 54 deletions
				
			
		|  | @ -53,6 +53,7 @@ from apps.rag.utils import ( | |||
|     query_collection, | ||||
|     query_embeddings_collection, | ||||
|     get_embedding_model_path, | ||||
|     generate_openai_embeddings, | ||||
| ) | ||||
| 
 | ||||
| from utils.misc import ( | ||||
|  | @ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE | |||
| app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||
| 
 | ||||
| app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com" | ||||
| app.state.RAG_OPENAI_API_KEY = "" | ||||
| 
 | ||||
| app.state.PDF_EXTRACT_IMAGES = False | ||||
| 
 | ||||
|  | @ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)): | |||
|         "status": True, | ||||
|         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, | ||||
|         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||
|         "openai_config": { | ||||
|             "url": app.state.RAG_OPENAI_API_BASE_URL, | ||||
|             "key": app.state.RAG_OPENAI_API_KEY, | ||||
|         }, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class OpenAIConfigForm(BaseModel): | ||||
|     url: str | ||||
|     key: str | ||||
| 
 | ||||
| 
 | ||||
| class EmbeddingModelUpdateForm(BaseModel): | ||||
|     openai_config: Optional[OpenAIConfigForm] = None | ||||
|     embedding_engine: str | ||||
|     embedding_model: str | ||||
| 
 | ||||
|  | @ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel): | |||
| async def update_embedding_config( | ||||
|     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|     log.info( | ||||
|         f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" | ||||
|     ) | ||||
| 
 | ||||
|     try: | ||||
|         app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine | ||||
| 
 | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|         if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: | ||||
|             app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model | ||||
|             app.state.sentence_transformer_ef = None | ||||
| 
 | ||||
|             if form_data.openai_config != None: | ||||
|                 app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url | ||||
|                 app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key | ||||
|         else: | ||||
|             sentence_transformer_ef = ( | ||||
|                 embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|  | @ -183,6 +198,10 @@ async def update_embedding_config( | |||
|             "status": True, | ||||
|             "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, | ||||
|             "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||
|             "openai_config": { | ||||
|                 "url": app.state.RAG_OPENAI_API_BASE_URL, | ||||
|                 "key": app.state.RAG_OPENAI_API_KEY, | ||||
|             }, | ||||
|         } | ||||
| 
 | ||||
|     except Exception as e: | ||||
|  | @ -275,28 +294,37 @@ def query_doc_handler( | |||
| ): | ||||
| 
 | ||||
|     try: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             query_embeddings = generate_ollama_embeddings( | ||||
|                 GenerateEmbeddingsForm( | ||||
|                     **{ | ||||
|                         "model": app.state.RAG_EMBEDDING_MODEL, | ||||
|                         "prompt": form_data.query, | ||||
|                     } | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|             return query_embeddings_doc( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query_embeddings=query_embeddings, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
|         else: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "": | ||||
|             return query_doc( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query=form_data.query, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|                 embedding_function=app.state.sentence_transformer_ef, | ||||
|             ) | ||||
|         else: | ||||
|             if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|                 query_embeddings = generate_ollama_embeddings( | ||||
|                     GenerateEmbeddingsForm( | ||||
|                         **{ | ||||
|                             "model": app.state.RAG_EMBEDDING_MODEL, | ||||
|                             "prompt": form_data.query, | ||||
|                         } | ||||
|                     ) | ||||
|                 ) | ||||
|             elif app.state.RAG_EMBEDDING_ENGINE == "openai": | ||||
|                 query_embeddings = generate_openai_embeddings( | ||||
|                     model=app.state.RAG_EMBEDDING_MODEL, | ||||
|                     text=form_data.query, | ||||
|                     key=app.state.RAG_OPENAI_API_KEY, | ||||
|                     url=app.state.RAG_OPENAI_API_BASE_URL, | ||||
|                 ) | ||||
| 
 | ||||
|             return query_embeddings_doc( | ||||
|                 collection_name=form_data.collection_name, | ||||
|                 query_embeddings=query_embeddings, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|  | @ -317,28 +345,38 @@ def query_collection_handler( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             query_embeddings = generate_ollama_embeddings( | ||||
|                 GenerateEmbeddingsForm( | ||||
|                     **{ | ||||
|                         "model": app.state.RAG_EMBEDDING_MODEL, | ||||
|                         "prompt": form_data.query, | ||||
|                     } | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|             return query_embeddings_collection( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query_embeddings=query_embeddings, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
|         else: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "": | ||||
|             return query_collection( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query=form_data.query, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|                 embedding_function=app.state.sentence_transformer_ef, | ||||
|             ) | ||||
|         else: | ||||
| 
 | ||||
|             if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|                 query_embeddings = generate_ollama_embeddings( | ||||
|                     GenerateEmbeddingsForm( | ||||
|                         **{ | ||||
|                             "model": app.state.RAG_EMBEDDING_MODEL, | ||||
|                             "prompt": form_data.query, | ||||
|                         } | ||||
|                     ) | ||||
|                 ) | ||||
|             elif app.state.RAG_EMBEDDING_ENGINE == "openai": | ||||
|                 query_embeddings = generate_openai_embeddings( | ||||
|                     model=app.state.RAG_EMBEDDING_MODEL, | ||||
|                     text=form_data.query, | ||||
|                     key=app.state.RAG_OPENAI_API_KEY, | ||||
|                     url=app.state.RAG_OPENAI_API_BASE_URL, | ||||
|                 ) | ||||
| 
 | ||||
|             return query_embeddings_collection( | ||||
|                 collection_names=form_data.collection_names, | ||||
|                 query_embeddings=query_embeddings, | ||||
|                 k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|  | @ -414,24 +452,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b | |||
|                     log.info(f"deleting existing collection {collection_name}") | ||||
|                     CHROMA_CLIENT.delete_collection(name=collection_name) | ||||
| 
 | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
| 
 | ||||
|             for batch in create_batches( | ||||
|                 api=CHROMA_CLIENT, | ||||
|                 ids=[str(uuid.uuid1()) for _ in texts], | ||||
|                 metadatas=metadatas, | ||||
|                 embeddings=[ | ||||
|                     generate_ollama_embeddings( | ||||
|                         GenerateEmbeddingsForm( | ||||
|                             **{"model": RAG_EMBEDDING_MODEL, "prompt": text} | ||||
|                         ) | ||||
|                     ) | ||||
|                     for text in texts | ||||
|                 ], | ||||
|             ): | ||||
|                 collection.add(*batch) | ||||
|         else: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "": | ||||
| 
 | ||||
|             collection = CHROMA_CLIENT.create_collection( | ||||
|                 name=collection_name, | ||||
|  | @ -446,7 +467,36 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b | |||
|             ): | ||||
|                 collection.add(*batch) | ||||
| 
 | ||||
|             return True | ||||
|         else: | ||||
|             if app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|                 embeddings = [ | ||||
|                     generate_ollama_embeddings( | ||||
|                         GenerateEmbeddingsForm( | ||||
|                             **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text} | ||||
|                         ) | ||||
|                     ) | ||||
|                     for text in texts | ||||
|                 ] | ||||
|             elif app.state.RAG_EMBEDDING_ENGINE == "openai": | ||||
|                 embeddings = [ | ||||
|                     generate_openai_embeddings( | ||||
|                         model=app.state.RAG_EMBEDDING_MODEL, | ||||
|                         text=text, | ||||
|                         key=app.state.RAG_OPENAI_API_KEY, | ||||
|                         url=app.state.RAG_OPENAI_API_BASE_URL, | ||||
|                     ) | ||||
|                     for text in texts | ||||
|                 ] | ||||
| 
 | ||||
|             for batch in create_batches( | ||||
|                 api=CHROMA_CLIENT, | ||||
|                 ids=[str(uuid.uuid1()) for _ in texts], | ||||
|                 metadatas=metadatas, | ||||
|                 embeddings=embeddings, | ||||
|             ): | ||||
|                 collection.add(*batch) | ||||
| 
 | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         if e.__class__.__name__ == "UniqueConstraintError": | ||||
|  |  | |||
|  | @ -269,3 +269,26 @@ def get_embedding_model_path( | |||
|     except Exception as e: | ||||
|         log.exception(f"Cannot determine embedding model snapshot path: {e}") | ||||
|         return embedding_model | ||||
| 
 | ||||
| 
 | ||||
| def generate_openai_embeddings( | ||||
|     model: str, text: str, key: str, url: str = "https://api.openai.com" | ||||
| ): | ||||
|     try: | ||||
|         r = requests.post( | ||||
|             f"{url}/v1/embeddings", | ||||
|             headers={ | ||||
|                 "Content-Type": "application/json", | ||||
|                 "Authorization": f"Bearer {key}", | ||||
|             }, | ||||
|             json={"input": text, "model": model}, | ||||
|         ) | ||||
|         r.raise_for_status() | ||||
|         data = r.json() | ||||
|         if "data" in data: | ||||
|             return data["data"][0]["embedding"] | ||||
|         else: | ||||
|             raise "Something went wrong :/" | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         return None | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek