Merge conflicts

This commit is contained in:
Self Denial 2024-03-21 00:14:13 -06:00
commit f74f2ea765
30 changed files with 2318 additions and 257 deletions

View file

@ -3,14 +3,26 @@ import logging
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
from utils.utils import get_http_authorization_cred, get_current_user
from config import SRC_LOG_LEVELS, ENV
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
proxy_config = ProxyConfig()
@ -31,16 +43,58 @@ async def on_startup():
await startup()
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "")
request.state.user = None
if ENV != "dev":
try:
user = get_current_user(get_http_authorization_cred(auth_header))
log.debug(f"user: {user}")
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
try:
user = get_current_user(get_http_authorization_cred(auth_header))
log.debug(f"user: {user}")
request.state.user = user
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
response = await call_next(request)
return response
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
user = request.state.user
if "/models" in request.url.path:
if isinstance(response, StreamingResponse):
# Read the content of the streaming response
body = b""
async for chunk in response.body_iterator:
body += chunk
data = json.loads(body.decode("utf-8"))
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"]
in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
# Modified Flag
data["modified"] = True
return JSONResponse(content=data)
return response
app.add_middleware(ModifyModelsResponseMiddleware)

View file

@ -712,7 +712,7 @@ class GenerateChatCompletionForm(BaseModel):
format: Optional[str] = None
options: Optional[dict] = None
template: Optional[str] = None
stream: Optional[bool] = True
stream: Optional[bool] = None
keep_alive: Optional[Union[int, str]] = None

View file

@ -19,6 +19,7 @@ from config import (
DEFAULT_USER_ROLE,
ENABLE_SIGNUP,
USER_PERMISSIONS,
WEBHOOK_URL,
)
app = FastAPI()
@ -32,6 +33,7 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL
app.add_middleware(

View file

@ -27,7 +27,8 @@ from utils.utils import (
create_token,
)
from utils.misc import parse_duration, validate_email_format
from constants import ERROR_MESSAGES
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
router = APIRouter()
@ -155,6 +156,17 @@ async def signup(request: Request, form_data: SignupForm):
)
# response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL:
post_webhook(
request.app.state.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
return {
"token": token,
"token_type": "Bearer",

View file

@ -320,13 +320,19 @@ DEFAULT_PROMPT_SUGGESTIONS = (
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
USER_PERMISSIONS = {"chat": {"deletion": True}}
USER_PERMISSIONS_CHAT_DELETION = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
####################################
# WEBUI_VERSION

View file

@ -5,6 +5,13 @@ class MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
class WEBHOOK_MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
USER_SIGNUP = lambda username="": (
f"New user signed up: {username}" if username else "New user signed up"
)
class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str:
return super().__str__()
@ -46,7 +53,7 @@ class ERROR_MESSAGES(str, Enum):
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
INCORRECT_FORMAT = (
lambda err="": f"Invalid format. Please use the correct format{err if err else ''}"
lambda err="": f"Invalid format. Please use the correct format{err}"
)
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"

View file

@ -1,5 +1,5 @@
{
"version": "0.0.1",
"version": 0,
"ui": {
"prompt_suggestions": [
{

View file

@ -41,6 +41,7 @@ from config import (
MODEL_FILTER_LIST,
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
WEBHOOK_URL,
)
from constants import ERROR_MESSAGES
@ -64,6 +65,9 @@ app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.WEBHOOK_URL = WEBHOOK_URL
origins = ["*"]
@ -184,7 +188,7 @@ class ModelFilterConfigForm(BaseModel):
@app.post("/api/config/model/filter")
async def get_model_filter_config(
async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
@ -203,6 +207,28 @@ async def get_model_filter_config(
}
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
return {
"url": app.state.WEBHOOK_URL,
}
class UrlForm(BaseModel):
url: str
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL
return {
"url": app.state.WEBHOOK_URL,
}
@app.get("/api/version")
async def get_app_config():

20
backend/utils/webhook.py Normal file
View file

@ -0,0 +1,20 @@
import requests
def post_webhook(url: str, message: str, event_data: dict) -> bool:
try:
payload = {}
if "https://hooks.slack.com" in url:
payload["text"] = message
elif "https://discord.com/api/webhooks" in url:
payload["content"] = message
else:
payload = {**event_data}
r = requests.post(url, json=payload)
r.raise_for_status()
return True
except Exception as e:
print(e)
return False