|
import os |
|
import math |
|
import json |
|
from typing import List |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from PIL.PngImagePlugin import PngInfo |
|
from nodes import SaveImage |
|
|
|
class GridImage(SaveImage):
    """Output node that tiles a batch of images into one grid and saves it as a PNG.

    The grid has ``x`` columns and as many rows as are needed to hold every
    image, with ``gap`` black pixels between neighbouring cells. The file is
    written below ``self.output_dir`` (provided by ``SaveImage``) using a
    ``<prefix>_<counter>_.png`` naming scheme.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification for this node.

        Required inputs: ``images``, ``filename_prefix``, ``x`` (number of
        columns, 1..64) and ``gap`` (pixels between cells, 0..32).
        Hidden inputs: ``prompt`` and ``extra_pnginfo``, which are embedded
        into the PNG metadata by ``execute``.
        """
        return {
            'required': {
                'images': ('IMAGE',),
                'filename_prefix': ('STRING', {'default': 'ComfyUI-Grid'}),
                'x': ('INT', {
                    'default': 1,
                    'min': 1,
                    'max': 64,
                    'step': 1
                }),
                'gap': ('INT', {
                    'default': 0,
                    'min': 0,
                    'max': 32,
                    'step': 1
                }),
            },
            'hidden': {
                'prompt': 'PROMPT',
                'extra_pnginfo': 'EXTRA_PNGINFO'
            },
        }

    OUTPUT_NODE = True

    RETURN_TYPES = ()

    FUNCTION = 'execute'

    CATEGORY = 'image'

    def execute(self, images: List[torch.Tensor], filename_prefix: str = 'ComfyUI-Grid', x: int = 1, gap: int = 0, prompt=None, extra_pnginfo=None):
        """Compose ``images`` into a grid, save it as a PNG and report the file.

        Args:
            images: Images to tile; the first image determines the cell size.
            filename_prefix: Output name prefix; may contain a sub-folder part
                that is resolved relative to ``self.output_dir``.
            x: Number of grid columns (values below 1 are treated as 1).
            gap: Pixels of black spacing between neighbouring cells.
            prompt: Optional object stored JSON-encoded under the PNG text
                key ``"prompt"``.
            extra_pnginfo: Optional mapping; each entry is stored JSON-encoded
                under its own PNG text key.

        Returns:
            ``{"ui": {"images": [...]}}`` describing the saved file, or ``{}``
            when the target folder would lie outside ``self.output_dir``.

        Raises:
            ValueError: If ``images`` is empty.
        """
        # Guard against a zero/negative column count from programmatic callers,
        # which would otherwise divide by zero below.
        columns = max(x, 1)
        rows = max(math.ceil(len(images) / columns), 1)

        normalized_prefix = os.path.normpath(filename_prefix)
        subfolder = os.path.dirname(normalized_prefix)
        filename = os.path.basename(normalized_prefix)

        full_output_folder = os.path.join(self.output_dir, subfolder)

        # Refuse prefixes such as "../x" that would escape the output folder.
        if os.path.commonpath((self.output_dir, os.path.realpath(full_output_folder))) != self.output_dir:
            print("Saving image outside the output folder is not allowed.")
            return {}

        try:
            existing_names = os.listdir(full_output_folder)
        except FileNotFoundError:
            existing_names = []
        # Creates the output folder and any missing parents in one step.
        os.makedirs(full_output_folder, exist_ok=True)

        counter = self._next_counter(filename, existing_names)

        canvas = self.grid_image(images, columns, rows, gap)

        metadata = PngInfo()
        if prompt is not None:
            metadata.add_text("prompt", json.dumps(prompt))
        if extra_pnginfo is not None:
            # Use a dedicated loop variable so the column count `x` is not shadowed.
            for key in extra_pnginfo:
                metadata.add_text(key, json.dumps(extra_pnginfo[key]))

        file = f"{filename}_{counter:05}_.png"
        canvas.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)

        results = [{
            "filename": file,
            "subfolder": subfolder,
            "type": 'output'
        }]

        return { "ui": { "images": results } }

    @staticmethod
    def _next_counter(filename, existing_names):
        """Return one more than the highest counter used by ``<filename>_<n>...`` names.

        Names that do not start with ``filename + "_"`` are ignored; names with
        that prefix but no parsable number count as 0. Returns 1 when nothing
        matches.
        """
        marker = filename + '_'
        highest = 0
        for name in existing_names:
            if not name.startswith(marker):
                continue
            try:
                number = int(name[len(marker):].split('_')[0])
            except ValueError:
                # Prefix matches but no numeric counter follows it.
                number = 0
            highest = max(highest, number)
        return highest + 1

    def grid_image(self, images: List[torch.Tensor], x: int, y: int, gap: int):
        """Paste ``images`` row by row onto a black RGB canvas of ``x`` by ``y`` cells.

        Args:
            images: Images to place; each is scaled from [0, 1] to [0, 255],
                clipped, and converted to ``uint8`` before pasting.
            x: Number of columns.
            y: Number of rows; images beyond ``x * y`` are not placed.
            gap: Pixels between neighbouring cells.

        Returns:
            The composed ``PIL.Image.Image``.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("GridImage requires at least one image")

        # Image.fromarray treats axis 0 of the array as rows, so the first
        # dimension is the height and the second the width. PIL sizes and
        # paste offsets are (horizontal, vertical).
        height, width = images[0].shape[0], images[0].shape[1]
        canvas = Image.new('RGB', (x * (width + gap) - gap, y * (height + gap) - gap), color='black')

        for idx in range(min(len(images), x * y)):
            row, col = divmod(idx, x)

            pixels = 255. * images[idx].cpu().numpy()
            img = Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))

            canvas.paste(img, (col * (width + gap), row * (height + gap)))

        return canvas
|
|