feat: rag backend auth

This commit is contained in:
Timothy J. Baek 2024-01-07 02:46:12 -08:00
parent c43df8850f
commit 70d2571be1

View file

@ -24,6 +24,8 @@ from typing import Optional
import uuid import uuid
from utils.utils import get_current_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -84,7 +86,12 @@ async def get_status():
@app.get("/query/{collection_name}") @app.get("/query/{collection_name}")
def query_collection(collection_name: str, query: str, k: Optional[int] = 4): def query_collection(
collection_name: str,
query: str,
k: Optional[int] = 4,
user=Depends(get_current_user),
):
try: try:
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=collection_name, name=collection_name,
@ -101,7 +108,7 @@ def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
@app.post("/web") @app.post("/web")
def store_web(form_data: StoreWebForm): def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = WebBaseLoader(form_data.url) loader = WebBaseLoader(form_data.url)
@ -117,7 +124,11 @@ def store_web(form_data: StoreWebForm):
@app.post("/doc") @app.post("/doc")
def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)): def store_doc(
collection_name: str = Form(...),
file: UploadFile = File(...),
user=Depends(get_current_user),
):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
file.filename = f"{collection_name}-{file.filename}" file.filename = f"{collection_name}-{file.filename}"
@ -159,26 +170,38 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
@app.get("/reset/db") @app.get("/reset/db")
def reset_vector_db(): def reset_vector_db(user=Depends(get_current_user)):
CHROMA_CLIENT.reset() if user.role == "admin":
CHROMA_CLIENT.reset()
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@app.get("/reset") @app.get("/reset")
def reset(): def reset(user=Depends(get_current_user)):
folder = f"{UPLOAD_DIR}" if user.role == "admin":
for filename in os.listdir(folder): folder = f"{UPLOAD_DIR}"
file_path = os.path.join(folder, filename) for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print("Failed to delete %s. Reason: %s" % (file_path, e))
try: try:
if os.path.isfile(file_path) or os.path.islink(file_path): CHROMA_CLIENT.reset()
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e: except Exception as e:
print("Failed to delete %s. Reason: %s" % (file_path, e)) print(e)
try: return {"status": True}
CHROMA_CLIENT.reset() else:
except Exception as e: raise HTTPException(
print(e) status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
return {"status": True} )