forked from open-webui/open-webui
		
	chore: py formatting
This commit is contained in:
		
							parent
							
								
									d4fabeee3c
								
							
						
					
					
						commit
						5af8d0612a
					
				
					 12 changed files with 77 additions and 45 deletions
				
			
		|  | @ -22,7 +22,13 @@ from utils.utils import ( | |||
| ) | ||||
| from utils.misc import calculate_sha256 | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS, CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR | ||||
| from config import ( | ||||
|     SRC_LOG_LEVELS, | ||||
|     CACHE_DIR, | ||||
|     UPLOAD_DIR, | ||||
|     WHISPER_MODEL, | ||||
|     WHISPER_MODEL_DIR, | ||||
| ) | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["AUDIO"]) | ||||
|  |  | |||
|  | @ -33,7 +33,13 @@ from constants import ERROR_MESSAGES | |||
| from utils.utils import decode_token, get_current_user, get_admin_user | ||||
| 
 | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS, OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR | ||||
| from config import ( | ||||
|     SRC_LOG_LEVELS, | ||||
|     OLLAMA_BASE_URLS, | ||||
|     MODEL_FILTER_ENABLED, | ||||
|     MODEL_FILTER_LIST, | ||||
|     UPLOAD_DIR, | ||||
| ) | ||||
| from utils.misc import calculate_sha256 | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
|  | @ -770,7 +776,11 @@ async def generate_chat_completion( | |||
| 
 | ||||
|     r = None | ||||
| 
 | ||||
|     log.debug("form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(form_data.model_dump_json(exclude_none=True).encode())) | ||||
|     log.debug( | ||||
|         "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( | ||||
|             form_data.model_dump_json(exclude_none=True).encode() | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     def get_request(): | ||||
|         nonlocal form_data | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ from utils.utils import verify_password | |||
| from apps.web.internal.db import DB | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
|  |  | |||
|  | @ -13,6 +13,7 @@ from apps.web.internal.db import DB | |||
| import json | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
|  |  | |||
|  | @ -64,8 +64,8 @@ class ModelfilesTable: | |||
|         self.db.create_tables([Modelfile]) | ||||
| 
 | ||||
|     def insert_new_modelfile( | ||||
|             self, user_id: str, | ||||
|             form_data: ModelfileForm) -> Optional[ModelfileModel]: | ||||
|         self, user_id: str, form_data: ModelfileForm | ||||
|     ) -> Optional[ModelfileModel]: | ||||
|         if "tagName" in form_data.modelfile: | ||||
|             modelfile = ModelfileModel( | ||||
|                 **{ | ||||
|  | @ -73,7 +73,8 @@ class ModelfilesTable: | |||
|                     "tag_name": form_data.modelfile["tagName"], | ||||
|                     "modelfile": json.dumps(form_data.modelfile), | ||||
|                     "timestamp": int(time.time()), | ||||
|                 }) | ||||
|                 } | ||||
|             ) | ||||
| 
 | ||||
|             try: | ||||
|                 result = Modelfile.create(**modelfile.model_dump()) | ||||
|  | @ -87,29 +88,28 @@ class ModelfilesTable: | |||
|         else: | ||||
|             return None | ||||
| 
 | ||||
|     def get_modelfile_by_tag_name(self, | ||||
|                                   tag_name: str) -> Optional[ModelfileModel]: | ||||
|     def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]: | ||||
|         try: | ||||
|             modelfile = Modelfile.get(Modelfile.tag_name == tag_name) | ||||
|             return ModelfileModel(**model_to_dict(modelfile)) | ||||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def get_modelfiles(self, | ||||
|                        skip: int = 0, | ||||
|                        limit: int = 50) -> List[ModelfileResponse]: | ||||
|     def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]: | ||||
|         return [ | ||||
|             ModelfileResponse( | ||||
|                 **{ | ||||
|                     **model_to_dict(modelfile), | ||||
|                     "modelfile": | ||||
|                     json.loads(modelfile.modelfile), | ||||
|                 }) for modelfile in Modelfile.select() | ||||
|                     "modelfile": json.loads(modelfile.modelfile), | ||||
|                 } | ||||
|             ) | ||||
|             for modelfile in Modelfile.select() | ||||
|             # .limit(limit).offset(skip) | ||||
|         ] | ||||
| 
 | ||||
|     def update_modelfile_by_tag_name( | ||||
|             self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]: | ||||
|         self, tag_name: str, modelfile: dict | ||||
|     ) -> Optional[ModelfileModel]: | ||||
|         try: | ||||
|             query = Modelfile.update( | ||||
|                 modelfile=json.dumps(modelfile), | ||||
|  |  | |||
|  | @ -52,8 +52,9 @@ class PromptsTable: | |||
|         self.db = db | ||||
|         self.db.create_tables([Prompt]) | ||||
| 
 | ||||
|     def insert_new_prompt(self, user_id: str, | ||||
|                           form_data: PromptForm) -> Optional[PromptModel]: | ||||
|     def insert_new_prompt( | ||||
|         self, user_id: str, form_data: PromptForm | ||||
|     ) -> Optional[PromptModel]: | ||||
|         prompt = PromptModel( | ||||
|             **{ | ||||
|                 "user_id": user_id, | ||||
|  | @ -61,7 +62,8 @@ class PromptsTable: | |||
|                 "title": form_data.title, | ||||
|                 "content": form_data.content, | ||||
|                 "timestamp": int(time.time()), | ||||
|             }) | ||||
|             } | ||||
|         ) | ||||
| 
 | ||||
|         try: | ||||
|             result = Prompt.create(**prompt.model_dump()) | ||||
|  | @ -81,13 +83,14 @@ class PromptsTable: | |||
| 
 | ||||
