forked from open-webui/open-webui
		
	Merge branch 'main' into rag
This commit is contained in:
		
						commit
						fa598b59e2
					
				
					 31 changed files with 1917 additions and 143 deletions
				
			
		|  | @ -1,7 +1,7 @@ | |||
| from fastapi import FastAPI, Depends | ||||
| from fastapi.routing import APIRoute | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from apps.web.routers import auths, users, chats, modelfiles, utils | ||||
| from apps.web.routers import auths, users, chats, modelfiles, prompts, configs, utils | ||||
| from config import WEBUI_VERSION, WEBUI_AUTH | ||||
| 
 | ||||
| app = FastAPI() | ||||
|  | @ -9,6 +9,7 @@ app = FastAPI() | |||
| origins = ["*"] | ||||
| 
 | ||||
| app.state.ENABLE_SIGNUP = True | ||||
| app.state.DEFAULT_MODELS = None | ||||
| 
 | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|  | @ -19,13 +20,21 @@ app.add_middleware( | |||
| ) | ||||
| 
 | ||||
| app.include_router(auths.router, prefix="/auths", tags=["auths"]) | ||||
| 
 | ||||
| app.include_router(users.router, prefix="/users", tags=["users"]) | ||||
| app.include_router(chats.router, prefix="/chats", tags=["chats"]) | ||||
| app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) | ||||
| app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) | ||||
| 
 | ||||
| 
 | ||||
| app.include_router(configs.router, prefix="/configs", tags=["configs"]) | ||||
| app.include_router(utils.router, prefix="/utils", tags=["utils"]) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/") | ||||
| async def get_status(): | ||||
|     return {"status": True, "version": WEBUI_VERSION, "auth": WEBUI_AUTH} | ||||
|     return { | ||||
|         "status": True, | ||||
|         "version": WEBUI_VERSION, | ||||
|         "auth": WEBUI_AUTH, | ||||
|         "default_models": app.state.DEFAULT_MODELS, | ||||
|     } | ||||
|  |  | |||
|  | @ -12,7 +12,7 @@ from apps.web.internal.db import DB | |||
| import json | ||||
| 
 | ||||
| #################### | ||||
| # User DB Schema | ||||
| # Modelfile DB Schema | ||||
| #################### | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										117
									
								
								backend/apps/web/models/prompts.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								backend/apps/web/models/prompts.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,117 @@ | |||
| from pydantic import BaseModel | ||||
| from peewee import * | ||||
| from playhouse.shortcuts import model_to_dict | ||||
| from typing import List, Union, Optional | ||||
| import time | ||||
| 
 | ||||
| from utils.utils import decode_token | ||||
| from utils.misc import get_gravatar_url | ||||
| 
 | ||||
| from apps.web.internal.db import DB | ||||
| 
 | ||||
| import json | ||||
| 
 | ||||
| #################### | ||||
| # Prompts DB Schema | ||||
| #################### | ||||
| 
 | ||||
| 
 | ||||
| class Prompt(Model): | ||||
|     command = CharField(unique=True) | ||||
|     user_id = CharField() | ||||
|     title = CharField() | ||||
|     content = TextField() | ||||
|     timestamp = DateField() | ||||
| 
 | ||||
|     class Meta: | ||||
|         database = DB | ||||
| 
 | ||||
| 
 | ||||
| class PromptModel(BaseModel): | ||||
|     command: str | ||||
|     user_id: str | ||||
|     title: str | ||||
|     content: str | ||||
|     timestamp: int  # timestamp in epoch | ||||
| 
 | ||||
| 
 | ||||
| #################### | ||||
| # Forms | ||||
| #################### | ||||
| 
 | ||||
| 
 | ||||
| class PromptForm(BaseModel): | ||||
|     command: str | ||||
|     title: str | ||||
|     content: str | ||||
| 
 | ||||
| 
 | ||||
| class PromptsTable: | ||||
|     def __init__(self, db): | ||||
|         self.db = db | ||||
|         self.db.create_tables([Prompt]) | ||||
| 
 | ||||
|     def insert_new_prompt( | ||||
|         self, user_id: str, form_data: PromptForm | ||||
|     ) -> Optional[PromptModel]: | ||||
|         prompt = PromptModel( | ||||
|             **{ | ||||
|                 "user_id": user_id, | ||||
|                 "command": form_data.command, | ||||
|                 "title": form_data.title, | ||||
|                 "content": form_data.content, | ||||
|                 "timestamp": int(time.time()), | ||||
|             } | ||||
|         ) | ||||
| 
 | ||||
|         try: | ||||
|             result = Prompt.create(**prompt.model_dump()) | ||||
|             if result: | ||||
|                 return prompt | ||||
|             else: | ||||
|                 return None | ||||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: | ||||
|         try: | ||||
|             prompt = Prompt.get(Prompt.command == command) | ||||
|             return PromptModel(**model_to_dict(prompt)) | ||||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def get_prompts(self) -> List[PromptModel]: | ||||
|         return [ | ||||
|             PromptModel(**model_to_dict(prompt)) | ||||
|             for prompt in Prompt.select() | ||||
|             # .limit(limit).offset(skip) | ||||
|         ] | ||||
| 
 | ||||
