feat: model filter backend

This commit is contained in:
Timothy J. Baek 2024-03-09 21:19:20 -08:00
parent 6d5ff8d469
commit b550e23bf6
4 changed files with 61 additions and 6 deletions

View file

@ -29,6 +29,10 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {} app.state.MODELS = {}
@ -129,9 +133,16 @@ async def get_all_models():
async def get_ollama_tags( async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_current_user) url_idx: Optional[int] = None, user=Depends(get_current_user)
): ):
if url_idx == None: if url_idx == None:
return await get_all_models() models = await get_all_models()
if app.state.MODEL_FILTER_ENABLED:
if user.role == "user":
models["models"] = filter(
lambda model: model["name"] in app.state.MODEL_LIST,
models["models"],
)
return models
return models
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
try: try:

View file

@ -34,6 +34,9 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
@ -186,12 +189,19 @@ async def get_all_models():
return models return models
# , user=Depends(get_current_user)
@app.get("/models") @app.get("/models")
@app.get("/models/{url_idx}") @app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None: if url_idx == None:
return await get_all_models() models = await get_all_models()
if app.state.MODEL_FILTER_ENABLED:
if user.role == "user":
models["data"] = filter(
lambda model: model["id"] in app.state.MODEL_LIST,
models["data"],
)
return models
return models
else: else:
url = app.state.OPENAI_API_BASE_URLS[url_idx] url = app.state.OPENAI_API_BASE_URLS[url_idx]
try: try:

View file

@ -23,7 +23,11 @@ from apps.images.main import app as images_app
from apps.rag.main import app as rag_app from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app from apps.web.main import app as webui_app
from pydantic import BaseModel
from typing import List
from utils.utils import get_admin_user
from apps.rag.utils import query_doc, query_collection, rag_template from apps.rag.utils import query_doc, query_collection, rag_template
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
@ -43,6 +47,9 @@ class SPAStaticFiles(StaticFiles):
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []
origins = ["*"] origins = ["*"]
app.add_middleware( app.add_middleware(
@ -211,6 +218,33 @@ async def get_app_config():
} }
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
class ModelFilterConfigForm(BaseModel):
enabled: bool
models: List[str]
@app.post("/api/config/model/filter")
async def get_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.MODEL_FILTER_ENABLED = form_data.enabled
app.state.MODEL_LIST = form_data.models
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
ollama_app.state.MODEL_LIST = app.state.MODEL_LIST
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
openai_app.state.MODEL_LIST = app.state.MODEL_LIST
return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
@app.get("/api/version") @app.get("/api/version")
async def get_app_config(): async def get_app_config():

View file

@ -19,7 +19,7 @@
export let suggestionPrompts = []; export let suggestionPrompts = [];
export let autoScroll = true; export let autoScroll = true;
let chatTextAreaElement:HTMLTextAreaElement let chatTextAreaElement: HTMLTextAreaElement;
let filesInputElement; let filesInputElement;
let promptsElement; let promptsElement;