forked from open-webui/open-webui
		
	feat: modelfiles backend
This commit is contained in:
		
							parent
							
								
									d78df83453
								
							
						
					
					
						commit
						a2b1e3756b
					
				
					 3 changed files with 330 additions and 28 deletions
				
			
		
							
								
								
									
										122
									
								
								backend/apps/web/models/modelfiles.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										122
									
								
								backend/apps/web/models/modelfiles.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,122 @@ | |||
| from pydantic import BaseModel | ||||
| from peewee import * | ||||
| from playhouse.shortcuts import model_to_dict | ||||
| from typing import List, Union, Optional | ||||
| import time | ||||
| 
 | ||||
| from utils.utils import decode_token | ||||
| from utils.misc import get_gravatar_url | ||||
| 
 | ||||
| from apps.web.internal.db import DB | ||||
| 
 | ||||
| import json | ||||
| 
 | ||||
| #################### | ||||
| # User DB Schema | ||||
| #################### | ||||
| 
 | ||||
| 
 | ||||
| class Modelfile(Model): | ||||
|     tag_name = CharField(unique=True) | ||||
|     user_id = CharField() | ||||
|     modelfile = TextField() | ||||
|     timestamp = DateField() | ||||
| 
 | ||||
|     class Meta: | ||||
|         database = DB | ||||
| 
 | ||||
| 
 | ||||
| class ModelfileModel(BaseModel): | ||||
|     tag_name: str | ||||
|     user_id: str | ||||
|     modelfile: str | ||||
|     timestamp: int  # timestamp in epoch | ||||
| 
 | ||||
| 
 | ||||
| #################### | ||||
| # Forms | ||||
| #################### | ||||
| 
 | ||||
| 
 | ||||
| class ModelfileForm(BaseModel): | ||||
|     modelfile: dict | ||||
| 
 | ||||
| 
 | ||||
| class ModelfileResponse(BaseModel): | ||||
|     tag_name: str | ||||
|     user_id: str | ||||
|     modelfile: dict | ||||
|     timestamp: int  # timestamp in epoch | ||||
| 
 | ||||
| 
 | ||||
| class ModelfilesTable: | ||||
|     def __init__(self, db): | ||||
|         self.db = db | ||||
|         self.db.create_tables([Modelfile]) | ||||
| 
 | ||||
|     def insert_new_modelfile( | ||||
|         self, user_id: str, form_data: ModelfileForm | ||||
|     ) -> Optional[ModelfileModel]: | ||||
|         if "title" in form_data.modelfile: | ||||
|             modelfile = ModelfileModel( | ||||
|                 **{ | ||||
|                     "user_id": user_id, | ||||
|                     "tag_name": form_data.modelfile["title"], | ||||
|                     "modelfile": json.dumps(form_data.modelfile), | ||||
|                     "timestamp": int(time.time()), | ||||
|                 } | ||||
|             ) | ||||
|             result = Modelfile.create(**modelfile.model_dump()) | ||||
|             if result: | ||||
|                 return modelfile | ||||
|             else: | ||||
|                 return None | ||||
|         else: | ||||
|             return None | ||||
| 
 | ||||
|     def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]: | ||||
|         try: | ||||
|             modelfile = Modelfile.get(Modelfile.tag_name == tag_name) | ||||
|             return ModelfileModel(**model_to_dict(modelfile)) | ||||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]: | ||||
|         return [ | ||||
|             ModelfileResponse( | ||||
|                 **{ | ||||
|                     **model_to_dict(modelfile), | ||||
|                     "modelfile": json.loads(modelfile.modelfile), | ||||
|                 } | ||||
|             ) | ||||
|             for modelfile in Modelfile.select() | ||||
|             # .limit(limit).offset(skip) | ||||
|         ] | ||||
| 
 | ||||
|     def update_modelfile_by_tag_name( | ||||
|         self, tag_name: str, modelfile: dict | ||||
|     ) -> Optional[ModelfileModel]: | ||||
|         try: | ||||
|             query = Modelfile.update( | ||||
|                 modelfile=json.dumps(modelfile), | ||||
|                 timestamp=int(time.time()), | ||||
|             ).where(Modelfile.tag_name == tag_name) | ||||
| 
 | ||||
|             query.execute() | ||||
| 
 | ||||
|             modelfile = Modelfile.get(Modelfile.tag_name == tag_name) | ||||
|             return ModelfileModel(**model_to_dict(modelfile)) | ||||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def delete_modelfile_by_tag_name(self, tag_name: str) -> bool: | ||||
|         try: | ||||
|             query = Modelfile.delete().where((Modelfile.tag_name == tag_name)) | ||||
|             query.execute()  # Remove the rows, return number of rows removed. | ||||
| 
 | ||||
|             return True | ||||
|         except: | ||||
|             return False | ||||
| 
 | ||||
| 
 | ||||
| Modelfiles = ModelfilesTable(DB) | ||||
							
								
								
									
										178
									
								
								backend/apps/web/routers/modelfiles.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								backend/apps/web/routers/modelfiles.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,178 @@ | |||
