From a4ca1fc5c4e417a90416c72032e0b749dbef6703 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 9 Mar 2024 21:47:01 -0800 Subject: [PATCH] feat: model filter list env var --- backend/apps/ollama/main.py | 8 ++++---- backend/apps/openai/main.py | 14 ++++++++++---- backend/config.py | 5 +++++ backend/main.py | 30 ++++++++++++++++++++++-------- 4 files changed, 41 insertions(+), 16 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 532e5523..5ecbaa29 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -15,7 +15,7 @@ import asyncio from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import decode_token, get_current_user, get_admin_user -from config import OLLAMA_BASE_URLS +from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST from typing import Optional, List, Union @@ -30,8 +30,8 @@ app.add_middleware( ) -app.state.MODEL_FILTER_ENABLED = False -app.state.MODEL_LIST = [] +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -140,7 +140,7 @@ async def get_ollama_tags( if user.role == "user": models["models"] = list( filter( - lambda model: model["name"] in app.state.MODEL_LIST, + lambda model: model["name"] in app.state.MODEL_FILTER_LIST, models["models"], ) ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 546de3d5..e902bea2 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -18,7 +18,13 @@ from utils.utils import ( get_verified_user, get_admin_user, ) -from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR +from config import ( + OPENAI_API_BASE_URLS, + OPENAI_API_KEYS, + CACHE_DIR, + MODEL_FILTER_ENABLED, + MODEL_FILTER_LIST, +) from typing import List, Optional @@ -34,8 +40,8 @@ app.add_middleware( allow_headers=["*"], ) -app.state.MODEL_FILTER_ENABLED = False -app.state.MODEL_LIST = [] +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -198,7 +204,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use if user.role == "user": models["data"] = list( filter( - lambda model: model["id"] in app.state.MODEL_LIST, + lambda model: model["id"] in app.state.MODEL_FILTER_LIST, models["data"], ) ) diff --git a/backend/config.py b/backend/config.py index 05ed686b..019e44e0 100644 --- a/backend/config.py +++ b/backend/config.py @@ -292,6 +292,11 @@ DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") USER_PERMISSIONS = {"chat": {"deletion": True}} +MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False) +MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") +MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] + + #################################### # WEBUI_VERSION #################################### diff --git a/backend/main.py b/backend/main.py index 07543e7e..c7523ec6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -30,7 +30,15 @@ from typing import List from utils.utils import get_admin_user from apps.rag.utils import query_doc, query_collection, rag_template -from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR +from config import ( + WEBUI_NAME, + ENV, + VERSION, + CHANGELOG, + FRONTEND_BUILD_DIR, + MODEL_FILTER_ENABLED, + MODEL_FILTER_LIST, +) from constants import ERROR_MESSAGES @@ -47,8 +55,8 @@ class SPAStaticFiles(StaticFiles): app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) -app.state.MODEL_FILTER_ENABLED = False -app.state.MODEL_LIST = [] +app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED +app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST origins = ["*"] @@ -222,7 +230,10 @@ async def get_app_config(): @app.get("/api/config/model/filter") async def get_model_filter_config(user=Depends(get_admin_user)): - return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST} + return { + "enabled": app.state.MODEL_FILTER_ENABLED, + "models": app.state.MODEL_FILTER_LIST, + } class ModelFilterConfigForm(BaseModel): @@ -236,15 +247,18 @@ async def get_model_filter_config( ): app.state.MODEL_FILTER_ENABLED = form_data.enabled - app.state.MODEL_LIST = form_data.models + app.state.MODEL_FILTER_LIST = form_data.models ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED - ollama_app.state.MODEL_LIST = app.state.MODEL_LIST + ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED - openai_app.state.MODEL_LIST = app.state.MODEL_LIST + openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST - return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST} + return { + "enabled": app.state.MODEL_FILTER_ENABLED, + "models": app.state.MODEL_FILTER_LIST, + } @app.get("/api/version")