forked from open-webui/open-webui
		
	Merge pull request #1117 from open-webui/model-whitelist
feat: model filter (whitelist)
This commit is contained in:
		
						commit
						bcabd3df84
					
				
					 7 changed files with 241 additions and 88 deletions
				
			
		|  | @ -29,6 +29,10 @@ app.add_middleware( | |||
|     allow_headers=["*"], | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| app.state.MODEL_FILTER_ENABLED = False | ||||
| app.state.MODEL_LIST = [] | ||||
| 
 | ||||
| app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS | ||||
| app.state.MODELS = {} | ||||
| 
 | ||||
|  | @ -129,9 +133,19 @@ async def get_all_models(): | |||
| async def get_ollama_tags( | ||||
|     url_idx: Optional[int] = None, user=Depends(get_current_user) | ||||
| ): | ||||
| 
 | ||||
|     if url_idx == None: | ||||
|         return await get_all_models() | ||||
|         models = await get_all_models() | ||||
| 
 | ||||
|         if app.state.MODEL_FILTER_ENABLED: | ||||
|             if user.role == "user": | ||||
|                 models["models"] = list( | ||||
|                     filter( | ||||
|                         lambda model: model["name"] in app.state.MODEL_LIST, | ||||
|                         models["models"], | ||||
|                     ) | ||||
|                 ) | ||||
|                 return models | ||||
|         return models | ||||
|     else: | ||||
|         url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|         try: | ||||
|  |  | |||
|  | @ -34,6 +34,9 @@ app.add_middleware( | |||
|     allow_headers=["*"], | ||||
| ) | ||||
| 
 | ||||
| app.state.MODEL_FILTER_ENABLED = False | ||||
| app.state.MODEL_LIST = [] | ||||
| 
 | ||||
| app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS | ||||
| app.state.OPENAI_API_KEYS = OPENAI_API_KEYS | ||||
| 
 | ||||
|  | @ -186,12 +189,21 @@ async def get_all_models(): | |||
|     return models | ||||
| 
 | ||||
| 
 | ||||
| # , user=Depends(get_current_user) | ||||
| @app.get("/models") | ||||
| @app.get("/models/{url_idx}") | ||||
| async def get_models(url_idx: Optional[int] = None): | ||||
| async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): | ||||
|     if url_idx == None: | ||||
|         return await get_all_models() | ||||
|         models = await get_all_models() | ||||
|         if app.state.MODEL_FILTER_ENABLED: | ||||
|             if user.role == "user": | ||||
|                 models["data"] = list( | ||||
|                     filter( | ||||
|                         lambda model: model["id"] in app.state.MODEL_LIST, | ||||
|                         models["data"], | ||||
|                     ) | ||||
|                 ) | ||||
|                 return models | ||||
|         return models | ||||
|     else: | ||||
|         url = app.state.OPENAI_API_BASE_URLS[url_idx] | ||||
|         try: | ||||
|  |  | |||
|  | @ -23,7 +23,11 @@ from apps.images.main import app as images_app | |||
| from apps.rag.main import app as rag_app | ||||
| from apps.web.main import app as webui_app | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from typing import List | ||||
| 
 | ||||
| 
 | ||||
| from utils.utils import get_admin_user | ||||
| from apps.rag.utils import query_doc, query_collection, rag_template | ||||
| 
 | ||||
| from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR | ||||
|  | @ -43,6 +47,9 @@ class SPAStaticFiles(StaticFiles): | |||
| 
 | ||||
| app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) | ||||
| 
 | ||||
| app.state.MODEL_FILTER_ENABLED = False | ||||
| app.state.MODEL_LIST = [] | ||||
| 
 | ||||
| origins = ["*"] | ||||
| 
 | ||||
| app.add_middleware( | ||||
|  | @ -213,6 +220,33 @@ async def get_app_config(): | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/api/config/model/filter") | ||||
| async def get_model_filter_config(user=Depends(get_admin_user)): | ||||
|     return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST} | ||||
| 
 | ||||
| 
 | ||||
| class ModelFilterConfigForm(BaseModel): | ||||
|     enabled: bool | ||||
|     models: List[str] | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/api/config/model/filter") | ||||
| async def get_model_filter_config( | ||||
|     form_data: ModelFilterConfigForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|     app.state.MODEL_FILTER_ENABLED = form_data.enabled | ||||
|     app.state.MODEL_LIST = form_data.models | ||||
| 
 | ||||
|     ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED | ||||
|     ollama_app.state.MODEL_LIST = app.state.MODEL_LIST | ||||
| 
 | ||||
|     openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED | ||||
|     openai_app.state.MODEL_LIST = app.state.MODEL_LIST | ||||
| 
 | ||||
|     return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST} | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/api/version") | ||||
| async def get_app_config(): | ||||
| 
 | ||||
|  |  | |||
|  | @ -77,3 +77,65 @@ export const getVersionUpdates = async () => { | |||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| export const getModelFilterConfig = async (token: string) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, { | ||||
| 		method: 'GET', | ||||
| 		headers: { | ||||
| 			'Content-Type': 'application/json', | ||||
| 			Authorization: `Bearer ${token}` | ||||
| 		} | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
| 			if (!res.ok) throw await res.json(); | ||||
| 			return res.json(); | ||||
| 		}) | ||||
| 		.catch((err) => { | ||||
| 			console.log(err); | ||||
| 			error = err; | ||||
| 			return null; | ||||
| 		}); | ||||
| 
 | ||||
| 	if (error) { | ||||
| 		throw error; | ||||
| 	} | ||||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| export const updateModelFilterConfig = async ( | ||||
| 	token: string, | ||||
| 	enabled: boolean, | ||||
| 	models: string[] | ||||
| ) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, { | ||||
| 		method: 'POST', | ||||
| 		headers: { | ||||
| 			'Content-Type': 'application/json', | ||||
| 			Authorization: `Bearer ${token}` | ||||
| 		}, | ||||
| 		body: JSON.stringify({ | ||||
| 			enabled: enabled, | ||||
| 			models: models | ||||
| 		}) | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
| 			if (!res.ok) throw await res.json(); | ||||
| 			return res.json(); | ||||
| 		}) | ||||
| 		.catch((err) => { | ||||
| 			console.log(err); | ||||
| 			error = err; | ||||
| 			return null; | ||||
| 		}); | ||||
| 
 | ||||
| 	if (error) { | ||||
| 		throw error; | ||||
| 	} | ||||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
|  |  | |||
|  | @ -1,10 +1,14 @@ | |||
| <script lang="ts"> | ||||
| 	import { getModelFilterConfig, updateModelFilterConfig } from '$lib/apis'; | ||||
| 	import { getSignUpEnabledStatus, toggleSignUpEnabledStatus } from '$lib/apis/auths'; | ||||
| 	import { getUserPermissions, updateUserPermissions } from '$lib/apis/users'; | ||||
| 	import { models } from '$lib/stores'; | ||||
| 	import { onMount } from 'svelte'; | ||||
| 
 | ||||
| 	export let saveHandler: Function; | ||||
| 
 | ||||
