diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 0e7f9b07..4d563ab7 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -9,6 +9,7 @@ from fastapi import ( Form, ) from fastapi.middleware.cors import CORSMiddleware +import os, shutil from chromadb.utils import embedding_functions @@ -23,7 +24,7 @@ from typing import Optional import uuid -from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP +from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from constants import ERROR_MESSAGES EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( @@ -51,7 +52,7 @@ class StoreWebForm(CollectionNameForm): url: str -def store_data_in_vector_db(data, collection_name): +def store_data_in_vector_db(data, collection_name) -> bool: text_splitter = RecursiveCharacterTextSplitter( chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP ) @@ -60,13 +61,22 @@ def store_data_in_vector_db(data, collection_name): texts = [doc.page_content for doc in docs] metadatas = [doc.metadata for doc in docs] - collection = CHROMA_CLIENT.create_collection( - name=collection_name, embedding_function=EMBEDDING_FUNC - ) + try: + collection = CHROMA_CLIENT.create_collection( + name=collection_name, embedding_function=EMBEDDING_FUNC + ) - collection.add( - documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] - ) + collection.add( + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) + return True + except Exception as e: + print(e) + print(e.__class__.__name__) + if e.__class__.__name__ == "UniqueConstraintError": + return True + + return False @app.get("/") @@ -116,7 +126,7 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)): try: filename = file.filename - file_path = f"./data/{filename}" + file_path = f"{UPLOAD_DIR}/{filename}" contents = file.file.read() with open(file_path, "wb") as f: f.write(contents) @@ -128,8 +138,15 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)): loader = TextLoader(file_path) data = loader.load() - store_data_in_vector_db(data, collection_name) - return {"status": True, "collection_name": collection_name} + result = store_data_in_vector_db(data, collection_name) + + if result: + return {"status": True, "collection_name": collection_name} + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(), + ) except Exception as e: print(e) raise HTTPException( @@ -138,6 +155,27 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)): ) +@app.get("/reset/db") def reset_vector_db(): CHROMA_CLIENT.reset() + + +@app.get("/reset") +def reset(): + folder = f"{UPLOAD_DIR}" + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print("Failed to delete %s. Reason: %s" % (file_path, e)) + + try: + CHROMA_CLIENT.reset() + except Exception as e: + print(e) + return {"status": True} diff --git a/backend/config.py b/backend/config.py index df57c829..03718a06 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,14 +1,31 @@ from dotenv import load_dotenv, find_dotenv import os + + import chromadb +from chromadb import Settings + from secrets import token_bytes from base64 import b64encode from constants import ERROR_MESSAGES + +from pathlib import Path + load_dotenv(find_dotenv("../.env")) + +#################################### +# File Upload +#################################### + + +UPLOAD_DIR = "./data/uploads" +Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # ENV (dev,test,prod) #################################### @@ -64,6 +81,8 @@ if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": CHROMA_DATA_PATH = "./data/vector_db" EMBED_MODEL = "all-MiniLM-L6-v2" -CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH) +CHROMA_CLIENT = chromadb.PersistentClient( + path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True) +) CHUNK_SIZE = 1500 CHUNK_OVERLAP = 100 diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 13078d64..54baaace 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -124,16 +124,16 @@ reader.readAsDataURL(file); } else if (['application/pdf', 'text/plain'].includes(file['type'])) { console.log(file); - const hash = await calculateSHA256(file); - // const res = uploadDocToVectorDB(localStorage.token, hash,file); + const hash = (await calculateSHA256(file)).substring(0, 63); + const res = await uploadDocToVectorDB(localStorage.token, hash, file); - if (true) { + if (res) { files = [ ...files, { type: 'doc', name: file.name, - collection_name: hash + collection_name: res.collection_name } ]; } @@ -243,16 +243,16 @@ reader.readAsDataURL(file); } else if (['application/pdf', 'text/plain'].includes(file['type'])) { console.log(file); - const hash = await calculateSHA256(file); - // const res = uploadDocToVectorDB(localStorage.token,hash,file); + const hash = (await calculateSHA256(file)).substring(0, 63); + const res = await uploadDocToVectorDB(localStorage.token, hash, file); - if (true) { + if (res) { files = [ ...files, { type: 'doc', name: file.name, - collection_name: hash + collection_name: res.collection_name } ]; filesInputElement.value = ''; @@ -280,7 +280,7 @@ {:else if file.type === 'doc'}