import numpy as np

try:
    from PIL import Image
except ImportError:  # Pillow is only required for output_type="pil"
    Image = None


def _to_uint8(array):
    """Map an array of intensities in [0, 1] to uint8 values in [0, 255]."""
    # Clip first: casting out-of-range values to uint8 would wrap around
    # (or be platform-dependent) instead of saturating.
    return (np.clip(array, 0, 1) * 255).round().astype("uint8")


def postprocess_image(image, output_type: str = "pil"):
    """
    Postprocesses the given image.

    NumPy arrays (or NumPy arrays inside a list) are treated as intensities
    in the range [0, 1]: they are clipped to that range, scaled to
    [0, 255] and converted to ``uint8``.

    Args:
        image: The image to postprocess. Either a single ``np.ndarray`` or a
            list of arrays. Any other object is returned unchanged.
        output_type (str): The desired output type. Can be "pil" or "np".

    Returns:
        The postprocessed image.

        * ``"np"``: a ``uint8`` array for array input, or a new list of
          ``uint8`` arrays for list input.
        * ``"pil"``: a single ``PIL.Image.Image``. A list input is
          concatenated along axis 0 into one image.

    Raises:
        ValueError: If ``output_type`` is not "pil" or "np".
        ImportError: If ``output_type`` is "pil" but Pillow is not installed.
    """
    if output_type not in ("pil", "np"):
        raise ValueError(f"Invalid output_type: {output_type}")
    if output_type == "pil" and Image is None:
        raise ImportError("Pillow is required for output_type='pil'")

    if isinstance(image, np.ndarray):
        image = _to_uint8(image)
    elif isinstance(image, list):
        # List entries need the same scaling as a bare array; otherwise raw
        # float data would be handed to Image.fromarray below.
        image = [
            _to_uint8(item) if isinstance(item, np.ndarray) else item
            for item in image
        ]

    if output_type == "pil":
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, list):
            # All list entries are stacked along the first axis into one image.
            image = Image.fromarray(np.concatenate(image, axis=0))

    return image