Merge branch 'dev' into dockerfile-optimisation

This commit is contained in:
Jannik S 2024-04-03 11:37:46 +02:00 committed by GitHub
commit f669c0e78e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 1644 additions and 380 deletions

View file

@ -81,6 +81,12 @@ async def check_url(request: Request, call_next):
return response
@app.head("/")
@app.get("/")
async def get_status():
return {"status": True}
@app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}

View file

@ -1,4 +1,5 @@
from peewee import *
from peewee_migrate import Router
from config import SRC_LOG_LEVELS, DATA_DIR
import os
import logging
@ -16,4 +17,6 @@ else:
DB = SqliteDatabase(f"{DATA_DIR}/webui.db")
DB.connect()
router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
router.run()
DB.connect(reuse_if_open=True)

View file

@ -0,0 +1,149 @@
"""Peewee migrations -- 001_initial_schema.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.CharField(max_length=255)
active = pw.BooleanField()
class Meta:
table_name = "auth"
@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.CharField()
chat = pw.TextField()
timestamp = pw.DateField()
class Meta:
table_name = "chat"
@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.DateField()
class Meta:
table_name = "chatidtag"
@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.CharField()
filename = pw.CharField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.DateField()
class Meta:
table_name = "document"
@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.DateField()
class Meta:
table_name = "modelfile"
@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.CharField()
content = pw.TextField()
timestamp = pw.DateField()
class Meta:
table_name = "prompt"
@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.CharField(max_length=255)
timestamp = pw.DateField()
class Meta:
table_name = "user"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("user")
migrator.remove_model("tag")
migrator.remove_model("prompt")
migrator.remove_model("modelfile")
migrator.remove_model("document")
migrator.remove_model("chatidtag")
migrator.remove_model("chat")
migrator.remove_model("auth")

View file

@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("chat", "share_id")

View file

@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "api_key")

View file

@ -0,0 +1,21 @@
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the [`peewee-migrate`](https://github.com/klen/peewee_migrate) library.
Migrations are automatically ran at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1. Have a database file (`webui.db`) that has the old schema prior to any of your changes.
2. Make your changes to the models.
3. From the `backend` directory, run the following command:
```bash
pw_migrate create --auto --auto-source apps.web.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
```
- `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration.
4. The migration file will be created in the `apps/web/internal/migrations` directory.

View file

@ -20,6 +20,7 @@ from config import (
ENABLE_SIGNUP,
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
)
app = FastAPI()
@ -34,7 +35,7 @@ app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware(
CORSMiddleware,

View file

@ -47,6 +47,10 @@ class Token(BaseModel):
token_type: str
class ApiKey(BaseModel):
api_key: Optional[str] = None
class UserResponse(BaseModel):
id: str
email: str
@ -123,6 +127,28 @@ class AuthsTable:
except:
return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None
if not api_key:
return None
try:
user = Users.get_user_by_api_key(api_key)
return user if user else None
except:
return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
if auth:
user = Users.get_user_by_id(auth.id)
return user
except:
return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try:
query = Auth.update(password=new_password).where(Auth.id == id)

View file

@ -20,6 +20,7 @@ class Chat(Model):
title = CharField()
chat = TextField() # Save Chat JSON as Text
timestamp = DateField()
share_id = CharField(null=True, unique=True)
class Meta:
database = DB
@ -31,6 +32,7 @@ class ChatModel(BaseModel):
title: str
chat: str
timestamp: int # timestamp in epoch
share_id: Optional[str] = None
####################
@ -52,6 +54,7 @@ class ChatResponse(BaseModel):
title: str
chat: dict
timestamp: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
class ChatTitleIdResponse(BaseModel):
@ -95,6 +98,71 @@ class ChatTable:
except:
return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
# Get the existing chat to share
chat = Chat.get(Chat.id == chat_id)
# Check if the chat is already shared
if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
# Create a new chat with the same data, but with a new ID
shared_chat = ChatModel(
**{
"id": str(uuid.uuid4()),
"user_id": f"shared-{chat_id}",
"title": chat.title,
"chat": chat.chat,
"timestamp": int(time.time()),
}
)
shared_result = Chat.create(**shared_chat.model_dump())
# Update the original chat with the share_id
result = (
Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
)
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
print("update_shared_chat_by_id")
chat = Chat.get(Chat.id == chat_id)
print(chat)
query = Chat.update(
title=chat.title,
chat=chat.chat,
).where(Chat.id == chat.share_id)
query.execute()
chat = Chat.get(Chat.id == chat.share_id)
return ChatModel(**model_to_dict(chat))
except:
return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try:
query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
try:
query = Chat.update(
share_id=share_id,
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
@ -131,6 +199,13 @@ class ChatTable:
.order_by(Chat.timestamp.desc())
]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
@ -149,12 +224,15 @@ class ChatTable:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed.
return True
return True and self.delete_shared_chat_by_chat_id(id)
except:
return False
def delete_chats_by_user_id(self, user_id: str) -> bool:
try:
self.delete_shared_chats_by_user_id(user_id)
query = Chat.delete().where(Chat.user_id == user_id)
query.execute() # Remove the rows, return number of rows removed.
@ -162,5 +240,19 @@ class ChatTable:
except:
return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
shared_chat_ids = [
f"shared-{chat.id}"
for chat in Chat.select().where(Chat.user_id == user_id)
]
query = Chat.delete().where(Chat.user_id << shared_chat_ids)
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Chats = ChatTable(DB)

View file

@ -20,6 +20,7 @@ class User(Model):
role = CharField()
profile_image_url = CharField()
timestamp = DateField()
api_key = CharField(null=True, unique=True)
class Meta:
database = DB
@ -32,6 +33,7 @@ class UserModel(BaseModel):
role: str = "pending"
profile_image_url: str = "/user.png"
timestamp: int # timestamp in epoch
api_key: Optional[str] = None
####################
@ -82,6 +84,13 @@ class UsersTable:
except:
return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
user = User.get(User.api_key == api_key)
return UserModel(**model_to_dict(user))
except:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]:
try:
user = User.get(User.email == email)
@ -149,5 +158,21 @@ class UsersTable:
except:
return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try:
query = User.update(api_key=api_key).where(User.id == id)
result = query.execute()
return True if result == 1 else False
except:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try:
user = User.get(User.id == id)
return user.api_key
except:
return None
Users = UsersTable(DB)

View file

@ -1,13 +1,10 @@
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import Request
from fastapi import Depends, HTTPException, status
from fastapi import APIRouter, status
from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
import re
import uuid
from apps.web.models.auths import (
SigninForm,
@ -17,6 +14,7 @@ from apps.web.models.auths import (
UserResponse,
SigninResponse,
Auths,
ApiKey,
)
from apps.web.models.users import Users
@ -25,10 +23,12 @@ from utils.utils import (
get_current_user,
get_admin_user,
create_token,
create_api_key,
)
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
router = APIRouter()
@ -79,6 +79,8 @@ async def update_profile(
async def update_password(
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password)
@ -98,7 +100,22 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse)
async def signin(request: Request, form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
if not Users.get_user_by_email(trusted_email.lower()):
await signup(
request,
SignupForm(
email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
),
)
user = Auths.authenticate_user_by_trusted_header(trusted_email)
else:
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user:
token = create_token(
data={"id": user.id},
@ -249,3 +266,40 @@ async def update_token_expires_duration(
return request.app.state.JWT_EXPIRES_IN
else:
return request.app.state.JWT_EXPIRES_IN
############################
# API Key
############################
# create api key
@router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user)):
api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key)
if success:
return {
"api_key": api_key,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
# delete api key
@router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(user.id, None)
return success
# get api key
@router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(user.id)
if api_key:
return {
"api_key": api_key,
}
else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

View file

@ -189,6 +189,78 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
return result
############################
# ShareChatById
############################
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeletedSharedChatById
############################
@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if not chat.share_id:
return False
result = Chats.delete_shared_chat_by_chat_id(id)
update_result = Chats.update_chat_share_id_by_id(id, None)
return result and update_result != None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################
# GetChatTagsById
############################