forked from open-webui/open-webui
feat: rag folder scan support
This commit is contained in:
parent
9f869f6573
commit
e07001e5f6
9 changed files with 350 additions and 12 deletions
|
@ -10,6 +10,8 @@ from fastapi import (
|
|||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import os, shutil
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# from chromadb.utils import embedding_functions
|
||||
|
@ -28,19 +30,39 @@ from langchain_community.document_loaders import (
|
|||
UnstructuredExcelLoader,
|
||||
)
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain_community.vectorstores import Chroma
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
import mimetypes
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
|
||||
from utils.misc import calculate_sha256, calculate_sha256_string
|
||||
|
||||
from apps.web.models.documents import (
|
||||
Documents,
|
||||
DocumentForm,
|
||||
DocumentResponse,
|
||||
)
|
||||
|
||||
from utils.misc import (
|
||||
calculate_sha256,
|
||||
calculate_sha256_string,
|
||||
sanitize_filename,
|
||||
extract_folders_after_data_docs,
|
||||
)
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
||||
from config import (
|
||||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
EMBED_MODEL,
|
||||
CHROMA_CLIENT,
|
||||
CHUNK_SIZE,
|
||||
CHUNK_OVERLAP,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
|
@ -220,8 +242,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
|||
)
|
||||
|
||||
|
||||
def get_loader(file, file_path):
|
||||
file_ext = file.filename.split(".")[-1].lower()
|
||||
def get_loader(filename: str, file_content_type: str, file_path: str):
|
||||
file_ext = filename.split(".")[-1].lower()
|
||||
known_type = True
|
||||
|
||||
known_source_ext = [
|
||||
|
@ -279,20 +301,20 @@ def get_loader(file, file_path):
|
|||
loader = UnstructuredXMLLoader(file_path)
|
||||
elif file_ext == "md":
|
||||
loader = UnstructuredMarkdownLoader(file_path)
|
||||
elif file.content_type == "application/epub+zip":
|
||||
elif file_content_type == "application/epub+zip":
|
||||
loader = UnstructuredEPubLoader(file_path)
|
||||
elif (
|
||||
file.content_type
|
||||
file_content_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
or file_ext in ["doc", "docx"]
|
||||
):
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif file.content_type in [
|
||||
elif file_content_type in [
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
] or file_ext in ["xls", "xlsx"]:
|
||||
loader = UnstructuredExcelLoader(file_path)
|
||||
elif file_ext in known_source_ext or file.content_type.find("text/") >= 0:
|
||||
elif file_ext in known_source_ext or file_content_type.find("text/") >= 0:
|
||||
loader = TextLoader(file_path)
|
||||
else:
|
||||
loader = TextLoader(file_path)
|
||||
|
@ -323,7 +345,7 @@ def store_doc(
|
|||
collection_name = calculate_sha256(f)[:63]
|
||||
f.close()
|
||||
|
||||
loader, known_type = get_loader(file, file_path)
|
||||
loader, known_type = get_loader(file.filename, file.content_type, file_path)
|
||||
data = loader.load()
|
||||
result = store_data_in_vector_db(data, collection_name)
|
||||
|
||||
|
@ -353,6 +375,61 @@ def store_doc(
|
|||
)
|
||||
|
||||
|
||||
@app.get("/scan")
|
||||
def scan_docs_dir(user=Depends(get_admin_user)):
|
||||
try:
|
||||
for path in Path(DOCS_DIR).rglob("./**/*"):
|
||||
if path.is_file() and not path.name.startswith("."):
|
||||
tags = extract_folders_after_data_docs(path)
|
||||
filename = path.name
|
||||
file_content_type = mimetypes.guess_type(path)
|
||||
|
||||
f = open(path, "rb")
|
||||
collection_name = calculate_sha256(f)[:63]
|
||||
f.close()
|
||||
|
||||
loader, known_type = get_loader(filename, file_content_type, str(path))
|
||||
data = loader.load()
|
||||
|
||||
result = store_data_in_vector_db(data, collection_name)
|
||||
|
||||
if result:
|
||||
sanitized_filename = sanitize_filename(filename)
|
||||
doc = Documents.get_doc_by_name(sanitized_filename)
|
||||
|
||||
if doc == None:
|
||||
doc = Documents.insert_new_doc(
|
||||
user.id,
|
||||
DocumentForm(
|
||||
**{
|
||||
"name": sanitized_filename,
|
||||
"title": filename,
|
||||
"collection_name": collection_name,
|
||||
"filename": filename,
|
||||
"content": (
|
||||
json.dumps(
|
||||
{
|
||||
"tags": list(
|
||||
map(
|
||||
lambda name: {"name": name},
|
||||
tags,
|
||||
)
|
||||
)
|
||||
}
|
||||
)
|
||||
if len(tags)
|
||||
else "{}"
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@app.get("/reset/db")
|
||||
def reset_vector_db(user=Depends(get_admin_user)):
|
||||
CHROMA_CLIENT.reset()
|
||||
|
|
|
@ -96,6 +96,10 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
|
|||
############################
|
||||
|
||||
|
||||
class TagItem(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class TagDocumentForm(BaseModel):
|
||||
name: str
|
||||
tags: List[dict]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue