File size: 3,571 Bytes
92d45d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56f91c4
92d45d2
 
5be1082
92d45d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5be1082
92d45d2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import torch
import cv2

import imgproc
from imgproc import image_to_tensor
from inference import choice_device, build_model
from utils import load_state_dict

model = "srresnet_x4"

device = choice_device("cpu")

# Construct the generator architecture named by `model` on `device`.
sr_model = build_model(model, device)
print(f"Build {model} model successfully.")

# Replace the freshly built network with one holding the pretrained SRGAN x4
# checkpoint (path is relative to the working directory the app is run from).
sr_model = load_state_dict(
    sr_model,
    "weights/SRGAN_x4-ImageNet-8c4a7569.pth.tar",
)
print(f"Load `{model}` model weights successfully.")

# Put the network into evaluation (inference) mode before serving requests.
sr_model.eval()

def downscale(image, target_width=60):
    """Shrink ``image`` to ``target_width`` pixels wide, preserving its aspect ratio.

    Args:
        image: Image array indexed numpy-style as (rows, cols[, channels]),
            i.e. height first.
        target_width: Width in pixels of the returned image. Defaults to 60,
            the value previously hard-coded here.

    Returns:
        The resized array: ``target_width`` pixels wide and
        ``int(target_width * rows / cols)`` pixels tall (never less than 1).
    """
    # numpy image arrays are (height, width, ...). The previous code unpacked
    # them as (width, height, ...), so its variable names were swapped.
    rows, cols = image.shape[:2]

    # Keep the aspect ratio. Clamp to 1 so that very wide images do not
    # produce a zero-height target, which cv2.resize rejects.
    new_rows = max(1, int(target_width * rows / cols))

    # cv2.resize takes dsize as (width, height); INTER_AREA is the
    # interpolation OpenCV recommends for shrinking.
    return cv2.resize(image, (target_width, new_rows), interpolation=cv2.INTER_AREA)


def preprocess(image):
    """Turn an image array into a batched model-input tensor.

    Args:
        image: Image array with 0-255 pixel values (assumed 8-bit RGB as
            delivered by Gradio / ``downscale`` -- confirm if other modes
            are allowed).

    Returns:
        The output of ``image_to_tensor`` with a leading batch dimension added
        (presumably 1xCxHxW -- confirm against ``imgproc``). It is placed on
        the module-level ``device`` in ``torch.channels_last`` memory format.
    """
    # Scale 8-bit pixel values into [0, 1].
    normalized = image / 255.0

    # Convert to a PyTorch tensor and add the batch dimension in place.
    tensor = image_to_tensor(normalized, False, False).unsqueeze_(0)

    # Move the input to the same device the model was built on. Previously
    # this was hard-coded to "cpu", which would break if `device` were
    # switched to CUDA. non_blocking only has an effect for pinned host ->
    # CUDA copies and is harmless on CPU.
    return tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True)

def processHighRes(image):
    if image is None:
        raise gr.Error("Please enter an image")
    downscaled = downscale(image)
    lr_tensor = preprocess(downscaled)

    # Use the model to generate super-resolved images
    with torch.no_grad():
        sr_tensor = sr_model(lr_tensor)

    # Save image
    sr_image = imgproc.tensor_to_image(sr_tensor, False, False)

    return [downscaled, sr_image]

def processLowRes(image):
    if image is None:
        raise gr.Error("Please enter an image")

    (width, height, colors) = image.shape

    if width > 150 or height > 150:
        raise gr.Error("Image is too big")

    lr_tensor = preprocess(image)

    # Use the model to generate super-resolved images
    with torch.no_grad():
        sr_tensor = sr_model(lr_tensor)

    # Save image
    sr_image = imgproc.tensor_to_image(sr_tensor, False, False)

    return sr_image

# HTML snippets rendered through gr.Markdown: the links under the title and the credits footer.
description = """<p style='text-align: center'> <a href='https://arxiv.org/abs/1609.04802' target='_blank'>Paper</a> | <a href=https://github.com/Lornatang/SRGAN-PyTorch target='_blank'>GitHub</a></p>"""
about = "<p style='text-align: center'>Made for the 2022-2023 Grenoble-INP Phelma Image analysis course by Thibaud CHERUY, Clément DEBUY & Yassine EL KHANOUSSI.</p>"

# Components appear on the page in the order they are created inside the
# nested `with` contexts below, so the statement order here defines the layout.
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("# **<p align='center'>SRGAN: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network</p>**")
    gr.Markdown(description)

    # Tab 1: shrink a high-resolution picture, then upscale the shrunken copy.
    with gr.Tab("From high res"):
        high_res_input = gr.Image(label="High-res source image", show_label=True)
        with gr.Row():
            low_res_output = gr.Image(label="Low-res image")
            srgan_output = gr.Image(label="SRGAN Output")
        high_res_button = gr.Button("Process")

    # Tab 2: upscale a picture that is already small.
    with gr.Tab("From low res"):
        low_res_input = gr.Image(label="Low-res source image", show_label=True)
        srgan_upscale = gr.Image(label="SRGAN Output")
        low_res_button = gr.Button("Process")

    # Bundled sample images; they feed the high-res pipeline.
    gr.Examples(
        examples=[f"examples/{stem}.png" for stem in ("bird", "butterfly", "comic", "gray", "man")],
        inputs=[high_res_input],
        outputs=[low_res_output, srgan_output],
        fn=processHighRes,
    )

    # Connect each "Process" button to its handler.
    high_res_button.click(
        fn=processHighRes,
        inputs=[high_res_input],
        outputs=[low_res_output, srgan_output],
    )
    low_res_button.click(
        fn=processLowRes,
        inputs=[low_res_input],
        outputs=[srgan_upscale],
    )

    # Credits footer.
    gr.Markdown(about)

demo.launch()