|     def update_prompt_by_command( | ||||
|         self, command: str, form_data: PromptForm | ||||
|     ) -> Optional[PromptModel]: | ||||
|         try: | ||||
|             query = Prompt.update( | ||||
|                 title=form_data.title, | ||||
|                 content=form_data.content, | ||||
|                 timestamp=int(time.time()), | ||||
|             ).where(Prompt.command == command) | ||||
| 
 | ||||
|             query.execute() | ||||
| 
 | ||||
|             prompt = Prompt.get(Prompt.command == command) | ||||
|             return PromptModel(**model_to_dict(prompt)) | ||||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def delete_prompt_by_command(self, command: str) -> bool: | ||||
|         try: | ||||
|             query = Prompt.delete().where((Prompt.command == command)) | ||||
|             query.execute()  # Remove the rows, return number of rows removed. | ||||
| 
 | ||||
|             return True | ||||
|         except: | ||||
|             return False | ||||
| 
 | ||||
| 
 | ||||
| Prompts = PromptsTable(DB) | ||||
|  | @ -8,6 +8,7 @@ from pydantic import BaseModel | |||
| import time | ||||
| import uuid | ||||
| 
 | ||||
| 
 | ||||
| from apps.web.models.auths import ( | ||||
|     SigninForm, | ||||
|     SignupForm, | ||||
|  | @ -20,7 +21,7 @@ from apps.web.models.users import Users | |||
| 
 | ||||
| 
 | ||||
| from utils.utils import get_password_hash, get_current_user, create_token | ||||
| from utils.misc import get_gravatar_url | ||||
| from utils.misc import get_gravatar_url, validate_email_format | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| 
 | ||||
|  | @ -95,33 +96,38 @@ async def signin(form_data: SigninForm): | |||
| @router.post("/signup", response_model=SigninResponse) | ||||
| async def signup(request: Request, form_data: SignupForm): | ||||
|     if request.app.state.ENABLE_SIGNUP: | ||||
|         if not Users.get_user_by_email(form_data.email.lower()): | ||||
|             try: | ||||
|                 role = "admin" if Users.get_num_users() == 0 else "pending" | ||||
|                 hashed = get_password_hash(form_data.password) | ||||
|                 user = Auths.insert_new_auth( | ||||
|                     form_data.email.lower(), hashed, form_data.name, role | ||||
|                 ) | ||||
|         if validate_email_format(form_data.email.lower()): | ||||
|             if not Users.get_user_by_email(form_data.email.lower()): | ||||
|                 try: | ||||
|                     role = "admin" if Users.get_num_users() == 0 else "pending" | ||||
|                     hashed = get_password_hash(form_data.password) | ||||
|                     user = Auths.insert_new_auth( | ||||
|                         form_data.email.lower(), hashed, form_data.name, role | ||||
|                     ) | ||||
| 
 | ||||
|                 if user: | ||||
|                     token = create_token(data={"email": user.email}) | ||||
|                     # response.set_cookie(key='token', value=token, httponly=True) | ||||
|                     if user: | ||||
|                         token = create_token(data={"email": user.email}) | ||||
|                         # response.set_cookie(key='token', value=token, httponly=True) | ||||
| 
 | ||||
|                     return { | ||||
|                         "token": token, | ||||
|                         "token_type": "Bearer", | ||||
|                         "id": user.id, | ||||
|                         "email": user.email, | ||||
|                         "name": user.name, | ||||
|                         "role": user.role, | ||||
|                         "profile_image_url": user.profile_image_url, | ||||
|                     } | ||||
|                 else: | ||||
|                     raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) | ||||
|             except Exception as err: | ||||
|                 raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) | ||||
|                         return { | ||||
|                             "token": token, | ||||
|                             "token_type": "Bearer", | ||||
|                             "id": user.id, | ||||
|                             "email": user.email, | ||||
|                             "name": user.name, | ||||
|                             "role": user.role, | ||||
|                             "profile_image_url": user.profile_image_url, | ||||
|                         } | ||||
|                     else: | ||||
|                         raise HTTPException( | ||||
|                             500, detail=ERROR_MESSAGES.CREATE_USER_ERROR | ||||
|                         ) | ||||
|                 except Exception as err: | ||||
|                     raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) | ||||
|             else: | ||||
|                 raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) | ||||
|         else: | ||||
|             raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) | ||||
|             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT) | ||||
|     else: | ||||
|         raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										41
									
								
								backend/apps/web/routers/configs.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								backend/apps/web/routers/configs.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,41 @@ | |||
| from fastapi import Response, Request | ||||
| from fastapi import Depends, FastAPI, HTTPException, status | ||||
| from datetime import datetime, timedelta | ||||
| from typing import List, Union | ||||
| 
 | ||||
| from fastapi import APIRouter | ||||
| from pydantic import BaseModel | ||||
| import time | ||||
| import uuid | ||||
| 
 | ||||
| from apps.web.models.users import Users | ||||
| 
 | ||||
| 
 | ||||
| from utils.utils import get_password_hash, get_current_user, create_token | ||||
| from utils.misc import get_gravatar_url, validate_email_format | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
| 
 | ||||
| class SetDefaultModelsForm(BaseModel): | ||||
|     models: str | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # SetDefaultModels | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/default/models", response_model=str) | ||||
| async def set_global_default_models( | ||||
|     request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) | ||||
| ): | ||||
|     if user.role == "admin": | ||||
|         request.app.state.DEFAULT_MODELS = form_data.models | ||||
|         return request.app.state.DEFAULT_MODELS | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_403_FORBIDDEN, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
							
								
								
									
										115
									
								
								backend/apps/web/routers/prompts.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								backend/apps/web/routers/prompts.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,115 @@ | |||
| from fastapi import Depends, FastAPI, HTTPException, status | ||||
| from datetime import datetime, timedelta | ||||
| from typing import List, Union, Optional | ||||
| 
 | ||||
| from fastapi import APIRouter | ||||
| from pydantic import BaseModel | ||||
| import json | ||||
| 
 | ||||
| 
 | ||||
| from apps.web.models.prompts import Prompts, PromptForm, PromptModel | ||||
| 
 | ||||
| from utils.utils import get_current_user | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
| ############################ | ||||
| # GetPrompts | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/", response_model=List[PromptModel]) | ||||
| async def get_prompts(user=Depends(get_current_user)): | ||||
|     return Prompts.get_prompts() | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # CreateNewPrompt | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/create", response_model=Optional[PromptModel]) | ||||
| async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|     prompt = Prompts.get_prompt_by_command(form_data.command) | ||||
|     if prompt == None: | ||||
|         prompt = Prompts.insert_new_prompt(user.id, form_data) | ||||
| 
 | ||||
|         if prompt: | ||||
|             return prompt | ||||
|         else: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|                 detail=ERROR_MESSAGES.DEFAULT(), | ||||
|             ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|             detail=ERROR_MESSAGES.COMMAND_TAKEN, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # GetPromptByCommand | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/{command}", response_model=Optional[PromptModel]) | ||||
| async def get_prompt_by_command(command: str, user=Depends(get_current_user)): | ||||
|     prompt = Prompts.get_prompt_by_command(f"/{command}") | ||||
| 
 | ||||
|     if prompt: | ||||
|         return prompt | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.NOT_FOUND, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # UpdatePromptByCommand | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/{command}/update", response_model=Optional[PromptModel]) | ||||
| async def update_prompt_by_command( | ||||
|     command: str, form_data: PromptForm, user=Depends(get_current_user) | ||||
| ): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|     prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) | ||||
|     if prompt: | ||||
|         return prompt | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # DeletePromptByCommand | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.delete("/{command}/delete", response_model=bool) | ||||
| async def delete_prompt_by_command(command: str, user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | ||||
|         ) | ||||
| 
 | ||||
|     result = Prompts.delete_prompt_by_command(f"/{command}") | ||||
|     return result | ||||
|  | @ -17,10 +17,12 @@ class ERROR_MESSAGES(str, Enum): | |||
|     USERNAME_TAKEN = ( | ||||
|         "Uh-oh! This username is already registered. Please choose another username." | ||||
|     ) | ||||
|     COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." | ||||
|     INVALID_TOKEN = ( | ||||
|         "Your session has expired or the token is invalid. Please sign in again." | ||||
|     ) | ||||
|     INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again." | ||||
|     INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)." | ||||
|     INVALID_PASSWORD = ( | ||||
|         "The password provided is incorrect. Please check for typos and try again." | ||||
|     ) | ||||
|  | @ -31,5 +33,4 @@ class ERROR_MESSAGES(str, Enum): | |||
|     ) | ||||
|     NOT_FOUND = "We could not find what you're looking for :/" | ||||
|     USER_NOT_FOUND = "We could not find what you're looking for :/" | ||||
| 
 | ||||
|     MALICIOUS = "Unusual activities detected, please try again in a few minutes." | ||||
|  |  | |||
							
								
								
									
										4
									
								
								backend/start.sh
									
										
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										4
									
								
								backend/start.sh
									
										
									
									
									
										
										
										Normal file → Executable file
									
								
							|  | @ -1 +1,3 @@ | |||
| uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*' | ||||
| #!/usr/bin/env bash | ||||
| 
 | ||||
| uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*' | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| import hashlib | ||||
| import re | ||||
| 
 | ||||
| 
 | ||||
| def get_gravatar_url(email): | ||||
|  | @ -21,3 +22,9 @@ def calculate_sha256(file): | |||
|     for chunk in iter(lambda: file.read(8192), b""): | ||||
|         sha256.update(chunk) | ||||
|     return sha256.hexdigest() | ||||
| 
 | ||||
| 
 | ||||
| def validate_email_format(email: str) -> bool: | ||||
|     if not re.match(r"[^@]+@[^@]+\.[^@]+", email): | ||||
|         return False | ||||
|     return True | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy Jaeryang Baek
						Timothy Jaeryang Baek