feat: model filter list env var

This commit is contained in:
Timothy J. Baek 2024-03-09 21:47:01 -08:00
parent bcabd3df84
commit a4ca1fc5c4
4 changed files with 41 additions and 16 deletions

View file

@ -15,7 +15,7 @@ import asyncio
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_BASE_URLS
from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
from typing import Optional, List, Union
@ -30,8 +30,8 @@ app.add_middleware(
)
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
@ -140,7 +140,7 @@ async def get_ollama_tags(
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"] in app.state.MODEL_LIST,
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
models["models"],
)
)

View file

@ -18,7 +18,13 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
from config import (
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
from typing import List, Optional
@ -34,8 +40,8 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
@ -198,7 +204,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_LIST,
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
models["data"],
)
)