| from fastapi import Response | ||||
| from fastapi import Depends, FastAPI, HTTPException, status | ||||
| from datetime import datetime, timedelta | ||||
| from typing import List, Union, Optional | ||||
| 
 | ||||
| from fastapi import APIRouter | ||||
| from pydantic import BaseModel | ||||
| import json | ||||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| from apps.web.models.modelfiles import ( | ||||
|     Modelfiles, | ||||
|     ModelfileForm, | ||||
|     ModelfileResponse, | ||||
| ) | ||||
| 
 | ||||
| from utils.utils import ( | ||||
|     bearer_scheme, | ||||
| ) | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
| ############################ | ||||
| # GetModelfiles | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/", response_model=List[ModelfileResponse]) | ||||
| async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): | ||||
|     token = cred.credentials | ||||
|     user = Users.get_user_by_token(token) | ||||
| 
 | ||||
|     if user: | ||||
|         return Modelfiles.get_modelfiles(skip, limit) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.INVALID_TOKEN, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # CreateNewModelfile | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/create", response_model=Optional[ModelfileResponse]) | ||||
| async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)): | ||||
|     token = cred.credentials | ||||
|     user = Users.get_user_by_token(token) | ||||
| 
 | ||||
|     if user: | ||||
|         # Admin Only | ||||
|         if user.role == "admin": | ||||
|             modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) | ||||
|             return ModelfileResponse( | ||||
|                 **{ | ||||
|                     **modelfile.model_dump(), | ||||
|                     "modelfile": json.loads(modelfile.modelfile), | ||||
|                 } | ||||
|             ) | ||||
|         else: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|             ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.INVALID_TOKEN, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # GetModelfileByTagName | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/{tag_name}", response_model=Optional[ModelfileResponse]) | ||||
| async def get_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)): | ||||
|     token = cred.credentials | ||||
|     user = Users.get_user_by_token(token) | ||||
| 
 | ||||
|     if user: | ||||
|         modelfile = Modelfiles.get_modelfile_by_tag_name(tag_name) | ||||
| 
 | ||||
|         if modelfile: | ||||
|             return ModelfileResponse( | ||||
|                 **{ | ||||
|                     **modelfile.model_dump(), | ||||
|                     "modelfile": json.loads(modelfile.modelfile), | ||||
|                 } | ||||
|             ) | ||||
|         else: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                 detail=ERROR_MESSAGES.NOT_FOUND, | ||||
|             ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.INVALID_TOKEN, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # UpdateModelfileByTagName | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/{tag_name}", response_model=Optional[ModelfileResponse]) | ||||
| async def update_modelfile_by_tag_name( | ||||
|     tag_name: str, form_data: ModelfileForm, cred=Depends(bearer_scheme) | ||||
| ): | ||||
|     token = cred.credentials | ||||
|     user = Users.get_user_by_token(token) | ||||
| 
 | ||||
|     if user: | ||||
|         if user.role == "admin": | ||||
|             modelfile = Modelfiles.get_modelfile_by_tag_name(tag_name) | ||||
|             if modelfile: | ||||
|                 updated_modelfile = { | ||||
|                     **json.loads(modelfile.modelfile), | ||||
|                     **form_data.modelfile, | ||||
|                 } | ||||
| 
 | ||||
|                 modelfile = Modelfiles.update_modelfile_by_tag_name( | ||||
|                     tag_name, updated_modelfile | ||||
|                 ) | ||||
| 
 | ||||
|                 return ModelfileResponse( | ||||
|                     **{ | ||||
|                         **modelfile.model_dump(), | ||||
|                         "modelfile": json.loads(modelfile.modelfile), | ||||
|                     } | ||||
|                 ) | ||||
|             else: | ||||
|                 raise HTTPException( | ||||
|                     status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                     detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|                 ) | ||||
|         else: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|             ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.INVALID_TOKEN, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # DeleteModelfileByTagName | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.delete("/{tag_name}", response_model=bool) | ||||
| async def delete_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)): | ||||
|     token = cred.credentials | ||||
|     user = Users.get_user_by_token(token) | ||||
| 
 | ||||
|     if user: | ||||
|         if user.role == "admin": | ||||
|             result = Modelfiles.delete_modelfile_by_tag_name(tag_name) | ||||
|             return result | ||||
|         else: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|             ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.INVALID_TOKEN, | ||||
|         ) | ||||
|  | @ -98,6 +98,7 @@ | |||
| 			</button> | ||||
| 		</div> | ||||
| 
 | ||||
| 		{#if $user?.role === 'admin'} | ||||
| 			<div class="px-2.5 flex justify-center my-1"> | ||||
| 				<button | ||||
| 					class="flex-grow flex space-x-3 rounded-md px-3 py-2 hover:bg-gray-900 transition" | ||||
|  | @ -127,6 +128,7 @@ | |||
| 					</div> | ||||
| 				</button> | ||||
| 			</div> | ||||
| 		{/if} | ||||
| 
 | ||||
| 		<div class="px-2.5 mt-1 mb-2 flex justify-center space-x-2"> | ||||
| 			<div class="flex w-full"> | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek