diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index f62857b7..078ae2a3 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -1,7 +1,7 @@ from fastapi import FastAPI, Depends from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware -from apps.web.routers import auths, users, chats, modelfiles, utils +from apps.web.routers import auths, users, chats, modelfiles, configs, utils from config import WEBUI_VERSION, WEBUI_AUTH app = FastAPI() @@ -9,6 +9,7 @@ app = FastAPI() origins = ["*"] app.state.ENABLE_SIGNUP = True +app.state.DEFAULT_MODELS = "llava:13b" app.add_middleware( CORSMiddleware, @@ -19,13 +20,18 @@ app.add_middleware( ) app.include_router(auths.router, prefix="/auths", tags=["auths"]) - app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) +app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) @app.get("/") async def get_status(): - return {"status": True, "version": WEBUI_VERSION, "auth": WEBUI_AUTH} + return { + "status": True, + "version": WEBUI_VERSION, + "auth": WEBUI_AUTH, + "default_models": app.state.DEFAULT_MODELS, + } diff --git a/backend/apps/web/routers/configs.py b/backend/apps/web/routers/configs.py new file mode 100644 index 00000000..b57fae3d --- /dev/null +++ b/backend/apps/web/routers/configs.py @@ -0,0 +1,41 @@ +from fastapi import Response, Request +from fastapi import Depends, FastAPI, HTTPException, status +from datetime import datetime, timedelta +from typing import List, Union + +from fastapi import APIRouter +from pydantic import BaseModel +import time +import uuid + +from apps.web.models.users import Users + + +from utils.utils import get_password_hash, get_current_user, create_token +from utils.misc import get_gravatar_url, validate_email_format +from constants import ERROR_MESSAGES + +router = APIRouter() + + +class SetDefaultModelsForm(BaseModel): + models: str + + +############################ +# SetDefaultModels +############################ + + +@router.post("/default/models", response_model=str) +async def set_global_default_models( + request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user) +): + if user.role == "admin": + request.app.state.DEFAULT_MODELS = form_data.models + return request.app.state.DEFAULT_MODELS + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts new file mode 100644 index 00000000..9762c41f --- /dev/null +++ b/src/lib/apis/configs/index.ts @@ -0,0 +1,31 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const setDefaultModels = async (token: string, models: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/default/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + models: models + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 4e3347ff..30920140 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -6,7 +6,7 @@ import { goto } from '$app/navigation'; import { page } from '$app/stores'; - import { models, modelfiles, user, settings, chats, chatId } from '$lib/stores'; + import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; import { OLLAMA_API_BASE_URL } from '$lib/constants'; import { generateChatCompletion, generateTitle } from '$lib/apis/ollama'; @@ -90,9 +90,18 @@ messages: {}, currentId: null }; - selectedModels = $page.url.searchParams.get('models') - ? $page.url.searchParams.get('models')?.split(',') - : $settings.models ?? ['']; + + console.log($config); + + if ($page.url.searchParams.get('models')) { + selectedModels = $page.url.searchParams.get('models')?.split(','); + } else if ($settings?.models) { + selectedModels = $settings?.models; + } else if ($config?.default_models) { + selectedModels = $config?.default_models.split(','); + } else { + selectedModels = ['']; + } let _settings = JSON.parse(localStorage.getItem('settings') ?? '{}'); settings.set({ @@ -383,13 +392,13 @@ } : { content: message.content }) })), - seed: $settings.options.seed ?? undefined, - stop: $settings.options.stop ?? undefined, - temperature: $settings.options.temperature ?? undefined, - top_p: $settings.options.top_p ?? undefined, - num_ctx: $settings.options.num_ctx ?? undefined, - frequency_penalty: $settings.options.repeat_penalty ?? undefined, - max_tokens: $settings.options.num_predict ?? undefined + seed: $settings?.options?.seed ?? undefined, + stop: $settings?.options?.stop ?? undefined, + temperature: $settings?.options?.temperature ?? undefined, + top_p: $settings?.options?.top_p ?? undefined, + num_ctx: $settings?.options?.num_ctx ?? undefined, + frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, + max_tokens: $settings?.options?.num_predict ?? undefined }) } ).catch((err) => { diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index 62c271bf..e3a70ca2 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -409,13 +409,13 @@ } : { content: message.content }) })), - seed: $settings.options.seed ?? undefined, - stop: $settings.options.stop ?? undefined, - temperature: $settings.options.temperature ?? undefined, - top_p: $settings.options.top_p ?? undefined, - num_ctx: $settings.options.num_ctx ?? undefined, - frequency_penalty: $settings.options.repeat_penalty ?? undefined, - max_tokens: $settings.options.num_predict ?? undefined + seed: $settings?.options?.seed ?? undefined, + stop: $settings?.options?.stop ?? undefined, + temperature: $settings?.options?.temperature ?? undefined, + top_p: $settings?.options?.top_p ?? undefined, + num_ctx: $settings?.options?.num_ctx ?? undefined, + frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, + max_tokens: $settings?.options?.num_predict ?? undefined }) } ).catch((err) => {