File size: 2,693 Bytes
b42bd57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import cv2
import onnxruntime
import gradio as gr


def pre_process(img: np.array) -> np.array:
    """Turn an H x W x C image into a float32, channels-first batch of one.

    At most the first three channels are kept, so a fourth (e.g. alpha)
    channel is dropped. The result is shaped 1 x C x H x W.
    """
    # Keep the colour channels and move them in front: H, W, C -> C, H, W
    channels_first = img[:, :, :3].transpose(2, 0, 1)
    # Prepend the batch axis (C, H, W -> 1, C, H, W) and cast to float32
    return channels_first[np.newaxis].astype(np.float32)


def post_process(img: np.array) -> np.array:
    """Turn a channels-first model output into an 8-bit H x W x C image.

    The batch axis of a 4-D ``(1, C, H, W)`` array is removed, the channel
    axis is moved last and its order is reversed, and the values are
    saturated to ``[0, 255]`` before the ``uint8`` cast (fractions are
    truncated, as before).

    Args:
        img: Array shaped ``(1, C, H, W)`` or ``(C, H, W)``.

    Returns:
        A C-contiguous ``uint8`` array shaped ``(H, W, C)``.
    """
    img = np.asarray(img)
    # 1, C, H, W -> C, H, W. Drop only the batch axis: a bare np.squeeze
    # would also collapse a height or width of 1 and break the transpose.
    if img.ndim == 4:
        img = np.squeeze(img, axis=0)
    # C, H, W -> H, W, C, with the channel order flipped
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1]
    # Saturate before casting: out-of-range floats cast straight to uint8
    # wrap around (e.g. 256 -> 0) instead of clamping.
    img = np.clip(img, 0, 255).astype(np.uint8)
    return np.ascontiguousarray(img)


# One InferenceSession per model path, kept for the lifetime of the process so
# the model file is not reloaded on every request.
_SESSION_CACHE = {}


def _get_session(model_path: str):
    """Return the cached ONNX Runtime session for *model_path*, creating it on first use."""
    session = _SESSION_CACHE.get(model_path)
    if session is None:
        options = onnxruntime.SessionOptions()
        # Single-threaded execution, as configured originally.
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        session = onnxruntime.InferenceSession(model_path, options)
        _SESSION_CACHE[model_path] = session
    return session


def inference(model_path: str, img_array: np.array) -> np.array:
    """Run the model stored at *model_path* on *img_array*.

    Args:
        model_path: Path of the model file handed to ONNX Runtime.
        img_array: Array fed to the model's first input.

    Returns:
        The first output of the model.
    """
    ort_session = _get_session(model_path)
    # The model is fed through its first declared input.
    input_name = ort_session.get_inputs()[0].name
    ort_outs = ort_session.run(None, {input_name: img_array})

    return ort_outs[0]


def convert_pil_to_cv2(image):
    """Convert an RGB(A)-ordered image into a BGR(A)-ordered NumPy array.

    Args:
        image: Anything ``np.array`` accepts, e.g. a PIL image.

    Returns:
        A new array. 2-D (single-channel) input is returned without channel
        reordering; 4-channel input has its three colour channels reversed
        while the alpha channel stays last; any other 3-D input has its
        channel axis reversed.
    """
    pixels = np.array(image)
    if pixels.ndim < 3:
        # Grayscale: there is no channel axis to reverse.
        return pixels
    if pixels.shape[2] == 4:
        # RGBA -> BGRA. Reversing all four channels would give ABGR and
        # put a colour plane where callers expect alpha (index 3).
        return np.ascontiguousarray(pixels[:, :, [2, 1, 0, 3]])
    # RGB to BGR
    return pixels[:, :, ::-1].copy()


def upscale(image, model):
    """Upscale *image* with the model named *model*.

    Args:
        image: Input image as delivered by the Gradio image input.
        model: Model name; the weights are read from ``models/<model>.ort``.

    Returns:
        The upscaled image as a ``uint8`` array. A 4-channel input produces
        a 4-channel output whose alpha plane was upscaled by the same model.

    Raises:
        ValueError: If no image or no model was supplied, or if the image
            has a channel count other than 1 (2-D), 3 or 4.
    """
    if image is None:
        raise ValueError("No input image provided")
    if not model:
        raise ValueError("No upscaler model selected")

    model_path = f"models/{model}.ort"
    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    channels = img.shape[2]
    if channels == 3:
        return post_process(inference(model_path, pre_process(img)))
    if channels != 4:
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError(f"Unsupported number of channels: {channels}")

    # The model works on 3-channel input, so the alpha plane is replicated
    # to three channels, upscaled, and collapsed back to a single plane.
    alpha = img[:, :, 3]
    alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)
    alpha_output = post_process(inference(model_path, pre_process(alpha)))
    alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)

    # Note: post_process reverses the channel order of the model output, so
    # the colour result is no longer in the order produced by
    # convert_pil_to_cv2. COLOR_BGR2BGRA is used only to append a fourth
    # channel, which is then overwritten with the upscaled alpha.
    colour = img[:, :, 0:3]
    image_output = post_process(inference(model_path, pre_process(colour)))
    image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)
    image_output[:, :, 3] = alpha_output

    return image_output


# Stylesheet passed to the interface: fixes the height of the image panes.
css = ".output-image, .input-image, .image-preview {height: 480px !important} "
# Names offered in the radio selector; the chosen one is passed to `upscale`.
model_choices = [
    "modelx2",
    "modelx2 25 JXL",
    "modelx4",
    "minecraft_modelx4",
]

# Input widgets, in the order `upscale` receives their values.
image_input = gr.inputs.Image(type="pil", label="Input Image")
model_input = gr.inputs.Radio(
    model_choices,
    type="value",
    default=None,
    label="Choose Upscaler",
    optional=False,
)

demo = gr.Interface(
    fn=upscale,
    inputs=[image_input, model_input],
    outputs="image",
    title="Image Upscaling 🦆",
    description="Model: [Anchor-based Plain Net for Mobile Image Super-Resolution](https://arxiv.org/abs/2105.09750). Repository: [SR Mobile PyTorch](https://github.com/w11wo/sr_mobile_pytorch)",
    allow_flagging="never",
    css=css,
)
demo.launch()