diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 89616064..dafb29a5 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -11,14 +11,15 @@ from apps.web.routers import ( configs, utils, ) -from config import WEBUI_VERSION, WEBUI_AUTH +from config import WEBUI_VERSION, WEBUI_AUTH, DEFAULT_MODELS, DEFAULT_PROMPT_SUGGESTIONS app = FastAPI() origins = ["*"] app.state.ENABLE_SIGNUP = True -app.state.DEFAULT_MODELS = None +app.state.DEFAULT_MODELS = DEFAULT_MODELS +app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.add_middleware( CORSMiddleware, @@ -46,4 +47,5 @@ async def get_status(): "version": WEBUI_VERSION, "auth": WEBUI_AUTH, "default_models": app.state.DEFAULT_MODELS, + "default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS, } diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py index 4dfe79fd..379ba9f0 100644 --- a/backend/apps/web/routers/configs.py +++ b/backend/apps/web/routers/configs.py @@ -21,15 +21,24 @@ class SetDefaultModelsForm(BaseModel): models: str +class PromptSuggestion(BaseModel): + title: List[str] + content: str + + +class SetDefaultSuggestionsForm(BaseModel): + suggestions: List[PromptSuggestion] + + ############################ # SetDefaultModels ############################ @router.post("/default/models", response_model=str) -async def set_global_default_models(request: Request, - form_data: SetDefaultModelsForm, - user=Depends(get_current_user)): +async def set_global_default_models( + request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) +): if user.role == "admin": request.app.state.DEFAULT_MODELS = form_data.models return request.app.state.DEFAULT_MODELS @@ -38,3 +47,19 @@ async def set_global_default_models(request: Request, status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + + +@router.post("/default/suggestions", response_model=str) +async def set_global_default_suggestions( + request: Request, + form_data: SetDefaultSuggestionsForm, + user=Depends(get_current_user), +): + if user.role == "admin": + request.app.state.DEFAULT_PROMPT_SUGGESTIONS = form_data.suggestions + return request.app.state.DEFAULT_PROMPT_SUGGESTIONS + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) diff --git a/backend/config.py b/backend/config.py index 2a96d018..8f37eb41 100644 --- a/backend/config.py +++ b/backend/config.py @@ -54,6 +54,34 @@ OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" + +#################################### +# WEBUI +#################################### + +DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None) +DEFAULT_PROMPT_SUGGESTIONS = os.environ.get( + "DEFAULT_PROMPT_SUGGESTIONS", + [ + { + "title": ["Help me study", "vocabulary for a college entrance exam"], + "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", + }, + { + "title": ["Give me ideas", "for what to do with my kids' art"], + "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.", + }, + { + "title": ["Tell me a fun fact", "about the Roman Empire"], + "content": "Tell me a random fun fact about the Roman Empire", + }, + { + "title": ["Show me a code snippet", "of a website's sticky header"], + "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.", + }, + ], +) + #################################### # WEBUI_VERSION ####################################