"""FastAPI sub-application for document ingestion and retrieval.

Exposes endpoints that load a web page or an uploaded file, split it into
chunks, store the chunks in a Chroma collection, and query one or more
collections. Also exposes admin-only endpoints that reset the vector
database and clear the upload directory.
"""

from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
import os
import shutil
from typing import List

# from chromadb.utils import embedding_functions

from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
    PyPDFLoader,
    CSVLoader,
    Docx2txtLoader,
    UnstructuredEPubLoader,
    UnstructuredWordDocumentLoader,
    UnstructuredMarkdownLoader,
    UnstructuredXMLLoader,
    UnstructuredRSTLoader,
    UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA


from pydantic import BaseModel
from typing import Optional

import uuid
import time

from utils.misc import calculate_sha256, calculate_sha256_string
from utils.utils import get_current_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES

# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
#     model_name=EMBED_MODEL
# )

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployment.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# File extensions that get_loader treats as plain-text source/config files.
_KNOWN_SOURCE_EXT = frozenset(
    {
        "go",
        "py",
        "java",
        "sh",
        "bat",
        "ps1",
        "cmd",
        "js",
        "ts",
        "css",
        "cpp",
        "hpp",
        "h",
        "c",
        "cs",
        "sql",
        "log",
        "ini",
        "pl",
        "pm",
        "r",
        "dart",
        "dockerfile",
        "env",
        "php",
        "hs",
        "hsc",
        "lua",
        "nginxconf",
        "conf",
        "m",
        "mm",
        "plsql",
        "perl",
        "rb",
        "rs",
        "db2",
        "scala",
        "bash",
        "swift",
        "vue",
        "svelte",
    }
)


# Base request body carrying the target collection name (defaults to "test").
class CollectionNameForm(BaseModel):
    collection_name: Optional[str] = "test"


# Request body for POST /web: the URL to load plus the target collection.
class StoreWebForm(CollectionNameForm):
    url: str


