feat: rag folder scan support

This commit is contained in:
Timothy J. Baek 2024-02-17 21:06:08 -08:00
parent 9f869f6573
commit e07001e5f6
9 changed files with 350 additions and 12 deletions

View file

@ -10,6 +10,8 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil
from pathlib import Path
from typing import List
# from chromadb.utils import embedding_functions
@ -28,19 +30,39 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import Chroma
from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
import json
import time
from utils.misc import calculate_sha256, calculate_sha256_string
from apps.web.models.documents import (
Documents,
DocumentForm,
DocumentResponse,
)
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
sanitize_filename,
extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from config import (
UPLOAD_DIR,
DOCS_DIR,
EMBED_MODEL,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
)
from constants import ERROR_MESSAGES
# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
@ -220,8 +242,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
)
def get_loader(file, file_path):
file_ext = file.filename.split(".")[-1].lower()
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
known_source_ext = [
@ -279,20 +301,20 @@ def get_loader(file, file_path):
loader = UnstructuredXMLLoader(file_path)
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
elif file.content_type == "application/epub+zip":
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (
file.content_type
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext in ["doc", "docx"]
):
loader = Docx2txtLoader(file_path)
elif file.content_type in [
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path)
elif file_ext in known_source_ext or file.content_type.find("text/") >= 0:
elif file_ext in known_source_ext or file_content_type.find("text/") >= 0:
loader = TextLoader(file_path)
else:
loader = TextLoader(file_path)
@ -323,7 +345,7 @@ def store_doc(
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(file, file_path)
loader, known_type = get_loader(file.filename, file.content_type, file_path)
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
@ -353,6 +375,61 @@ def store_doc(
)
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
try:
for path in Path(DOCS_DIR).rglob("./**/*"):
if path.is_file() and not path.name.startswith("."):
tags = extract_folders_after_data_docs(path)
filename = path.name
file_content_type = mimetypes.guess_type(path)
f = open(path, "rb")
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(filename, file_content_type, str(path))
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
if result:
sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None:
doc = Documents.insert_new_doc(
user.id,
DocumentForm(
**{
"name": sanitized_filename,
"title": filename,
"collection_name": collection_name,
"filename": filename,
"content": (
json.dumps(
{
"tags": list(
map(
lambda name: {"name": name},
tags,
)
)
}
)
if len(tags)
else "{}"
),
}
),
)
except Exception as e:
print(e)
return True
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
CHROMA_CLIENT.reset()

View file

@ -96,6 +96,10 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
############################
class TagItem(BaseModel):
name: str
class TagDocumentForm(BaseModel):
name: str
tags: List[dict]