forked from open-webui/open-webui
feat: model filter backend
This commit is contained in:
parent
6d5ff8d469
commit
b550e23bf6
4 changed files with 61 additions and 6 deletions
|
@ -29,6 +29,10 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app.state.MODEL_FILTER_ENABLED = False
|
||||||
|
app.state.MODEL_LIST = []
|
||||||
|
|
||||||
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
||||||
app.state.MODELS = {}
|
app.state.MODELS = {}
|
||||||
|
|
||||||
|
@ -129,9 +133,16 @@ async def get_all_models():
|
||||||
async def get_ollama_tags(
|
async def get_ollama_tags(
|
||||||
url_idx: Optional[int] = None, user=Depends(get_current_user)
|
url_idx: Optional[int] = None, user=Depends(get_current_user)
|
||||||
):
|
):
|
||||||
|
|
||||||
if url_idx == None:
|
if url_idx == None:
|
||||||
return await get_all_models()
|
models = await get_all_models()
|
||||||
|
if app.state.MODEL_FILTER_ENABLED:
|
||||||
|
if user.role == "user":
|
||||||
|
models["models"] = filter(
|
||||||
|
lambda model: model["name"] in app.state.MODEL_LIST,
|
||||||
|
models["models"],
|
||||||
|
)
|
||||||
|
return models
|
||||||
|
return models
|
||||||
else:
|
else:
|
||||||
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -34,6 +34,9 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.state.MODEL_FILTER_ENABLED = False
|
||||||
|
app.state.MODEL_LIST = []
|
||||||
|
|
||||||
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||||
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
|
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
|
||||||
|
|
||||||
|
@ -186,12 +189,19 @@ async def get_all_models():
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
# , user=Depends(get_current_user)
|
|
||||||
@app.get("/models")
|
@app.get("/models")
|
||||||
@app.get("/models/{url_idx}")
|
@app.get("/models/{url_idx}")
|
||||||
async def get_models(url_idx: Optional[int] = None):
|
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
|
||||||
if url_idx == None:
|
if url_idx == None:
|
||||||
return await get_all_models()
|
models = await get_all_models()
|
||||||
|
if app.state.MODEL_FILTER_ENABLED:
|
||||||
|
if user.role == "user":
|
||||||
|
models["data"] = filter(
|
||||||
|
lambda model: model["id"] in app.state.MODEL_LIST,
|
||||||
|
models["data"],
|
||||||
|
)
|
||||||
|
return models
|
||||||
|
return models
|
||||||
else:
|
else:
|
||||||
url = app.state.OPENAI_API_BASE_URLS[url_idx]
|
url = app.state.OPENAI_API_BASE_URLS[url_idx]
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -23,7 +23,11 @@ from apps.images.main import app as images_app
|
||||||
from apps.rag.main import app as rag_app
|
from apps.rag.main import app as rag_app
|
||||||
from apps.web.main import app as webui_app
|
from apps.web.main import app as webui_app
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
from utils.utils import get_admin_user
|
||||||
from apps.rag.utils import query_doc, query_collection, rag_template
|
from apps.rag.utils import query_doc, query_collection, rag_template
|
||||||
|
|
||||||
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
|
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
|
||||||
|
@ -43,6 +47,9 @@ class SPAStaticFiles(StaticFiles):
|
||||||
|
|
||||||
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
|
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
|
||||||
|
|
||||||
|
app.state.MODEL_FILTER_ENABLED = False
|
||||||
|
app.state.MODEL_LIST = []
|
||||||
|
|
||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
@ -211,6 +218,33 @@ async def get_app_config():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/config/model/filter")
|
||||||
|
async def get_model_filter_config(user=Depends(get_admin_user)):
|
||||||
|
return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFilterConfigForm(BaseModel):
|
||||||
|
enabled: bool
|
||||||
|
models: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/config/model/filter")
|
||||||
|
async def get_model_filter_config(
|
||||||
|
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
||||||
|
):
|
||||||
|
|
||||||
|
app.state.MODEL_FILTER_ENABLED = form_data.enabled
|
||||||
|
app.state.MODEL_LIST = form_data.models
|
||||||
|
|
||||||
|
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
||||||
|
ollama_app.state.MODEL_LIST = app.state.MODEL_LIST
|
||||||
|
|
||||||
|
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
|
||||||
|
openai_app.state.MODEL_LIST = app.state.MODEL_LIST
|
||||||
|
|
||||||
|
return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/version")
|
@app.get("/api/version")
|
||||||
async def get_app_config():
|
async def get_app_config():
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
|
|
||||||
export let suggestionPrompts = [];
|
export let suggestionPrompts = [];
|
||||||
export let autoScroll = true;
|
export let autoScroll = true;
|
||||||
let chatTextAreaElement:HTMLTextAreaElement
|
let chatTextAreaElement: HTMLTextAreaElement;
|
||||||
let filesInputElement;
|
let filesInputElement;
|
||||||
|
|
||||||
let promptsElement;
|
let promptsElement;
|
||||||
|
|
Loading…
Reference in a new issue