forked from open-webui/open-webui
Merge pull request #682 from explorigin/simplify-endpoint-code
Simplify endpoint role checking
This commit is contained in:
commit
9f3346a6ec
11 changed files with 127 additions and 251 deletions
|
@ -1,4 +1,4 @@
|
||||||
from fastapi import FastAPI, Request, Response, HTTPException, Depends
|
from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from fastapi.concurrency import run_in_threadpool
|
from fastapi.concurrency import run_in_threadpool
|
||||||
|
@ -10,7 +10,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
from utils.utils import decode_token, get_current_user
|
from utils.utils import decode_token, get_current_user, get_admin_user
|
||||||
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
|
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
@ -31,11 +31,8 @@ REQUEST_POOL = []
|
||||||
|
|
||||||
|
|
||||||
@app.get("/url")
|
@app.get("/url")
|
||||||
async def get_ollama_api_url(user=Depends(get_current_user)):
|
async def get_ollama_api_url(user=Depends(get_admin_user)):
|
||||||
if user and user.role == "admin":
|
|
||||||
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
|
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
|
||||||
|
|
||||||
|
|
||||||
class UrlUpdateForm(BaseModel):
|
class UrlUpdateForm(BaseModel):
|
||||||
|
@ -44,13 +41,10 @@ class UrlUpdateForm(BaseModel):
|
||||||
|
|
||||||
@app.post("/url/update")
|
@app.post("/url/update")
|
||||||
async def update_ollama_api_url(
|
async def update_ollama_api_url(
|
||||||
form_data: UrlUpdateForm, user=Depends(get_current_user)
|
form_data: UrlUpdateForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
if user and user.role == "admin":
|
|
||||||
app.state.OLLAMA_API_BASE_URL = form_data.url
|
app.state.OLLAMA_API_BASE_URL = form_data.url
|
||||||
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
|
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/cancel/{request_id}")
|
@app.get("/cancel/{request_id}")
|
||||||
|
@ -74,10 +68,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
|
||||||
if path in ["pull", "delete", "push", "copy", "create"]:
|
if path in ["pull", "delete", "push", "copy", "create"]:
|
||||||
if user.role != "admin":
|
if user.role != "admin":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
||||||
|
|
||||||
headers.pop("host", None)
|
headers.pop("host", None)
|
||||||
headers.pop("authorization", None)
|
headers.pop("authorization", None)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
from utils.utils import decode_token, get_current_user
|
from utils.utils import decode_token, get_current_user, get_verified_user, get_admin_user
|
||||||
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
|
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
@ -37,45 +37,32 @@ class KeyUpdateForm(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@app.get("/url")
|
@app.get("/url")
|
||||||
async def get_openai_url(user=Depends(get_current_user)):
|
async def get_openai_url(user=Depends(get_admin_user)):
|
||||||
if user and user.role == "admin":
|
|
||||||
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
|
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/url/update")
|
@app.post("/url/update")
|
||||||
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
|
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
|
||||||
if user and user.role == "admin":
|
|
||||||
app.state.OPENAI_API_BASE_URL = form_data.url
|
app.state.OPENAI_API_BASE_URL = form_data.url
|
||||||
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
|
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/key")
|
@app.get("/key")
|
||||||
async def get_openai_key(user=Depends(get_current_user)):
|
async def get_openai_key(user=Depends(get_admin_user)):
|
||||||
if user and user.role == "admin":
|
|
||||||
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
|
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/key/update")
|
@app.post("/key/update")
|
||||||
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
|
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
|
||||||
if user and user.role == "admin":
|
|
||||||
app.state.OPENAI_API_KEY = form_data.key
|
app.state.OPENAI_API_KEY = form_data.key
|
||||||
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
|
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/audio/speech")
|
@app.post("/audio/speech")
|
||||||
async def speech(request: Request, user=Depends(get_current_user)):
|
async def speech(request: Request, user=Depends(get_verified_user)):
|
||||||
target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
|
target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
|
||||||
|
|
||||||
if user.role not in ["user", "admin"]:
|
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
|
||||||
if app.state.OPENAI_API_KEY == "":
|
if app.state.OPENAI_API_KEY == "":
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
|
@ -133,12 +120,10 @@ async def speech(request: Request, user=Depends(get_current_user)):
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||||
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
|
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||||
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
|
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
|
||||||
print(target_url, app.state.OPENAI_API_KEY)
|
print(target_url, app.state.OPENAI_API_KEY)
|
||||||
|
|
||||||
if user.role not in ["user", "admin"]:
|
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
|
||||||
if app.state.OPENAI_API_KEY == "":
|
if app.state.OPENAI_API_KEY == "":
|
||||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ import uuid
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from utils.misc import calculate_sha256, calculate_sha256_string
|
from utils.misc import calculate_sha256, calculate_sha256_string
|
||||||
from utils.utils import get_current_user
|
from utils.utils import get_current_user, get_admin_user
|
||||||
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
|
@ -354,19 +354,12 @@ def store_doc(
|
||||||
|
|
||||||
|
|
||||||
@app.get("/reset/db")
|
@app.get("/reset/db")
|
||||||
def reset_vector_db(user=Depends(get_current_user)):
|
def reset_vector_db(user=Depends(get_admin_user)):
|
||||||
if user.role == "admin":
|
|
||||||
CHROMA_CLIENT.reset()
|
CHROMA_CLIENT.reset()
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/reset")
|
@app.get("/reset")
|
||||||
def reset(user=Depends(get_current_user)) -> bool:
|
def reset(user=Depends(get_admin_user)) -> bool:
|
||||||
if user.role == "admin":
|
|
||||||
folder = f"{UPLOAD_DIR}"
|
folder = f"{UPLOAD_DIR}"
|
||||||
for filename in os.listdir(folder):
|
for filename in os.listdir(folder):
|
||||||
file_path = os.path.join(folder, filename)
|
file_path = os.path.join(folder, filename)
|
||||||
|
@ -384,8 +377,3 @@ def reset(user=Depends(get_current_user)) -> bool:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI, HTTPException, status
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -19,7 +19,7 @@ from apps.web.models.auths import (
|
||||||
)
|
)
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
|
|
||||||
from utils.utils import get_password_hash, get_current_user, create_token
|
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
|
||||||
from utils.misc import get_gravatar_url, validate_email_format
|
from utils.misc import get_gravatar_url, validate_email_format
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
|
@ -116,10 +116,10 @@ async def signin(form_data: SigninForm):
|
||||||
@router.post("/signup", response_model=SigninResponse)
|
@router.post("/signup", response_model=SigninResponse)
|
||||||
async def signup(request: Request, form_data: SignupForm):
|
async def signup(request: Request, form_data: SignupForm):
|
||||||
if not request.app.state.ENABLE_SIGNUP:
|
if not request.app.state.ENABLE_SIGNUP:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
||||||
|
|
||||||
if not validate_email_format(form_data.email.lower()):
|
if not validate_email_format(form_data.email.lower()):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
|
||||||
|
|
||||||
if Users.get_user_by_email(form_data.email.lower()):
|
if Users.get_user_by_email(form_data.email.lower()):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
@ -156,23 +156,11 @@ async def signup(request: Request, form_data: SignupForm):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/signup/enabled", response_model=bool)
|
@router.get("/signup/enabled", response_model=bool)
|
||||||
async def get_sign_up_status(request: Request, user=Depends(get_current_user)):
|
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
|
||||||
if user.role == "admin":
|
|
||||||
return request.app.state.ENABLE_SIGNUP
|
return request.app.state.ENABLE_SIGNUP
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/signup/enabled/toggle", response_model=bool)
|
@router.get("/signup/enabled/toggle", response_model=bool)
|
||||||
async def toggle_sign_up(request: Request, user=Depends(get_current_user)):
|
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
|
||||||
if user.role == "admin":
|
|
||||||
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
|
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
|
||||||
return request.app.state.ENABLE_SIGNUP
|
return request.app.state.ENABLE_SIGNUP
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from fastapi import Depends, Request, HTTPException, status
|
from fastapi import Depends, Request, HTTPException, status
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import List, Union, Optional
|
from typing import List, Union, Optional
|
||||||
from utils.utils import get_current_user
|
from utils.utils import get_current_user, get_admin_user
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import json
|
import json
|
||||||
|
@ -60,17 +60,11 @@ async def get_all_user_chats(user=Depends(get_current_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/all/db", response_model=List[ChatResponse])
|
@router.get("/all/db", response_model=List[ChatResponse])
|
||||||
async def get_all_user_chats_in_db(user=Depends(get_current_user)):
|
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
||||||
if user.role == "admin":
|
|
||||||
return [
|
return [
|
||||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||||
for chat in Chats.get_all_chats()
|
for chat in Chats.get_all_chats()
|
||||||
]
|
]
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
|
@ -10,7 +10,7 @@ import uuid
|
||||||
|
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
|
|
||||||
from utils.utils import get_password_hash, get_current_user, create_token
|
from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
|
||||||
from utils.misc import get_gravatar_url, validate_email_format
|
from utils.misc import get_gravatar_url, validate_email_format
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
|
@ -37,30 +37,19 @@ class SetDefaultSuggestionsForm(BaseModel):
|
||||||
|
|
||||||
@router.post("/default/models", response_model=str)
|
@router.post("/default/models", response_model=str)
|
||||||
async def set_global_default_models(
|
async def set_global_default_models(
|
||||||
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
|
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
if user.role == "admin":
|
|
||||||
request.app.state.DEFAULT_MODELS = form_data.models
|
request.app.state.DEFAULT_MODELS = form_data.models
|
||||||
return request.app.state.DEFAULT_MODELS
|
return request.app.state.DEFAULT_MODELS
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
|
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
|
||||||
async def set_global_default_suggestions(
|
async def set_global_default_suggestions(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: SetDefaultSuggestionsForm,
|
form_data: SetDefaultSuggestionsForm,
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
if user.role == "admin":
|
|
||||||
data = form_data.model_dump()
|
data = form_data.model_dump()
|
||||||
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
|
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
|
||||||
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
|
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ from apps.web.models.documents import (
|
||||||
DocumentResponse,
|
DocumentResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.utils import get_current_user
|
from utils.utils import get_current_user, get_admin_user
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -44,13 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create", response_model=Optional[DocumentResponse])
|
@router.post("/create", response_model=Optional[DocumentResponse])
|
||||||
async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)):
|
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
doc = Documents.get_doc_by_name(form_data.name)
|
doc = Documents.get_doc_by_name(form_data.name)
|
||||||
if doc == None:
|
if doc == None:
|
||||||
doc = Documents.insert_new_doc(user.id, form_data)
|
doc = Documents.insert_new_doc(user.id, form_data)
|
||||||
|
@ -132,14 +126,8 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
|
||||||
|
|
||||||
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
|
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
|
||||||
async def update_doc_by_name(
|
async def update_doc_by_name(
|
||||||
name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user)
|
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
doc = Documents.update_doc_by_name(name, form_data)
|
doc = Documents.update_doc_by_name(name, form_data)
|
||||||
if doc:
|
if doc:
|
||||||
return DocumentResponse(
|
return DocumentResponse(
|
||||||
|
@ -161,12 +149,6 @@ async def update_doc_by_name(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/name/{name}/delete", response_model=bool)
|
@router.delete("/name/{name}/delete", response_model=bool)
|
||||||
async def delete_doc_by_name(name: str, user=Depends(get_current_user)):
|
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = Documents.delete_doc_by_name(name)
|
result = Documents.delete_doc_by_name(name)
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -13,7 +13,7 @@ from apps.web.models.modelfiles import (
|
||||||
ModelfileResponse,
|
ModelfileResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.utils import get_current_user
|
from utils.utils import get_current_user, get_admin_user
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0,
|
||||||
|
|
||||||
@router.post("/create", response_model=Optional[ModelfileResponse])
|
@router.post("/create", response_model=Optional[ModelfileResponse])
|
||||||
async def create_new_modelfile(form_data: ModelfileForm,
|
async def create_new_modelfile(form_data: ModelfileForm,
|
||||||
user=Depends(get_current_user)):
|
user=Depends(get_admin_user)):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
|
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
|
||||||
|
|
||||||
if modelfile:
|
if modelfile:
|
||||||
|
@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
|
||||||
|
|
||||||
@router.post("/update", response_model=Optional[ModelfileResponse])
|
@router.post("/update", response_model=Optional[ModelfileResponse])
|
||||||
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
|
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
|
||||||
user=Depends(get_current_user)):
|
user=Depends(get_admin_user)):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
|
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
|
||||||
if modelfile:
|
if modelfile:
|
||||||
updated_modelfile = {
|
updated_modelfile = {
|
||||||
|
@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
|
||||||
|
|
||||||
@router.delete("/delete", response_model=bool)
|
@router.delete("/delete", response_model=bool)
|
||||||
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
|
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
|
||||||
user=Depends(get_current_user)):
|
user=Depends(get_admin_user)):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
|
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -8,7 +8,7 @@ import json
|
||||||
|
|
||||||
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
|
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
|
||||||
|
|
||||||
from utils.utils import get_current_user
|
from utils.utils import get_current_user, get_admin_user
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -29,25 +29,17 @@ async def get_prompts(user=Depends(get_current_user)):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create", response_model=Optional[PromptModel])
|
@router.post("/create", response_model=Optional[PromptModel])
|
||||||
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
|
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
prompt = Prompts.get_prompt_by_command(form_data.command)
|
||||||
if prompt == None:
|
if prompt == None:
|
||||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
return prompt
|
return prompt
|
||||||
else:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.DEFAULT(),
|
detail=ERROR_MESSAGES.DEFAULT(),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.COMMAND_TAKEN,
|
detail=ERROR_MESSAGES.COMMAND_TAKEN,
|
||||||
|
@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
|
||||||
|
|
||||||
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
|
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
|
||||||
async def update_prompt_by_command(
|
async def update_prompt_by_command(
|
||||||
command: str, form_data: PromptForm, user=Depends(get_current_user)
|
command: str, form_data: PromptForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
||||||
if prompt:
|
if prompt:
|
||||||
return prompt
|
return prompt
|
||||||
|
@ -103,12 +89,6 @@ async def update_prompt_by_command(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/command/{command}/delete", response_model=bool)
|
@router.delete("/command/{command}/delete", response_model=bool)
|
||||||
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
|
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
result = Prompts.delete_prompt_by_command(f"/{command}")
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -11,7 +11,7 @@ import uuid
|
||||||
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
|
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
|
||||||
from apps.web.models.auths import Auths
|
from apps.web.models.auths import Auths
|
||||||
|
|
||||||
from utils.utils import get_current_user, get_password_hash
|
from utils.utils import get_current_user, get_password_hash, get_admin_user
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -22,12 +22,7 @@ router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[UserModel])
|
@router.get("/", response_model=List[UserModel])
|
||||||
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
|
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
return Users.get_users(skip, limit)
|
return Users.get_users(skip, limit)
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,17 +33,11 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
|
||||||
|
|
||||||
@router.post("/update/role", response_model=Optional[UserModel])
|
@router.post("/update/role", response_model=Optional[UserModel])
|
||||||
async def update_user_role(
|
async def update_user_role(
|
||||||
form_data: UserRoleUpdateForm, user=Depends(get_current_user)
|
form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
if user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
if user.id != form_data.id:
|
if user.id != form_data.id:
|
||||||
return Users.update_user_role_by_id(form_data.id, form_data.role)
|
return Users.update_user_role_by_id(form_data.id, form_data.role)
|
||||||
else:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||||
|
@ -62,14 +51,8 @@ async def update_user_role(
|
||||||
|
|
||||||
@router.post("/{user_id}/update", response_model=Optional[UserModel])
|
@router.post("/{user_id}/update", response_model=Optional[UserModel])
|
||||||
async def update_user_by_id(
|
async def update_user_by_id(
|
||||||
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user)
|
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
if session_user.role != "admin":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
||||||
user = Users.get_user_by_id(user_id)
|
user = Users.get_user_by_id(user_id)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
|
@ -98,13 +81,12 @@ async def update_user_by_id(
|
||||||
|
|
||||||
if updated_user:
|
if updated_user:
|
||||||
return updated_user
|
return updated_user
|
||||||
else:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.DEFAULT(),
|
detail=ERROR_MESSAGES.DEFAULT(),
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||||
|
@ -117,25 +99,20 @@ async def update_user_by_id(
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", response_model=bool)
|
@router.delete("/{user_id}", response_model=bool)
|
||||||
async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
|
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||||
if user.role == "admin":
|
|
||||||
if user.id != user_id:
|
if user.id != user_id:
|
||||||
result = Auths.delete_auth_by_id(user_id)
|
result = Auths.delete_auth_by_id(user_id)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
return True
|
return True
|
||||||
else:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
|
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
|
||||||
)
|
|
||||||
|
|
|
@ -73,3 +73,19 @@ def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(bearer_s
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_verified_user(user: Users = Depends(get_current_user)):
|
||||||
|
if user.role not in {"user", "admin"}:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_admin_user(user: Users = Depends(get_current_user)):
|
||||||
|
if user.role != "admin":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue