forked from open-webui/open-webui
		
	Merge branch 'main' into choose-embedding-model
This commit is contained in:
		
						commit
						4b88e7e44f
					
				
					 16 changed files with 771 additions and 35 deletions
				
			
		|  | @ -9,6 +9,8 @@ from fastapi import ( | |||
| ) | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| import os, shutil | ||||
| 
 | ||||
| from pathlib import Path | ||||
| from typing import List | ||||
| 
 | ||||
| from chromadb.utils import embedding_functions | ||||
|  | @ -28,23 +30,45 @@ from langchain_community.document_loaders import ( | |||
| ) | ||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from typing import Optional | ||||
| 
 | ||||
| import mimetypes | ||||
| import uuid | ||||
| 
 | ||||
| from apps.web.models.documents import ( | ||||
|     Documents, | ||||
|     DocumentForm, | ||||
|     DocumentResponse, | ||||
| ) | ||||
| 
 | ||||
| from utils.misc import calculate_sha256, calculate_sha256_string | ||||
| from utils.misc import ( | ||||
|     calculate_sha256, | ||||
|     calculate_sha256_string, | ||||
|     sanitize_filename, | ||||
|     extract_folders_after_data_docs, | ||||
| ) | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP | ||||
| from config import ( | ||||
|     UPLOAD_DIR, | ||||
|     DOCS_DIR, | ||||
|     SENTENCE_TRANSFORMER_EMBED_MODEL, | ||||
|     CHROMA_CLIENT, | ||||
|     CHUNK_SIZE, | ||||
|     CHUNK_OVERLAP, | ||||
|     RAG_TEMPLATE, | ||||
| ) | ||||
| 
 | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
| app.state.CHUNK_SIZE = CHUNK_SIZE | ||||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||
| 
 | ||||
| 
 | ||||
| origins = ["*"] | ||||
| 
 | ||||
| app.add_middleware( | ||||
|  | @ -66,7 +90,7 @@ class StoreWebForm(CollectionNameForm): | |||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP | ||||
|         chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP | ||||
|     ) | ||||
|     docs = text_splitter.split_documents(data) | ||||
| 
 | ||||
|  | @ -96,7 +120,60 @@ def store_data_in_vector_db(data, collection_name) -> bool: | |||
| 
 | ||||
