Merge branch 'main' into choose-embedding-model

This commit is contained in:
Jannik S 2024-02-18 09:20:54 +01:00 committed by GitHub
commit 4b88e7e44f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 771 additions and 35 deletions

View file

@ -9,6 +9,8 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil
from pathlib import Path
from typing import List
from chromadb.utils import embedding_functions
@ -28,23 +30,45 @@ from langchain_community.document_loaders import (
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
from apps.web.models.documents import (
Documents,
DocumentForm,
DocumentResponse,
)
from utils.misc import calculate_sha256, calculate_sha256_string
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
sanitize_filename,
extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from config import (
UPLOAD_DIR,
DOCS_DIR,
SENTENCE_TRANSFORMER_EMBED_MODEL,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
)
from constants import ERROR_MESSAGES
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
app = FastAPI()
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
origins = ["*"]
app.add_middleware(
@ -66,7 +90,7 @@ class StoreWebForm(CollectionNameForm):
def store_data_in_vector_db(data, collection_name) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
@ -96,7 +120,60 @@ def store_data_in_vector_db(data, collection_name) -> bool:
@app.get("/")
async def get_status():
return {"status": True}
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}
@app.get("/chunk")
async def get_chunk_params(user=Depends(get_admin_user)):
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}
class ChunkParamUpdateForm(BaseModel):
chunk_size: int
chunk_overlap: int
@app.post("/chunk/update")
async def update_chunk_params(
form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
):
app.state.CHUNK_SIZE = form_data.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk_overlap
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
}
class RAGTemplateForm(BaseModel):
template: str
@app.post("/template/update")
async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
# TODO: check template requirements
app.state.RAG_TEMPLATE = (
form_data.template if form_data.template != "" else RAG_TEMPLATE
)
return {"status": True, "template": app.state.RAG_TEMPLATE}
class QueryDocForm(BaseModel):
@ -239,8 +316,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
)
def get_loader(file, file_path):
file_ext = file.filename.split(".")[-1].lower()
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
known_source_ext = [
@ -298,20 +375,20 @@ def get_loader(file, file_path):
loader = UnstructuredXMLLoader(file_path)
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
elif file.content_type == "application/epub+zip":
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (
file.content_type
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext in ["doc", "docx"]
):
loader = Docx2txtLoader(file_path)
elif file.content_type in [
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path)
elif file_ext in known_source_ext or file.content_type.find("text/") >= 0:
elif file_ext in known_source_ext or file_content_type.find("text/") >= 0:
loader = TextLoader(file_path)
else:
loader = TextLoader(file_path)
@ -342,7 +419,7 @@ def store_doc(
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(file, file_path)
loader, known_type = get_loader(file.filename, file.content_type, file_path)
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
@ -372,6 +449,63 @@ def store_doc(
)
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
try:
for path in Path(DOCS_DIR).rglob("./**/*"):
if path.is_file() and not path.name.startswith("."):
tags = extract_folders_after_data_docs(path)
filename = path.name
file_content_type = mimetypes.guess_type(path)
f = open(path, "rb")
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(
filename, file_content_type[0], str(path)
)
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
if result:
sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None:
doc = Documents.insert_new_doc(
user.id,
DocumentForm(
**{
"name": sanitized_filename,
"title": filename,
"collection_name": collection_name,
"filename": filename,
"content": (
json.dumps(
{
"tags": list(
map(
lambda name: {"name": name},
tags,
)
)
}
)
if len(tags)
else "{}"
),
}
),
)
except Exception as e:
print(e)
return True
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
CHROMA_CLIENT.reset()

View file

@ -96,6 +96,10 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
############################
class TagItem(BaseModel):
name: str
class TagDocumentForm(BaseModel):
name: str
tags: List[dict]