Merge branch 'main' into dev

This commit is contained in:
Timothy Jaeryang Baek 2024-01-01 14:09:45 -05:00 committed by GitHub
commit 127886db14
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 237 additions and 330 deletions

1
backend/.gitignore vendored
View file

@ -5,3 +5,4 @@ uploads
.ipynb_checkpoints .ipynb_checkpoints
*.db *.db
_test _test
Pipfile

View file

@ -8,7 +8,7 @@ import json
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import extract_token_from_auth_header from utils.utils import decode_token
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
app = Flask(__name__) app = Flask(__name__)
@ -34,8 +34,12 @@ def proxy(path):
# Basic RBAC support # Basic RBAC support
if WEBUI_AUTH: if WEBUI_AUTH:
if "Authorization" in headers: if "Authorization" in headers:
token = extract_token_from_auth_header(headers["Authorization"]) _, credentials = headers["Authorization"].split()
user = Users.get_user_by_token(token) token_data = decode_token(credentials)
if token_data is None or "email" not in token_data:
return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
user = Users.get_user_by_email(token_data["email"])
if user: if user:
# Only user and admin roles can access # Only user and admin roles can access
if user.role in ["user", "admin"]: if user.role in ["user", "admin"]:

View file

@ -1,6 +1,6 @@
from fastapi import FastAPI, Request, Depends, HTTPException from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH from config import WEBUI_VERSION, WEBUI_AUTH
@ -16,13 +16,11 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])

View file

@ -3,8 +3,6 @@ from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.web.internal.db import DB
@ -85,14 +83,6 @@ class UsersTable:
except: except:
return None return None
def get_user_by_token(self, token: str) -> Optional[UserModel]:
data = decode_token(token)
if data != None and "email" in data:
return self.get_user_by_email(data["email"])
else:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ return [
UserModel(**model_to_dict(user)) UserModel(**model_to_dict(user))

View file

@ -19,11 +19,7 @@ from apps.web.models.auths import (
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import ( from utils.utils import get_password_hash, get_current_user, create_token
get_password_hash,
bearer_scheme,
create_token,
)
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -36,10 +32,7 @@ router = APIRouter()
@router.get("/", response_model=UserResponse) @router.get("/", response_model=UserResponse)
async def get_session_user(cred=Depends(bearer_scheme)): async def get_session_user(user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return { return {
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
@ -47,11 +40,6 @@ async def get_session_user(cred=Depends(bearer_scheme)):
"role": user.role, "role": user.role,
"profile_image_url": user.profile_image_url, "profile_image_url": user.profile_image_url,
} }
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -60,10 +48,9 @@ async def get_session_user(cred=Depends(bearer_scheme)):
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password(form_data: UpdatePasswordForm, cred=Depends(bearer_scheme)): async def update_password(
token = cred.credentials form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
session_user = Users.get_user_by_token(token) ):
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)

View file

@ -1,8 +1,7 @@
from fastapi import Response from fastapi import Depends, Request, HTTPException, status
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from utils.utils import get_current_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
@ -30,17 +29,10 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_user_chats(
token = cred.credentials user=Depends(get_current_user), skip: int = 0, limit: int = 50
user = Users.get_user_by_token(token) ):
if user:
return Chats.get_chat_lists_by_user_id(user.id, skip, limit) return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -49,20 +41,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(cred=Depends(bearer_scheme)): async def get_all_user_chats(user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id) for chat in Chats.get_all_chats_by_user_id(user.id)
] ]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -71,18 +54,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
chat = Chats.insert_new_chat(user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -91,24 +65,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def get_chat_by_id(id: str, user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
@ -118,11 +82,9 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)): async def update_chat_by_id(
token = cred.credentials id: str, form_data: ChatForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
if user:
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
@ -134,11 +96,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -147,18 +104,9 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
result = Chats.delete_chat_by_id_and_user_id(id, user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -167,15 +115,6 @@ async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.delete("/", response_model=bool) @router.delete("/", response_model=bool)
async def delete_all_user_chats(cred=Depends(bearer_scheme)): async def delete_all_user_chats(user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
result = Chats.delete_chats_by_user_id(user.id) result = Chats.delete_chats_by_user_id(user.id)
return result return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View file

@ -1,4 +1,3 @@
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
@ -6,8 +5,6 @@ from typing import List, Union, Optional
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.web.models.users import Users
from apps.web.models.modelfiles import ( from apps.web.models.modelfiles import (
Modelfiles, Modelfiles,
ModelfileForm, ModelfileForm,
@ -16,9 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse, ModelfileResponse,
) )
from utils.utils import ( from utils.utils import get_current_user
bearer_scheme,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -29,17 +24,8 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
return Modelfiles.get_modelfiles(skip, limit) return Modelfiles.get_modelfiles(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -48,13 +34,15 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)): async def create_new_modelfile(
token = cred.credentials form_data: ModelfileForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user:
# Admin Only
if user.role == "admin":
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile: if modelfile:
@ -69,16 +57,6 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(), detail=ERROR_MESSAGES.DEFAULT(),
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -87,13 +65,7 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name( async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
@ -108,11 +80,6 @@ async def get_modelfile_by_tag_name(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND,
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -122,13 +89,13 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name( async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme) form_data: ModelfileUpdateForm, user=Depends(get_current_user)
): ):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token) raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
if user: detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
if user.role == "admin": )
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
updated_modelfile = { updated_modelfile = {
@ -151,16 +118,6 @@ async def update_modelfile_by_tag_name(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -170,22 +127,13 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name( async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme) form_data: ModelfileTagNameForm, user=Depends(get_current_user)
): ):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else:
raise HTTPException( result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
status_code=status.HTTP_401_UNAUTHORIZED, return result
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View file

@ -12,11 +12,7 @@ from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths from apps.web.models.auths import Auths
from utils.utils import ( from utils.utils import get_current_user
get_password_hash,
bearer_scheme,
create_token,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
@ -27,23 +23,13 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel]) @router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
return Users.get_users(skip, limit)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else: return Users.get_users(skip, limit)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -52,12 +38,15 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)): async def update_user_role(
token = cred.credentials form_data: UserRoleUpdateForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user:
if user.role == "admin":
if user.id != form_data.id: if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
else: else:
@ -65,16 +54,6 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
@ -83,11 +62,7 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
@router.delete("/{user_id}", response_model=bool) @router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)): async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
if user.role == "admin": if user.role == "admin":
if user.id != user_id: if user.id != user_id:
result = Auths.delete_auth_by_id(user_id) result = Auths.delete_auth_by_id(user_id)
@ -109,8 +84,3 @@ async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)):
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)

