expend the image format type after the file is downloaded

This commit is contained in:
Yanyutin753 2024-04-28 12:00:52 +08:00
parent 9af6c5300b
commit 3321a1b922

View file

@ -24,6 +24,7 @@ from utils.misc import calculate_sha256
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from pathlib import Path from pathlib import Path
import mimetypes
import uuid import uuid
import base64 import base64
import json import json
@ -315,38 +316,47 @@ class GenerateImageForm(BaseModel):
def save_b64_image(b64_str): def save_b64_image(b64_str):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try: try:
# Split the base64 string to get the actual image data header, encoded = b64_str.split(",", 1)
img_data = base64.b64decode(b64_str) mime_type = header.split(";")[0]
# Write the image data to a file image_format = mimetypes.guess_extension(mime_type)
img_data = base64.b64decode(encoded)
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR / f"{image_id}{image_format}"
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(img_data) f.write(img_data)
return image_id, image_format
return image_id
except Exception as e: except Exception as e:
log.error(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None, None
def save_url_image(url): def save_url_image(url):
image_id = str(uuid.uuid4()) image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try: try:
r = requests.get(url) r = requests.get(url)
r.raise_for_status() r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}")
with open(file_path, "wb") as image_file: with open(file_path, "wb") as image_file:
image_file.write(r.content) for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_id, image_format
else:
log.error(f"Url does not point to an image.")
return None, None
return image_id
except Exception as e: except Exception as e:
log.exception(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None, None
@app.post("/generations") @app.post("/generations")
@ -385,8 +395,10 @@ def generate_image(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_id = save_b64_image(image["b64_json"]) image_id, image_format = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append(
{"url": f"/cache/image/generations/{image_id}{image_format}"}
)
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
@ -422,8 +434,10 @@ def generate_image(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_id = save_url_image(image["url"]) image_id, image_format = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append(
{"url": f"/cache/image/generations/{image_id}{image_format}"}
)
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
@ -460,8 +474,10 @@ def generate_image(
images = [] images = []
for image in res["images"]: for image in res["images"]:
image_id = save_b64_image(image) image_id, image_format = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append(
{"url": f"/cache/image/generations/{image_id}{image_format}"}
)
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f: