From e07001e5f6c0570b7a4dbf594b30b1d5f1ccbd54 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 17 Feb 2024 21:06:08 -0800 Subject: [PATCH] feat: rag folder scan support --- backend/apps/rag/main.py | 99 ++++++++++++++++--- backend/apps/web/routers/documents.py | 4 + backend/config.py | 8 ++ backend/utils/misc.py | 38 +++++++ src/lib/apis/rag/index.ts | 26 +++++ .../chat/Messages/ResponseMessage.svelte | 2 +- .../documents/Settings/General.svelte | 68 +++++++++++++ .../components/documents/SettingsModal.svelte | 86 ++++++++++++++++ src/routes/(app)/documents/+page.svelte | 31 ++++++ 9 files changed, 350 insertions(+), 12 deletions(-) create mode 100644 src/lib/components/documents/Settings/General.svelte create mode 100644 src/lib/components/documents/SettingsModal.svelte diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 07a30ade..ec9e0a8b 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -10,6 +10,8 @@ from fastapi import ( ) from fastapi.middleware.cors import CORSMiddleware import os, shutil + +from pathlib import Path from typing import List # from chromadb.utils import embedding_functions @@ -28,19 +30,39 @@ from langchain_community.document_loaders import ( UnstructuredExcelLoader, ) from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_community.vectorstores import Chroma from langchain.chains import RetrievalQA +from langchain_community.vectorstores import Chroma from pydantic import BaseModel from typing import Optional - +import mimetypes import uuid +import json import time -from utils.misc import calculate_sha256, calculate_sha256_string + +from apps.web.models.documents import ( + Documents, + DocumentForm, + DocumentResponse, +) + +from utils.misc import ( + calculate_sha256, + calculate_sha256_string, + sanitize_filename, + extract_folders_after_data_docs, +) from utils.utils import get_current_user, get_admin_user -from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP +from config import ( + UPLOAD_DIR, + DOCS_DIR, + EMBED_MODEL, + CHROMA_CLIENT, + CHUNK_SIZE, + CHUNK_OVERLAP, +) from constants import ERROR_MESSAGES # EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction( @@ -220,8 +242,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): ) -def get_loader(file, file_path): - file_ext = file.filename.split(".")[-1].lower() +def get_loader(filename: str, file_content_type: str, file_path: str): + file_ext = filename.split(".")[-1].lower() known_type = True known_source_ext = [ @@ -279,20 +301,20 @@ def get_loader(file, file_path): loader = UnstructuredXMLLoader(file_path) elif file_ext == "md": loader = UnstructuredMarkdownLoader(file_path) - elif file.content_type == "application/epub+zip": + elif file_content_type == "application/epub+zip": loader = UnstructuredEPubLoader(file_path) elif ( - file.content_type + file_content_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" or file_ext in ["doc", "docx"] ): loader = Docx2txtLoader(file_path) - elif file.content_type in [ + elif file_content_type in [ "application/vnd.ms-excel", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ] or file_ext in ["xls", "xlsx"]: loader = UnstructuredExcelLoader(file_path) - elif file_ext in known_source_ext or file.content_type.find("text/") >= 0: + elif file_ext in known_source_ext or file_content_type.find("text/") >= 0: loader = TextLoader(file_path) else: loader = TextLoader(file_path) @@ -323,7 +345,7 @@ def store_doc( collection_name = calculate_sha256(f)[:63] f.close() - loader, known_type = get_loader(file, file_path) + loader, known_type = get_loader(file.filename, file.content_type, file_path) data = loader.load() result = store_data_in_vector_db(data, collection_name) @@ -353,6 +375,61 @@ def store_doc( ) +@app.get("/scan") +def scan_docs_dir(user=Depends(get_admin_user)): + try: + for path in Path(DOCS_DIR).rglob("./**/*"): + if path.is_file() and not path.name.startswith("."): + tags = extract_folders_after_data_docs(path) + filename = path.name + file_content_type = mimetypes.guess_type(path) + + f = open(path, "rb") + collection_name = calculate_sha256(f)[:63] + f.close() + + loader, known_type = get_loader(filename, file_content_type, str(path)) + data = loader.load() + + result = store_data_in_vector_db(data, collection_name) + + if result: + sanitized_filename = sanitize_filename(filename) + doc = Documents.get_doc_by_name(sanitized_filename) + + if doc == None: + doc = Documents.insert_new_doc( + user.id, + DocumentForm( + **{ + "name": sanitized_filename, + "title": filename, + "collection_name": collection_name, + "filename": filename, + "content": ( + json.dumps( + { + "tags": list( + map( + lambda name: {"name": name}, + tags, + ) + ) + } + ) + if len(tags) + else "{}" + ), + } + ), + ) + + except Exception as e: + print(e) + + return True + + @app.get("/reset/db") def reset_vector_db(user=Depends(get_admin_user)): CHROMA_CLIENT.reset() diff --git a/backend/apps/web/routers/documents.py b/backend/apps/web/routers/documents.py index 5bc473fa..7c69514f 100644 --- a/backend/apps/web/routers/documents.py +++ b/backend/apps/web/routers/documents.py @@ -96,6 +96,10 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)): ############################ +class TagItem(BaseModel): + name: str + + class TagDocumentForm(BaseModel): name: str tags: List[dict] diff --git a/backend/config.py b/backend/config.py index d7c89b3b..f5acf06b 100644 --- a/backend/config.py +++ b/backend/config.py @@ -43,6 +43,14 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) CACHE_DIR = f"{DATA_DIR}/cache" Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) + +#################################### +# Docs DIR +#################################### + +DOCS_DIR = f"{DATA_DIR}/docs" +Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) + #################################### # OLLAMA_API_BASE_URL #################################### diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 385a2c41..5e9d5876 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,3 +1,4 @@ +from pathlib import Path import hashlib import re @@ -38,3 +39,40 @@ def validate_email_format(email: str) -> bool: if not re.match(r"[^@]+@[^@]+\.[^@]+", email): return False return True + + +def sanitize_filename(file_name): + # Convert to lowercase + lower_case_file_name = file_name.lower() + + # Remove special characters using regular expression + sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name) + + # Replace spaces with dashes + final_file_name = re.sub(r"\s+", "-", sanitized_file_name) + + return final_file_name + + +def extract_folders_after_data_docs(path): + # Convert the path to a Path object if it's not already + path = Path(path) + + # Extract parts of the path + parts = path.parts + + # Find the index of '/data/docs' in the path + try: + index_data_docs = parts.index("data") + 1 + index_docs = parts.index("docs", index_data_docs) + 1 + except ValueError: + return [] + + # Exclude the filename and accumulate folder names + tags = [] + + folders = parts[index_docs:-1] + for idx, part in enumerate(folders): + tags.append("/".join(folders[: idx + 1])) + + return tags diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 3f4f30bf..fc3571aa 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -138,6 +138,32 @@ export const queryCollection = async ( return res; }; +export const scanDocs = async (token: string) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/scan`, { + method: 'GET', + headers: { + Accept: 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const resetVectorDB = async (token: string) => { let error = null; diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index a0ffc83c..cc42d0b9 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -366,7 +366,7 @@ {#if message.done}
{#if siblings.length > 1}
diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte new file mode 100644 index 00000000..19a6493b --- /dev/null +++ b/src/lib/components/documents/Settings/General.svelte @@ -0,0 +1,68 @@ + + +
{ + // console.log('submit'); + saveHandler(); + }} +> +
+
+
General Settings
+ +
+
Scan for documents from '/data/docs'
+ + +
+
+
+ + +
diff --git a/src/lib/components/documents/SettingsModal.svelte b/src/lib/components/documents/SettingsModal.svelte new file mode 100644 index 00000000..332bf9d5 --- /dev/null +++ b/src/lib/components/documents/SettingsModal.svelte @@ -0,0 +1,86 @@ + + + +
+
+
Document Settings
+ +
+
+ +
+
+ +
+
+ {#if selectedTab === 'general'} + { + show = false; + }} + /> + + + {/if} +
+
+
+
diff --git a/src/routes/(app)/documents/+page.svelte b/src/routes/(app)/documents/+page.svelte index a5653483..ab3d6553 100644 --- a/src/routes/(app)/documents/+page.svelte +++ b/src/routes/(app)/documents/+page.svelte @@ -13,6 +13,7 @@ import EditDocModal from '$lib/components/documents/EditDocModal.svelte'; import AddFilesPlaceholder from '$lib/components/AddFilesPlaceholder.svelte'; + import SettingsModal from '$lib/components/documents/SettingsModal.svelte'; let importFiles = ''; let inputFiles = ''; @@ -20,6 +21,7 @@ let tags = []; + let showSettingsModal = false; let showEditDocModal = false; let selectedDoc; let selectedTag = ''; @@ -179,11 +181,38 @@ }} /> + +
My Documents
+ +
+ +
@@ -419,6 +448,8 @@
+ +
{/each} {#if $documents.length > 0}