forked from open-webui/open-webui
feat: modelfiles backend
This commit is contained in:
parent
d78df83453
commit
a2b1e3756b
3 changed files with 330 additions and 28 deletions
122
backend/apps/web/models/modelfiles.py
Normal file
122
backend/apps/web/models/modelfiles.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
from pydantic import BaseModel
|
||||
from peewee import *
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from typing import List, Union, Optional
|
||||
import time
|
||||
|
||||
from utils.utils import decode_token
|
||||
from utils.misc import get_gravatar_url
|
||||
|
||||
from apps.web.internal.db import DB
|
||||
|
||||
import json
|
||||
|
||||
####################
|
||||
# User DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Modelfile(Model):
|
||||
tag_name = CharField(unique=True)
|
||||
user_id = CharField()
|
||||
modelfile = TextField()
|
||||
timestamp = DateField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
|
||||
class ModelfileModel(BaseModel):
|
||||
tag_name: str
|
||||
user_id: str
|
||||
modelfile: str
|
||||
timestamp: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class ModelfileForm(BaseModel):
|
||||
modelfile: dict
|
||||
|
||||
|
||||
class ModelfileResponse(BaseModel):
|
||||
tag_name: str
|
||||
user_id: str
|
||||
modelfile: dict
|
||||
timestamp: int # timestamp in epoch
|
||||
|
||||
|
||||
class ModelfilesTable:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.db.create_tables([Modelfile])
|
||||
|
||||
def insert_new_modelfile(
|
||||
self, user_id: str, form_data: ModelfileForm
|
||||
) -> Optional[ModelfileModel]:
|
||||
if "title" in form_data.modelfile:
|
||||
modelfile = ModelfileModel(
|
||||
**{
|
||||
"user_id": user_id,
|
||||
"tag_name": form_data.modelfile["title"],
|
||||
"modelfile": json.dumps(form_data.modelfile),
|
||||
"timestamp": int(time.time()),
|
||||
}
|
||||
)
|
||||
result = Modelfile.create(**modelfile.model_dump())
|
||||
if result:
|
||||
return modelfile
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
|
||||
try:
|
||||
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
|
||||
return ModelfileModel(**model_to_dict(modelfile))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
|
||||
return [
|
||||
ModelfileResponse(
|
||||
**{
|
||||
**model_to_dict(modelfile),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
for modelfile in Modelfile.select()
|
||||
# .limit(limit).offset(skip)
|
||||
]
|
||||
|
||||
def update_modelfile_by_tag_name(
|
||||
self, tag_name: str, modelfile: dict
|
||||
) -> Optional[ModelfileModel]:
|
||||
try:
|
||||
query = Modelfile.update(
|
||||
modelfile=json.dumps(modelfile),
|
||||
timestamp=int(time.time()),
|
||||
).where(Modelfile.tag_name == tag_name)
|
||||
|
||||
query.execute()
|
||||
|
||||
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
|
||||
return ModelfileModel(**model_to_dict(modelfile))
|
||||
except:
|
||||
return None
|
||||
|
||||
def delete_modelfile_by_tag_name(self, tag_name: str) -> bool:
|
||||
try:
|
||||
query = Modelfile.delete().where((Modelfile.tag_name == tag_name))
|
||||
query.execute() # Remove the rows, return number of rows removed.
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
Modelfiles = ModelfilesTable(DB)
|
178
backend/apps/web/routers/modelfiles.py
Normal file
178
backend/apps/web/routers/modelfiles.py
Normal file
|
@ -0,0 +1,178 @@
|
|||
from fastapi import Response
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from apps.web.models.users import Users
|
||||
from apps.web.models.modelfiles import (
|
||||
Modelfiles,
|
||||
ModelfileForm,
|
||||
ModelfileResponse,
|
||||
)
|
||||
|
||||
from utils.utils import (
|
||||
bearer_scheme,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetModelfiles
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=List[ModelfileResponse])
|
||||
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
|
||||
token = cred.credentials
|
||||
user = Users.get_user_by_token(token)
|
||||
|
||||
if user:
|
||||
return Modelfiles.get_modelfiles(skip, limit)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewModelfile
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[ModelfileResponse])
|
||||
async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
|
||||
token = cred.credentials
|
||||
user = Users.get_user_by_token(token)
|
||||
|
||||
if user:
|
||||
# Admin Only
|
||||
if user.role == "admin":
|
||||
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
|
||||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetModelfileByTagName
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{tag_name}", response_model=Optional[ModelfileResponse])
|
||||
async def get_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)):
|
||||
token = cred.credentials
|
||||
user = Users.get_user_by_token(token)
|
||||
|
||||
if user:
|
||||
modelfile = Modelfiles.get_modelfile_by_tag_name(tag_name)
|
||||
|
||||
if modelfile:
|
||||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateModelfileByTagName
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{tag_name}", response_model=Optional[ModelfileResponse])
|
||||
async def update_modelfile_by_tag_name(
|
||||
tag_name: str, form_data: ModelfileForm, cred=Depends(bearer_scheme)
|
||||
):
|
||||
token = cred.credentials
|
||||
user = Users.get_user_by_token(token)
|
||||
|
||||
if user:
|
||||
if user.role == "admin":
|
||||
modelfile = Modelfiles.get_modelfile_by_tag_name(tag_name)
|
||||
if modelfile:
|
||||
updated_modelfile = {
|
||||
**json.loads(modelfile.modelfile),
|
||||
**form_data.modelfile,
|
||||
}
|
||||
|
||||
modelfile = Modelfiles.update_modelfile_by_tag_name(
|
||||
tag_name, updated_modelfile
|
||||
)
|
||||
|
||||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteModelfileByTagName
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{tag_name}", response_model=bool)
|
||||
async def delete_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)):
|
||||
token = cred.credentials
|
||||
user = Users.get_user_by_token(token)
|
||||
|
||||
if user:
|
||||
if user.role == "admin":
|
||||
result = Modelfiles.delete_modelfile_by_tag_name(tag_name)
|
||||
return result
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
|
@ -98,35 +98,37 @@
|
|||
</button>
|
||||
</div>
|
||||
|
||||
<div class="px-2.5 flex justify-center my-1">
|
||||
<button
|
||||
class="flex-grow flex space-x-3 rounded-md px-3 py-2 hover:bg-gray-900 transition"
|
||||
on:click={async () => {
|
||||
goto('/modelfiles');
|
||||
}}
|
||||
>
|
||||
<div class="self-center">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="1.5"
|
||||
stroke="currentColor"
|
||||
class="w-4 h-4"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M13.5 16.875h3.375m0 0h3.375m-3.375 0V13.5m0 3.375v3.375M6 10.5h2.25a2.25 2.25 0 002.25-2.25V6a2.25 2.25 0 00-2.25-2.25H6A2.25 2.25 0 003.75 6v2.25A2.25 2.25 0 006 10.5zm0 9.75h2.25A2.25 2.25 0 0010.5 18v-2.25a2.25 2.25 0 00-2.25-2.25H6a2.25 2.25 0 00-2.25 2.25V18A2.25 2.25 0 006 20.25zm9.75-9.75H18a2.25 2.25 0 002.25-2.25V6A2.25 2.25 0 0018 3.75h-2.25A2.25 2.25 0 0013.5 6v2.25a2.25 2.25 0 002.25 2.25z"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
{#if $user?.role === 'admin'}
|
||||
<div class="px-2.5 flex justify-center my-1">
|
||||
<button
|
||||
class="flex-grow flex space-x-3 rounded-md px-3 py-2 hover:bg-gray-900 transition"
|
||||
on:click={async () => {
|
||||
goto('/modelfiles');
|
||||
}}
|
||||
>
|
||||
<div class="self-center">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke-width="1.5"
|
||||
stroke="currentColor"
|
||||
class="w-4 h-4"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M13.5 16.875h3.375m0 0h3.375m-3.375 0V13.5m0 3.375v3.375M6 10.5h2.25a2.25 2.25 0 002.25-2.25V6a2.25 2.25 0 00-2.25-2.25H6A2.25 2.25 0 003.75 6v2.25A2.25 2.25 0 006 10.5zm0 9.75h2.25A2.25 2.25 0 0010.5 18v-2.25a2.25 2.25 0 00-2.25-2.25H6a2.25 2.25 0 00-2.25 2.25V18A2.25 2.25 0 006 20.25zm9.75-9.75H18a2.25 2.25 0 002.25-2.25V6A2.25 2.25 0 0018 3.75h-2.25A2.25 2.25 0 0013.5 6v2.25a2.25 2.25 0 002.25 2.25z"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
<div class="flex self-center">
|
||||
<div class=" self-center font-medium text-sm">Modelfiles</div>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
<div class="flex self-center">
|
||||
<div class=" self-center font-medium text-sm">Modelfiles</div>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="px-2.5 mt-1 mb-2 flex justify-center space-x-2">
|
||||
<div class="flex w-full">
|
||||
|
|
Loading…
Reference in a new issue