def store_data_in_vector_db(data, collection_name: str) -> bool:
    """Split ``data`` into chunks and add them to a new Chroma collection.

    Args:
        data: Documents as returned by a loader's ``load()``; passed to
            ``RecursiveCharacterTextSplitter.split_documents``.
        collection_name: Name of the collection to create.

    Returns:
        True when the chunks were added, or when the failure's exception
        class is named ``UniqueConstraintError`` (presumably the collection
        already exists, so the content is treated as already stored).
        False for any other failure.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    docs = text_splitter.split_documents(data)

    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    try:
        collection = CHROMA_CLIENT.create_collection(name=collection_name)

        collection.add(
            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
        )
        return True
    except Exception as e:
        print(e)
        # Matched by class name so no chromadb exception import is needed here.
        if e.__class__.__name__ == "UniqueConstraintError":
            return True

        return False


@app.get("/")
async def get_status():
    """Return a static status payload."""
    return {"status": True}


# Request body for POST /query/collection.
class QueryCollectionForm(BaseModel):
    collection_name: str
    query: str
    k: Optional[int] = 4  # number of results requested from the collection


@app.post("/query/collection")
def query_collection(
    form_data: QueryCollectionForm,
    user=Depends(get_current_user),
):
    """Query a single collection and return the raw query result.

    Raises:
        HTTPException: 400 if the collection lookup or the query fails.
    """
    try:
        collection = CHROMA_CLIENT.get_collection(
            name=form_data.collection_name,
        )
        result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
        return result
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


# Request body for POST /query/collections.
class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    k: Optional[int] = 4  # number of results requested per collection


@app.post("/query/collections")
def query_collections(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    """Query several collections and return the list of per-collection results.

    Collections that cannot be fetched or queried are skipped, so the
    returned list may be shorter than ``collection_names``.
    """
    results = []

    for collection_name in form_data.collection_names:
        try:
            collection = CHROMA_CLIENT.get_collection(
                name=collection_name,
            )
            result = collection.query(
                query_texts=[form_data.query], n_results=form_data.k
            )
            results.append(result)
        except Exception as e:
            # Best-effort: skip the failing collection, but leave a trace
            # instead of silently swallowing the error.
            print(e)

    return results


@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
    """Load a web page and store its chunks in a collection.

    When ``collection_name`` is an empty string, the first 63 characters of
    the URL's SHA-256 string are used as the collection name.

    Raises:
        HTTPException: 500 if the data could not be stored; 400 for any
            other failure (e.g. the page could not be loaded).
    """
    # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
        loader = WebBaseLoader(form_data.url)
        data = loader.load()

        collection_name = form_data.collection_name
        if collection_name == "":
            collection_name = calculate_sha256_string(form_data.url)[:63]

        # Do not report success when nothing was stored (same as /doc).
        if not store_data_in_vector_db(data, collection_name):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )

        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
        }
    except HTTPException:
        # Keep the intended status code instead of rewrapping it as a 400.
        raise
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


def get_loader(file, file_path):
    """Pick a document loader for an uploaded file.

    The choice is based on the filename extension and, for some formats, the
    upload's content type. Anything unrecognised falls back to ``TextLoader``.

    Args:
        file: The upload; its ``filename`` and ``content_type`` are read and
            either may be None.
        file_path: Path of the saved file handed to the loader.

    Returns:
        A ``(loader, known_type)`` tuple; ``known_type`` is False only when
        the plain-text fallback was chosen without a matching extension or a
        ``text/`` content type.
    """
    # Both attributes may be missing on a multipart upload; normalise to "".
    file_ext = (file.filename or "").split(".")[-1].lower()
    content_type = file.content_type or ""
    known_type = True

    if file_ext == "pdf":
        loader = PyPDFLoader(file_path)
    elif file_ext == "csv":
        loader = CSVLoader(file_path)
    elif file_ext == "rst":
        loader = UnstructuredRSTLoader(file_path, mode="elements")
    elif file_ext == "xml":
        loader = UnstructuredXMLLoader(file_path)
    elif file_ext == "md":
        loader = UnstructuredMarkdownLoader(file_path)
    elif content_type == "application/epub+zip":
        loader = UnstructuredEPubLoader(file_path)
    elif (
        content_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        or file_ext in ["doc", "docx"]
    ):
        loader = Docx2txtLoader(file_path)
    elif content_type in [
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ] or file_ext in ["xls", "xlsx"]:
        loader = UnstructuredExcelLoader(file_path)
    else:
        # Plain-text fallback; only "known" when the extension or MIME type
        # actually indicates text.
        loader = TextLoader(file_path)
        known_type = file_ext in _KNOWN_SOURCE_EXT or "text/" in content_type

    return loader, known_type


@app.post("/doc")
def store_doc(
    collection_name: Optional[str] = Form(None),
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded file, load it and store its chunks in a collection.

    When ``collection_name`` is not supplied, the first 63 characters of the
    saved file's SHA-256 are used as the collection name.

    Raises:
        HTTPException: 500 if the data could not be stored; 400 for any
            other failure, with a dedicated message when pandoc is missing.
    """
    print(file.content_type)
    try:
        # The filename comes from the client: keep only its final path
        # component so the write cannot escape UPLOAD_DIR.
        filename = os.path.basename(file.filename or "")
        if not filename:
            raise ValueError("Uploaded file has no usable filename")

        file_path = f"{UPLOAD_DIR}/{filename}"
        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if collection_name is None:
            with open(file_path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(file, file_path)
        data = loader.load()
        result = store_data_in_vector_db(data, collection_name)

        if result:
            return {
                "status": True,
                "collection_name": collection_name,
                "filename": filename,
                "known_type": known_type,
            }
        else:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT(),
            )
    except HTTPException:
        # Without this the 500 above would be caught below and turned into
        # a 400.
        raise
    except Exception as e:
        print(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )


# NOTE(review): both reset routes are destructive but registered as GET;
# changing the method would break existing callers, so it is left as is.
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_current_user)):
    """Reset the Chroma client. Admin only.

    Raises:
        HTTPException: 403 if the current user's role is not "admin".
    """
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    CHROMA_CLIENT.reset()


@app.get("/reset")
def reset(user=Depends(get_current_user)) -> bool:
    """Delete everything in UPLOAD_DIR and reset the Chroma client. Admin only.

    Individual deletion failures and a failing Chroma reset are printed and
    otherwise ignored.

    Returns:
        True once the cleanup has been attempted.

    Raises:
        HTTPException: 403 if the current user's role is not "admin".
    """
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    folder = f"{UPLOAD_DIR}"
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")

    try:
        CHROMA_CLIENT.reset()
    except Exception as e:
        print(e)

    return True