diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 23b39224..153b5dcb 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -1,7 +1,7 @@ from fastapi import FastAPI, Depends from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware -from apps.web.routers import auths, users, chats, modelfiles, configs, utils +from apps.web.routers import auths, users, chats, modelfiles, prompts, configs, utils from config import WEBUI_VERSION, WEBUI_AUTH app = FastAPI() @@ -23,6 +23,9 @@ app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) +app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) + + app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) diff --git a/backend/apps/web/models/modelfiles.py b/backend/apps/web/models/modelfiles.py index 4d8202db..8231d8df 100644 --- a/backend/apps/web/models/modelfiles.py +++ b/backend/apps/web/models/modelfiles.py @@ -12,7 +12,7 @@ from apps.web.internal.db import DB import json #################### -# User DB Schema +# Modelfile DB Schema #################### diff --git a/backend/apps/web/models/prompts.py b/backend/apps/web/models/prompts.py new file mode 100644 index 00000000..bb0710b6 --- /dev/null +++ b/backend/apps/web/models/prompts.py @@ -0,0 +1,117 @@ +from pydantic import BaseModel +from peewee import * +from playhouse.shortcuts import model_to_dict +from typing import List, Union, Optional +import time + +from utils.utils import decode_token +from utils.misc import get_gravatar_url + +from apps.web.internal.db import DB + +import json + +#################### +# Prompts DB Schema +#################### + + +class Prompt(Model): + command = CharField(unique=True) + user_id = CharField() + title = CharField() + content = TextField() + timestamp = DateField() + + class Meta: + database = DB + + +class PromptModel(BaseModel): + command: str + user_id: str + title: str + content: str + timestamp: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class PromptForm(BaseModel): + command: str + title: str + content: str + + +class PromptsTable: + def __init__(self, db): + self.db = db + self.db.create_tables([Prompt]) + + def insert_new_prompt( + self, user_id: str, form_data: PromptForm + ) -> Optional[PromptModel]: + prompt = PromptModel( + **{ + "user_id": user_id, + "command": form_data.command, + "title": form_data.title, + "content": form_data.content, + "timestamp": int(time.time()), + } + ) + + try: + result = Prompt.create(**prompt.model_dump()) + if result: + return prompt + else: + return None + except: + return None + + def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: + try: + prompt = Prompt.get(Prompt.command == command) + return PromptModel(**model_to_dict(prompt)) + except: + return None + + def get_prompts(self) -> List[PromptModel]: + return [ + PromptModel(**model_to_dict(prompt)) + for prompt in Prompt.select() + # .limit(limit).offset(skip) + ] + + def update_prompt_by_command( + self, command: str, form_data: PromptForm + ) -> Optional[PromptModel]: + try: + query = Prompt.update( + title=form_data.title, + content=form_data.content, + timestamp=int(time.time()), + ).where(Prompt.command == command) + + query.execute() + + prompt = Prompt.get(Prompt.command == command) + return PromptModel(**model_to_dict(prompt)) + except: + return None + + def delete_prompt_by_command(self, command: str) -> bool: + try: + query = Prompt.delete().where((Prompt.command == command)) + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + + +Prompts = PromptsTable(DB) diff --git a/backend/apps/web/routers/prompts.py b/backend/apps/web/routers/prompts.py new file mode 100644 index 00000000..5a002c94 --- /dev/null +++ b/backend/apps/web/routers/prompts.py @@ -0,0 +1,115 @@ +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + + +from apps.web.models.prompts import Prompts, PromptForm, PromptModel + +from utils.utils import get_current_user +from constants import ERROR_MESSAGES + +router = APIRouter() + +############################ +# GetPrompts +############################ + + +@router.get("/", response_model=List[PromptModel]) +async def get_prompts(user=Depends(get_current_user)): + return Prompts.get_prompts() + + +############################ +# CreateNewPrompt +############################ + + +@router.post("/create", response_model=Optional[PromptModel]) +async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)): + if user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + prompt = Prompts.get_prompt_by_command(form_data.command) + if prompt == None: + prompt = Prompts.insert_new_prompt(user.id, form_data) + + if prompt: + return prompt + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.COMMAND_TAKEN, + ) + + +############################ +# GetPromptByCommand +############################ + + +@router.get("/{command}", response_model=Optional[PromptModel]) +async def get_prompt_by_command(command: str, user=Depends(get_current_user)): + prompt = Prompts.get_prompt_by_command(f"/{command}") + + if prompt: + return prompt + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdatePromptByCommand +############################ + + +@router.post("/{command}/update", response_model=Optional[PromptModel]) +async def update_prompt_by_command( + command: str, form_data: PromptForm, user=Depends(get_current_user) +): + if user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) + if prompt: + return prompt + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +############################ +# DeletePromptByCommand +############################ + + +@router.delete("/{command}/delete", response_model=bool) +async def delete_prompt_by_command(command: str, user=Depends(get_current_user)): + if user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + result = Prompts.delete_prompt_by_command(f"/{command}") + return result diff --git a/backend/constants.py b/backend/constants.py index ec3ce337..0817445b 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -17,6 +17,7 @@ class ERROR_MESSAGES(str, Enum): USERNAME_TAKEN = ( "Uh-oh! This username is already registered. Please choose another username." ) + COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." INVALID_TOKEN = ( "Your session has expired or the token is invalid. Please sign in again." ) @@ -32,5 +33,4 @@ class ERROR_MESSAGES(str, Enum): ) NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/" - MALICIOUS = "Unusual activities detected, please try again in a few minutes." diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts new file mode 100644 index 00000000..7ed303b3 --- /dev/null +++ b/src/lib/apis/prompts/index.ts @@ -0,0 +1,178 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewPrompt = async ( + token: string, + command: string, + title: string, + content: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + command: `/${command}`, + title: title, + content: content + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getPrompts = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getPromptByCommand = async (token: string, command: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/${command}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updatePromptByCommand = async ( + token: string, + command: string, + title: string, + content: string +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/${command}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + command: `/${command}`, + title: title, + content: content + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deletePromptByCommand = async (token: string, command: string) => { + let error = null; + + command = command.charAt(0) === '/' ? command.slice(1) : command; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/${command}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/MessageInput/PromptCommands.svelte b/src/lib/components/chat/MessageInput/PromptCommands.svelte index 2b41bac0..ddf35360 100644 --- a/src/lib/components/chat/MessageInput/PromptCommands.svelte +++ b/src/lib/components/chat/MessageInput/PromptCommands.svelte @@ -1,158 +1,13 @@ + +