feat: custom interface support

This commit is contained in:
Timothy J. Baek 2024-01-22 21:07:40 -08:00
parent b246c62d2c
commit 4e1b52e91b
3 changed files with 60 additions and 5 deletions

View file

@ -11,14 +11,15 @@ from apps.web.routers import (
configs,
utils,
)
from config import WEBUI_VERSION, WEBUI_AUTH
from config import WEBUI_VERSION, WEBUI_AUTH, DEFAULT_MODELS, DEFAULT_PROMPT_SUGGESTIONS
app = FastAPI()
origins = ["*"]
app.state.ENABLE_SIGNUP = True
app.state.DEFAULT_MODELS = None
app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.add_middleware(
CORSMiddleware,
@ -46,4 +47,5 @@ async def get_status():
"version": WEBUI_VERSION,
"auth": WEBUI_AUTH,
"default_models": app.state.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS,
}

View file

@ -21,15 +21,24 @@ class SetDefaultModelsForm(BaseModel):
models: str
class PromptSuggestion(BaseModel):
title: List[str]
content: str
class SetDefaultSuggestionsForm(BaseModel):
suggestions: List[PromptSuggestion]
############################
# SetDefaultModels
############################
@router.post("/default/models", response_model=str)
async def set_global_default_models(request: Request,
form_data: SetDefaultModelsForm,
user=Depends(get_current_user)):
async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
):
if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS
@ -38,3 +47,19 @@ async def set_global_default_models(request: Request,
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@router.post("/default/suggestions", response_model=str)
async def set_global_default_suggestions(
request: Request,
form_data: SetDefaultSuggestionsForm,
user=Depends(get_current_user),
):
if user.role == "admin":
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = form_data.suggestions
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)