forked from open-webui/open-webui
		
	feat: dall-e integration
This commit is contained in:
		
							parent
							
								
									dd3a4b3889
								
							
						
					
					
						commit
						0221acd163
					
				
					 6 changed files with 296 additions and 64 deletions
				
			
		|  | @ -21,7 +21,16 @@ from utils.utils import ( | |||
| from utils.misc import calculate_sha256 | ||||
| from typing import Optional | ||||
| from pydantic import BaseModel | ||||
| from config import AUTOMATIC1111_BASE_URL | ||||
| from pathlib import Path | ||||
| import uuid | ||||
| import base64 | ||||
| import json | ||||
| 
 | ||||
| from config import CACHE_DIR, AUTOMATIC1111_BASE_URL | ||||
| 
 | ||||
| 
 | ||||
| IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/") | ||||
| IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| app.add_middleware( | ||||
|  | @ -32,25 +41,34 @@ app.add_middleware( | |||
|     allow_headers=["*"], | ||||
| ) | ||||
| 
 | ||||
| app.state.ENGINE = "" | ||||
| app.state.ENABLED = False | ||||
| 
 | ||||
| app.state.OPENAI_API_KEY = "" | ||||
| app.state.MODEL = "" | ||||
| 
 | ||||
| 
 | ||||
| app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL | ||||
| app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != "" | ||||
| 
 | ||||
| app.state.IMAGE_SIZE = "512x512" | ||||
| app.state.IMAGE_STEPS = 50 | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/enabled", response_model=bool) | ||||
| async def get_enable_status(request: Request, user=Depends(get_admin_user)): | ||||
|     return app.state.ENABLED | ||||
| @app.get("/config") | ||||
| async def get_config(request: Request, user=Depends(get_admin_user)): | ||||
|     return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/enabled/toggle", response_model=bool) | ||||
| async def toggle_enabled(request: Request, user=Depends(get_admin_user)): | ||||
|     try: | ||||
|         r = requests.head(app.state.AUTOMATIC1111_BASE_URL) | ||||
|         app.state.ENABLED = not app.state.ENABLED | ||||
|         return app.state.ENABLED | ||||
|     except Exception as e: | ||||
|         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
| class ConfigUpdateForm(BaseModel): | ||||
|     engine: str | ||||
|     enabled: bool | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/config/update") | ||||
| async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): | ||||
|     app.state.ENGINE = form_data.engine | ||||
|     app.state.ENABLED = form_data.enabled | ||||
|     return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} | ||||
| 
 | ||||
| 
 | ||||
