forked from open-webui/open-webui
Merge branch 'main' into rag
This commit is contained in:
commit
fa598b59e2
31 changed files with 1917 additions and 143 deletions
|
@ -1,7 +1,7 @@
|
|||
from fastapi import FastAPI, Depends
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from apps.web.routers import auths, users, chats, modelfiles, utils
|
||||
from apps.web.routers import auths, users, chats, modelfiles, prompts, configs, utils
|
||||
from config import WEBUI_VERSION, WEBUI_AUTH
|
||||
|
||||
app = FastAPI()
|
||||
|
@ -9,6 +9,7 @@ app = FastAPI()
|
|||
origins = ["*"]
|
||||
|
||||
app.state.ENABLE_SIGNUP = True
|
||||
app.state.DEFAULT_MODELS = None
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
@ -19,13 +20,21 @@ app.add_middleware(
|
|||
)
|
||||
|
||||
app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
||||
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
|
||||
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
||||
|
||||
|
||||
app.include_router(configs.router, prefix="/configs", tags=["configs"])
|
||||
app.include_router(utils.router, prefix="/utils", tags=["utils"])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
return {"status": True, "version": WEBUI_VERSION, "auth": WEBUI_AUTH}
|
||||
return {
|
||||
"status": True,
|
||||
"version": WEBUI_VERSION,
|
||||
"auth": WEBUI_AUTH,
|
||||
"default_models": app.state.DEFAULT_MODELS,
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ from apps.web.internal.db import DB
|
|||
import json
|
||||
|
||||
####################
|
||||
# User DB Schema
|
||||
# Modelfile DB Schema
|
||||
####################
|
||||
|
||||
|
||||
|
|
117
backend/apps/web/models/prompts.py
Normal file
117
backend/apps/web/models/prompts.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
from pydantic import BaseModel
|
||||
from peewee import *
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from typing import List, Union, Optional
|
||||
import time
|
||||
|
||||
from utils.utils import decode_token
|
||||
from utils.misc import get_gravatar_url
|
||||
|
||||
from apps.web.internal.db import DB
|
||||
|
||||
import json
|
||||
|
||||
####################
|
||||
# Prompts DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Prompt(Model):
|
||||
command = CharField(unique=True)
|
||||
user_id = CharField()
|
||||
title = CharField()
|
||||
content = TextField()
|
||||
timestamp = DateField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
|
||||
class PromptModel(BaseModel):
|
||||
command: str
|
||||
user_id: str
|
||||
title: str
|
||||
content: str
|
||||
timestamp: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class PromptForm(BaseModel):
|
||||
command: str
|
||||
title: str
|
||||
content: str
|
||||
|
||||
|
||||
class PromptsTable:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.db.create_tables([Prompt])
|
||||
|
||||
def insert_new_prompt(
|
||||
self, user_id: str, form_data: PromptForm
|
||||
) -> Optional[PromptModel]:
|
||||
prompt = PromptModel(
|
||||
**{
|
||||
"user_id": user_id,
|
||||
"command": form_data.command,
|
||||
"title": form_data.title,
|
||||
"content": form_data.content,
|
||||
"timestamp": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = Prompt.create(**prompt.model_dump())
|
||||
if result:
|
||||
return prompt
|
||||
else:
|
||||
return None
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
||||
try:
|
||||
prompt = Prompt.get(Prompt.command == command)
|
||||
return PromptModel(**model_to_dict(prompt))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_prompts(self) -> List[PromptModel]:
|
||||
return [
|
||||
PromptModel(**model_to_dict(prompt))
|
||||
for prompt in Prompt.select()
|
||||
# .limit(limit).offset(skip)
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
) -> Optional[PromptModel]:
|
||||
try:
|
||||
query = Prompt.update(
|
||||
title=form_data.title,
|
||||
content=form_data.content,
|
||||
timestamp=int(time.time()),
|
||||
).where(Prompt.command == command)
|
||||
|
||||
query.execute()
|
||||
|
||||
prompt = Prompt.get(Prompt.command == command)
|
||||
return PromptModel(**model_to_dict(prompt))
|
||||
except:
|
||||
return None
|
||||
|
||||
def delete_prompt_by_command(self, command: str) -> bool:
|
||||
try:
|
||||
query = Prompt.delete().where((Prompt.command == command))
|
||||
query.execute() # Remove the rows, return number of rows removed.
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
Prompts = PromptsTable(DB)
|
|
@ -8,6 +8,7 @@ from pydantic import BaseModel
|
|||
import time
|
||||
import uuid
|
||||
|
||||
|
||||
from apps.web.models.auths import (
|
||||
SigninForm,
|
||||
SignupForm,
|
||||
|
@ -20,7 +21,7 @@ from apps.web.models.users import Users
|
|||
|
||||
|
||||
from utils.utils import get_password_hash, get_current_user, create_token
|
||||
from utils.misc import get_gravatar_url
|
||||
from utils.misc import get_gravatar_url, validate_email_format
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
|
@ -95,33 +96,38 @@ async def signin(form_data: SigninForm):
|
|||
@router.post("/signup", response_model=SigninResponse)
|
||||
async def signup(request: Request, form_data: SignupForm):
|
||||
if request.app.state.ENABLE_SIGNUP:
|
||||
if not Users.get_user_by_email(form_data.email.lower()):
|
||||
try:
|
||||
role = "admin" if Users.get_num_users() == 0 else "pending"
|
||||
hashed = get_password_hash(form_data.password)
|
||||
user = Auths.insert_new_auth(
|
||||
form_data.email.lower(), hashed, form_data.name, role
|
||||
)
|
||||
if validate_email_format(form_data.email.lower()):
|
||||
if not Users.get_user_by_email(form_data.email.lower()):
|
||||
try:
|
||||
role = "admin" if Users.get_num_users() == 0 else "pending"
|
||||
hashed = get_password_hash(form_data.password)
|
||||
user = Auths.insert_new_auth(
|
||||
form_data.email.lower(), hashed, form_data.name, role
|
||||
)
|
||||
|
||||
if user:
|
||||
token = create_token(data={"email": user.email})
|
||||
# response.set_cookie(key='token', value=token, httponly=True)
|
||||
if user:
|
||||
token = create_token(data={"email": user.email})
|
||||
# response.set_cookie(key='token', value=token, httponly=True)
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
"profile_image_url": user.profile_image_url,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
"profile_image_url": user.profile_image_url,
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
|
||||
)
|
||||
except Exception as err:
|
||||
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
|
||||
|
||||
|
|
41
backend/apps/web/routers/configs.py
Normal file
41
backend/apps/web/routers/configs.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
from fastapi import Response, Request
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from apps.web.models.users import Users
|
||||
|
||||
|
||||
from utils.utils import get_password_hash, get_current_user, create_token
|
||||
from utils.misc import get_gravatar_url, validate_email_format
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SetDefaultModelsForm(BaseModel):
|
||||
models: str
|
||||
|
||||
|
||||
############################
|
||||
# SetDefaultModels
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/default/models", response_model=str)
|
||||
async def set_global_default_models(
|
||||
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
|
||||
):
|
||||
if user.role == "admin":
|
||||
request.app.state.DEFAULT_MODELS = form_data.models
|
||||
return request.app.state.DEFAULT_MODELS
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
115
backend/apps/web/routers/prompts.py
Normal file
115
backend/apps/web/routers/prompts.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
|
||||
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
|
||||
|
||||
from utils.utils import get_current_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetPrompts
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=List[PromptModel])
|
||||
async def get_prompts(user=Depends(get_current_user)):
|
||||
return Prompts.get_prompts()
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewPrompt
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[PromptModel])
|
||||
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
||||
if prompt == None:
|
||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
||||
|
||||
if prompt:
|
||||
return prompt
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.COMMAND_TAKEN,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetPromptByCommand
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{command}", response_model=Optional[PromptModel])
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
|
||||
if prompt:
|
||||
return prompt
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdatePromptByCommand
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{command}/update", response_model=Optional[PromptModel])
|
||||
async def update_prompt_by_command(
|
||||
command: str, form_data: PromptForm, user=Depends(get_current_user)
|
||||
):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
||||
if prompt:
|
||||
return prompt
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeletePromptByCommand
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{command}/delete", response_model=bool)
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
||||
return result
|
|
@ -17,10 +17,12 @@ class ERROR_MESSAGES(str, Enum):
|
|||
USERNAME_TAKEN = (
|
||||
"Uh-oh! This username is already registered. Please choose another username."
|
||||
)
|
||||
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
|
||||
INVALID_TOKEN = (
|
||||
"Your session has expired or the token is invalid. Please sign in again."
|
||||
)
|
||||
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
|
||||
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
|
||||
INVALID_PASSWORD = (
|
||||
"The password provided is incorrect. Please check for typos and try again."
|
||||
)
|
||||
|
@ -31,5 +33,4 @@ class ERROR_MESSAGES(str, Enum):
|
|||
)
|
||||
NOT_FOUND = "We could not find what you're looking for :/"
|
||||
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
||||
|
||||
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
|
||||
|
|
4
backend/start.sh
Normal file → Executable file
4
backend/start.sh
Normal file → Executable file
|
@ -1 +1,3 @@
|
|||
uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*'
|
||||
#!/usr/bin/env bash
|
||||
|
||||
uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*'
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import hashlib
|
||||
import re
|
||||
|
||||
|
||||
def get_gravatar_url(email):
|
||||
|
@ -21,3 +22,9 @@ def calculate_sha256(file):
|
|||
for chunk in iter(lambda: file.read(8192), b""):
|
||||
sha256.update(chunk)
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
def validate_email_format(email: str) -> bool:
|
||||
if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
|
||||
return False
|
||||
return True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue