forked from open-webui/open-webui

commit b1b72441bb (parent b48e73fa43)
Author: Timothy J. Baek

feat: openai embeddings integration

6 changed files with 155 additions and 46 deletions
@@ -659,7 +659,7 @@ def generate_ollama_embeddings(
     url_idx: Optional[int] = None,
 ):
 
-    log.info("generate_ollama_embeddings", form_data)
+    log.info(f"generate_ollama_embeddings {form_data}")
 
     if url_idx == None:
         model = form_data.model
@@ -688,7 +688,7 @@ def generate_ollama_embeddings(
 
         data = r.json()
 
-        log.info("generate_ollama_embeddings", data)
+        log.info(f"generate_ollama_embeddings {data}")
 
         if "embedding" in data:
             return data["embedding"]
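The logging changes in this hunk fix a real bug, not just style: logging.Logger.info treats extra positional arguments as %-style formatting values, so log.info("generate_ollama_embeddings", form_data) with no %s placeholder in the message makes the logging module hit "TypeError: not all arguments converted during string formatting" internally, and the payload never reaches the log. A minimal standalone sketch of the broken call, the commit's f-string fix, and a lazy alternative (the form_data value is illustrative):

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

form_data = {"model": "all-minilm", "prompt": "hello"}  # illustrative payload

# Broken: no %-placeholder in the message, so logging reports a
# formatting error instead of the data
log.info("generate_ollama_embeddings", form_data)

# Fixed, as in this commit: eager f-string interpolation
log.info(f"generate_ollama_embeddings {form_data}")

# Alternative: lazy %-style formatting, which skips the string work
# entirely when the INFO level is disabled
log.info("generate_ollama_embeddings %s", form_data)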
@@ -421,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
-        log.info("store_data_in_vector_db", "store_docs_in_vector_db")
+        log.info(f"store_data_in_vector_db {docs}")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -440,7 +440,7 @@ def store_text_in_vector_db(
 
 
 def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
-    log.info("store_docs_in_vector_db", docs, collection_name)
+    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
@@ -468,6 +468,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                 collection.add(*batch)
 
         else:
+            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+
             if app.state.RAG_EMBEDDING_ENGINE == "ollama":
                 embeddings = [
                     generate_ollama_embeddings(
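The added create_collection line ensures the target collection exists (and binds the collection variable) before vectors computed outside Chroma are inserted in the branch that follows. A minimal sketch of that pattern with the chromadb client — collection name, texts, and vector values are illustrative, and open-webui uses its shared CHROMA_CLIENT rather than a fresh client:

import chromadb

client = chromadb.Client()

# Create the collection up front, as the added line does
collection = client.create_collection(name="docs-example")

texts = ["first chunk", "second chunk"]
# Precomputed vectors, e.g. from generate_ollama_embeddings();
# the numbers here are placeholders
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

collection.add(
    ids=[str(i) for i in range(len(texts))],
    documents=texts,
    embeddings=embeddings,
)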
@@ -6,9 +6,12 @@ import requests
 
 
 from huggingface_hub import snapshot_download
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+
+
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
-        log.info("query_embeddings_doc", query_embeddings)
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
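query_embeddings_doc looks up a single collection with a query vector that was computed outside Chroma, so no embedding_function is attached to the collection handle. A self-contained sketch of the underlying chromadb call (collection name and vector values are illustrative):

import chromadb

client = chromadb.Client()
collection = client.get_or_create_collection(name="docs-example")

result = collection.query(
    query_embeddings=[[0.1, 0.2, 0.3]],  # one precomputed query vector
    n_results=4,                         # k nearest chunks
)
print(result["documents"])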
@@ -118,7 +121,7 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
 
     results = []
-    log.info("query_embeddings_collection", query_embeddings)
+    log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
@@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
     return template
 
 
-def rag_messages(docs, messages, template, k, embedding_function):
+def rag_messages(
+    docs,
+    messages,
+    template,
+    k,
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
     log.debug(f"docs: {docs}")
 
     last_user_message_idx = None
@@ -175,22 +188,57 @@ def rag_messages(docs, messages, template, k, embedding_function):
         context = None
 
         try:
-            if doc["type"] == "collection":
-                context = query_collection(
-                    collection_names=doc["collection_names"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
-            elif doc["type"] == "text":
+
+            if doc["type"] == "text":
                 context = doc["content"]
             else:
-                context = query_doc(
-                    collection_name=doc["collection_name"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
+                if embedding_engine == "":
+                    if doc["type"] == "collection":
+                        context = query_collection(
+                            collection_names=doc["collection_names"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+                    else:
+                        context = query_doc(
+                            collection_name=doc["collection_name"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+
+                else:
+                    if embedding_engine == "ollama":
+                        query_embeddings = generate_ollama_embeddings(
+                            GenerateEmbeddingsForm(
+                                **{
+                                    "model": embedding_model,
+                                    "prompt": query,
+                                }
+                            )
+                        )
+                    elif embedding_engine == "openai":
+                        query_embeddings = generate_openai_embeddings(
+                            model=embedding_model,
+                            text=query,
+                            key=openai_key,
+                            url=openai_url,
+                        )
+
+                    if doc["type"] == "collection":
+                        context = query_embeddings_collection(
+                            collection_names=doc["collection_names"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+                    else:
+                        context = query_embeddings_doc(
+                            collection_name=doc["collection_name"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+
         except Exception as e:
             log.exception(e)
             context = None
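The dispatch above works in two modes: with an empty embedding_engine the raw query text goes to Chroma, which embeds it with the attached sentence-transformers embedding_function; otherwise the query vector is computed up front (via Ollama or OpenAI) and handed to the query_embeddings_* helpers. The "openai" branch calls generate_openai_embeddings, which this commit adds outside this excerpt. A minimal sketch of such a helper, assuming a standard OpenAI-compatible /embeddings endpoint — the name and signature follow the call site above, but the body is an assumption, not the commit's exact code:

import logging

import requests

log = logging.getLogger(__name__)


def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    # POST to an OpenAI-compatible embeddings endpoint; assumes the
    # response shape is {"data": [{"embedding": [...]}]}
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
        )
        r.raise_for_status()
        return r.json()["data"][0]["embedding"]
    except Exception as e:
        log.exception(e)
        return None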
@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     data["messages"],
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
+                    rag_app.state.RAG_EMBEDDING_ENGINE,
+                    rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
+                    rag_app.state.RAG_OPENAI_API_KEY,
+                    rag_app.state.RAG_OPENAI_API_BASE_URL,
                 )
                 del data["docs"]
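These positional arguments must stay in exactly the order of the new rag_messages signature (docs, messages, template, k, embedding_engine, embedding_model, embedding_function, openai_key, openai_url). A hedged sketch of the same call written with keyword arguments, which would survive future reordering of the signature; it assumes the rag_app.state attributes shown in the hunk and the surrounding call shape, which is partly outside this excerpt:

data["messages"] = rag_messages(
    docs=data["docs"],
    messages=data["messages"],
    template=rag_app.state.RAG_TEMPLATE,
    k=rag_app.state.TOP_K,
    embedding_engine=rag_app.state.RAG_EMBEDDING_ENGINE,
    embedding_model=rag_app.state.RAG_EMBEDDING_MODEL,
    embedding_function=rag_app.state.sentence_transformer_ef,
    openai_key=rag_app.state.RAG_OPENAI_API_KEY,
    openai_url=rag_app.state.RAG_OPENAI_API_BASE_URL,
)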