Merge pull request #1815 from Yanyutin753/new-dev

 expend the image format type after the file is downloaded
This commit is contained in:
Timothy Jaeryang Baek 2024-04-30 11:52:30 -07:00 committed by GitHub
commit de62153d49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -24,6 +24,7 @@ from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
from pathlib import Path
import mimetypes
import uuid
import base64
import json
@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel):
def save_b64_image(b64_str):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try:
# Split the base64 string to get the actual image data
img_data = base64.b64decode(b64_str)
header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0]
# Write the image data to a file
img_data = base64.b64decode(encoded)
image_id = str(uuid.uuid4())
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f:
f.write(img_data)
return image_id
return image_filename
except Exception as e:
log.error(f"Error saving image: {e}")
log.exception(f"Error saving image: {e}")
return None
def save_url_image(url):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try:
r = requests.get(url)
r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
with open(file_path, "wb") as image_file:
image_file.write(r.content)
mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}")
with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_id, image_format
else:
log.error(f"Url does not point to an image.")
return None, None
return image_id
except Exception as e:
log.exception(f"Error saving image: {e}")
return None
return None, None
@app.post("/generations")
@ -385,8 +398,8 @@ def generate_image(
images = []
for image in res["data"]:
image_id = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_id}.png"})
image_filename = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
@ -422,8 +435,10 @@ def generate_image(
images = []
for image in res["data"]:
image_id = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_id}.png"})
image_id, image_format = save_url_image(image["url"])
images.append(
{"url": f"/cache/image/generations/{image_id}{image_format}"}
)
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
@ -460,8 +475,8 @@ def generate_image(
images = []
for image in res["images"]:
image_id = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_id}.png"})
image_filename = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: