forked from open-webui/open-webui
		
	Merge conflicts
This commit is contained in:
		
						commit
						f74f2ea765
					
				
					 30 changed files with 2318 additions and 257 deletions
				
			
		|  | @ -3,14 +3,26 @@ import logging | |||
| from litellm.proxy.proxy_server import ProxyConfig, initialize | ||||
| from litellm.proxy.proxy_server import app | ||||
| 
 | ||||
| from fastapi import FastAPI, Request, Depends, status | ||||
| from fastapi import FastAPI, Request, Depends, status, Response | ||||
| from fastapi.responses import JSONResponse | ||||
| 
 | ||||
| from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint | ||||
| from starlette.responses import StreamingResponse | ||||
| import json | ||||
| 
 | ||||
| from utils.utils import get_http_authorization_cred, get_current_user | ||||
| from config import SRC_LOG_LEVELS, ENV | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["LITELLM"]) | ||||
| 
 | ||||
| 
 | ||||
| from config import ( | ||||
|     MODEL_FILTER_ENABLED, | ||||
|     MODEL_FILTER_LIST, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| proxy_config = ProxyConfig() | ||||
| 
 | ||||
| 
 | ||||
|  | @ -31,16 +43,58 @@ async def on_startup(): | |||
|     await startup() | ||||
| 
 | ||||
| 
 | ||||
| app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED | ||||
| app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST | ||||
| 
 | ||||
| 
 | ||||
| @app.middleware("http") | ||||
| async def auth_middleware(request: Request, call_next): | ||||
|     auth_header = request.headers.get("Authorization", "") | ||||
|     request.state.user = None | ||||
| 
 | ||||
|     if ENV != "dev": | ||||
|         try: | ||||
|             user = get_current_user(get_http_authorization_cred(auth_header)) | ||||
|             log.debug(f"user: {user}") | ||||
|         except Exception as e: | ||||
|             return JSONResponse(status_code=400, content={"detail": str(e)}) | ||||
|     try: | ||||
|         user = get_current_user(get_http_authorization_cred(auth_header)) | ||||
|         log.debug(f"user: {user}") | ||||
|         request.state.user = user | ||||
|     except Exception as e: | ||||
|         return JSONResponse(status_code=400, content={"detail": str(e)}) | ||||
| 
 | ||||
|     response = await call_next(request) | ||||
|     return response | ||||
| 
 | ||||
| 
 | ||||
| class ModifyModelsResponseMiddleware(BaseHTTPMiddleware): | ||||
|     async def dispatch( | ||||
|         self, request: Request, call_next: RequestResponseEndpoint | ||||
|     ) -> Response: | ||||
| 
 | ||||
|         response = await call_next(request) | ||||
|         user = request.state.user | ||||
| 
 | ||||
|         if "/models" in request.url.path: | ||||
|             if isinstance(response, StreamingResponse): | ||||
|                 # Read the content of the streaming response | ||||
|                 body = b"" | ||||
|                 async for chunk in response.body_iterator: | ||||
|                     body += chunk | ||||
| 
 | ||||
|                 data = json.loads(body.decode("utf-8")) | ||||
| 
 | ||||
|                 if app.state.MODEL_FILTER_ENABLED: | ||||
|                     if user and user.role == "user": | ||||
|                         data["data"] = list( | ||||
|                             filter( | ||||
|                                 lambda model: model["id"] | ||||
|                                 in app.state.MODEL_FILTER_LIST, | ||||
|                                 data["data"], | ||||
|                             ) | ||||
|                         ) | ||||
| 
 | ||||
|                 # Modified Flag | ||||
|                 data["modified"] = True | ||||
|                 return JSONResponse(content=data) | ||||
| 
 | ||||
|         return response | ||||
| 
 | ||||
| 
 | ||||
| app.add_middleware(ModifyModelsResponseMiddleware) | ||||
|  |  | |||
|  | @ -712,7 +712,7 @@ class GenerateChatCompletionForm(BaseModel): | |||
|     format: Optional[str] = None | ||||
|     options: Optional[dict] = None | ||||
|     template: Optional[str] = None | ||||
|     stream: Optional[bool] = True | ||||
|     stream: Optional[bool] = None | ||||
|     keep_alive: Optional[Union[int, str]] = None | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ from config import ( | |||
|     DEFAULT_USER_ROLE, | ||||
|     ENABLE_SIGNUP, | ||||
|     USER_PERMISSIONS, | ||||
|     WEBHOOK_URL, | ||||
| ) | ||||
| 
 | ||||
| app = FastAPI() | ||||
|  | @ -32,6 +33,7 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS | |||
| app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS | ||||
| app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE | ||||
| app.state.USER_PERMISSIONS = USER_PERMISSIONS | ||||
| app.state.WEBHOOK_URL = WEBHOOK_URL | ||||
| 
 | ||||
| 
 | ||||
| app.add_middleware( | ||||
|  |  | |||
|  | @ -27,7 +27,8 @@ from utils.utils import ( | |||
|     create_token, | ||||
| ) | ||||
| from utils.misc import parse_duration, validate_email_format | ||||
| from constants import ERROR_MESSAGES | ||||
| from utils.webhook import post_webhook | ||||
| from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
|  | @ -155,6 +156,17 @@ async def signup(request: Request, form_data: SignupForm): | |||
|             ) | ||||
|             # response.set_cookie(key='token', value=token, httponly=True) | ||||
| 
 | ||||
|             if request.app.state.WEBHOOK_URL: | ||||
|                 post_webhook( | ||||
|                     request.app.state.WEBHOOK_URL, | ||||
|                     WEBHOOK_MESSAGES.USER_SIGNUP(user.name), | ||||
|                     { | ||||
|                         "action": "signup", | ||||
|                         "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), | ||||
|                         "user": user.model_dump_json(exclude_none=True), | ||||
|                     }, | ||||
|                 ) | ||||
| 
 | ||||
|             return { | ||||
|                 "token": token, | ||||
|                 "token_type": "Bearer", | ||||
|  |  | |||
|  | @ -320,13 +320,19 @@ DEFAULT_PROMPT_SUGGESTIONS = ( | |||
| 
 | ||||
| 
 | ||||
| DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending") | ||||
| USER_PERMISSIONS = {"chat": {"deletion": True}} | ||||
| 
 | ||||
| USER_PERMISSIONS_CHAT_DELETION = ( | ||||
|     os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true" | ||||
| ) | ||||
| 
 | ||||
| USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} | ||||
| 
 | ||||
| 
 | ||||
| MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False) | ||||
| MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true" | ||||
| MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") | ||||
| MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] | ||||
| 
 | ||||
| WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "") | ||||
| 
 | ||||
| #################################### | ||||
| # WEBUI_VERSION | ||||
|  |  | |||
|  | @ -5,6 +5,13 @@ class MESSAGES(str, Enum): | |||
|     DEFAULT = lambda msg="": f"{msg if msg else ''}" | ||||
| 
 | ||||
| 
 | ||||
| class WEBHOOK_MESSAGES(str, Enum): | ||||
|     DEFAULT = lambda msg="": f"{msg if msg else ''}" | ||||
|     USER_SIGNUP = lambda username="": ( | ||||
|         f"New user signed up: {username}" if username else "New user signed up" | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| class ERROR_MESSAGES(str, Enum): | ||||
|     def __str__(self) -> str: | ||||
|         return super().__str__() | ||||
|  | @ -46,7 +53,7 @@ class ERROR_MESSAGES(str, Enum): | |||
| 
 | ||||
|     PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance." | ||||
|     INCORRECT_FORMAT = ( | ||||
|         lambda err="": f"Invalid format. Please use the correct format{err if err else ''}" | ||||
|         lambda err="": f"Invalid format. Please use the correct format{err}" | ||||
|     ) | ||||
|     RATE_LIMIT_EXCEEDED = "API rate limit exceeded" | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| { | ||||
|     "version": "0.0.1", | ||||
|     "version": 0, | ||||
|     "ui": { | ||||
|         "prompt_suggestions": [ | ||||
|             { | ||||
|  |  | |||
|  | @ -41,6 +41,7 @@ from config import ( | |||
|     MODEL_FILTER_LIST, | ||||
|     GLOBAL_LOG_LEVEL, | ||||
|     SRC_LOG_LEVELS, | ||||
|     WEBHOOK_URL, | ||||
| ) | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
|  | @ -64,6 +65,9 @@ app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) | |||
| app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED | ||||
| app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST | ||||
| 
 | ||||
| app.state.WEBHOOK_URL = WEBHOOK_URL | ||||
| 
 | ||||
| 
 | ||||
| origins = ["*"] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -184,7 +188,7 @@ class ModelFilterConfigForm(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| @app.post("/api/config/model/filter") | ||||
| async def get_model_filter_config( | ||||
| async def update_model_filter_config( | ||||
|     form_data: ModelFilterConfigForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|  | @ -203,6 +207,28 @@ async def get_model_filter_config( | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/api/webhook") | ||||
| async def get_webhook_url(user=Depends(get_admin_user)): | ||||
|     return { | ||||
|         "url": app.state.WEBHOOK_URL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class UrlForm(BaseModel): | ||||
|     url: str | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/api/webhook") | ||||
| async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)): | ||||
|     app.state.WEBHOOK_URL = form_data.url | ||||
| 
 | ||||
|     webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL | ||||
| 
 | ||||
|     return { | ||||
|         "url": app.state.WEBHOOK_URL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/api/version") | ||||
| async def get_app_config(): | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										20
									
								
								backend/utils/webhook.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								backend/utils/webhook.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,20 @@ | |||
| import requests | ||||
| 
 | ||||
| 
 | ||||
| def post_webhook(url: str, message: str, event_data: dict) -> bool: | ||||
|     try: | ||||
|         payload = {} | ||||
| 
 | ||||
|         if "https://hooks.slack.com" in url: | ||||
|             payload["text"] = message | ||||
|         elif "https://discord.com/api/webhooks" in url: | ||||
|             payload["content"] = message | ||||
|         else: | ||||
|             payload = {**event_data} | ||||
| 
 | ||||
|         r = requests.post(url, json=payload) | ||||
|         r.raise_for_status() | ||||
|         return True | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         return False | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Self Denial
						Self Denial