|     def get_prompts(self) -> List[PromptModel]: | ||||
|         return [ | ||||
|             PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select() | ||||
|             PromptModel(**model_to_dict(prompt)) | ||||
|             for prompt in Prompt.select() | ||||
|             # .limit(limit).offset(skip) | ||||
|         ] | ||||
| 
 | ||||
|     def update_prompt_by_command( | ||||
|             self, command: str, | ||||
|             form_data: PromptForm) -> Optional[PromptModel]: | ||||
|         self, command: str, form_data: PromptForm | ||||
|     ) -> Optional[PromptModel]: | ||||
|         try: | ||||
|             query = Prompt.update( | ||||
|                 title=form_data.title, | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ import logging | |||
| from apps.web.internal.db import DB | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
|  |  | |||
|  | @ -29,6 +29,7 @@ from apps.web.models.tags import ( | |||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
|  |  | |||
|  | @ -10,7 +10,12 @@ import uuid | |||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| 
 | ||||
| from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token | ||||
| from utils.utils import ( | ||||
|     get_password_hash, | ||||
|     get_current_user, | ||||
|     get_admin_user, | ||||
|     create_token, | ||||
| ) | ||||
| from utils.misc import get_gravatar_url, validate_email_format | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
|  | @ -43,7 +48,6 @@ async def set_global_default_models( | |||
|     return request.app.state.DEFAULT_MODELS | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/default/suggestions", response_model=List[PromptSuggestion]) | ||||
| async def set_global_default_suggestions( | ||||
|     request: Request, | ||||
|  |  | |||
|  | @ -24,9 +24,9 @@ router = APIRouter() | |||
| 
 | ||||
| 
 | ||||
| @router.get("/", response_model=List[ModelfileResponse]) | ||||
| async def get_modelfiles(skip: int = 0, | ||||
|                          limit: int = 50, | ||||
|                          user=Depends(get_current_user)): | ||||
| async def get_modelfiles( | ||||
|     skip: int = 0, limit: int = 50, user=Depends(get_current_user) | ||||
| ): | ||||
|     return Modelfiles.get_modelfiles(skip, limit) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -36,17 +36,16 @@ async def get_modelfiles(skip: int = 0, | |||
| 
 | ||||
| 
 | ||||
| @router.post("/create", response_model=Optional[ModelfileResponse]) | ||||
| async def create_new_modelfile(form_data: ModelfileForm, | ||||
|                                user=Depends(get_admin_user)): | ||||
| async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)): | ||||
|     modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) | ||||
| 
 | ||||
|     if modelfile: | ||||
|         return ModelfileResponse( | ||||
|             **{ | ||||
|                 **modelfile.model_dump(), | ||||
|                 "modelfile": | ||||
|                 json.loads(modelfile.modelfile), | ||||
|             }) | ||||
|                 "modelfile": json.loads(modelfile.modelfile), | ||||
|             } | ||||
|         ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|  | @ -60,17 +59,18 @@ async def create_new_modelfile(form_data: ModelfileForm, | |||
| 
 | ||||
| 
 | ||||
| @router.post("/", response_model=Optional[ModelfileResponse]) | ||||
| async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, | ||||
|                                     user=Depends(get_current_user)): | ||||
| async def get_modelfile_by_tag_name( | ||||
|     form_data: ModelfileTagNameForm, user=Depends(get_current_user) | ||||
| ): | ||||
|     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) | ||||
| 
 | ||||
|     if modelfile: | ||||
|         return ModelfileResponse( | ||||
|             **{ | ||||
|                 **modelfile.model_dump(), | ||||
|                 "modelfile": | ||||
|                 json.loads(modelfile.modelfile), | ||||
|             }) | ||||
|                 "modelfile": json.loads(modelfile.modelfile), | ||||
|             } | ||||
|         ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|  | @ -84,8 +84,9 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, | |||
| 
 | ||||
| 
 | ||||
| @router.post("/update", response_model=Optional[ModelfileResponse]) | ||||
| async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, | ||||
|                                        user=Depends(get_admin_user)): | ||||
| async def update_modelfile_by_tag_name( | ||||
|     form_data: ModelfileUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) | ||||
|     if modelfile: | ||||
|         updated_modelfile = { | ||||
|  | @ -94,14 +95,15 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, | |||
|         } | ||||
| 
 | ||||
|         modelfile = Modelfiles.update_modelfile_by_tag_name( | ||||
|             form_data.tag_name, updated_modelfile) | ||||
|             form_data.tag_name, updated_modelfile | ||||
|         ) | ||||
| 
 | ||||
|         return ModelfileResponse( | ||||
|             **{ | ||||
|                 **modelfile.model_dump(), | ||||
|                 "modelfile": | ||||
|                 json.loads(modelfile.modelfile), | ||||
|             }) | ||||
|                 "modelfile": json.loads(modelfile.modelfile), | ||||
|             } | ||||
|         ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|  | @ -115,7 +117,8 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, | |||
| 
 | ||||
| 
 | ||||
| @router.delete("/delete", response_model=bool) | ||||
| async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm, | ||||
|                                        user=Depends(get_admin_user)): | ||||
| async def delete_modelfile_by_tag_name( | ||||
|     form_data: ModelfileTagNameForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) | ||||
|     return result | ||||
|  |  | |||
|  | @ -16,6 +16,7 @@ from utils.utils import get_current_user, get_password_hash, get_admin_user | |||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | ||||
| 
 | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ import json | |||
| import requests | ||||
| from config import VERSION, WEBUI_FAVICON_URL, WEBUI_NAME | ||||
| 
 | ||||
| 
 | ||||
| def post_webhook(url: str, message: str, event_data: dict) -> bool: | ||||
|     try: | ||||
|         payload = {} | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue