File size: 5,370 Bytes
4384dd9
 
 
 
 
 
 
 
 
 
 
 
 
 
c5e4a40
 
34c993f
4384dd9
 
 
 
 
 
 
 
 
34c993f
 
 
4384dd9
34c993f
 
4384dd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import math
import os
import subprocess
import sys
import traceback
from typing import List

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torchvision import transforms

from src.de_net import DEResNet
from src.my_utils.testing_utils import parse_args_paired_testing
from src.s3diff_tile import S3Diff
from utils.wavelet_color import wavelet_color_fix, adain_color_fix

# PIL image -> float tensor in [0, 1] (C, H, W); used to feed the networks.
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])

args = parse_args_paired_testing()

# Inference runs on the GPU in full precision (autocast is applied in process()).
weight_dtype = torch.float32
device = "cuda"

# Run the helper script that fetches the pretrained models.
download = subprocess.run(["bash", "get_pretrained_models.sh"])
if download.returncode != 0:
    # Deliberately not fatal (the original script continued too), but surface
    # the failure so a later checkpoint-loading error is easier to diagnose.
    print(
        f"WARNING: get_pretrained_models.sh exited with code {download.returncode}; "
        "loading the pretrained weights below may fail."
    )

# Checkpoint / base-model locations.
pretrained_model_path = 'checkpoints/s3diff.pkl'
t2i_path = 'stabilityai/sd-turbo'
de_net_path = 'assets/mm-realsr/de_net.pth'

# Super-resolution network; sd_path points at the base text-to-image model.
net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=t2i_path, pretrained_path=pretrained_model_path, args=args)
net_sr.set_eval()

# Degradation estimation network (3 input channels, 2 degradation scores);
# its output is passed to net_sr as conditioning in process().
net_de = DEResNet(num_in_ch=3, num_degradation=2)
net_de.load_model(de_net_path)
net_de = net_de.to(device)
net_de.eval()

if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        net_sr.unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

if args.gradient_checkpointing:
    net_sr.unet.enable_gradient_checkpointing()

# Move both networks to the GPU and cast them to weight_dtype.
net_sr.to(device, dtype=weight_dtype)
net_de.to(device, dtype=weight_dtype)

def _pad_to_multiple(x: torch.Tensor, multiple: int = 64) -> torch.Tensor:
    """Pad an NCHW image batch on the bottom/right so H and W are multiples of ``multiple``.

    Reflect padding is preferred, but ``F.pad`` only allows it when each pad
    amount is smaller than the matching dimension; tiny inputs therefore fall
    back to replicate padding instead of raising.
    """
    h, w = x.shape[-2:]
    pad_h = math.ceil(h / multiple) * multiple - h
    pad_w = math.ceil(w / multiple) * multiple - w
    if pad_h == 0 and pad_w == 0:
        return x
    mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
    return F.pad(x, pad=(0, pad_w, 0, pad_h), mode=mode)


