Merge branch 'main' into choose-embedding-model

This commit is contained in:
Jannik S 2024-02-18 09:20:54 +01:00 committed by GitHub
commit 4b88e7e44f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 771 additions and 35 deletions

View file

@ -9,6 +9,8 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil
from pathlib import Path
from typing import List
from chromadb.utils import embedding_functions
@ -28,23 +30,45 @@ from langchain_community.document_loaders import (
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
from apps.web.models.documents import (
Documents,
DocumentForm,
DocumentResponse,
)
from utils.misc import calculate_sha256, calculate_sha256_string
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
sanitize_filename,
extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from config import (
UPLOAD_DIR,
DOCS_DIR,
SENTENCE_TRANSFORMER_EMBED_MODEL,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
)
from constants import ERROR_MESSAGES
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
app = FastAPI()
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
origins = ["*"]
app.add_middleware(
@ -66,7 +90,7 @@ class StoreWebForm(CollectionNameForm):
def store_data_in_vector_db(data, collection_name) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
@ -96,7 +120,60 @@ def store_data_in_vector_db(data, collection_name) -> bool:
# Health/status endpoint: reports that the service is up together with the
# chunking parameters currently held in app.state.
@app.get("/")
async def get_status():
    # A single return: the earlier bare `{"status": True}` return made the
    # richer payload below unreachable.
    return {
        "status": True,
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
# Returns the chunk size and chunk overlap currently stored in app.state.
# Access is gated by the get_admin_user dependency.
@app.get("/chunk")
async def get_chunk_params(user=Depends(get_admin_user)):
    state = app.state
    payload = {"status": True}
    payload["chunk_size"] = state.CHUNK_SIZE
    payload["chunk_overlap"] = state.CHUNK_OVERLAP
    return payload
# Request body for POST /chunk/update.
class ChunkParamUpdateForm(BaseModel):
    # New value for app.state.CHUNK_SIZE (passed to the text splitter as chunk_size).
    chunk_size: int
    # New value for app.state.CHUNK_OVERLAP (passed to the text splitter as chunk_overlap).
    chunk_overlap: int
# Replaces the chunking parameters held in app.state and echoes the new values.
# Access is gated by the get_admin_user dependency.
#
# Raises HTTPException(400) when the submitted values could not be used to
# build a text splitter: chunk_size must be positive and chunk_overlap must lie
# in [0, chunk_size]. Without this check a bad update would be stored and every
# later call to store_data_in_vector_db would fail.
@app.post("/chunk/update")
async def update_chunk_params(
    form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
):
    # Imported locally: the module's fastapi import list is only partially
    # visible, so do not rely on HTTPException being in scope already.
    from fastapi import HTTPException

    chunk_size = form_data.chunk_size
    chunk_overlap = form_data.chunk_overlap

    if chunk_size <= 0 or chunk_overlap < 0 or chunk_overlap > chunk_size:
        raise HTTPException(
            status_code=400,
            detail=(
                "chunk_size must be a positive integer and chunk_overlap "
                "must be between 0 and chunk_size"
            ),
        )

    app.state.CHUNK_SIZE = chunk_size
    app.state.CHUNK_OVERLAP = chunk_overlap

    return {
        "status": True,
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
# Returns the RAG prompt template currently stored in app.state.
# Access is gated by the get_current_user dependency.
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    current_template = app.state.RAG_TEMPLATE
    return {"status": True, "template": current_template}
# Request body for POST /template/update.
class RAGTemplateForm(BaseModel):
    # New template text; an empty string means "restore the configured default".
    template: str
# Stores a new RAG prompt template in app.state and echoes the stored value.
# An empty submission resets the template to the RAG_TEMPLATE default from
# config. Access is gated by the get_admin_user dependency.
@app.post("/template/update")
async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
    # TODO: check template requirements
    new_template = form_data.template
    if new_template == "":
        new_template = RAG_TEMPLATE
    app.state.RAG_TEMPLATE = new_template

    response = {"status": True}
    response["template"] = app.state.RAG_TEMPLATE
    return response
class QueryDocForm(BaseModel):
@ -239,8 +316,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
)
def get_loader(file, file_path):
file_ext = file.filename.split(".")[-1].lower()
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
known_source_ext = [
@ -298,20 +375,20 @@ def get_loader(file, file_path):
loader = UnstructuredXMLLoader(file_path)
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
elif file.content_type == "application/epub+zip":
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (
file.content_type
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext in ["doc", "docx"]
):
loader = Docx2txtLoader(file_path)
elif file.content_type in [
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path)
elif file_ext in known_source_ext or file.content_type.find("text/") >= 0:
elif file_ext in known_source_ext or file_content_type.find("text/") >= 0:
loader = TextLoader(file_path)
else:
loader = TextLoader(file_path)
@ -342,7 +419,7 @@ def store_doc(
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(file, file_path)
loader, known_type = get_loader(file.filename, file.content_type, file_path)
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
@ -372,6 +449,63 @@ def store_doc(
)
def _index_scanned_file(path, user_id):
    """Embed one file found under DOCS_DIR and register it as a document.

    The file is hashed to derive its collection name, loaded with the loader
    chosen by ``get_loader``, and stored in the vector DB. If storing succeeds
    and no document with the sanitized filename exists yet, a new document
    record is inserted whose content carries the folder-derived tags.

    Args:
        path: ``pathlib.Path`` of the file to index.
        user_id: id recorded as the owner of a newly inserted document.
    """
    # `json` is not among the module-level imports visible for this file, so
    # import it here to guarantee the name resolves.
    import json

    filename = path.name
    tags = extract_folders_after_data_docs(path)

    # guess_type() yields None for unknown extensions, while get_loader()
    # calls string methods on the content type; fall back to "".
    content_type, _ = mimetypes.guess_type(path)

    # Context manager so the handle is closed even if hashing raises.
    with open(path, "rb") as f:
        collection_name = calculate_sha256(f)[:63]

    loader, _ = get_loader(filename, content_type or "", str(path))
    data = loader.load()
    if not store_data_in_vector_db(data, collection_name):
        return

    sanitized_filename = sanitize_filename(filename)
    if Documents.get_doc_by_name(sanitized_filename) is not None:
        # Already registered under this name; nothing to insert.
        return

    if tags:
        content = json.dumps({"tags": [{"name": name} for name in tags]})
    else:
        content = "{}"

    Documents.insert_new_doc(
        user_id,
        DocumentForm(
            **{
                "name": sanitized_filename,
                "title": filename,
                "collection_name": collection_name,
                "filename": filename,
                "content": content,
            }
        ),
    )


# Walks DOCS_DIR recursively and indexes every regular file whose name does
# not start with ".". Always returns True; failures are printed, not raised.
# Each file is handled independently so that one unreadable or unparsable file
# no longer stops the remaining files from being scanned.
# Access is gated by the get_admin_user dependency.
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    try:
        # rglob() already recurses, so "*" matches every entry at any depth.
        candidates = [
            p
            for p in Path(DOCS_DIR).rglob("*")
            if p.is_file() and not p.name.startswith(".")
        ]
    except Exception as e:
        print(e)
        return True

    for path in candidates:
        try:
            _index_scanned_file(path, user.id)
        except Exception as e:
            # Best-effort scan: report the failing file and keep going.
            print(f"{path}: {e}")

    return True
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
CHROMA_CLIENT.reset()

View file

@ -96,6 +96,10 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
############################
# A single tag, identified only by its name.
class TagItem(BaseModel):
    name: str
class TagDocumentForm(BaseModel):
name: str
tags: List[dict]

View file

@ -43,6 +43,14 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Docs DIR
####################################
DOCS_DIR = f"{DATA_DIR}/docs"
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# OLLAMA_API_BASE_URL
####################################
@ -137,6 +145,21 @@ CHROMA_CLIENT = chromadb.PersistentClient(
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100
RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
####################################
# Transcribe
####################################

32
backend/start_windows.bat Normal file
View file

@ -0,0 +1,32 @@
:: This method is not recommended, and we recommend you use the `start.sh` file with WSL instead.
@echo off
SETLOCAL ENABLEDELAYEDEXPANSION

:: Get the directory of the current script
SET "SCRIPT_DIR=%~dp0"
cd /d "%SCRIPT_DIR%" || exit /b

SET "KEY_FILE=.webui_secret_key"

:: Default PORT to 8080 when the caller did not set it. Batch has no
:: "%VAR:default%" substitution, so the fallback needs an explicit test.
IF "%PORT%"=="" SET "PORT=8080"

SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"

:: Check if WEBUI_SECRET_KEY and WEBUI_JWT_SECRET_KEY are not set.
:: Both expand to nothing when unset, so the concatenation is compared with
:: the empty string.
IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%"=="" (
    echo No WEBUI_SECRET_KEY provided

    IF NOT EXIST "%KEY_FILE%" (
        echo Generating WEBUI_SECRET_KEY
        REM Generate a random value to use as a WEBUI_SECRET_KEY in case the user didn't provide one.
        REM "SET /p =text <nul" prints text without a trailing newline, so the
        REM twelve numbers are appended to the key file as one single line.
        REM NOTE: !random! is a small pseudo-random number, not a cryptographically
        REM strong source; provide WEBUI_SECRET_KEY yourself for real deployments.
        FOR /L %%i IN (1,1,12) DO SET /p "=!random!" <nul >>"%KEY_FILE%"
        echo WEBUI_SECRET_KEY generated
    )

    echo Loading WEBUI_SECRET_KEY from %KEY_FILE%
    SET /p WEBUI_SECRET_KEY=<"%KEY_FILE%"
)

:: Execute uvicorn
:: cmd does not strip single quotes, so the wildcard is passed in double quotes.
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips "*"

View file

@ -1,3 +1,4 @@
from pathlib import Path
import hashlib
import re
@ -38,3 +39,40 @@ def validate_email_format(email: str) -> bool:
if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
return False
return True
def sanitize_filename(file_name):
    """Return a lowercase, dash-separated form of *file_name*.

    Every character that is neither a word character nor whitespace is
    dropped (this includes the dot before an extension), then each run of
    whitespace is collapsed into a single dash.
    """
    without_specials = re.sub(r"[^\w\s]", "", file_name.lower())
    return re.sub(r"\s+", "-", without_specials)
def extract_folders_after_data_docs(path):
    """Return cumulative folder paths found between ``data/.../docs`` and the file.

    The path is split into components; the first ``data`` component is
    located, then the first ``docs`` component after it. Every folder between
    that ``docs`` component and the final component (the filename) yields one
    tag made of all folders up to and including it, joined with ``/``.
    An empty list is returned when ``data`` or a following ``docs`` component
    is missing.
    """
    components = Path(path).parts

    # Locate the component just past 'data', then just past the next 'docs'.
    try:
        after_data = components.index("data") + 1
        after_docs = components.index("docs", after_data) + 1
    except ValueError:
        return []

    # Drop the trailing filename; what remains are the nested folders.
    nested = components[after_docs:-1]
    return ["/".join(nested[:depth]) for depth in range(1, len(nested) + 1)]