File size: 3,912 Bytes
6e70c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
aa8edf3
 
 
6e70c4a
 
 
 
40d12a9
6e70c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40d12a9
6e70c4a
 
 
 
 
2ac9f6b
 
 
 
 
 
 
 
6e70c4a
 
 
 
 
 
 
 
 
7da59af
6e70c4a
 
 
 
 
 
2ac9f6b
6e70c4a
 
2ac9f6b
6e70c4a
40d12a9
6e70c4a
 
 
 
 
 
 
40d12a9
6e70c4a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import os, requests
import numpy as np
import torch.nn.functional as F
from model.model import ResHalf
from inference import Inferencer
from utils import util

## local |  remote
RUN_MODE = "remote"
# Hugging Face repo that hosts the checkpoint and the example images.
HF_REPO_URL = "https://huggingface.co/menghanxia/ReversibleHalftoning/resolve/main"


def _download(url, dest, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The parent directory of ``dest`` is created if needed. The payload is
    streamed into ``dest + ".part"`` and only moved into place once complete,
    so an interrupted download never leaves a truncated file at ``dest``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems / timeouts.
    """
    if os.path.exists(dest):
        # Already fetched on a previous start-up; avoid re-downloading.
        return
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_path = dest + ".part"
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                fh.write(chunk)
    os.replace(tmp_path, dest)


if RUN_MODE != "local":
    # Model weights, placed where the Inferencer below expects them.
    _download(HF_REPO_URL + "/model_best.pth.tar", "./checkpoints/model_best.pth.tar")
    ## examples (saved into the working directory, used by gr.Examples)
    for _example_name in ("girl.png", "wave.png", "painting.png"):
        _download(HF_REPO_URL + "/" + _example_name, _example_name)

## step 1: set up model
# Everything runs on the CPU. The inference wrapper is built once at import
# time (presumably loading the checkpoint below) and reused by every request.
device = "cpu"
checkpt_path = "checkpoints/model_best.pth.tar"
invhalfer = Inferencer(
    checkpoint_path=checkpt_path,
    model=ResHalf(train=False),
    use_cuda=False,
    multi_gpu=False,
)


def prepare_data(input_img, decoding_only=False):
    """Convert an image array into the tensor fed to the network.

    Pixel values are divided by 255 and then mapped from [0, 1] to [-1, 1]
    before ``util.img2tensor`` converts the array to a tensor.

    Args:
        input_img: image array, presumably HxWxC with values in [0, 255] as
            delivered by ``gr.Image(type="numpy")`` — TODO confirm.
        decoding_only: when truthy, keep only the first channel (the input is
            treated as a single-channel halftone; requires a 3-D array).

    Returns:
        Whatever ``util.img2tensor`` produces for the normalised array.
    """
    normalized = np.array(input_img / 255., np.float32)
    if decoding_only:
        # Halftone input: one channel is enough.
        normalized = normalized[:, :, :1]
    signed = normalized * 2. - 1.
    return util.img2tensor(signed)


def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
    """Run the reversible-halftoning network on a single image.

    Args:
        invhalfer: inference callable, invoked as
            ``invhalfer(tensor, decoding_only=...)``.
        input_img: image array accepted by ``prepare_data``.
        decoding_only: truthy -> restoration mode (halftone to photo, the
            network returns one tensor); falsy -> halftoning mode (the network
            returns a ``(halftone, colour)`` pair and the halftone is shown).
        device: torch device the input tensor is moved to. Defaults to
            "cuda"; the demo passes its module-level ``device``.

    Returns:
        np.uint8 image array with values clipped to [0, 255].
    """
    import torch  # only torch.nn.functional is imported at module level

    input_tensor = prepare_data(input_img, decoding_only).to(device)
    # Pure inference: disable autograd so no graph is recorded.
    with torch.no_grad():
        if decoding_only:
            print('>>>:restoration mode')
            result = invhalfer(input_tensor, decoding_only=decoding_only)
        else:
            print('>>>:halftoning mode')
            # The restored colour image is also returned but not displayed.
            result, _ = invhalfer(input_tensor, decoding_only=decoding_only)
    # Undo prepare_data's [0, 1] -> [-1, 1] mapping, then scale to 0..255.
    output = util.tensor2img(result / 2. + 0.5) * 255.
    return np.clip(output, 0, 255).astype(np.uint8)


def click_run(input_img, decoding_only):
    """Gradio callback for the "Run" button.

    Args:
        input_img: numpy image from the input component, or None when the
            user has not provided one.
        decoding_only: index selected in the mode radio (0 = halftoning,
            1 = restoration); truthy selects restoration.

    Returns:
        The uint8 result image, or None (clears the output) when there is no
        input image.
    """
    if input_img is None:
        # Without this guard prepare_data would fail on `None / 255.`.
        return None
    return run_invhalf(invhalfer, input_img, decoding_only, device)
    
    
def click_move(output_img, decoding_only):
    """Gradio callback for "Use it as input".

    Moves the current result into the input slot, flips the running mode
    (a halftone goes on to restoration, a restored photo goes back to
    halftoning) and clears the output slot.

    Returns:
        Tuple ``(new_input_image, new_mode_label, new_output_image)``.
    """
    next_mode = ("Halftoning (Photo2Halftone)" if decoding_only
                 else "Restoration (Halftone2Photo)")
    cleared_output = None
    return output_img, next_mode, cleared_output

## step 2: configure interface
# Build the Blocks UI. Components are laid out in the order they are created
# inside the nested `with` contexts below.
demo = gr.Blocks(title="ReversibleHalftoning")
with demo:
    gr.Markdown(value="""

                    **Gradio demo for ReversibleHalftoning: Deep Halftoning with Reversible Binary Pattern**. Check our [github page](https://github.com/MenghanXia/ReversibleHalftoning) 😛.

                    """)
    with gr.Row():
        # Left column: input image, mode selector and Run button.
        with gr.Column():
            # NOTE(review): `.style()` was removed in Gradio 4, so this pins the
            # app to Gradio 3.x — confirm the deployed version.
            Image_input = gr.Image(type="numpy", label="Input", interactive=True).style(height=480)
            with gr.Row():
                # type="index": callbacks receive 0 (Halftoning) or 1
                # (Restoration), which click_run forwards as `decoding_only`.
                Radio_mode = gr.Radio(type="index", choices=["Halftoning (Photo2Halftone)", "Restoration (Halftone2Photo)"], \
                                                label="Choose a running mode", value="Halftoning (Photo2Halftone)")
                Button_run = gr.Button(value="Run")
        # Right column: result image and a button to feed it back as input.
        with gr.Column():
            Image_output = gr.Image(type="numpy", label="Output").style(height=480)
            Button_move = gr.Button(value="Use it as input")

    # Run: (input image, mode index) -> output image.
    Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)
    # Move: copy output into input, flip the mode, clear the output.
    Button_move.click(fn=click_move, inputs=[Image_output, Radio_mode], outputs=[Image_input, Radio_mode, Image_output])
    
    # The example images are downloaded at startup only when RUN_MODE != "local".
    # No `fn` is passed, so selecting an example just fills the inputs.
    if RUN_MODE != "local":
        gr.Examples(examples=[
                    ['girl.png', "Halftoning (Photo2Halftone)"],
                    ['wave.png', "Halftoning (Photo2Halftone)"],
                    ['painting.png', "Restoration (Halftone2Photo)"],
                    ], 
                    inputs=[Image_input,Radio_mode], outputs=[Image_output], label="Examples")

if RUN_MODE == "local":
    # The defaults are the host/port of the original development machine.
    # Override them with GRADIO_SERVER_NAME / GRADIO_SERVER_PORT to run elsewhere.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "9.134.253.83"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7788")),
    )
else:
    # Remote (hosted) mode: let Gradio choose its default host/port.
    demo.launch()