From 9634e2da3e4ac8db6f13407d301ff27af9901d1f Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Sun, 7 Jan 2024 01:40:36 -0800
Subject: [PATCH] feat: full integration

---
 backend/apps/rag/main.py                    | 60 +++++++++++++++----
 backend/config.py                           | 21 ++++++-
 src/lib/components/chat/MessageInput.svelte | 18 +++---
 .../chat/Messages/UserMessage.svelte        | 32 +++++++++-
 src/lib/utils/index.ts                      |  3 +-
 src/routes/(app)/+page.svelte               |  7 ++-
 6 files changed, 116 insertions(+), 25 deletions(-)

diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 0e7f9b07..4d563ab7 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -9,6 +9,7 @@ from fastapi import (
     Form,
 )
 from fastapi.middleware.cors import CORSMiddleware
+import os, shutil

 from chromadb.utils import embedding_functions

@@ -23,7 +24,7 @@ from typing import Optional

 import uuid

-from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
+from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 from constants import ERROR_MESSAGES

 EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
@@ -51,7 +52,7 @@ class StoreWebForm(CollectionNameForm):
     url: str


-def store_data_in_vector_db(data, collection_name):
+def store_data_in_vector_db(data, collection_name) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
     )
@@ -60,13 +61,22 @@ def store_data_in_vector_db(data, collection_name):
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]

-    collection = CHROMA_CLIENT.create_collection(
-        name=collection_name, embedding_function=EMBEDDING_FUNC
-    )
+    try:
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name, embedding_function=EMBEDDING_FUNC
+        )

-    collection.add(
-        documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
-    )
+        collection.add(
+            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
+        )
+        return True
+    except Exception as e:
+        print(e)
+        print(e.__class__.__name__)
+        if e.__class__.__name__ == "UniqueConstraintError":
+            return True
+
+        return False


 @app.get("/")
@@ -116,7 +126,7 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):

     try:
         filename = file.filename
-        file_path = f"./data/{filename}"
+        file_path = f"{UPLOAD_DIR}/{filename}"
         contents = file.file.read()
         with open(file_path, "wb") as f:
             f.write(contents)
@@ -128,8 +138,15 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
             loader = TextLoader(file_path)

         data = loader.load()
-        store_data_in_vector_db(data, collection_name)
-        return {"status": True, "collection_name": collection_name}
+        result = store_data_in_vector_db(data, collection_name)
+
+        if result:
+            return {"status": True, "collection_name": collection_name}
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=ERROR_MESSAGES.DEFAULT(),
+            )
     except Exception as e:
         print(e)
         raise HTTPException(
@@ -138,6 +155,27 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
         )


+@app.get("/reset/db")
 def reset_vector_db():
     CHROMA_CLIENT.reset()
+
+
+@app.get("/reset")
+def reset():
+    folder = f"{UPLOAD_DIR}"
+    for filename in os.listdir(folder):
+        file_path = os.path.join(folder, filename)
+        try:
+            if os.path.isfile(file_path) or os.path.islink(file_path):
+                os.unlink(file_path)
+            elif os.path.isdir(file_path):
+                shutil.rmtree(file_path)
+        except Exception as e:
+            print("Failed to delete %s. Reason: %s" % (file_path, e))
+
+    try:
+        CHROMA_CLIENT.reset()
+    except Exception as e:
+        print(e)
+
+    return {"status": True}
diff --git a/backend/config.py b/backend/config.py
index df57c829..03718a06 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -1,14 +1,31 @@
 from dotenv import load_dotenv, find_dotenv
 import os
+
+
 import chromadb
+from chromadb import Settings
+
 from secrets import token_bytes
 from base64 import b64encode
 from constants import ERROR_MESSAGES

+from pathlib import Path
+
 load_dotenv(find_dotenv("../.env"))

+
+####################################
+# File Upload
+####################################
+
+
+UPLOAD_DIR = "./data/uploads"
+Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
+
+
 ####################################
 # ENV (dev,test,prod)
 ####################################
@@ -64,6 +81,8 @@ if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":

 CHROMA_DATA_PATH = "./data/vector_db"
 EMBED_MODEL = "all-MiniLM-L6-v2"
-CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH)
+CHROMA_CLIENT = chromadb.PersistentClient(
+    path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True)
+)
 CHUNK_SIZE = 1500
 CHUNK_OVERLAP = 100
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte
index 13078d64..54baaace 100644
--- a/src/lib/components/chat/MessageInput.svelte
+++ b/src/lib/components/chat/MessageInput.svelte
@@ -124,16 +124,16 @@
 					reader.readAsDataURL(file);
 				} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
 					console.log(file);
-					const hash = await calculateSHA256(file);
-					// const res = uploadDocToVectorDB(localStorage.token, hash,file);
+					const hash = (await calculateSHA256(file)).substring(0, 63);
+					const res = await uploadDocToVectorDB(localStorage.token, hash, file);

-					if (true) {
+					if (res) {
 						files = [
 							...files,
 							{
 								type: 'doc',
 								name: file.name,
-								collection_name: hash
+								collection_name: res.collection_name
 							}
 						];
 					}
@@ -243,16 +243,16 @@
 						reader.readAsDataURL(file);
 					} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
 						console.log(file);
-						const hash = await calculateSHA256(file);
-						// const res = uploadDocToVectorDB(localStorage.token,hash,file);
+						const hash = (await calculateSHA256(file)).substring(0, 63);
+						const res = await uploadDocToVectorDB(localStorage.token, hash, file);

-						if (true) {
+						if (res) {
 							files = [
 								...files,
 								{
 									type: 'doc',
 									name: file.name,
-									collection_name: hash
+									collection_name: res.collection_name
 								}
 							];
 							filesInputElement.value = '';
@@ -280,7 +280,7 @@ input
 					{:else if file.type === 'doc'}
diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte
@@ ... @@
 {#if message.files}
 	{#each message.files as file}
 		{#if file.type === 'image'}
 			[image preview markup]
+		{:else if file.type === 'doc'}
+			[document chip markup: file icon, {file.name}, "Document" label]
 		{/if}
 	{/each}
diff --git a/src/lib/utils/index.ts b/src/lib/utils/index.ts
index d9f6fd7d..46bc8f04 100644
--- a/src/lib/utils/index.ts
+++ b/src/lib/utils/index.ts
@@ -129,7 +129,6 @@ export const findWordIndices = (text) => {
 };

 export const calculateSHA256 = async (file) => {
-	console.log(file);
 	// Create a FileReader to read the file asynchronously
 	const reader = new FileReader();

@@ -156,7 +155,7 @@ export const calculateSHA256 = async (file) => {
 		const hashArray = Array.from(new Uint8Array(hashBuffer));
 		const hashHex = hashArray.map((byte) => byte.toString(16).padStart(2, '0')).join('');

-		return `sha256:${hashHex}`;
+		return `${hashHex}`;
 	} catch (error) {
 		console.error('Error calculating SHA-256 hash:', error);
 		throw error;
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index 12ccadbe..193d6d14 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -186,8 +186,11 @@
 		const _chatId = JSON.parse(JSON.stringify($chatId));

 		// TODO: update below to include all ancestral files
-		const docs = history.messages[parentId].files.filter((item) => item.type === 'file');
+		console.log(history.messages[parentId]);
+		const docs = history.messages[parentId]?.files?.filter((item) => item.type === 'doc') ?? [];
+
+		console.log(docs);

 		if (docs.length > 0) {
 			const query = history.messages[parentId].content;

@@ -207,6 +210,8 @@
 				return `${a}${context.documents.join(' ')}\n`;
 			}, '');

+			console.log(contextString);
+
 			history.messages[parentId].raContent = RAGTemplate(contextString, query);
 			history.messages[parentId].contexts = relevantContexts;
 			await tick();
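
The frontend half of this patch uploads each attached PDF or plain-text file to the RAG backend, using the first 63 characters of the file's SHA-256 digest as the collection name; Chroma caps collection names at 63 characters, which is why calculateSHA256 no longer prepends "sha256:" and MessageInput.svelte truncates the digest. A minimal Python sketch of the same upload flow against store_doc() follows; the base URL, the "/doc" route, and the bearer-token header are assumptions, since store_doc()'s route decorator and the auth wiring sit outside the hunks above.

# Hedged sketch only: the "/rag" prefix, the "/doc" path for store_doc(), and
# bearer-token auth are assumptions that this patch does not show.
import hashlib
import os

import requests

BASE_URL = "http://localhost:8080/rag"  # assumed mount point of the RAG app


def upload_doc(path: str, token: str = "") -> dict:
    with open(path, "rb") as f:
        content = f.read()

    # Mirrors (await calculateSHA256(file)).substring(0, 63) in MessageInput.svelte:
    # Chroma collection names are limited to 63 characters.
    collection_name = hashlib.sha256(content).hexdigest()[:63]

    res = requests.post(
        f"{BASE_URL}/doc",  # assumed route for store_doc()
        headers={"Authorization": f"Bearer {token}"} if token else {},
        data={"collection_name": collection_name},
        files={"file": (os.path.basename(path), content)},
    )
    res.raise_for_status()
    # store_doc() responds with {"status": True, "collection_name": "..."} on success.
    return res.json()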
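
On the backend, store_data_in_vector_db() now reports success when create_collection() raises a UniqueConstraintError (the same document hash has already been indexed), and config.py constructs the client with Settings(allow_reset=True) because chromadb refuses reset() otherwise. A small sketch of both behaviors, using a scratch on-disk path chosen only for illustration:

# Sketch of the two chromadb behaviors the backend changes rely on; the
# "./data/vector_db_demo" path is a throwaway location, not the app's store.
import chromadb
from chromadb import Settings

client = chromadb.PersistentClient(
    path="./data/vector_db_demo", settings=Settings(allow_reset=True)
)

client.create_collection(name="example-doc-hash")

try:
    # Creating the same collection twice raises UniqueConstraintError, which
    # store_data_in_vector_db() now maps to a successful (already stored) result.
    client.create_collection(name="example-doc-hash")
except Exception as e:
    print(e.__class__.__name__)  # UniqueConstraintError

client.reset()  # only permitted because allow_reset=True was passed above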
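
The two new GET endpoints are maintenance hooks: /reset/db only resets the Chroma store, while /reset first empties UPLOAD_DIR and then resets Chroma, returning {"status": True}. They could be exercised like this (the mount point is again an assumption):

# The "/rag" prefix is an assumed mount point, not something this patch shows.
import requests

BASE_URL = "http://localhost:8080/rag"

requests.get(f"{BASE_URL}/reset/db")             # reset_vector_db(): Chroma only
print(requests.get(f"{BASE_URL}/reset").json())  # uploads dir + Chroma -> {"status": True}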