feat: sd backend integration

This commit is contained in:
Timothy J. Baek 2024-02-21 18:12:01 -08:00
parent 7a730c3f0f
commit 733e963c44
11 changed files with 611 additions and 33 deletions

148
backend/apps/images/main.py Normal file
View file

@ -0,0 +1,148 @@
import os
import requests
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
get_current_user,
get_admin_user,
)
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
from config import AUTOMATIC1111_BASE_URL
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.ENABLED = False
@app.get("/enabled", response_model=bool)
async def get_enable_status(request: Request, user=Depends(get_admin_user)):
return app.state.ENABLED
@app.get("/enabled/toggle", response_model=bool)
async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
app.state.ENABLED = not app.state.ENABLED
return app.state.ENABLED
class UrlUpdateForm(BaseModel):
url: str
@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
try:
r = requests.head(form_data.url)
if r.ok:
app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"status": True,
}
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
models = r.json()
return models
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
raise HTTPException(status_code=r.status_code, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel):
model: str
def set_model_handler(model: str):
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
)
return options
@app.post("/models/default/update")
def update_default_model(
form_data: UpdateModelForm,
user=Depends(get_current_user),
):
return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
model: Optional[str] = None
prompt: str
n: int = 1
size: str = "512x512"
negative_prompt: Optional[str] = None
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, form_data.size.split("x")))
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json={
"prompt": form_data.prompt,
"negative_prompt": form_data.negative_prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
},
)
return r.json()

View file

@ -185,3 +185,10 @@ Query: [query]"""
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
####################################
# Images
####################################
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")

View file

@ -11,10 +11,10 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
from apps.rag.main import app as rag_app
from config import ENV, FRONTEND_BUILD_DIR
@ -58,10 +58,21 @@ app.mount("/api/v1", webui_app)
app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
@app.get("/api/config")
async def get_app_config():
return {
"status": True,
"images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
}
app.mount(
"/",
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),