forked from open-webui/open-webui
		
	feat: model filter backend
This commit is contained in:
		
							parent
							
								
									6d5ff8d469
								
							
						
					
					
						commit
						b550e23bf6
					
				
					 4 changed files with 61 additions and 6 deletions
				
			
		|  | @ -29,6 +29,10 @@ app.add_middleware( | |||
|     allow_headers=["*"], | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| app.state.MODEL_FILTER_ENABLED = False | ||||
| app.state.MODEL_LIST = [] | ||||
| 
 | ||||
| app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS | ||||
| app.state.MODELS = {} | ||||
| 
 | ||||
|  | @ -129,9 +133,16 @@ async def get_all_models(): | |||
| async def get_ollama_tags( | ||||
|     url_idx: Optional[int] = None, user=Depends(get_current_user) | ||||
| ): | ||||
| 
 | ||||
|     if url_idx == None: | ||||
|         return await get_all_models() | ||||
|         models = await get_all_models() | ||||
|         if app.state.MODEL_FILTER_ENABLED: | ||||
|             if user.role == "user": | ||||
|                 models["models"] = filter( | ||||
|                     lambda model: model["name"] in app.state.MODEL_LIST, | ||||
|                     models["models"], | ||||
|                 ) | ||||
|                 return models | ||||
|         return models | ||||
|     else: | ||||
|         url = app.state.OLLAMA_BASE_URLS[url_idx] | ||||
|         try: | ||||
|  |  | |||
|  | @ -34,6 +34,9 @@ app.add_middleware( | |||
|     allow_headers=["*"], | ||||
| ) | ||||
| 
 | ||||
| app.state.MODEL_FILTER_ENABLED = False | ||||
| app.state.MODEL_LIST = [] | ||||
| 
 | ||||
| app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS | ||||
| app.state.OPENAI_API_KEYS = OPENAI_API_KEYS | ||||
| 
 | ||||
|  | @ -186,12 +189,19 @@ async def get_all_models(): | |||
|     return models | ||||
| 
 | ||||
| 
 | ||||
| # , user=Depends(get_current_user) | ||||
| @app.get("/models") | ||||
| @app.get("/models/{url_idx}") | ||||
| async def get_models(url_idx: Optional[int] = None): | ||||
| async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): | ||||
|     if url_idx == None: | ||||
|         return await get_all_models() | ||||
|         models = await get_all_models() | ||||
|         if app.state.MODEL_FILTER_ENABLED: | ||||
|             if user.role == "user": | ||||
|                 models["data"] = filter( | ||||
|                     lambda model: model["id"] in app.state.MODEL_LIST, | ||||
|                     models["data"], | ||||
|                 ) | ||||
|                 return models | ||||
|         return models | ||||
|     else: | ||||
|         url = app.state.OPENAI_API_BASE_URLS[url_idx] | ||||
|         try: | ||||
|  |  | |||
|  | @ -23,7 +23,11 @@ from apps.images.main import app as images_app | |||
| from apps.rag.main import app as rag_app | ||||
| from apps.web.main import app as webui_app | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| from typing import List | ||||
| 
 | ||||
| 
 | ||||
| from utils.utils import get_admin_user | ||||
| from apps.rag.utils import query_doc, query_collection, rag_template | ||||
| 
 | ||||
| from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR | ||||
|  | @ -43,6 +47,9 @@ class SPAStaticFiles(StaticFiles): | |||
| 
 | ||||
| app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) | ||||
| 
 | ||||
| app.state.MODEL_FILTER_ENABLED = False | ||||
| app.state.MODEL_LIST = [] | ||||
| 
 | ||||
| origins = ["*"] | ||||
| 
 | ||||
| app.add_middleware( | ||||
|  | @ -211,6 +218,33 @@ async def get_app_config(): | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/api/config/model/filter") | ||||
| async def get_model_filter_config(user=Depends(get_admin_user)): | ||||
|     return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST} | ||||
| 
 | ||||
| 
 | ||||
| class ModelFilterConfigForm(BaseModel): | ||||
|     enabled: bool | ||||
|     models: List[str] | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/api/config/model/filter") | ||||
| async def get_model_filter_config( | ||||
|     form_data: ModelFilterConfigForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|     app.state.MODEL_FILTER_ENABLED = form_data.enabled | ||||
|     app.state.MODEL_LIST = form_data.models | ||||
| 
 | ||||
|     ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED | ||||
|     ollama_app.state.MODEL_LIST = app.state.MODEL_LIST | ||||
| 
 | ||||
|     openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED | ||||
|     openai_app.state.MODEL_LIST = app.state.MODEL_LIST | ||||
| 
 | ||||
|     return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST} | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/api/version") | ||||
| async def get_app_config(): | ||||
| 
 | ||||
|  |  | |||
|  | @ -19,7 +19,7 @@ | |||
| 
 | ||||
| 	export let suggestionPrompts = []; | ||||
| 	export let autoScroll = true; | ||||
| 	let chatTextAreaElement:HTMLTextAreaElement | ||||
| 	let chatTextAreaElement: HTMLTextAreaElement; | ||||
| 	let filesInputElement; | ||||
| 
 | ||||
| 	let promptsElement; | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek