From 2952e6116762f1d1b43466fa92bf8762bf09e447 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 14 Apr 2024 17:55:00 -0400 Subject: [PATCH] feat: external embeddings support --- backend/apps/ollama/main.py | 49 +++++ backend/apps/rag/main.py | 121 ++++++++--- backend/apps/rag/utils.py | 36 ++++ backend/config.py | 3 + src/lib/apis/ollama/index.ts | 26 +++ .../documents/Settings/General.svelte | 195 ++++++++++-------- 6 files changed, 312 insertions(+), 118 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 7140cad9..0132179f 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -654,6 +654,55 @@ async def generate_embeddings( ) +def generate_ollama_embeddings( + form_data: GenerateEmbeddingsForm, + url_idx: Optional[int] = None, +): + if url_idx == None: + model = form_data.model + + if ":" not in model: + model = f"{model}:latest" + + if model in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[model]["urls"]) + else: + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), + ) + + url = app.state.OLLAMA_BASE_URLS[url_idx] + log.info(f"url: {url}") + + try: + r = requests.request( + method="POST", + url=f"{url}/api/embeddings", + data=form_data.model_dump_json(exclude_none=True).encode(), + ) + r.raise_for_status() + + data = r.json() + + if "embedding" in data: + return data["embedding"] + else: + raise "Something went wrong :/" + except Exception as e: + log.exception(e) + error_detail = "Open WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"Ollama: {res['error']}" + except: + error_detail = f"Ollama: {e}" + + raise error_detail + + class GenerateCompletionForm(BaseModel): model: str prompt: str diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index f03aa4b7..423f1e03 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -39,13 +39,21 @@ import uuid import json +from apps.ollama.main import generate_ollama_embeddings + from apps.web.models.documents import ( Documents, DocumentForm, DocumentResponse, ) -from apps.rag.utils import query_doc, query_collection, get_embedding_model_path +from apps.rag.utils import ( + query_doc, + query_embeddings_doc, + query_collection, + query_embeddings_collection, + get_embedding_model_path, +) from utils.misc import ( calculate_sha256, @@ -58,6 +66,7 @@ from config import ( SRC_LOG_LEVELS, UPLOAD_DIR, DOCS_DIR, + RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, DEVICE_TYPE, @@ -74,17 +83,20 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) app = FastAPI() -app.state.PDF_EXTRACT_IMAGES = False + +app.state.TOP_K = 4 app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_OVERLAP = CHUNK_OVERLAP + + +app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE +app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_TEMPLATE = RAG_TEMPLATE -app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.PDF_EXTRACT_IMAGES = False -app.state.TOP_K = 4 - app.state.sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction( model_name=get_embedding_model_path( @@ -121,6 +133,7 @@ async def get_status(): "chunk_size": app.state.CHUNK_SIZE, "chunk_overlap": app.state.CHUNK_OVERLAP, "template": app.state.RAG_TEMPLATE, + "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.RAG_EMBEDDING_MODEL, } @@ -252,12 +265,23 @@ def query_doc_handler( ): try: - return query_doc( - collection_name=form_data.collection_name, - query=form_data.query, - k=form_data.k if form_data.k else app.state.TOP_K, - embedding_function=app.state.sentence_transformer_ef, - ) + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + query_embeddings = generate_ollama_embeddings( + {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query} + ) + + return query_embeddings_doc( + collection_name=form_data.collection_name, + query_embeddings=query_embeddings, + k=form_data.k if form_data.k else app.state.TOP_K, + ) + else: + return query_doc( + collection_name=form_data.collection_name, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, + embedding_function=app.state.sentence_transformer_ef, + ) except Exception as e: log.exception(e) raise HTTPException( @@ -277,12 +301,30 @@ def query_collection_handler( form_data: QueryCollectionsForm, user=Depends(get_current_user), ): - return query_collection( - collection_names=form_data.collection_names, - query=form_data.query, - k=form_data.k if form_data.k else app.state.TOP_K, - embedding_function=app.state.sentence_transformer_ef, - ) + try: + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + query_embeddings = generate_ollama_embeddings( + {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query} + ) + + return query_embeddings_collection( + collection_names=form_data.collection_names, + query_embeddings=query_embeddings, + k=form_data.k if form_data.k else app.state.TOP_K, + ) + else: + return query_collection( + collection_names=form_data.collection_names, + query=form_data.query, + k=form_data.k if form_data.k else app.state.TOP_K, + embedding_function=app.state.sentence_transformer_ef, + ) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) @app.post("/web") @@ -317,6 +359,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b chunk_overlap=app.state.CHUNK_OVERLAP, add_start_index=True, ) + docs = text_splitter.split_documents(data) if len(docs) > 0: @@ -337,7 +380,9 @@ def store_text_in_vector_db( return store_docs_in_vector_db(docs, collection_name, overwrite) -def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: +async def store_docs_in_vector_db( + docs, collection_name, overwrite: bool = False +) -> bool: texts = [doc.page_content for doc in docs] metadatas = [doc.metadata for doc in docs] @@ -349,20 +394,36 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b log.info(f"deleting existing collection {collection_name}") CHROMA_CLIENT.delete_collection(name=collection_name) - collection = CHROMA_CLIENT.create_collection( - name=collection_name, - embedding_function=app.state.sentence_transformer_ef, - ) + if app.state.RAG_EMBEDDING_ENGINE == "ollama": + collection = CHROMA_CLIENT.create_collection(name=collection_name) - for batch in create_batches( - api=CHROMA_CLIENT, - ids=[str(uuid.uuid1()) for _ in texts], - metadatas=metadatas, - documents=texts, - ): - collection.add(*batch) + for batch in create_batches( + api=CHROMA_CLIENT, + ids=[str(uuid.uuid1()) for _ in texts], + metadatas=metadatas, + embeddings=[ + generate_ollama_embeddings( + {"model": RAG_EMBEDDING_MODEL, "prompt": text} + ) + for text in texts + ], + ): + collection.add(*batch) + else: + collection = CHROMA_CLIENT.create_collection( + name=collection_name, + embedding_function=app.state.sentence_transformer_ef, + ) - return True + for batch in create_batches( + api=CHROMA_CLIENT, + ids=[str(uuid.uuid1()) for _ in texts], + metadatas=metadatas, + documents=texts, + ): + collection.add(*batch) + + return True except Exception as e: log.exception(e) if e.__class__.__name__ == "UniqueConstraintError": diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 7bbfe0b8..301c63b9 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -2,6 +2,9 @@ import os import re import logging from typing import List +import requests + + from huggingface_hub import snapshot_download from config import SRC_LOG_LEVELS, CHROMA_CLIENT @@ -26,6 +29,21 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function): raise e +def query_embeddings_doc(collection_name: str, query_embeddings, k: int): + try: + # if you use docker use the model from the environment variable + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + ) + result = collection.query( + query_embeddings=[query_embeddings], + n_results=k, + ) + return result + except Exception as e: + raise e + + def merge_and_sort_query_results(query_results, k): # Initialize lists to store combined data combined_ids = [] @@ -96,6 +114,24 @@ def query_collection( return merge_and_sort_query_results(results, k) +def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int): + + results = [] + for collection_name in collection_names: + try: + collection = CHROMA_CLIENT.get_collection(name=collection_name) + + result = collection.query( + query_embeddings=[query_embeddings], + n_results=k, + ) + results.append(result) + except: + pass + + return merge_and_sort_query_results(results, k) + + def rag_template(template: str, context: str, query: str): template = template.replace("[context]", context) template = template.replace("[query]", query) diff --git a/backend/config.py b/backend/config.py index 6d93115b..938df996 100644 --- a/backend/config.py +++ b/backend/config.py @@ -405,6 +405,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) + +RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") + RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), diff --git a/src/lib/apis/ollama/index.ts b/src/lib/apis/ollama/index.ts index 4618acc4..a94aceac 100644 --- a/src/lib/apis/ollama/index.ts +++ b/src/lib/apis/ollama/index.ts @@ -220,6 +220,32 @@ export const generatePrompt = async (token: string = '', model: string, conversa return res; }; +export const generateEmbeddings = async (token: string = '', model: string, text: string) => { + let error = null; + + const res = await fetch(`${OLLAMA_API_BASE_URL}/api/embeddings`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + model: model, + prompt: text + }) + }).catch((err) => { + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const generateTextCompletion = async (token: string = '', model: string, text: string) => { let error = null; diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index c94c1250..85df678c 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -26,6 +26,7 @@ let showResetConfirm = false; + let embeddingEngine = ''; let chunkSize = 0; let chunkOverlap = 0; let pdfExtractImages = true; @@ -118,81 +119,119 @@
{$i18n.t('General Settings')}
-
-
- {$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })} +
+
{$i18n.t('Embedding Engine')}
+
+
- -
-
-
-
{$i18n.t('Update Embedding Model')}
-
-
- -
- +
+ +
+ {$i18n.t( + 'Warning: If you update or change your embedding model, you will need to re-import all documents.' + )} +
+ {/if} + +
+ +
+
+ {$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })} +
+ +
-
- {$i18n.t( - 'Warning: If you update or change your embedding model, you will need to re-import all documents.' - )} -
-