From b1b72441bbf0a60d1d0bc873a4f0b86f35f9a0f1 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 14 Apr 2024 19:48:15 -0400 Subject: [PATCH] feat: openai embeddings integration --- backend/apps/ollama/main.py | 4 +- backend/apps/rag/main.py | 6 +- backend/apps/rag/utils.py | 82 +++++++++++---- backend/main.py | 4 + src/lib/apis/rag/index.ts | 6 ++ .../documents/Settings/General.svelte | 99 ++++++++++++++----- 6 files changed, 155 insertions(+), 46 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 387ff05d..9258efa6 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -659,7 +659,7 @@ def generate_ollama_embeddings( url_idx: Optional[int] = None, ): - log.info("generate_ollama_embeddings", form_data) + log.info(f"generate_ollama_embeddings {form_data}") if url_idx == None: model = form_data.model @@ -688,7 +688,7 @@ def generate_ollama_embeddings( data = r.json() - log.info("generate_ollama_embeddings", data) + log.info(f"generate_ollama_embeddings {data}") if "embedding" in data: return data["embedding"] diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 11860032..04554c3d 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -421,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b docs = text_splitter.split_documents(data) if len(docs) > 0: - log.info("store_data_in_vector_db", "store_docs_in_vector_db") + log.info(f"store_data_in_vector_db {docs}") return store_docs_in_vector_db(docs, collection_name, overwrite), None else: raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) @@ -440,7 +440,7 @@ def store_text_in_vector_db( def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: - log.info("store_docs_in_vector_db", docs, collection_name) + log.info(f"store_docs_in_vector_db {docs} {collection_name}") texts = [doc.page_content for doc in docs] metadatas = [doc.metadata for doc in docs] @@ -468,6 +468,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b collection.add(*batch) else: + collection = CHROMA_CLIENT.create_collection(name=collection_name) + if app.state.RAG_EMBEDDING_ENGINE == "ollama": embeddings = [ generate_ollama_embeddings( diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index a0956e2f..140fd88e 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -6,9 +6,12 @@ import requests from huggingface_hub import snapshot_download +from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm + from config import SRC_LOG_LEVELS, CHROMA_CLIENT + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function): def query_embeddings_doc(collection_name: str, query_embeddings, k: int): try: # if you use docker use the model from the environment variable - log.info("query_embeddings_doc", query_embeddings) + log.info(f"query_embeddings_doc {query_embeddings}") collection = CHROMA_CLIENT.get_collection( name=collection_name, ) @@ -118,7 +121,7 @@ def query_collection( def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int): results = [] - log.info("query_embeddings_collection", query_embeddings) + log.info(f"query_embeddings_collection {query_embeddings}") for collection_name in collection_names: try: @@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str): return template -def rag_messages(docs, messages, template, k, embedding_function): +def rag_messages( + docs, + messages, + template, + k, + embedding_engine, + embedding_model, + embedding_function, + openai_key, + openai_url, +): log.debug(f"docs: {docs}") last_user_message_idx = None @@ -175,22 +188,57 @@ def rag_messages(docs, messages, template, k, embedding_function): context = None try: - if doc["type"] == "collection": - context = query_collection( - collection_names=doc["collection_names"], - query=query, - k=k, - embedding_function=embedding_function, - ) - elif doc["type"] == "text": + + if doc["type"] == "text": context = doc["content"] else: - context = query_doc( - collection_name=doc["collection_name"], - query=query, - k=k, - embedding_function=embedding_function, - ) + if embedding_engine == "": + if doc["type"] == "collection": + context = query_collection( + collection_names=doc["collection_names"], + query=query, + k=k, + embedding_function=embedding_function, + ) + else: + context = query_doc( + collection_name=doc["collection_name"], + query=query, + k=k, + embedding_function=embedding_function, + ) + + else: + if embedding_engine == "ollama": + query_embeddings = generate_ollama_embeddings( + GenerateEmbeddingsForm( + **{ + "model": embedding_model, + "prompt": query, + } + ) + ) + elif embedding_engine == "openai": + query_embeddings = generate_openai_embeddings( + model=embedding_model, + text=query, + key=openai_key, + url=openai_url, + ) + + if doc["type"] == "collection": + context = query_embeddings_collection( + collection_names=doc["collection_names"], + query_embeddings=query_embeddings, + k=k, + ) + else: + context = query_embeddings_doc( + collection_name=doc["collection_name"], + query_embeddings=query_embeddings, + k=k, + ) + except Exception as e: log.exception(e) context = None diff --git a/backend/main.py b/backend/main.py index d63847bc..4b1809a2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware): data["messages"], rag_app.state.RAG_TEMPLATE, rag_app.state.TOP_K, + rag_app.state.RAG_EMBEDDING_ENGINE, + rag_app.state.RAG_EMBEDDING_MODEL, rag_app.state.sentence_transformer_ef, + rag_app.state.RAG_OPENAI_API_KEY, + rag_app.state.RAG_OPENAI_API_BASE_URL, ) del data["docs"] diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index bfcee55f..8a63b69c 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => { return res; }; +type OpenAIConfigForm = { + key: string; + url: string; +}; + type EmbeddingModelUpdateForm = { + openai_config?: OpenAIConfigForm; embedding_engine: string; embedding_model: string; }; diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte index c9142fbe..63d78562 100644 --- a/src/lib/components/documents/Settings/General.svelte +++ b/src/lib/components/documents/Settings/General.svelte @@ -29,6 +29,9 @@ let embeddingEngine = ''; let embeddingModel = ''; + let openAIKey = ''; + let openAIUrl = ''; + let chunkSize = 0; let chunkOverlap = 0; let pdfExtractImages = true; @@ -50,15 +53,6 @@ }; const embeddingModelUpdateHandler = async () => { - if (embeddingModel === '') { - toast.error( - $i18n.t( - 'Model filesystem path detected. Model shortname is required for update, cannot continue.' - ) - ); - return; - } - if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) { toast.error( $i18n.t( @@ -67,21 +61,46 @@ ); return; } + if (embeddingEngine === 'ollama' && embeddingModel === '') { + toast.error( + $i18n.t( + 'Model filesystem path detected. Model shortname is required for update, cannot continue.' + ) + ); + return; + } + + if (embeddingEngine === 'openai' && embeddingModel === '') { + toast.error( + $i18n.t( + 'Model filesystem path detected. Model shortname is required for update, cannot continue.' + ) + ); + return; + } + + if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') { + toast.error($i18n.t('OpenAI URL/Key required.')); + return; + } console.log('Update embedding model attempt:', embeddingModel); updateEmbeddingModelLoading = true; const res = await updateEmbeddingConfig(localStorage.token, { embedding_engine: embeddingEngine, - embedding_model: embeddingModel + embedding_model: embeddingModel, + ...(embeddingEngine === 'openai' + ? { + openai_config: { + key: openAIKey, + url: openAIUrl + } + } + : {}) }).catch(async (error) => { toast.error(error); - - const embeddingConfig = await getEmbeddingConfig(localStorage.token); - if (embeddingConfig) { - embeddingEngine = embeddingConfig.embedding_engine; - embeddingModel = embeddingConfig.embedding_model; - } + await setEmbeddingConfig(); return null; }); updateEmbeddingModelLoading = false; @@ -89,7 +108,7 @@ if (res) { console.log('embeddingModelUpdateHandler:', res); if (res.status === true) { - toast.success($i18n.t('Model {{embedding_model}} update complete!', res), { + toast.success($i18n.t('Embedding model set to "{{embedding_model}}"', res), { duration: 1000 * 10 }); } @@ -107,6 +126,18 @@ querySettings = await updateQuerySettings(localStorage.token, querySettings); }; + const setEmbeddingConfig = async () => { + const embeddingConfig = await getEmbeddingConfig(localStorage.token); + + if (embeddingConfig) { + embeddingEngine = embeddingConfig.embedding_engine; + embeddingModel = embeddingConfig.embedding_model; + + openAIKey = embeddingConfig.openai_config.key; + openAIUrl = embeddingConfig.openai_config.url; + } + }; + onMount(async () => { const res = await getRAGConfig(localStorage.token); @@ -117,12 +148,7 @@ chunkOverlap = res.chunk.chunk_overlap; } - const embeddingConfig = await getEmbeddingConfig(localStorage.token); - - if (embeddingConfig) { - embeddingEngine = embeddingConfig.embedding_engine; - embeddingModel = embeddingConfig.embedding_model; - } + await setEmbeddingConfig(); querySettings = await getQuerySettings(localStorage.token); }); @@ -146,15 +172,38 @@ class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right" bind:value={embeddingEngine} placeholder="Select an embedding engine" - on:change={() => { - embeddingModel = ''; + on:change={(e) => { + if (e.target.value === 'ollama') { + embeddingModel = ''; + } else if (e.target.value === 'openai') { + embeddingModel = 'text-embedding-3-small'; + } }} > + + + {#if embeddingEngine === 'openai'} +
+ + + +
+ {/if}