@app.get("/")
async def get_status():
    """Return a liveness flag together with the active chunking parameters.

    The payload reads the chunk settings from ``app.state`` so that it
    reflects values changed at runtime through ``/chunk/update`` rather than
    the constants imported from ``config``.
    """
    # The bare `return {"status": True}` that preceded this dict is the
    # pre-change line of the diff; left in place it made the fields below
    # unreachable.
    return {
        "status": True,
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
| 
 | ||||
| 
 | ||||
# GET /chunk: report the chunk size and overlap currently held in app.state.
# `user` is resolved through the get_admin_user dependency and is not read here.
@app.get("/chunk")
async def get_chunk_params(user=Depends(get_admin_user)):
    state = app.state
    return dict(
        status=True,
        chunk_size=state.CHUNK_SIZE,
        chunk_overlap=state.CHUNK_OVERLAP,
    )
| 
 | ||||
| 
 | ||||
# Request body for POST /chunk/update.
class ChunkParamUpdateForm(BaseModel):
    # Both values are later passed to RecursiveCharacterTextSplitter by
    # store_data_in_vector_db (as chunk_size / chunk_overlap).
    chunk_size: int
    chunk_overlap: int


@app.post("/chunk/update")
async def update_chunk_params(
    form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
):
    """Replace the chunking parameters stored in ``app.state``.

    Args:
        form_data: The new ``chunk_size`` and ``chunk_overlap``.
        user: Resolved through the ``get_admin_user`` dependency; not read.

    Returns:
        A dict with ``status`` and the values now stored in ``app.state``.

    Raises:
        HTTPException: 400 when ``chunk_size`` is not positive or when
            ``chunk_overlap`` is negative or larger than ``chunk_size``.
    """
    # Imported locally because the module-level fastapi import list is not
    # fully visible from this part of the file.
    from fastapi import HTTPException

    chunk_size = form_data.chunk_size
    chunk_overlap = form_data.chunk_overlap

    # Reject values the text splitter cannot work with *before* storing them;
    # otherwise the bad configuration is accepted here and only surfaces later
    # as a failure on every document that gets split.
    if chunk_size <= 0:
        raise HTTPException(
            status_code=400,
            detail="chunk_size must be a positive integer",
        )
    if chunk_overlap < 0 or chunk_overlap > chunk_size:
        raise HTTPException(
            status_code=400,
            detail="chunk_overlap must be between 0 and chunk_size",
        )

    app.state.CHUNK_SIZE = chunk_size
    app.state.CHUNK_OVERLAP = chunk_overlap

    return {
        "status": True,
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
    }
| 
 | ||||
| 
 | ||||
# GET /template: return the RAG prompt template currently held in app.state.
# `user` is resolved through the get_current_user dependency and is not read here.
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
    current_template = app.state.RAG_TEMPLATE
    return dict(status=True, template=current_template)
| 
 | ||||
| 
 | ||||
# Request body for POST /template/update.
class RAGTemplateForm(BaseModel):
    # New template text; an empty string selects the RAG_TEMPLATE default.
    template: str


# POST /template/update: store a new RAG template in app.state and echo it back.
@app.post("/template/update")
async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
    # TODO: check template requirements
    new_template = form_data.template
    if new_template == "":
        # Empty submission: fall back to the RAG_TEMPLATE imported from config.
        new_template = RAG_TEMPLATE
    app.state.RAG_TEMPLATE = new_template
    return dict(status=True, template=app.state.RAG_TEMPLATE)
| 
 | ||||
| 
 | ||||
| class QueryDocForm(BaseModel): | ||||
|  | @ -239,8 +316,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def get_loader(file, file_path): | ||||
|     file_ext = file.filename.split(".")[-1].lower() | ||||
| def get_loader(filename: str, file_content_type: str, file_path: str): | ||||
|     file_ext = filename.split(".")[-1].lower() | ||||
|     known_type = True | ||||
| 
 | ||||
|     known_source_ext = [ | ||||
|  | @ -298,20 +375,20 @@ def get_loader(file, file_path): | |||
|         loader = UnstructuredXMLLoader(file_path) | ||||
|     elif file_ext == "md": | ||||
|         loader = UnstructuredMarkdownLoader(file_path) | ||||
|     elif file.content_type == "application/epub+zip": | ||||
|     elif file_content_type == "application/epub+zip": | ||||
|         loader = UnstructuredEPubLoader(file_path) | ||||
|     elif ( | ||||
|         file.content_type | ||||
|         file_content_type | ||||
|         == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" | ||||
|         or file_ext in ["doc", "docx"] | ||||
|     ): | ||||
|         loader = Docx2txtLoader(file_path) | ||||
|     elif file.content_type in [ | ||||
|     elif file_content_type in [ | ||||
|         "application/vnd.ms-excel", | ||||
|         "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", | ||||
|     ] or file_ext in ["xls", "xlsx"]: | ||||
|         loader = UnstructuredExcelLoader(file_path) | ||||
|     elif file_ext in known_source_ext or file.content_type.find("text/") >= 0: | ||||
|     elif file_ext in known_source_ext or file_content_type.find("text/") >= 0: | ||||
|         loader = TextLoader(file_path) | ||||
|     else: | ||||
|         loader = TextLoader(file_path) | ||||
|  | @ -342,7 +419,7 @@ def store_doc( | |||
|             collection_name = calculate_sha256(f)[:63] | ||||
|         f.close() | ||||
| 
 | ||||
|         loader, known_type = get_loader(file, file_path) | ||||
|         loader, known_type = get_loader(file.filename, file.content_type, file_path) | ||||
|         data = loader.load() | ||||
|         result = store_data_in_vector_db(data, collection_name) | ||||
| 
 | ||||
|  | @ -372,6 +449,63 @@ def store_doc( | |||
|             ) | ||||
| 
 | ||||
| 
 | ||||
def _index_doc_file(path, user):
    """Embed one file from the docs directory and register it if it is new.

    The file's SHA-256 (truncated to 63 characters) is used as the vector
    collection name. A document row is inserted only when the vector store
    step reports success and no document with the sanitized name exists yet.

    Args:
        path: ``pathlib.Path`` of the file to index.
        user: The requesting user; ``user.id`` becomes the owner of a newly
            inserted document.
    """
    # Imported locally: a module-level `import json` is not visible in this
    # part of the file, and the tag payload below needs it.
    import json

    filename = path.name
    tags = extract_folders_after_data_docs(path)

    # guess_type() yields (None, None) for unknown extensions. get_loader calls
    # string methods on the content type, so pass "" instead of None.
    content_type = mimetypes.guess_type(str(path))[0] or ""

    # `with` guarantees the handle is closed even if hashing raises.
    with open(path, "rb") as f:
        collection_name = calculate_sha256(f)[:63]

    loader, _known_type = get_loader(filename, content_type, str(path))
    data = loader.load()

    if not store_data_in_vector_db(data, collection_name):
        return

    sanitized_filename = sanitize_filename(filename)
    if Documents.get_doc_by_name(sanitized_filename) is not None:
        # Already registered; leave the existing row untouched.
        return

    # Tags are stored as a JSON string: {"tags": [{"name": ...}, ...]},
    # or "{}" when there are none.
    if tags:
        content = json.dumps({"tags": [{"name": name} for name in tags]})
    else:
        content = "{}"

    Documents.insert_new_doc(
        user.id,
        DocumentForm(
            name=sanitized_filename,
            title=filename,
            collection_name=collection_name,
            filename=filename,
            content=content,
        ),
    )


@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Walk ``DOCS_DIR`` recursively and index every non-hidden file.

    Files whose name starts with "." are skipped. The scan is best-effort:
    a file that fails to hash, load, embed or register is reported with
    ``print`` and the scan moves on to the next file.

    Args:
        user: Resolved through the ``get_admin_user`` dependency; recorded as
            the owner of newly inserted documents.

    Returns:
        Always ``True``, including when some or all files failed.
    """
    try:
        # rglob() already recurses, so "*" visits every entry below DOCS_DIR.
        for path in Path(DOCS_DIR).rglob("*"):
            if not path.is_file() or path.name.startswith("."):
                continue

            try:
                _index_doc_file(path, user)
            except Exception as e:
                # Contain the failure to this file so one unreadable or
                # unsupported document does not abort the rest of the scan.
                print(f"scan_docs_dir: skipping {path}: {e}")
    except Exception as e:
        # Failure while walking the directory itself; keep the original
        # contract of never raising from this endpoint.
        print(e)

    return True
| 
 | ||||
| 
 | ||||
| @app.get("/reset/db") | ||||
| def reset_vector_db(user=Depends(get_admin_user)): | ||||
|     CHROMA_CLIENT.reset() | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jannik S
						Jannik S