@torch.no_grad()
def process(
    input_image: Image.Image,
    scale_factor: float,
    cfg_scale: float,
    latent_tiled_size: int,
    latent_tiled_overlap: int,
    align_method: str,
    ) -> Image.Image:
    """Super-resolve ``input_image`` with S3Diff and optionally color-correct it.

    Args:
        input_image: Uploaded low-resolution image (Gradio ``type="pil"``).
        scale_factor: Upscaling factor applied to both height and width.
        cfg_scale: Value of the guidance slider. NOTE(review): it is not
            forwarded to ``net_sr`` here; confirm whether S3Diff is meant to
            receive it (e.g. via ``args``) before relying on the slider.
        latent_tiled_size: Latent tile size for ``net_sr._set_latent_tile``.
        latent_tiled_overlap: Latent tile overlap for ``net_sr._set_latent_tile``.
        align_method: ``'wavelet'``, ``'adain'`` or ``'no fix'``; selects the
            color correction applied against the bicubic-upscaled input.

    Returns:
        The super-resolved PIL image, or a blank 512x512 RGB image if
        inference or post-processing raised an exception.

    Raises:
        ValueError: If no image was provided, the scaled size is empty, or
            ``align_method`` is not a supported value.
    """
    if input_image is None:
        raise ValueError("Please upload an input image.")
    if align_method not in ('wavelet', 'adain', 'no fix'):
        # Previously this left `image` unbound and was silently turned into a
        # black output by the broad except below.
        raise ValueError(f"Unknown color correction method: {align_method!r}")

    # Slider values may arrive as floats; tile sizes are integral.
    net_sr._set_latent_tile(latent_tiled_size=int(latent_tiled_size), latent_tiled_overlap=int(latent_tiled_overlap))

    # net_de is built with num_in_ch=3, so drop alpha / expand grayscale.
    im_lr = tensor_transforms(input_image.convert("RGB")).unsqueeze(0).to(device)
    ori_h, ori_w = im_lr.shape[2:]
    target_h = int(ori_h * scale_factor)
    target_w = int(ori_w * scale_factor)
    if target_h < 1 or target_w < 1:
        raise ValueError(f"SR scale {scale_factor} yields an empty output size.")

    # Bicubic interpolation overshoots [0, 1]. Clamp here so the tensor is a
    # valid image both for the model input and for the color-fix reference:
    # ToPILImage converts floats via mul(255).byte() without clamping.
    im_lr_resize = F.interpolate(
        im_lr,
        size=(target_h, target_w),
        mode='bicubic',
        ).clamp(0.0, 1.0).contiguous()
    im_lr_resize_norm = im_lr_resize * 2 - 1.0  # [0, 1] -> [-1, 1]
    resize_h, resize_w = im_lr_resize_norm.shape[2:]
    im_lr_resize_norm = _pad_to_multiple(im_lr_resize_norm, 64)

    try:
        with torch.autocast("cuda"):
            deg_score = net_de(im_lr)

            pos_tag_prompt = [args.pos_prompt]
            neg_tag_prompt = [args.neg_prompt]

            x_tgt_pred = net_sr(im_lr_resize_norm, deg_score, pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt)
            # Crop away the padding added above.
            x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w]

        # [-1, 1] -> [0, 1], clamped so the byte conversion cannot wrap.
        out_img = (x_tgt_pred.float() * 0.5 + 0.5).clamp(0.0, 1.0).cpu()
        output_pil = transforms.ToPILImage()(out_img[0])

        if align_method == 'no fix':
            image = output_pil
        else:
            reference_pil = transforms.ToPILImage()(im_lr_resize[0])
            if align_method == 'wavelet':
                image = wavelet_color_fix(output_pil, reference_pil)
            else:  # 'adain' (validated above)
                image = adain_color_fix(output_pil, reference_pil)

    except Exception:
        # Best-effort UI: log the full traceback and show a blank image rather
        # than failing the request.
        traceback.print_exc()
        image = Image.new(mode="RGB", size=(512, 512))

    return image


# Header text rendered at the top of the demo page.
MARKDOWN = \
"""
## Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors

[GitHub](https://github.com/ArcticHare105/S3Diff) | [Paper](https://arxiv.org/abs/2409.17058)

If S3Diff is helpful for you, please help star the GitHub Repo. Thanks!
"""

# Gradio UI: input image + options on the left, result on the right.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source="upload", type="pil")
            # A Button's displayed text is its `value`, not `label`.
            run_button = gr.Button(value="Run")
            with gr.Accordion("Options", open=True):
                # NOTE(review): process() receives this value but does not pass
                # it to the model; confirm the slider actually has an effect.
                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)", minimum=1.0, maximum=1.1, value=1.07, step=0.01)
                scale_factor = gr.Number(label="SR Scale", value=4)
                latent_tiled_size = gr.Slider(label="Tile Size", minimum=64, maximum=160, value=96, step=1)
                latent_tiled_overlap = gr.Slider(label="Tile Overlap", minimum=16, maximum=48, value=32, step=1)
                align_method = gr.Dropdown(label="Color Correction", choices=["wavelet", "adain", "no fix"], value="wavelet")
        with gr.Column():
            result_image = gr.Image(label="Output", show_label=False, elem_id="result_image", source="canvas", width="100%", height="auto")

    # Order must match process()'s positional parameters.
    inputs = [
        input_image,
        scale_factor,
        cfg_scale,
        latent_tiled_size,
        latent_tiled_overlap,
        align_method,
    ]
    run_button.click(fn=process, inputs=inputs, outputs=[result_image])

block.launch()