Merge branch 'dev' into buroa/hybrid-search

This commit is contained in:
Steven Kreitzer 2024-04-22 18:35:32 -05:00 committed by GitHub
commit db801aee79
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 701 additions and 204 deletions

View file

@ -195,7 +195,7 @@ class ImageGenerationPayload(BaseModel):
def comfyui_generate_image(
model: str, payload: ImageGenerationPayload, client_id, base_url
):
host = base_url.replace("http://", "").replace("https://", "")
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
@ -217,7 +217,7 @@ def comfyui_generate_image(
try:
ws = websocket.WebSocket()
ws.connect(f"ws://{host}/ws?clientId={client_id}")
ws.connect(f"{ws_url}/ws?clientId={client_id}")
log.info("WebSocket connection established.")
except Exception as e:
log.exception(f"Failed to connect to WebSocket server: {e}")

View file

@ -1,100 +1,336 @@
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
import logging
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
import time
import requests
from utils.utils import get_http_authorization_cred, get_current_user
from pydantic import BaseModel, ConfigDict
from typing import Optional, List
from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR
from litellm.utils import get_llm_provider
import asyncio
import subprocess
import yaml
app = FastAPI()
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
proxy_config = ProxyConfig()
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"
with open(LITELLM_CONFIG_DIR, "r") as file:
litellm_config = yaml.safe_load(file)
app.state.CONFIG = litellm_config
# Global variable to store the subprocess reference
background_process = None
async def config():
router, model_list, general_settings = await proxy_config.load_config(
router=None, config_file_path="./data/litellm/config.yaml"
async def run_background_process(command):
global background_process
log.info("run_background_process")
try:
# Log the command to be executed
log.info(f"Executing command: {command}")
# Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec(
*command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
background_process = process
log.info("Subprocess started successfully.")
# Capture STDERR for debugging purposes
stderr_output = await process.stderr.read()
stderr_text = stderr_output.decode().strip()
if stderr_text:
log.info(f"Subprocess STDERR: {stderr_text}")
# log.info output line by line
async for line in process.stdout:
log.info(line.decode().strip())
# Wait for the process to finish
returncode = await process.wait()
log.info(f"Subprocess exited with return code {returncode}")
except Exception as e:
log.error(f"Failed to start subprocess: {e}")
raise # Optionally re-raise the exception if you want it to propagate
async def start_litellm_background():
log.info("start_litellm_background")
# Command to run in the background
command = (
"litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml"
)
await initialize(config="./data/litellm/config.yaml", telemetry=False)
await run_background_process(command)
async def startup():
await config()
async def shutdown_litellm_background():
log.info("shutdown_litellm_background")
global background_process
if background_process:
background_process.terminate()
await background_process.wait() # Ensure the process has terminated
log.info("Subprocess terminated")
background_process = None
@app.on_event("startup")
async def on_startup():
await startup()
async def startup_event():
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "")
request.state.user = None
@app.get("/")
async def get_status():
return {"status": True}
async def restart_litellm():
"""
Endpoint to restart the litellm background service.
"""
log.info("Requested restart of litellm service.")
try:
# Shut down the existing process if it is running
await shutdown_litellm_background()
log.info("litellm service shutdown complete.")
# Restart the background service
asyncio.create_task(start_litellm_background())
log.info("litellm service restart complete.")
return {
"status": "success",
"message": "litellm service restarted successfully.",
}
except Exception as e:
log.info(f"Error restarting litellm service: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
return await restart_litellm()
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return app.state.CONFIG
class LiteLLMConfigForm(BaseModel):
general_settings: Optional[dict] = None
litellm_settings: Optional[dict] = None
model_list: Optional[List[dict]] = None
router_settings: Optional[dict] = None
model_config = ConfigDict(protected_namespaces=())
@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
app.state.CONFIG = form_data.model_dump(exclude_none=True)
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return app.state.CONFIG
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
while not background_process:
await asyncio.sleep(0.1)
url = "http://localhost:14365/v1"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
data = r.json()
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
return {
"data": [
{
"id": model["model_name"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in app.state.CONFIG["model_list"]
],
"object": "list",
}
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
return {"data": app.state.CONFIG["model_list"]}
class AddLiteLLMModelForm(BaseModel):
model_name: str
litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
@app.post("/model/new")
async def add_model_to_config(
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
try:
get_llm_provider(model=form_data.model_name)
app.state.CONFIG["model_list"].append(form_data.model_dump())
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
class DeleteLiteLLMModelForm(BaseModel):
id: str
@app.post("/model/delete")
async def delete_model_from_config(
form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
app.state.CONFIG["model_list"] = [
model
for model in app.state.CONFIG["model_list"]
if model["model_name"] != form_data.id
]
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
body = await request.body()
url = "http://localhost:14365"
target_url = f"{url}/{path}"
headers = {}
# headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
user = get_current_user(get_http_authorization_cred(auth_header))
log.debug(f"user: {user}")
request.state.user = user
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
else:
response_data = r.json()
return response_data
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
response = await call_next(request)
return response
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
user = request.state.user
if "/models" in request.url.path:
if isinstance(response, StreamingResponse):
# Read the content of the streaming response
body = b""
async for chunk in response.body_iterator:
body += chunk
data = json.loads(body.decode("utf-8"))
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"]
in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
# Modified Flag
data["modified"] = True
return JSONResponse(content=data)
return response
app.add_middleware(ModifyModelsResponseMiddleware)
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)

View file

@ -80,6 +80,7 @@ async def get_openai_urls(user=Depends(get_admin_user)):
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models()
app.state.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}

View file

@ -136,7 +136,9 @@ class TagTable:
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select().where(Tag.name.in_(tag_names))
for tag in Tag.select()
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
]
def get_tags_by_chat_id_and_user_id(
@ -151,7 +153,9 @@ class TagTable:
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select().where(Tag.name.in_(tag_names))
for tag in Tag.select()
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
]
def get_chat_ids_by_tag_name_and_user_id(

View file

@ -28,7 +28,7 @@ from apps.web.models.tags import (
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -79,6 +79,11 @@ async def get_all_user_chats(user=Depends(get_current_user)):
@router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats()

View file

@ -91,7 +91,11 @@ async def download_chat_as_pdf(
@router.get("/db/download")
async def download_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return FileResponse(
f"{DATA_DIR}/webui.db",
media_type="application/octet-stream",

View file

@ -322,9 +322,14 @@ OPENAI_API_BASE_URLS = [
]
OPENAI_API_KEY = ""
OPENAI_API_KEY = OPENAI_API_KEYS[
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
]
try:
OPENAI_API_KEY = OPENAI_API_KEYS[
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
]
except:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
@ -377,6 +382,8 @@ MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
####################################
# WEBUI_VERSION
####################################

View file

@ -3,6 +3,10 @@ from enum import Enum
class MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
MODEL_DELETED = (
lambda model="": f"The model '{model}' has been deleted successfully."
)
class WEBHOOK_MESSAGES(str, Enum):

View file

@ -20,12 +20,17 @@ from starlette.middleware.base import BaseHTTPMiddleware
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.litellm.main import app as litellm_app, startup as litellm_app_startup
from apps.litellm.main import (
app as litellm_app,
start_litellm_background,
shutdown_litellm_background,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
import asyncio
from pydantic import BaseModel
from typing import List
@ -47,6 +52,7 @@ from config import (
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
)
from constants import ERROR_MESSAGES
@ -171,7 +177,7 @@ async def check_url(request: Request, call_next):
@app.on_event("startup")
async def on_startup():
await litellm_app_startup()
asyncio.create_task(start_litellm_background())
app.mount("/api/v1", webui_app)
@ -203,6 +209,7 @@ async def get_app_config():
"default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
}
@ -316,3 +323,8 @@ app.mount(
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
name="spa-static-files",
)
@app.on_event("shutdown")
async def shutdown_event():
await shutdown_litellm_background()

View file

@ -17,7 +17,9 @@ peewee
peewee-migrate
bcrypt
litellm==1.30.7
litellm==1.35.17
litellm[proxy]==1.35.17
boto3
argon2-cffi