forked from open-webui/open-webui
		
	Merge Updates & Dockerfile improvements
This commit is contained in:
		
							parent
							
								
									fdef2abdfb
								
							
						
					
					
						commit
						9763d885be
					
				
					 155 changed files with 14509 additions and 4803 deletions
				
			
		|  | @ -8,7 +8,7 @@ from fastapi import ( | |||
|     Form, | ||||
| ) | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| import os, shutil | ||||
| import os, shutil, logging | ||||
| 
 | ||||
| from pathlib import Path | ||||
| from typing import List | ||||
|  | @ -21,6 +21,7 @@ from langchain_community.document_loaders import ( | |||
|     TextLoader, | ||||
|     PyPDFLoader, | ||||
|     CSVLoader, | ||||
|     BSHTMLLoader, | ||||
|     Docx2txtLoader, | ||||
|     UnstructuredEPubLoader, | ||||
|     UnstructuredWordDocumentLoader, | ||||
|  | @ -54,6 +55,7 @@ from utils.misc import ( | |||
| ) | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| from config import ( | ||||
|     SRC_LOG_LEVELS, | ||||
|     UPLOAD_DIR, | ||||
|     DOCS_DIR, | ||||
|     RAG_EMBEDDING_MODEL, | ||||
|  | @ -66,6 +68,9 @@ from config import ( | |||
| 
 | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||
| 
 | ||||
| # | ||||
| # if RAG_EMBEDDING_MODEL: | ||||
| #    sentence_transformer_ef = SentenceTransformer( | ||||
|  | @ -111,39 +116,6 @@ class StoreWebForm(CollectionNameForm): | |||
|     url: str | ||||
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP | ||||
|     ) | ||||
|     docs = text_splitter.split_documents(data) | ||||
| 
 | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         if overwrite: | ||||
|             for collection in CHROMA_CLIENT.list_collections(): | ||||
|                 if collection_name == collection.name: | ||||
|                     print(f"deleting existing collection {collection_name}") | ||||
|                     CHROMA_CLIENT.delete_collection(name=collection_name) | ||||
| 
 | ||||
|         collection = CHROMA_CLIENT.create_collection( | ||||
|             name=collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|         ) | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         if e.__class__.__name__ == "UniqueConstraintError": | ||||
|             return True | ||||
| 
 | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/") | ||||
| async def get_status(): | ||||
|     return { | ||||
|  | @ -273,7 +245,7 @@ def query_doc_handler( | |||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(e), | ||||
|  | @ -317,13 +289,69 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | |||
|             "filename": form_data.url, | ||||
|         } | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(e), | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: | ||||
| 
 | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, | ||||
|         chunk_overlap=app.state.CHUNK_OVERLAP, | ||||
|         add_start_index=True, | ||||
|     ) | ||||
|     docs = text_splitter.split_documents(data) | ||||
| 
 | ||||
|     if len(docs) > 0: | ||||
|         return store_docs_in_vector_db(docs, collection_name, overwrite), None | ||||
|     else: | ||||
|         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) | ||||
| 
 | ||||
| 
 | ||||
| def store_text_in_vector_db( | ||||
|     text, metadata, collection_name, overwrite: bool = False | ||||
| ) -> bool: | ||||
|     text_splitter = RecursiveCharacterTextSplitter( | ||||
|         chunk_size=app.state.CHUNK_SIZE, | ||||
|         chunk_overlap=app.state.CHUNK_OVERLAP, | ||||
|         add_start_index=True, | ||||
|     ) | ||||
|     docs = text_splitter.create_documents([text], metadatas=[metadata]) | ||||
|     return store_docs_in_vector_db(docs, collection_name, overwrite) | ||||
| 
 | ||||
| 
 | ||||
| def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: | ||||
| 
 | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|         if overwrite: | ||||
|             for collection in CHROMA_CLIENT.list_collections(): | ||||
|                 if collection_name == collection.name: | ||||
|                     log.info(f"deleting existing collection {collection_name}") | ||||
|                     CHROMA_CLIENT.delete_collection(name=collection_name) | ||||
| 
 | ||||
|         collection = CHROMA_CLIENT.create_collection( | ||||
|             name=collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
| 
 | ||||
|         collection.add( | ||||
|             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] | ||||
|         ) | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         if e.__class__.__name__ == "UniqueConstraintError": | ||||
|             return True | ||||
| 
 | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def get_loader(filename: str, file_content_type: str, file_path: str): | ||||
|     file_ext = filename.split(".")[-1].lower() | ||||
|     known_type = True | ||||
|  | @ -381,6 +409,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str): | |||
|         loader = UnstructuredRSTLoader(file_path, mode="elements") | ||||
|     elif file_ext == "xml": | ||||
|         loader = UnstructuredXMLLoader(file_path) | ||||
|     elif file_ext in ["htm", "html"]: | ||||
|         loader = BSHTMLLoader(file_path, open_encoding="unicode_escape") | ||||
|     elif file_ext == "md": | ||||
|         loader = UnstructuredMarkdownLoader(file_path) | ||||
|     elif file_content_type == "application/epub+zip": | ||||
|  | @ -399,9 +429,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str): | |||
|     elif file_ext in known_source_ext or ( | ||||
|         file_content_type and file_content_type.find("text/") >= 0 | ||||
|     ): | ||||
|         loader = TextLoader(file_path) | ||||
|         loader = TextLoader(file_path, autodetect_encoding=True) | ||||
|     else: | ||||
|         loader = TextLoader(file_path) | ||||
|         loader = TextLoader(file_path, autodetect_encoding=True) | ||||
|         known_type = False | ||||
| 
 | ||||
|     return loader, known_type | ||||
|  | @ -415,7 +445,7 @@ def store_doc( | |||
| ): | ||||
|     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" | ||||
| 
 | ||||
|     print(file.content_type) | ||||
|     log.info(f"file.content_type: {file.content_type}") | ||||
|     try: | ||||
|         filename = file.filename | ||||
|         file_path = f"{UPLOAD_DIR}/{filename}" | ||||
|  | @ -431,22 +461,24 @@ def store_doc( | |||
| 
 | ||||
|         loader, known_type = get_loader(file.filename, file.content_type, file_path) | ||||
|         data = loader.load() | ||||
|         result = store_data_in_vector_db(data, collection_name) | ||||
| 
 | ||||
|         if result: | ||||
|             return { | ||||
|                 "status": True, | ||||
|                 "collection_name": collection_name, | ||||
|                 "filename": filename, | ||||
|                 "known_type": known_type, | ||||
|             } | ||||
|         else: | ||||
|         try: | ||||
|             result = store_data_in_vector_db(data, collection_name) | ||||
| 
 | ||||
|             if result: | ||||
|                 return { | ||||
|                     "status": True, | ||||
|                     "collection_name": collection_name, | ||||
|                     "filename": filename, | ||||
|                     "known_type": known_type, | ||||
|                 } | ||||
|         except Exception as e: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||
|                 detail=ERROR_MESSAGES.DEFAULT(), | ||||
|                 detail=e, | ||||
|             ) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
|         if "No pandoc was found" in str(e): | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_400_BAD_REQUEST, | ||||
|  | @ -459,6 +491,37 @@ def store_doc( | |||
|             ) | ||||
| 
 | ||||
| 
 | ||||
| class TextRAGForm(BaseModel): | ||||
|     name: str | ||||
|     content: str | ||||
|     collection_name: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/text") | ||||
| def store_text( | ||||
|     form_data: TextRAGForm, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
| 
 | ||||
|     collection_name = form_data.collection_name | ||||
|     if collection_name == None: | ||||
|         collection_name = calculate_sha256_string(form_data.content) | ||||
| 
 | ||||
|     result = store_text_in_vector_db( | ||||
|         form_data.content, | ||||
|         metadata={"name": form_data.name, "created_by": user.id}, | ||||
|         collection_name=collection_name, | ||||
|     ) | ||||
| 
 | ||||
|     if result: | ||||
|         return {"status": True, "collection_name": collection_name} | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(), | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/scan") | ||||
| def scan_docs_dir(user=Depends(get_admin_user)): | ||||
|     for path in Path(DOCS_DIR).rglob("./**/*"): | ||||
|  | @ -477,41 +540,45 @@ def scan_docs_dir(user=Depends(get_admin_user)): | |||
|                 ) | ||||
|                 data = loader.load() | ||||
| 
 | ||||
|                 result = store_data_in_vector_db(data, collection_name) | ||||
|                 try: | ||||
|                     result = store_data_in_vector_db(data, collection_name) | ||||
| 
 | ||||
|                 if result: | ||||
|                     sanitized_filename = sanitize_filename(filename) | ||||
|                     doc = Documents.get_doc_by_name(sanitized_filename) | ||||
|                     if result: | ||||
|                         sanitized_filename = sanitize_filename(filename) | ||||
|                         doc = Documents.get_doc_by_name(sanitized_filename) | ||||
| 
 | ||||
|                     if doc == None: | ||||
|                         doc = Documents.insert_new_doc( | ||||
|                             user.id, | ||||
|                             DocumentForm( | ||||
|                                 **{ | ||||
|                                     "name": sanitized_filename, | ||||
|                                     "title": filename, | ||||
|                                     "collection_name": collection_name, | ||||
|                                     "filename": filename, | ||||
|                                     "content": ( | ||||
|                                         json.dumps( | ||||
|                                             { | ||||
|                                                 "tags": list( | ||||
|                                                     map( | ||||
|                                                         lambda name: {"name": name}, | ||||
|                                                         tags, | ||||
|                         if doc == None: | ||||
|                             doc = Documents.insert_new_doc( | ||||
|                                 user.id, | ||||
|                                 DocumentForm( | ||||
|                                     **{ | ||||
|                                         "name": sanitized_filename, | ||||
|                                         "title": filename, | ||||
|                                         "collection_name": collection_name, | ||||
|                                         "filename": filename, | ||||
|                                         "content": ( | ||||
|                                             json.dumps( | ||||
|                                                 { | ||||
|                                                     "tags": list( | ||||
|                                                         map( | ||||
|                                                             lambda name: {"name": name}, | ||||
|                                                             tags, | ||||
|                                                         ) | ||||
|                                                     ) | ||||
|                                                 ) | ||||
|                                             } | ||||
|                                         ) | ||||
|                                         if len(tags) | ||||
|                                         else "{}" | ||||
|                                     ), | ||||
|                                 } | ||||
|                             ), | ||||
|                         ) | ||||
|                                                 } | ||||
|                                             ) | ||||
|                                             if len(tags) | ||||
|                                             else "{}" | ||||
|                                         ), | ||||
|                                     } | ||||
|                                 ), | ||||
|                             ) | ||||
|                 except Exception as e: | ||||
|                     log.exception(e) | ||||
|                     pass | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
| 
 | ||||
|     return True | ||||
| 
 | ||||
|  | @ -532,11 +599,11 @@ def reset(user=Depends(get_admin_user)) -> bool: | |||
|             elif os.path.isdir(file_path): | ||||
|                 shutil.rmtree(file_path) | ||||
|         except Exception as e: | ||||
|             print("Failed to delete %s. Reason: %s" % (file_path, e)) | ||||
|             log.error("Failed to delete %s. Reason: %s" % (file_path, e)) | ||||
| 
 | ||||
|     try: | ||||
|         CHROMA_CLIENT.reset() | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         log.exception(e) | ||||
| 
 | ||||
|     return True | ||||
|  |  | |||
|  | @ -1,7 +1,11 @@ | |||
| import re | ||||
| import logging | ||||
| from typing import List | ||||
| 
 | ||||
| from config import CHROMA_CLIENT | ||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||
| 
 | ||||
| 
 | ||||
| def query_doc(collection_name: str, query: str, k: int, embedding_function): | ||||
|  | @ -91,14 +95,13 @@ def query_collection( | |||
| 
 | ||||
| 
 | ||||
| def rag_template(template: str, context: str, query: str): | ||||
|     template = re.sub(r"\[context\]", context, template) | ||||
|     template = re.sub(r"\[query\]", query, template) | ||||
| 
 | ||||
|     template = template.replace("[context]", context) | ||||
|     template = template.replace("[query]", query) | ||||
|     return template | ||||
| 
 | ||||
| 
 | ||||
| def rag_messages(docs, messages, template, k, embedding_function): | ||||
|     print(docs) | ||||
|     log.debug(f"docs: {docs}") | ||||
| 
 | ||||
|     last_user_message_idx = None | ||||
|     for i in range(len(messages) - 1, -1, -1): | ||||
|  | @ -138,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function): | |||
|                     k=k, | ||||
|                     embedding_function=embedding_function, | ||||
|                 ) | ||||
|             elif doc["type"] == "text": | ||||
|                 context = doc["content"] | ||||
|             else: | ||||
|                 context = query_doc( | ||||
|                     collection_name=doc["collection_name"], | ||||
|  | @ -146,11 +151,13 @@ def rag_messages(docs, messages, template, k, embedding_function): | |||
|                     embedding_function=embedding_function, | ||||
|                 ) | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             log.exception(e) | ||||
|             context = None | ||||
| 
 | ||||
|         relevant_contexts.append(context) | ||||
| 
 | ||||
|     log.debug(f"relevant_contexts: {relevant_contexts}") | ||||
| 
 | ||||
|     context_string = "" | ||||
|     for context in relevant_contexts: | ||||
|         if context: | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 lainedfles
						lainedfles