From b550e23bf6fcd54a4a31841e1201ea1b0e3937a3 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 9 Mar 2024 21:19:20 -0800 Subject: [PATCH] feat: model filter backend --- backend/apps/ollama/main.py | 15 +++++++-- backend/apps/openai/main.py | 16 ++++++++-- backend/main.py | 34 +++++++++++++++++++++ src/lib/components/chat/MessageInput.svelte | 2 +- 4 files changed, 61 insertions(+), 6 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index f8f166d0..97806ba7 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -29,6 +29,10 @@ app.add_middleware( allow_headers=["*"], ) + +app.state.MODEL_FILTER_ENABLED = False +app.state.MODEL_LIST = [] + app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -129,9 +133,16 @@ async def get_all_models(): async def get_ollama_tags( url_idx: Optional[int] = None, user=Depends(get_current_user) ): - if url_idx == None: - return await get_all_models() + models = await get_all_models() + if app.state.MODEL_FILTER_ENABLED: + if user.role == "user": + models["models"] = filter( + lambda model: model["name"] in app.state.MODEL_LIST, + models["models"], + ) + return models + return models else: url = app.state.OLLAMA_BASE_URLS[url_idx] try: diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 6b9c542e..ec3152e3 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -34,6 +34,9 @@ app.add_middleware( allow_headers=["*"], ) +app.state.MODEL_FILTER_ENABLED = False +app.state.MODEL_LIST = [] + app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -186,12 +189,19 @@ async def get_all_models(): return models -# , user=Depends(get_current_user) @app.get("/models") @app.get("/models/{url_idx}") -async def get_models(url_idx: Optional[int] = None): +async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): if url_idx == None: - return await get_all_models() + models = await get_all_models() + if app.state.MODEL_FILTER_ENABLED: + if user.role == "user": + models["data"] = filter( + lambda model: model["id"] in app.state.MODEL_LIST, + models["data"], + ) + return models + return models else: url = app.state.OPENAI_API_BASE_URLS[url_idx] try: diff --git a/backend/main.py b/backend/main.py index e63f91a0..01c59c15 100644 --- a/backend/main.py +++ b/backend/main.py @@ -23,7 +23,11 @@ from apps.images.main import app as images_app from apps.rag.main import app as rag_app from apps.web.main import app as webui_app +from pydantic import BaseModel +from typing import List + +from utils.utils import get_admin_user from apps.rag.utils import query_doc, query_collection, rag_template from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR @@ -43,6 +47,9 @@ class SPAStaticFiles(StaticFiles): app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) +app.state.MODEL_FILTER_ENABLED = False +app.state.MODEL_LIST = [] + origins = ["*"] app.add_middleware( @@ -211,6 +218,33 @@ async def get_app_config(): } +@app.get("/api/config/model/filter") +async def get_model_filter_config(user=Depends(get_admin_user)): + return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST} + + +class ModelFilterConfigForm(BaseModel): + enabled: bool + models: List[str] + + +@app.post("/api/config/model/filter") +async def get_model_filter_config( + form_data: ModelFilterConfigForm, user=Depends(get_admin_user) +): + + app.state.MODEL_FILTER_ENABLED = form_data.enabled + app.state.MODEL_LIST = form_data.models + + ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED + ollama_app.state.MODEL_LIST = app.state.MODEL_LIST + + openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED + openai_app.state.MODEL_LIST = app.state.MODEL_LIST + + return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST} + + @app.get("/api/version") async def get_app_config(): diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 5a7e8a05..0e396cad 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -19,7 +19,7 @@ export let suggestionPrompts = []; export let autoScroll = true; - let chatTextAreaElement:HTMLTextAreaElement + let chatTextAreaElement: HTMLTextAreaElement; let filesInputElement; let promptsElement;