View file

@ -18,3 +18,5 @@ bcrypt
PyJWT PyJWT
pyjwt[crypto] pyjwt[crypto]
black

View file

@ -1,7 +1,9 @@
from fastapi.security import HTTPBasicCredentials, HTTPBearer from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
import requests import requests
@ -53,16 +55,18 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def verify_token(request): def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
try: data = decode_token(auth_token.credentials)
bearer = request.headers["authorization"] if data != None and "email" in data:
if bearer: user = Users.get_user_by_email(data["email"])
token = bearer[len("Bearer ") :] if user is None:
decoded = jwt.decode( raise HTTPException(
token, JWT_SECRET_KEY, options={"verify_signature": False} status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
return decoded return user
else: else:
return None raise HTTPException(
except Exception as e: status_code=status.HTTP_401_UNAUTHORIZED,
return None detail=ERROR_MESSAGES.UNAUTHORIZED,
)

View file

@ -16,7 +16,7 @@ html {
code { code {
/* white-space-collapse: preserve !important; */ /* white-space-collapse: preserve !important; */
white-space: pre; overflow-x: auto;
width: auto; width: auto;
} }

View file

@ -298,6 +298,24 @@
submitPrompt(prompt); submitPrompt(prompt);
} }
}} }}
on:keydown={(e) => {
if (prompt === '' && e.key == 'ArrowUp') {
e.preventDefault();
const userMessageElement = [
...document.getElementsByClassName('user-message')
]?.at(-1);
const editButton = [
...document.getElementsByClassName('edit-user-message-button')
]?.at(-1);
console.log(userMessageElement);
userMessageElement.scrollIntoView({ block: 'center' });
editButton?.click();
}
}}
rows="1" rows="1"
on:input={(e) => { on:input={(e) => {
e.target.style.height = ''; e.target.style.height = '';

View file

@ -88,6 +88,7 @@
let code = block.querySelector('code'); let code = block.querySelector('code');
code.style.borderTopRightRadius = 0; code.style.borderTopRightRadius = 0;
code.style.borderTopLeftRadius = 0; code.style.borderTopLeftRadius = 0;
code.style.whiteSpace = 'pre';
let topBarDiv = document.createElement('div'); let topBarDiv = document.createElement('div');
topBarDiv.style.backgroundColor = '#202123'; topBarDiv.style.backgroundColor = '#202123';

View file

@ -24,6 +24,8 @@
editElement.style.height = ''; editElement.style.height = '';
editElement.style.height = `${editElement.scrollHeight}px`; editElement.style.height = `${editElement.scrollHeight}px`;
editElement?.focus();
}; };
const editMessageConfirmHandler = async () => { const editMessageConfirmHandler = async () => {
@ -43,7 +45,9 @@
<ProfileImage src={user?.profile_image_url ?? '/user.png'} /> <ProfileImage src={user?.profile_image_url ?? '/user.png'} />
<div class="w-full overflow-hidden"> <div class="w-full overflow-hidden">
<div class="user-message">
<Name>You</Name> <Name>You</Name>
</div>
<div <div
class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:my-0 prose-p:-mb-4 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-6 prose-li:-mb-4 whitespace-pre-line" class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:my-0 prose-p:-mb-4 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-6 prose-li:-mb-4 whitespace-pre-line"
@ -145,7 +149,7 @@
{/if} {/if}
<button <button
class="invisible group-hover:visible p-1 rounded dark:hover:bg-gray-800 transition" class="invisible group-hover:visible p-1 rounded dark:hover:bg-gray-800 transition edit-user-message-button"
on:click={() => { on:click={() => {
editMessageHandler(); editMessageHandler();
}} }}

View file

@ -1,7 +1,10 @@
<script lang="ts"> <script lang="ts">
import toast from 'svelte-french-toast';
import fileSaver from 'file-saver';
const { saveAs } = fileSaver;
import { getChatById } from '$lib/apis/chats'; import { getChatById } from '$lib/apis/chats';
import { chatId, db, modelfiles } from '$lib/stores'; import { chatId, db, modelfiles } from '$lib/stores';
import toast from 'svelte-french-toast';
export let initNewChat: Function; export let initNewChat: Function;
export let title: string = 'Ollama Web UI'; export let title: string = 'Ollama Web UI';
@ -33,6 +36,21 @@
false false
); );
}; };
const downloadChat = async () => {
const chat = (await getChatById(localStorage.token, $chatId)).chat;
console.log('download', chat);
const chatText = chat.messages.reduce((a, message, i, arr) => {
return `${a}### ${message.role.toUpperCase()}\n${message.content}\n\n`;
}, '');
let blob = new Blob([chatText], {
type: 'text/plain'
});
saveAs(blob, `chat-${chat.title}.txt`);
};
</script> </script>
<nav <nav
@ -69,7 +87,30 @@
</div> </div>
{#if shareEnabled} {#if shareEnabled}
<div class="pl-2"> <div class="pl-2 flex space-x-1.5">
<button
class=" cursor-pointer p-2 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600"
on:click={async () => {
downloadChat();
}}
>
<div class=" m-auto self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
/>
<path
d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
/>
</svg>
</div>
</button>
<button <button
class=" cursor-pointer p-2 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600" class=" cursor-pointer p-2 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600"
on:click={async () => { on:click={async () => {
@ -79,15 +120,15 @@
<div class=" m-auto self-center"> <div class=" m-auto self-center">
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20" viewBox="0 0 16 16"
fill="currentColor" fill="currentColor"
class="w-4 h-4" class="w-4 h-4"
> >
<path <path
d="M9.25 13.25a.75.75 0 001.5 0V4.636l2.955 3.129a.75.75 0 001.09-1.03l-4.25-4.5a.75.75 0 00-1.09 0l-4.25 4.5a.75.75 0 101.09 1.03L9.25 4.636v8.614z" d="M7.25 10.25a.75.75 0 0 0 1.5 0V4.56l2.22 2.22a.75.75 0 1 0 1.06-1.06l-3.5-3.5a.75.75 0 0 0-1.06 0l-3.5 3.5a.75.75 0 0 0 1.06 1.06l2.22-2.22v5.69Z"
/> />
<path <path
d="M3.5 12.75a.75.75 0 00-1.5 0v2.5A2.75 2.75 0 004.75 18h10.5A2.75 2.75 0 0018 15.25v-2.5a.75.75 0 00-1.5 0v2.5c0 .69-.56 1.25-1.25 1.25H4.75c-.69 0-1.25-.56-1.25-1.25v-2.5z" d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
/> />
</svg> </svg>
</div> </div>