Merge branch 'dev' into buroa/hybrid-search

This commit is contained in:
Steven Kreitzer 2024-04-24 14:51:49 -05:00 committed by GitHub
commit adb009f388
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 278 additions and 204 deletions

View file

@ -35,8 +35,8 @@ from config import (
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
COMFYUI_BASE_URL,
OPENAI_API_BASE_URL,
OPENAI_API_KEY,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
)
@ -58,8 +58,8 @@ app.add_middleware(
app.state.ENGINE = ""
app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = OPENAI_API_KEY
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.MODEL = ""
@ -135,27 +135,33 @@ async def update_engine_url(
}
class OpenAIKeyUpdateForm(BaseModel):
class OpenAIConfigUpdateForm(BaseModel):
url: str
key: str
@app.get("/key")
async def get_openai_key(user=Depends(get_admin_user)):
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
@app.post("/key/update")
async def update_openai_key(
form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
@app.post("/openai/config/update")
async def update_openai_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
return {
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}

View file

@ -1,3 +1,5 @@
import sys
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
@ -23,7 +25,13 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
DATA_DIR,
LITELLM_PROXY_PORT,
LITELLM_PROXY_HOST,
)
from litellm.utils import get_llm_provider
@ -64,7 +72,7 @@ async def run_background_process(command):
log.info(f"Executing command: {command}")
# Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec(
*command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
background_process = process
log.info("Subprocess started successfully.")
@ -90,9 +98,17 @@ async def run_background_process(command):
async def start_litellm_background():
log.info("start_litellm_background")
# Command to run in the background
command = (
"litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml"
)
command = [
"litellm",
"--port",
str(LITELLM_PROXY_PORT),
"--host",
LITELLM_PROXY_HOST,
"--telemetry",
"False",
"--config",
LITELLM_CONFIG_DIR,
]
await run_background_process(command)
@ -109,7 +125,6 @@ async def shutdown_litellm_background():
@app.on_event("startup")
async def startup_event():
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
@ -186,7 +201,7 @@ async def get_models(user=Depends(get_current_user)):
while not background_process:
await asyncio.sleep(0.1)
url = "http://localhost:14365/v1"
url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
@ -289,7 +304,7 @@ async def delete_model_from_config(
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
body = await request.body()
url = "http://localhost:14365"
url = f"http://localhost:{LITELLM_PROXY_PORT}"
target_url = f"{url}/{path}"

View file

@ -499,9 +499,24 @@ AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
IMAGES_OPENAI_API_BASE_URL = os.getenv(
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
)
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
####################################
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
####################################
# LiteLLM
####################################
LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")