forked from open-webui/open-webui
feat: model filter list env var
This commit is contained in:
parent
bcabd3df84
commit
a4ca1fc5c4
4 changed files with 41 additions and 16 deletions
|
@ -15,7 +15,7 @@ import asyncio
|
||||||
from apps.web.models.users import Users
|
from apps.web.models.users import Users
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
from utils.utils import decode_token, get_current_user, get_admin_user
|
from utils.utils import decode_token, get_current_user, get_admin_user
|
||||||
from config import OLLAMA_BASE_URLS
|
from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
|
||||||
|
|
||||||
from typing import Optional, List, Union
|
from typing import Optional, List, Union
|
||||||
|
|
||||||
|
@ -30,8 +30,8 @@ app.add_middleware(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
app.state.MODEL_FILTER_ENABLED = False
|
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||||
app.state.MODEL_LIST = []
|
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||||
|
|
||||||
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
||||||
app.state.MODELS = {}
|
app.state.MODELS = {}
|
||||||
|
@ -140,7 +140,7 @@ async def get_ollama_tags(
|
||||||
if user.role == "user":
|
if user.role == "user":
|
||||||
models["models"] = list(
|
models["models"] = list(
|
||||||
filter(
|
filter(
|
||||||
lambda model: model["name"] in app.state.MODEL_LIST,
|
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
|
||||||
models["models"],
|
models["models"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,13 @@ from utils.utils import (
|
||||||
get_verified_user,
|
get_verified_user,
|
||||||
get_admin_user,
|
get_admin_user,
|
||||||
)
|
)
|
||||||
from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
|
from config import (
|
||||||
|
OPENAI_API_BASE_URLS,
|
||||||
|
OPENAI_API_KEYS,
|
||||||
|
CACHE_DIR,
|
||||||
|
MODEL_FILTER_ENABLED,
|
||||||
|
MODEL_FILTER_LIST,
|
||||||
|
)
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,8 +40,8 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.state.MODEL_FILTER_ENABLED = False
|
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||||
app.state.MODEL_LIST = []
|
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||||
|
|
||||||
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||||
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
|
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
|
||||||
|
@ -198,7 +204,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
|
||||||
if user.role == "user":
|
if user.role == "user":
|
||||||
models["data"] = list(
|
models["data"] = list(
|
||||||
filter(
|
filter(
|
||||||
lambda model: model["id"] in app.state.MODEL_LIST,
|
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
|
||||||
models["data"],
|
models["data"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -292,6 +292,11 @@ DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
|
||||||
USER_PERMISSIONS = {"chat": {"deletion": True}}
|
USER_PERMISSIONS = {"chat": {"deletion": True}}
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
|
||||||
|
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
|
||||||
|
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
|
||||||
|
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
# WEBUI_VERSION
|
# WEBUI_VERSION
|
||||||
####################################
|
####################################
|
||||||
|
|
|
@ -30,7 +30,15 @@ from typing import List
|
||||||
from utils.utils import get_admin_user
|
from utils.utils import get_admin_user
|
||||||
from apps.rag.utils import query_doc, query_collection, rag_template
|
from apps.rag.utils import query_doc, query_collection, rag_template
|
||||||
|
|
||||||
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
|
from config import (
|
||||||
|
WEBUI_NAME,
|
||||||
|
ENV,
|
||||||
|
VERSION,
|
||||||
|
CHANGELOG,
|
||||||
|
FRONTEND_BUILD_DIR,
|
||||||
|
MODEL_FILTER_ENABLED,
|
||||||
|
MODEL_FILTER_LIST,
|
||||||
|
)
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,8 +55,8 @@ class SPAStaticFiles(StaticFiles):
|
||||||
|
|
||||||
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
|
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
|
||||||
|
|
||||||
app.state.MODEL_FILTER_ENABLED = False
|
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
|
||||||
app.state.MODEL_LIST = []
|
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||||
|
|
||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
|
|
||||||
|
@ -222,7 +230,10 @@ async def get_app_config():
|
||||||
|
|
||||||
@app.get("/api/config/model/filter")
|
@app.get("/api/config/model/filter")
|
||||||
async def get_model_filter_config(user=Depends(get_admin_user)):
|
async def get_model_filter_config(user=Depends(get_admin_user)):
|
||||||
return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
|
return {
|
||||||
|
"enabled": app.state.MODEL_FILTER_ENABLED,
|
||||||
|
"models": app.state.MODEL_FILTER_LIST,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelFilterConfigForm(BaseModel):
|
class ModelFilterConfigForm(BaseModel):
|
||||||
|
@ -236,15 +247,18 @@ async def get_model_filter_config(
|
||||||
):
|
):
|
||||||
|
|
||||||
app.state.MODEL_FILTER_ENABLED = form_data.enabled
|
app.state.MODEL_FILTER_ENABLED = form_data.enabled
|
||||||
app.state.MODEL_LIST = form_data.models
|
app.state.MODEL_FILTER_LIST = form_data.models
|
||||||
|
|
||||||
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
||||||
ollama_app.state.MODEL_LIST = app.state.MODEL_LIST
|
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
|
||||||
|
|
||||||
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
||||||
openai_app.state.MODEL_LIST = app.state.MODEL_LIST
|
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
|
||||||
|
|
||||||
return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
|
return {
|
||||||
|
"enabled": app.state.MODEL_FILTER_ENABLED,
|
||||||
|
"models": app.state.MODEL_FILTER_LIST,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/version")
|
@app.get("/api/version")
|
||||||
|
|
Loading…
Reference in a new issue