| class UrlUpdateForm(BaseModel): | ||||
|  | @ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| @app.get("/url") | ||||
| async def get_openai_url(user=Depends(get_admin_user)): | ||||
| async def get_automatic1111_url(user=Depends(get_admin_user)): | ||||
|     return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL} | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/url/update") | ||||
| async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): | ||||
| async def update_automatic1111_url( | ||||
|     form_data: UrlUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|     if form_data.url == "": | ||||
|         app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL | ||||
|     else: | ||||
|         app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/") | ||||
|         url = form_data.url.strip("/") | ||||
|         try: | ||||
|             r = requests.head(url) | ||||
|             app.state.AUTOMATIC1111_BASE_URL = url | ||||
|         except Exception as e: | ||||
|             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
| 
 | ||||
|     return { | ||||
|         "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, | ||||
|  | @ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class OpenAIKeyUpdateForm(BaseModel): | ||||
|     key: str | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/key") | ||||
| async def get_openai_key(user=Depends(get_admin_user)): | ||||
|     return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/key/update") | ||||
| async def update_openai_key( | ||||
|     form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
| 
 | ||||
|     if form_data.key == "": | ||||
|         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) | ||||
| 
 | ||||
|     app.state.OPENAI_API_KEY = form_data.key | ||||
|     return { | ||||
|         "OPENAI_API_KEY": app.state.OPENAI_API_KEY, | ||||
|         "status": True, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class ImageSizeUpdateForm(BaseModel): | ||||
|     size: str | ||||
| 
 | ||||
|  | @ -132,9 +181,22 @@ async def update_image_size( | |||
| @app.get("/models") | ||||
| def get_models(user=Depends(get_current_user)): | ||||
|     try: | ||||
|         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models") | ||||
|         models = r.json() | ||||
|         return models | ||||
|         if app.state.ENGINE == "openai": | ||||
|             return [ | ||||
|                 {"id": "dall-e-2", "name": "DALL·E 2"}, | ||||
|                 {"id": "dall-e-3", "name": "DALL·E 3"}, | ||||
|             ] | ||||
|         else: | ||||
|             r = requests.get( | ||||
|                 url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" | ||||
|             ) | ||||
|             models = r.json() | ||||
|             return list( | ||||
|                 map( | ||||
|                     lambda model: {"id": model["title"], "name": model["model_name"]}, | ||||
|                     models, | ||||
|                 ) | ||||
|             ) | ||||
|     except Exception as e: | ||||
|         app.state.ENABLED = False | ||||
|         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
|  | @ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)): | |||
| @app.get("/models/default") | ||||
| async def get_default_model(user=Depends(get_admin_user)): | ||||
|     try: | ||||
|         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|         options = r.json() | ||||
| 
 | ||||
|         return {"model": options["sd_model_checkpoint"]} | ||||
|         if app.state.ENGINE == "openai": | ||||
|             return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} | ||||
|         else: | ||||
|             r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|             options = r.json() | ||||
|             return {"model": options["sd_model_checkpoint"]} | ||||
|     except Exception as e: | ||||
|         app.state.ENABLED = False | ||||
|         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
|  | @ -157,16 +221,21 @@ class UpdateModelForm(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| def set_model_handler(model: str): | ||||
|     r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|     options = r.json() | ||||
| 
 | ||||
|     if model != options["sd_model_checkpoint"]: | ||||
|         options["sd_model_checkpoint"] = model | ||||
|         r = requests.post( | ||||
|             url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options | ||||
|         ) | ||||
|     if app.state.ENGINE == "openai": | ||||
|         app.state.MODEL = model | ||||
|         return app.state.MODEL | ||||
|     else: | ||||
|         r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") | ||||
|         options = r.json() | ||||
| 
 | ||||
|     return options | ||||
|         if model != options["sd_model_checkpoint"]: | ||||
|             options["sd_model_checkpoint"] = model | ||||
|             r = requests.post( | ||||
|                 url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options | ||||
|             ) | ||||
| 
 | ||||
|         return options | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/models/default/update") | ||||
|  | @ -185,6 +254,24 @@ class GenerateImageForm(BaseModel): | |||
|     negative_prompt: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| def save_b64_image(b64_str): | ||||
|     image_id = str(uuid.uuid4()) | ||||
|     file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png") | ||||
| 
 | ||||
|     try: | ||||
|         # Split the base64 string to get the actual image data | ||||
|         img_data = base64.b64decode(b64_str) | ||||
| 
 | ||||
|         # Write the image data to a file | ||||
|         with open(file_path, "wb") as f: | ||||
|             f.write(img_data) | ||||
| 
 | ||||
|         return image_id | ||||
|     except Exception as e: | ||||
|         print(f"Error saving image: {e}") | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/generations") | ||||
| def generate_image( | ||||
|     form_data: GenerateImageForm, | ||||
|  | @ -194,32 +281,82 @@ def generate_image( | |||
|     print(form_data) | ||||
| 
 | ||||
|     try: | ||||
|         if form_data.model: | ||||
|             set_model_handler(form_data.model) | ||||
|         if app.state.ENGINE == "openai": | ||||
| 
 | ||||
|         width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||
|             headers = {} | ||||
|             headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" | ||||
|             headers["Content-Type"] = "application/json" | ||||
| 
 | ||||
|         data = { | ||||
|             "prompt": form_data.prompt, | ||||
|             "batch_size": form_data.n, | ||||
|             "width": width, | ||||
|             "height": height, | ||||
|         } | ||||
|             data = { | ||||
|                 "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", | ||||
|                 "prompt": form_data.prompt, | ||||
|                 "n": form_data.n, | ||||
|                 "size": form_data.size, | ||||
|                 "response_format": "b64_json", | ||||
|             } | ||||
| 
 | ||||
|         if app.state.IMAGE_STEPS != None: | ||||
|             data["steps"] = app.state.IMAGE_STEPS | ||||
|             r = requests.post( | ||||
|                 url=f"https://api.openai.com/v1/images/generations", | ||||
|                 json=data, | ||||
|                 headers=headers, | ||||
|             ) | ||||
| 
 | ||||
|         if form_data.negative_prompt != None: | ||||
|             data["negative_prompt"] = form_data.negative_prompt | ||||
|             r.raise_for_status() | ||||
| 
 | ||||
|         print(data) | ||||
|             res = r.json() | ||||
| 
 | ||||
|         r = requests.post( | ||||
|             url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", | ||||
|             json=data, | ||||
|         ) | ||||
|             images = [] | ||||
| 
 | ||||
|             for image in res["data"]: | ||||
|                 image_id = save_b64_image(image["b64_json"]) | ||||
|                 images.append({"url": f"/cache/image/generations/{image_id}.png"}) | ||||
|                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") | ||||
| 
 | ||||
|                 with open(file_body_path, "w") as f: | ||||
|                     json.dump(data, f) | ||||
| 
 | ||||
|             return images | ||||
| 
 | ||||
|         else: | ||||
|             if form_data.model: | ||||
|                 set_model_handler(form_data.model) | ||||
| 
 | ||||
|             width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) | ||||
| 
 | ||||
|             data = { | ||||
|                 "prompt": form_data.prompt, | ||||
|                 "batch_size": form_data.n, | ||||
|                 "width": width, | ||||
|                 "height": height, | ||||
|             } | ||||
| 
 | ||||
|             if app.state.IMAGE_STEPS != None: | ||||
|                 data["steps"] = app.state.IMAGE_STEPS | ||||
| 
 | ||||
|             if form_data.negative_prompt != None: | ||||
|                 data["negative_prompt"] = form_data.negative_prompt | ||||
| 
 | ||||
|             r = requests.post( | ||||
|                 url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", | ||||
|                 json=data, | ||||
|             ) | ||||
| 
 | ||||
|             res = r.json() | ||||
| 
 | ||||
|             print(res) | ||||
| 
 | ||||
|             images = [] | ||||
| 
 | ||||
|             for image in res["images"]: | ||||
|                 image_id = save_b64_image(image) | ||||
|                 images.append({"url": f"/cache/image/generations/{image_id}.png"}) | ||||
|                 file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") | ||||
| 
 | ||||
|                 with open(file_body_path, "w") as f: | ||||
|                     json.dump({**data, "info": res["info"]}, f) | ||||
| 
 | ||||
|             return images | ||||
| 
 | ||||
|         return r.json() | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek