Merge pull request #1815 from Yanyutin753/new-dev

 expend the image format type after the file is downloaded
This commit is contained in:
Timothy Jaeryang Baek 2024-04-30 11:52:30 -07:00 committed by GitHub
commit de62153d49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -24,6 +24,7 @@ from utils.misc import calculate_sha256
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from pathlib import Path from pathlib import Path
import mimetypes
import uuid import uuid
import base64 import base64
import json import json
@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel):
def save_b64_image(b64_str): def save_b64_image(b64_str):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try: try:
# Split the base64 string to get the actual image data header, encoded = b64_str.split(",", 1)
img_data = base64.b64decode(b64_str) mime_type = header.split(";")[0]
# Write the image data to a file img_data = base64.b64decode(encoded)
image_id = str(uuid.uuid4())
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(img_data) f.write(img_data)
return image_filename
return image_id
except Exception as e: except Exception as e:
log.error(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None
def save_url_image(url): def save_url_image(url):
image_id = str(uuid.uuid4()) image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try: try:
r = requests.get(url) r = requests.get(url)
r.raise_for_status() r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
with open(file_path, "wb") as image_file: mime_type = r.headers["content-type"]
image_file.write(r.content) image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}")
with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_id, image_format
else:
log.error(f"Url does not point to an image.")
return None, None
return image_id
except Exception as e: except Exception as e:
log.exception(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None, None
@app.post("/generations") @app.post("/generations")
@ -385,8 +398,8 @@ def generate_image(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_id = save_b64_image(image["b64_json"]) image_filename = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
@ -422,8 +435,10 @@ def generate_image(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_id = save_url_image(image["url"]) image_id, image_format = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append(
{"url": f"/cache/image/generations/{image_id}{image_format}"}
)
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
@ -460,8 +475,8 @@ def generate_image(
images = [] images = []
for image in res["images"]: for image in res["images"]:
image_id = save_b64_image(image) image_filename = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f: