Merge pull request #1117 from open-webui/model-whitelist

feat: model filter (whitelist)
This commit is contained in:
Timothy Jaeryang Baek 2024-03-10 00:30:43 -05:00 committed by GitHub
commit bcabd3df84
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 241 additions and 88 deletions

View file

@ -29,6 +29,10 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {} app.state.MODELS = {}
@ -129,9 +133,19 @@ async def get_all_models():
async def get_ollama_tags( async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_current_user) url_idx: Optional[int] = None, user=Depends(get_current_user)
): ):
if url_idx == None: if url_idx == None:
return await get_all_models() models = await get_all_models()
if app.state.MODEL_FILTER_ENABLED:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"] in app.state.MODEL_LIST,
models["models"],
)
)
return models
return models
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.OLLAMA_BASE_URLS[url_idx]
try: try:

View file

@ -34,6 +34,9 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
@ -186,12 +189,21 @@ async def get_all_models():
return models return models
# , user=Depends(get_current_user)
@app.get("/models") @app.get("/models")
@app.get("/models/{url_idx}") @app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None: if url_idx == None:
return await get_all_models() models = await get_all_models()
if app.state.MODEL_FILTER_ENABLED:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_LIST,
models["data"],
)
)
return models
return models
else: else:
url = app.state.OPENAI_API_BASE_URLS[url_idx] url = app.state.OPENAI_API_BASE_URLS[url_idx]
try: try:

View file

@ -23,7 +23,11 @@ from apps.images.main import app as images_app
from apps.rag.main import app as rag_app from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app from apps.web.main import app as webui_app
from pydantic import BaseModel
from typing import List
from utils.utils import get_admin_user
from apps.rag.utils import query_doc, query_collection, rag_template from apps.rag.utils import query_doc, query_collection, rag_template
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
@ -43,6 +47,9 @@ class SPAStaticFiles(StaticFiles):
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.MODEL_FILTER_ENABLED = False
app.state.MODEL_LIST = []
origins = ["*"] origins = ["*"]
app.add_middleware( app.add_middleware(
@ -213,6 +220,33 @@ async def get_app_config():
} }
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
class ModelFilterConfigForm(BaseModel):
enabled: bool
models: List[str]
@app.post("/api/config/model/filter")
async def get_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.MODEL_FILTER_ENABLED = form_data.enabled
app.state.MODEL_LIST = form_data.models
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
ollama_app.state.MODEL_LIST = app.state.MODEL_LIST
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
openai_app.state.MODEL_LIST = app.state.MODEL_LIST
return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
@app.get("/api/version") @app.get("/api/version")
async def get_app_config(): async def get_app_config():

View file

@ -77,3 +77,65 @@ export const getVersionUpdates = async () => {
return res; return res;
}; };
export const getModelFilterConfig = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateModelFilterConfig = async (
token: string,
enabled: boolean,
models: string[]
) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
enabled: enabled,
models: models
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};

View file

@ -1,10 +1,14 @@
<script lang="ts"> <script lang="ts">
import { getModelFilterConfig, updateModelFilterConfig } from '$lib/apis';
import { getSignUpEnabledStatus, toggleSignUpEnabledStatus } from '$lib/apis/auths'; import { getSignUpEnabledStatus, toggleSignUpEnabledStatus } from '$lib/apis/auths';
import { getUserPermissions, updateUserPermissions } from '$lib/apis/users'; import { getUserPermissions, updateUserPermissions } from '$lib/apis/users';
import { models } from '$lib/stores';
import { onMount } from 'svelte'; import { onMount } from 'svelte';
export let saveHandler: Function; export let saveHandler: Function;
let whitelistEnabled = false;
let whitelistModels = [''];
let permissions = { let permissions = {
chat: { chat: {
deletion: true deletion: true
@ -13,6 +17,13 @@
onMount(async () => { onMount(async () => {
permissions = await getUserPermissions(localStorage.token); permissions = await getUserPermissions(localStorage.token);
const res = await getModelFilterConfig(localStorage.token);
if (res) {
whitelistEnabled = res.enabled;
whitelistModels = res.models.length > 0 ? res.models : [''];
}
}); });
</script> </script>
@ -21,6 +32,8 @@
on:submit|preventDefault={async () => { on:submit|preventDefault={async () => {
// console.log('submit'); // console.log('submit');
await updateUserPermissions(localStorage.token, permissions); await updateUserPermissions(localStorage.token, permissions);
await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels);
saveHandler(); saveHandler();
}} }}
> >
@ -69,6 +82,106 @@
</button> </button>
</div> </div>
</div> </div>
<hr class=" dark:border-gray-700 my-2" />
<div class="mt-2 space-y-3 pr-1.5">
<div>
<div class="mb-2">
<div class="flex justify-between items-center text-xs">
<div class=" text-sm font-medium">Manage Models</div>
</div>
</div>
<div class=" space-y-3">
<div>
<div class="flex justify-between items-center text-xs">
<div class=" text-xs font-medium">Model Whitelisting</div>
<button
class=" text-xs font-medium text-gray-500"
type="button"
on:click={() => {
whitelistEnabled = !whitelistEnabled;
}}>{whitelistEnabled ? 'On' : 'Off'}</button
>
</div>
</div>
{#if whitelistEnabled}
<div>
<div class=" space-y-1.5">
{#each whitelistModels as modelId, modelIdx}
<div class="flex w-full">
<div class="flex-1 mr-2">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={modelId}
placeholder="Select a model"
>
<option value="" disabled selected>Select a model</option>
{#each $models.filter((model) => model.id) as model}
<option value={model.id} class="bg-gray-100 dark:bg-gray-700"
>{model.name}</option
>
{/each}
</select>
</div>
{#if modelIdx === 0}
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
type="button"
on:click={() => {
if (whitelistModels.at(-1) !== '') {
whitelistModels = [...whitelistModels, ''];
}
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
/>
</svg>
</button>
{:else}
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
type="button"
on:click={() => {
whitelistModels.splice(modelIdx, 1);
whitelistModels = whitelistModels;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path d="M3.75 7.25a.75.75 0 0 0 0 1.5h8.5a.75.75 0 0 0 0-1.5h-8.5Z" />
</svg>
</button>
{/if}
</div>
{/each}
</div>
<div class="flex justify-end items-center text-xs mt-1.5 text-right">
<div class=" text-xs font-medium">
{whitelistModels.length} Model(s) Whitelisted
</div>
</div>
</div>
{/if}
</div>
</div>
</div>
</div> </div>
<div class="flex justify-end pt-3 text-sm font-medium"> <div class="flex justify-end pt-3 text-sm font-medium">

View file

@ -912,88 +912,6 @@
{/if} {/if}
</div> </div>
</div> </div>
<!-- <div class="mt-2 space-y-3 pr-1.5">
<div>
<div class=" mb-2.5 text-sm font-medium">Add LiteLLM Model</div>
<div class="flex w-full mb-2">
<div class="flex-1">
<input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
placeholder="Enter LiteLLM Model (e.g. ollama/mistral)"
bind:value={liteLLMModel}
autocomplete="off"
/>
</div>
</div>
<div class="flex justify-between items-center text-sm">
<div class=" font-medium">Advanced Model Params</div>
<button
class=" text-xs font-medium text-gray-500"
type="button"
on:click={() => {
showLiteLLMParams = !showLiteLLMParams;
}}>{showLiteLLMParams ? 'Hide' : 'Show'}</button
>
</div>
{#if showLiteLLMParams}
<div>
<div class=" mb-2.5 text-sm font-medium">LiteLLM API Key</div>
<div class="flex w-full">
<div class="flex-1">
<input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
placeholder="Enter LiteLLM API Key (e.g. os.environ/AZURE_API_KEY_CA)"
bind:value={liteLLMAPIKey}
autocomplete="off"
/>
</div>
</div>
</div>
<div>
<div class=" mb-2.5 text-sm font-medium">LiteLLM API Base URL</div>
<div class="flex w-full">
<div class="flex-1">
<input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
placeholder="Enter LiteLLM API Base URL"
bind:value={liteLLMAPIBase}
autocomplete="off"
/>
</div>
</div>
</div>
<div>
<div class=" mb-2.5 text-sm font-medium">LiteLLM API RPM</div>
<div class="flex w-full">
<div class="flex-1">
<input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
placeholder="Enter LiteLLM API RPM"
bind:value={liteLLMRPM}
autocomplete="off"
/>
</div>
</div>
</div>
{/if}
<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
Not sure what to add?
<a
class=" text-gray-300 font-medium underline"
href="https://litellm.vercel.app/docs/proxy/configs#quick-start"
target="_blank"
>
Click here for help.
</a>
</div>
</div>
</div> -->
</div> </div>
</div> </div>
</div> </div>

View file

@ -267,7 +267,7 @@
<div class="min-h-screen max-h-[100dvh] w-full flex justify-center dark:text-white"> <div class="min-h-screen max-h-[100dvh] w-full flex justify-center dark:text-white">
<div class=" flex flex-col justify-between w-full overflow-y-auto h-[100dvh]"> <div class=" flex flex-col justify-between w-full overflow-y-auto h-[100dvh]">
<div class="max-w-2xl mx-auto w-full px-3 p-3 md:px-0 h-full"> <div class="max-w-2xl mx-auto w-full px-3 md:px-0 my-10 h-full">
<div class=" flex flex-col h-full"> <div class=" flex flex-col h-full">
<div class="flex flex-col justify-between mb-2.5 gap-1"> <div class="flex flex-col justify-between mb-2.5 gap-1">
<div class="flex justify-between items-center gap-2"> <div class="flex justify-between items-center gap-2">