forked from open-webui/open-webui
chore: 🚨 lint and format
This commit is contained in:
parent
037793161e
commit
07cc7f15d5
25 changed files with 190 additions and 180 deletions
|
@ -8,7 +8,6 @@ from pydantic import BaseModel
|
|||
import time
|
||||
import uuid
|
||||
|
||||
|
||||
from apps.web.models.auths import (
|
||||
SigninForm,
|
||||
SignupForm,
|
||||
|
@ -19,12 +18,10 @@ from apps.web.models.auths import (
|
|||
)
|
||||
from apps.web.models.users import Users
|
||||
|
||||
|
||||
from utils.utils import get_password_hash, get_current_user, create_token
|
||||
from utils.misc import get_gravatar_url, validate_email_format
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
|
@ -49,9 +46,8 @@ async def get_session_user(user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.post("/update/password", response_model=bool)
|
||||
async def update_password(
|
||||
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
|
||||
):
|
||||
async def update_password(form_data: UpdatePasswordForm,
|
||||
session_user=Depends(get_current_user)):
|
||||
if session_user:
|
||||
user = Auths.authenticate_user(session_user.email, form_data.password)
|
||||
|
||||
|
@ -101,9 +97,8 @@ async def signup(request: Request, form_data: SignupForm):
|
|||
try:
|
||||
role = "admin" if Users.get_num_users() == 0 else "pending"
|
||||
hashed = get_password_hash(form_data.password)
|
||||
user = Auths.insert_new_auth(
|
||||
form_data.email.lower(), hashed, form_data.name, role
|
||||
)
|
||||
user = Auths.insert_new_auth(form_data.email.lower(),
|
||||
hashed, form_data.name, role)
|
||||
|
||||
if user:
|
||||
token = create_token(data={"email": user.email})
|
||||
|
@ -120,14 +115,15 @@ async def signup(request: Request, form_data: SignupForm):
|
|||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
|
||||
)
|
||||
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
raise HTTPException(500,
|
||||
detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
|
||||
raise HTTPException(400,
|
||||
detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
||||
|
||||
|
|
|
@ -17,8 +17,7 @@ from apps.web.models.chats import (
|
|||
)
|
||||
|
||||
from utils.utils import (
|
||||
bearer_scheme,
|
||||
)
|
||||
bearer_scheme, )
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -30,8 +29,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/", response_model=List[ChatTitleIdResponse])
|
||||
async def get_user_chats(
|
||||
user=Depends(get_current_user), skip: int = 0, limit: int = 50
|
||||
):
|
||||
user=Depends(get_current_user), skip: int = 0, limit: int = 50):
|
||||
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
|
||||
|
||||
|
||||
|
@ -43,8 +41,9 @@ async def get_user_chats(
|
|||
@router.get("/all", response_model=List[ChatResponse])
|
||||
async def get_all_user_chats(user=Depends(get_current_user)):
|
||||
return [
|
||||
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
for chat in Chats.get_all_chats_by_user_id(user.id)
|
||||
ChatResponse(**{
|
||||
**chat.model_dump(), "chat": json.loads(chat.chat)
|
||||
}) for chat in Chats.get_all_chats_by_user_id(user.id)
|
||||
]
|
||||
|
||||
|
||||
|
@ -69,11 +68,12 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
|
|||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
|
||||
if chat:
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
return ChatResponse(**{
|
||||
**chat.model_dump(), "chat": json.loads(chat.chat)
|
||||
})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND)
|
||||
|
||||
|
||||
############################
|
||||
|
@ -82,15 +82,17 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.post("/{id}", response_model=Optional[ChatResponse])
|
||||
async def update_chat_by_id(
|
||||
id: str, form_data: ChatForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def update_chat_by_id(id: str,
|
||||
form_data: ChatForm,
|
||||
user=Depends(get_current_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
updated_chat = {**json.loads(chat.chat), **form_data.chat}
|
||||
|
||||
chat = Chats.update_chat_by_id(id, updated_chat)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
return ChatResponse(**{
|
||||
**chat.model_dump(), "chat": json.loads(chat.chat)
|
||||
})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
|
@ -10,7 +10,6 @@ import uuid
|
|||
|
||||
from apps.web.models.users import Users
|
||||
|
||||
|
||||
from utils.utils import get_password_hash, get_current_user, create_token
|
||||
from utils.misc import get_gravatar_url, validate_email_format
|
||||
from constants import ERROR_MESSAGES
|
||||
|
@ -28,9 +27,9 @@ class SetDefaultModelsForm(BaseModel):
|
|||
|
||||
|
||||
@router.post("/default/models", response_model=str)
|
||||
async def set_global_default_models(
|
||||
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def set_global_default_models(request: Request,
|
||||
form_data: SetDefaultModelsForm,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role == "admin":
|
||||
request.app.state.DEFAULT_MODELS = form_data.models
|
||||
return request.app.state.DEFAULT_MODELS
|
||||
|
|
|
@ -24,7 +24,9 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.get("/", response_model=List[ModelfileResponse])
|
||||
async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
|
||||
async def get_modelfiles(skip: int = 0,
|
||||
limit: int = 50,
|
||||
user=Depends(get_current_user)):
|
||||
return Modelfiles.get_modelfiles(skip, limit)
|
||||
|
||||
|
||||
|
@ -34,9 +36,8 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_curren
|
|||
|
||||
|
||||
@router.post("/create", response_model=Optional[ModelfileResponse])
|
||||
async def create_new_modelfile(
|
||||
form_data: ModelfileForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def create_new_modelfile(form_data: ModelfileForm,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -49,9 +50,9 @@ async def create_new_modelfile(
|
|||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
"modelfile":
|
||||
json.loads(modelfile.modelfile),
|
||||
})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -65,16 +66,17 @@ async def create_new_modelfile(
|
|||
|
||||
|
||||
@router.post("/", response_model=Optional[ModelfileResponse])
|
||||
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
|
||||
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
|
||||
user=Depends(get_current_user)):
|
||||
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
|
||||
|
||||
if modelfile:
|
||||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
"modelfile":
|
||||
json.loads(modelfile.modelfile),
|
||||
})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -88,9 +90,8 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depend
|
|||
|
||||
|
||||
@router.post("/update", response_model=Optional[ModelfileResponse])
|
||||
async def update_modelfile_by_tag_name(
|
||||
form_data: ModelfileUpdateForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -104,15 +105,14 @@ async def update_modelfile_by_tag_name(
|
|||
}
|
||||
|
||||
modelfile = Modelfiles.update_modelfile_by_tag_name(
|
||||
form_data.tag_name, updated_modelfile
|
||||
)
|
||||
form_data.tag_name, updated_modelfile)
|
||||
|
||||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
"modelfile":
|
||||
json.loads(modelfile.modelfile),
|
||||
})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -126,9 +126,8 @@ async def update_modelfile_by_tag_name(
|
|||
|
||||
|
||||
@router.delete("/delete", response_model=bool)
|
||||
async def delete_modelfile_by_tag_name(
|
||||
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
|
@ -6,7 +6,6 @@ from fastapi import APIRouter
|
|||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
|
||||
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
|
||||
|
||||
from utils.utils import get_current_user
|
||||
|
@ -30,7 +29,8 @@ async def get_prompts(user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.post("/create", response_model=Optional[PromptModel])
|
||||
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
|
||||
async def create_new_prompt(form_data: PromptForm,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -79,9 +79,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.post("/{command}/update", response_model=Optional[PromptModel])
|
||||
async def update_prompt_by_command(
|
||||
command: str, form_data: PromptForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def update_prompt_by_command(command: str,
|
||||
form_data: PromptForm,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -104,7 +104,8 @@ async def update_prompt_by_command(
|
|||
|
||||
|
||||
@router.delete("/{command}/delete", response_model=bool)
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
|
||||
async def delete_prompt_by_command(command: str,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
|
@ -11,11 +11,9 @@ import uuid
|
|||
from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
|
||||
from apps.web.models.auths import Auths
|
||||
|
||||
|
||||
from utils.utils import get_current_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
|
@ -24,7 +22,9 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.get("/", response_model=List[UserModel])
|
||||
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
|
||||
async def get_users(skip: int = 0,
|
||||
limit: int = 50,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
@ -39,9 +39,8 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
|
|||
|
||||
|
||||
@router.post("/update/role", response_model=Optional[UserModel])
|
||||
async def update_user_role(
|
||||
form_data: UserRoleUpdateForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def update_user_role(form_data: UserRoleUpdateForm,
|
||||
user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
|
@ -9,12 +9,10 @@ import os
|
|||
import aiohttp
|
||||
import json
|
||||
|
||||
|
||||
from utils.misc import calculate_sha256
|
||||
|
||||
from config import OLLAMA_API_BASE_URL
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
@ -42,7 +40,10 @@ def parse_huggingface_url(hf_url):
|
|||
return None
|
||||
|
||||
|
||||
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
|
||||
async def download_file_stream(url,
|
||||
file_path,
|
||||
file_name,
|
||||
chunk_size=1024 * 1024):
|
||||
done = False
|
||||
|
||||
if os.path.exists(file_path):
|
||||
|
@ -56,7 +57,8 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
|
|||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
total_size = int(response.headers.get("content-length", 0)) + current_size
|
||||
total_size = int(response.headers.get("content-length",
|
||||
0)) + current_size
|
||||
|
||||
with open(file_path, "ab+") as file:
|
||||
async for data in response.content.iter_chunked(chunk_size):
|
||||
|
@ -89,9 +91,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
|
|||
|
||||
|
||||
@router.get("/download")
|
||||
async def download(
|
||||
url: str,
|
||||
):
|
||||
async def download(url: str, ):
|
||||
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
|
||||
file_name = parse_huggingface_url(url)
|
||||
|
||||
|
@ -161,4 +161,5 @@ async def upload(file: UploadFile = File(...)):
|
|||
res = {"error": str(e)}
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
|
||||
return StreamingResponse(file_write_stream(), media_type="text/event-stream")
|
||||
return StreamingResponse(file_write_stream(),
|
||||
media_type="text/event-stream")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue