File size: 751 Bytes
9628f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1210c1d
9628f1b
 
 
1210c1d
9628f1b
 
 
1210c1d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
from PIL import Image

def postprocess_image(image, output_type="pil"):
    """
    Convert an image to 8-bit form and, optionally, to a PIL image.

    A NumPy array is treated as holding intensities in the range [0, 1]:
    it is clipped to that range, scaled by 255, rounded and cast to
    ``uint8``. Lists are never rescaled.

    Args:
        image: The image to postprocess. A ``np.ndarray`` is rescaled as
            described above. A list is concatenated along axis 0 and
            converted as-is when ``output_type`` is "pil". Any other value
            is returned unchanged.
        output_type (str): The desired output type. Can be "pil" or "np".

    Returns:
        For an array input: a ``uint8`` array ("np") or a PIL image built
        from it ("pil"). For a list input: a PIL image ("pil") or the same
        list, untouched ("np"). Otherwise the input itself.

    Raises:
        ValueError: If ``output_type`` is neither "pil" nor "np".
    """
    if output_type not in ("pil", "np"):
        raise ValueError(f"Invalid output_type: {output_type}")

    if isinstance(image, np.ndarray):
        # Clip before scaling: without it, values just outside [0, 1] wrap
        # around in the uint8 cast instead of saturating at black/white.
        image = (np.clip(image, 0, 1) * 255).round().astype("uint8")
        if output_type == "pil":
            return Image.fromarray(image)
        return image

    if output_type == "pil" and isinstance(image, list):
        # NOTE(review): list elements are not rescaled here; presumably they
        # are already uint8 arrays -- confirm against callers.
        return Image.fromarray(np.concatenate(image, axis=0))

    return image