| 	let whitelistEnabled = false; | ||||
| 	let whitelistModels = ['']; | ||||
| 	let permissions = { | ||||
| 		chat: { | ||||
| 			deletion: true | ||||
|  | @ -13,6 +17,13 @@ | |||
| 
 | ||||
| 	onMount(async () => { | ||||
| 		permissions = await getUserPermissions(localStorage.token); | ||||
| 
 | ||||
| 		const res = await getModelFilterConfig(localStorage.token); | ||||
| 		if (res) { | ||||
| 			whitelistEnabled = res.enabled; | ||||
| 
 | ||||
| 			whitelistModels = res.models.length > 0 ? res.models : ['']; | ||||
| 		} | ||||
| 	}); | ||||
| </script> | ||||
| 
 | ||||
|  | @ -21,6 +32,8 @@ | |||
| 	on:submit|preventDefault={async () => { | ||||
| 		// console.log('submit'); | ||||
| 		await updateUserPermissions(localStorage.token, permissions); | ||||
| 
 | ||||
| 		await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels); | ||||
| 		saveHandler(); | ||||
| 	}} | ||||
| > | ||||
|  | @ -69,6 +82,106 @@ | |||
| 				</button> | ||||
| 			</div> | ||||
| 		</div> | ||||
| 
 | ||||
| 		<hr class=" dark:border-gray-700 my-2" /> | ||||
| 
 | ||||
| 		<div class="mt-2 space-y-3 pr-1.5"> | ||||
| 			<div> | ||||
| 				<div class="mb-2"> | ||||
| 					<div class="flex justify-between items-center text-xs"> | ||||
| 						<div class=" text-sm font-medium">Manage Models</div> | ||||
| 					</div> | ||||
| 				</div> | ||||
| 
 | ||||
| 				<div class=" space-y-3"> | ||||
| 					<div> | ||||
| 						<div class="flex justify-between items-center text-xs"> | ||||
| 							<div class=" text-xs font-medium">Model Whitelisting</div> | ||||
| 
 | ||||
| 							<button | ||||
| 								class=" text-xs font-medium text-gray-500" | ||||
| 								type="button" | ||||
| 								on:click={() => { | ||||
| 									whitelistEnabled = !whitelistEnabled; | ||||
| 								}}>{whitelistEnabled ? 'On' : 'Off'}</button | ||||
| 							> | ||||
| 						</div> | ||||
| 					</div> | ||||
| 
 | ||||
| 					{#if whitelistEnabled} | ||||
| 						<div> | ||||
| 							<div class=" space-y-1.5"> | ||||
| 								{#each whitelistModels as modelId, modelIdx} | ||||
| 									<div class="flex w-full"> | ||||
| 										<div class="flex-1 mr-2"> | ||||
| 											<select | ||||
| 												class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" | ||||
| 												bind:value={modelId} | ||||
| 												placeholder="Select a model" | ||||
| 											> | ||||
| 												<option value="" disabled selected>Select a model</option> | ||||
| 												{#each $models.filter((model) => model.id) as model} | ||||
| 													<option value={model.id} class="bg-gray-100 dark:bg-gray-700" | ||||
| 														>{model.name}</option | ||||
| 													> | ||||
| 												{/each} | ||||
| 											</select> | ||||
| 										</div> | ||||
| 
 | ||||
| 										{#if modelIdx === 0} | ||||
| 											<button | ||||
| 												class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition" | ||||
| 												type="button" | ||||
| 												on:click={() => { | ||||
| 													if (whitelistModels.at(-1) !== '') { | ||||
| 														whitelistModels = [...whitelistModels, '']; | ||||
| 													} | ||||
| 												}} | ||||
| 											> | ||||
| 												<svg | ||||
| 													xmlns="http://www.w3.org/2000/svg" | ||||
| 													viewBox="0 0 16 16" | ||||
| 													fill="currentColor" | ||||
| 													class="w-4 h-4" | ||||
| 												> | ||||
| 													<path | ||||
| 														d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z" | ||||
| 													/> | ||||
| 												</svg> | ||||
| 											</button> | ||||
| 										{:else} | ||||
| 											<button | ||||
| 												class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition" | ||||
| 												type="button" | ||||
| 												on:click={() => { | ||||
| 													whitelistModels.splice(modelIdx, 1); | ||||
| 													whitelistModels = whitelistModels; | ||||
| 												}} | ||||
| 											> | ||||
| 												<svg | ||||
| 													xmlns="http://www.w3.org/2000/svg" | ||||
| 													viewBox="0 0 16 16" | ||||
| 													fill="currentColor" | ||||
| 													class="w-4 h-4" | ||||
| 												> | ||||
| 													<path d="M3.75 7.25a.75.75 0 0 0 0 1.5h8.5a.75.75 0 0 0 0-1.5h-8.5Z" /> | ||||
| 												</svg> | ||||
| 											</button> | ||||
| 										{/if} | ||||
| 									</div> | ||||
| 								{/each} | ||||
| 							</div> | ||||
| 
 | ||||
| 							<div class="flex justify-end items-center text-xs mt-1.5 text-right"> | ||||
| 								<div class=" text-xs font-medium"> | ||||
| 									{whitelistModels.length} Model(s) Whitelisted | ||||
| 								</div> | ||||
| 							</div> | ||||
| 						</div> | ||||
| 					{/if} | ||||
| 				</div> | ||||
| 			</div> | ||||
| 		</div> | ||||
| 	</div> | ||||
| 
 | ||||
| 	<div class="flex justify-end pt-3 text-sm font-medium"> | ||||
|  |  | |||
|  | @ -912,88 +912,6 @@ | |||
| 					{/if} | ||||
| 				</div> | ||||
| 			</div> | ||||
| 
 | ||||
| 			<!-- <div class="mt-2 space-y-3 pr-1.5"> | ||||
| 				<div> | ||||
| 					<div class=" mb-2.5 text-sm font-medium">Add LiteLLM Model</div> | ||||
| 					<div class="flex w-full mb-2"> | ||||
| 						<div class="flex-1"> | ||||
| 							<input | ||||
| 								class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none" | ||||
| 								placeholder="Enter LiteLLM Model (e.g. ollama/mistral)" | ||||
| 								bind:value={liteLLMModel} | ||||
| 								autocomplete="off" | ||||
| 							/> | ||||
| 						</div> | ||||
| 					</div> | ||||
| 
 | ||||
| 					<div class="flex justify-between items-center text-sm"> | ||||
| 						<div class="  font-medium">Advanced Model Params</div> | ||||
| 						<button | ||||
| 							class=" text-xs font-medium text-gray-500" | ||||
| 							type="button" | ||||
| 							on:click={() => { | ||||
| 								showLiteLLMParams = !showLiteLLMParams; | ||||
| 							}}>{showLiteLLMParams ? 'Hide' : 'Show'}</button | ||||
| 						> | ||||
| 					</div> | ||||
| 
 | ||||
| 					{#if showLiteLLMParams} | ||||
| 						<div> | ||||
| 							<div class=" mb-2.5 text-sm font-medium">LiteLLM API Key</div> | ||||
| 							<div class="flex w-full"> | ||||
| 								<div class="flex-1"> | ||||
| 									<input | ||||
| 										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none" | ||||
| 										placeholder="Enter LiteLLM API Key (e.g. os.environ/AZURE_API_KEY_CA)" | ||||
| 										bind:value={liteLLMAPIKey} | ||||
| 										autocomplete="off" | ||||
| 									/> | ||||
| 								</div> | ||||
| 							</div> | ||||
| 						</div> | ||||
| 
 | ||||
| 						<div> | ||||
| 							<div class=" mb-2.5 text-sm font-medium">LiteLLM API Base URL</div> | ||||
| 							<div class="flex w-full"> | ||||
| 								<div class="flex-1"> | ||||
| 									<input | ||||
| 										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none" | ||||
| 										placeholder="Enter LiteLLM API Base URL" | ||||
| 										bind:value={liteLLMAPIBase} | ||||
| 										autocomplete="off" | ||||
| 									/> | ||||
| 								</div> | ||||
| 							</div> | ||||
| 						</div> | ||||
| 
 | ||||
| 						<div> | ||||
| 							<div class=" mb-2.5 text-sm font-medium">LiteLLM API RPM</div> | ||||
| 							<div class="flex w-full"> | ||||
| 								<div class="flex-1"> | ||||
| 									<input | ||||
| 										class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none" | ||||
| 										placeholder="Enter LiteLLM API RPM" | ||||
| 										bind:value={liteLLMRPM} | ||||
| 										autocomplete="off" | ||||
| 									/> | ||||
| 								</div> | ||||
| 							</div> | ||||
| 						</div> | ||||
| 					{/if} | ||||
| 
 | ||||
| 					<div class="mt-2 text-xs text-gray-400 dark:text-gray-500"> | ||||
| 						Not sure what to add? | ||||
| 						<a | ||||
| 							class=" text-gray-300 font-medium underline" | ||||
| 							href="https://litellm.vercel.app/docs/proxy/configs#quick-start" | ||||
| 							target="_blank" | ||||
| 						> | ||||
| 							Click here for help. | ||||
| 						</a> | ||||
| 					</div> | ||||
| 				</div> | ||||
| 			</div> --> | ||||
| 		</div> | ||||
| 	</div> | ||||
| </div> | ||||
|  |  | |||
|  | @ -267,7 +267,7 @@ | |||
| 
 | ||||
| <div class="min-h-screen max-h-[100dvh] w-full flex justify-center dark:text-white"> | ||||
| 	<div class=" flex flex-col justify-between w-full overflow-y-auto h-[100dvh]"> | ||||
| 		<div class="max-w-2xl mx-auto w-full px-3 p-3 md:px-0 h-full"> | ||||
| 		<div class="max-w-2xl mx-auto w-full px-3 md:px-0 my-10 h-full"> | ||||
| 			<div class=" flex flex-col h-full"> | ||||
| 				<div class="flex flex-col justify-between mb-2.5 gap-1"> | ||||
| 					<div class="flex justify-between items-center gap-2"> | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy Jaeryang Baek
						Timothy Jaeryang Baek