jjourney1125 commited on
Commit
fac140f
1 Parent(s): d53e2c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +322 -31
app.py CHANGED
@@ -1,45 +1,336 @@
1
- import os
2
  import cv2
3
- import requests
 
4
  import gradio as gr
5
- from PIL import Image
 
6
  import torch
7
- import shutil
 
 
 
8
 
9
- model_path = 'experiments/pretrained_models/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth'
10
 
11
- if os.path.exists(model_path):
12
- print(f'loading model from {model_path}')
13
- else:
14
- os.makedirs(os.path.dirname(model_path), exist_ok=True)
15
- url = 'https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/{}'.format(os.path.basename(model_path))
16
- r = requests.get(url, allow_redirects=True)
17
- print(f'downloading model {model_path}')
18
- open(model_path, 'wb').write(r.content)
19
 
 
 
 
 
 
20
 
21
- os.makedirs("test", exist_ok=True)
22
 
23
- def inference(img):
24
  basewidth = 256
25
  wpercent = (basewidth/float(img.size[0]))
26
  hsize = int((float(img.size[1])*float(wpercent)))
27
  img = img.resize((basewidth,hsize), Image.ANTIALIAS)
28
  img.save("test/1.png", "PNG")
29
- os.system('python main_test_swin2sr.py --task real_sr --model_path experiments/pretrained_models/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth --folder_lq test --scale 4')
30
- return 'results/swin2sr_real_sr_x4/1_Swin2SR.png'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- title = "Swin2SR"
33
- description = "Gradio demo for Swin2SR."
34
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2209.11345' target='_blank'>Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration</a> | <a href='https://github.com/mv-lab/swin2sr' target='_blank'>Github Repo</a></p>"
35
-
36
- examples=[['butterflyx4.png']]
37
- gr.Interface(
38
- inference,
39
- gr.inputs.Image(type="pil", label="Input"),
40
- "image",
41
- title=title,
42
- description=description,
43
- article=article,
44
- examples=examples,
45
- ).launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
  import cv2
3
+ import glob
4
+ import numpy as np
5
  import gradio as gr
6
+ from collections import OrderedDict
7
+ import os
8
  import torch
9
+ import requests
10
+ from PIL import Image
11
+ from models.network_swin2sr import Swin2SR as net
12
+ from utils import util_calculate_psnr_ssim as util
13
 
 
14
 
15
+ def setup_model(args):
 
 
 
 
 
 
 
16
 
17
+ model = define_model(args)
18
+ model.eval()
19
+ model = model.to(device)
20
+
21
+ return model
22
 
23
+ def main(img):
24
 
25
+ # setup folder and path
26
  basewidth = 256
27
  wpercent = (basewidth/float(img.size[0]))
28
  hsize = int((float(img.size[1])*float(wpercent)))
29
  img = img.resize((basewidth,hsize), Image.ANTIALIAS)
30
  img.save("test/1.png", "PNG")
31
+
32
+ folder, save_dir, border, window_size = setup(args)
33
+ os.makedirs(save_dir, exist_ok=True)
34
+ test_results = OrderedDict()
35
+ test_results['psnr'] = []
36
+ test_results['ssim'] = []
37
+ test_results['psnr_y'] = []
38
+ test_results['ssim_y'] = []
39
+ test_results['psnrb'] = []
40
+ test_results['psnrb_y'] = []
41
+ psnr, ssim, psnr_y, ssim_y, psnrb, psnrb_y = 0, 0, 0, 0, 0, 0
42
+
43
+ for idx, path in enumerate(sorted(glob.glob(os.path.join(folder, '*')))):
44
+ # read image
45
+ imgname, img_lq, img_gt = get_image_pair(args, path) # image to HWC-BGR, float32
46
+ img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1)) # HCW-BGR to CHW-RGB
47
+ img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device) # CHW-RGB to NCHW-RGB
48
+
49
+ # inference
50
+ with torch.no_grad():
51
+ # pad input image to be a multiple of window_size
52
+ _, _, h_old, w_old = img_lq.size()
53
+ h_pad = (h_old // window_size + 1) * window_size - h_old
54
+ w_pad = (w_old // window_size + 1) * window_size - w_old
55
+ img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
56
+ img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
57
+ output = test(img_lq, model, args, window_size)
58
+
59
+ if args.task == 'compressed_sr':
60
+ output = output[0][..., :h_old * args.scale, :w_old * args.scale]
61
+ else:
62
+ output = output[..., :h_old * args.scale, :w_old * args.scale]
63
+
64
+ # save image
65
+ output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
66
+ if output.ndim == 3:
67
+ output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
68
+ output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
69
+ cv2.imwrite(f'{save_dir}/{imgname}_Swin2SR.png', output)
70
+
71
+
72
+ # evaluate psnr/ssim/psnr_b
73
+ if img_gt is not None:
74
+ img_gt = (img_gt * 255.0).round().astype(np.uint8) # float32 to uint8
75
+ img_gt = img_gt[:h_old * args.scale, :w_old * args.scale, ...] # crop gt
76
+ img_gt = np.squeeze(img_gt)
77
+
78
+ psnr = util.calculate_psnr(output, img_gt, crop_border=border)
79
+ ssim = util.calculate_ssim(output, img_gt, crop_border=border)
80
+ test_results['psnr'].append(psnr)
81
+ test_results['ssim'].append(ssim)
82
+ if img_gt.ndim == 3: # RGB image
83
+ psnr_y = util.calculate_psnr(output, img_gt, crop_border=border, test_y_channel=True)
84
+ ssim_y = util.calculate_ssim(output, img_gt, crop_border=border, test_y_channel=True)
85
+ test_results['psnr_y'].append(psnr_y)
86
+ test_results['ssim_y'].append(ssim_y)
87
+ if args.task in ['jpeg_car', 'color_jpeg_car']:
88
+ psnrb = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=False)
89
+ test_results['psnrb'].append(psnrb)
90
+ if args.task in ['color_jpeg_car']:
91
+ psnrb_y = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=True)
92
+ test_results['psnrb_y'].append(psnrb_y)
93
+ print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNRB: {:.2f} dB;'
94
+ 'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}; PSNRB_Y: {:.2f} dB.'.
95
+ format(idx, imgname, psnr, ssim, psnrb, psnr_y, ssim_y, psnrb_y))
96
+ else:
97
+ print('Testing {:d} {:20s}'.format(idx, imgname))
98
+
99
+ # summarize psnr/ssim
100
+ if img_gt is not None:
101
+ ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
102
+ ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
103
+ print('\n{} \n-- Average PSNR/SSIM(RGB): {:.2f} dB; {:.4f}'.format(save_dir, ave_psnr, ave_ssim))
104
+ if img_gt.ndim == 3:
105
+ ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
106
+ ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
107
+ print('-- Average PSNR_Y/SSIM_Y: {:.2f} dB; {:.4f}'.format(ave_psnr_y, ave_ssim_y))
108
+ if args.task in ['jpeg_car', 'color_jpeg_car']:
109
+ ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb'])
110
+ print('-- Average PSNRB: {:.2f} dB'.format(ave_psnrb))
111
+ if args.task in ['color_jpeg_car']:
112
+ ave_psnrb_y = sum(test_results['psnrb_y']) / len(test_results['psnrb_y'])
113
+ print('-- Average PSNRB_Y: {:.2f} dB'.format(ave_psnrb_y))
114
+
115
+ return f"results/swin2sr_{args.task}_x{args.scale}/1_Swin2SR.png"
116
+
117
+
118
+ def define_model(args):
119
+ # 001 classical image sr
120
+ if args.task == 'classical_sr':
121
+ model = net(upscale=args.scale, in_chans=3, img_size=args.training_patch_size, window_size=8,
122
+ img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
123
+ mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv')
124
+ param_key_g = 'params'
125
+
126
+ # 002 lightweight image sr
127
+ # use 'pixelshuffledirect' to save parameters
128
+ elif args.task in ['lightweight_sr']:
129
+ model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
130
+ img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
131
+ mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
132
+ param_key_g = 'params'
133
+
134
+ elif args.task == 'compressed_sr':
135
+ model = net(upscale=args.scale, in_chans=3, img_size=args.training_patch_size, window_size=8,
136
+ img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
137
+ mlp_ratio=2, upsampler='pixelshuffle_aux', resi_connection='1conv')
138
+ param_key_g = 'params'
139
+
140
+ # 003 real-world image sr
141
+ elif args.task == 'real_sr':
142
+ if not args.large_model:
143
+ # use 'nearest+conv' to avoid block artifacts
144
+ model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
145
+ img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
146
+ mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
147
+ else:
148
+ # larger model size; use '3conv' to save parameters and memory; use ema for GAN training
149
+ model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
150
+ img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
151
+ num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
152
+ mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
153
+ param_key_g = 'params_ema'
154
+
155
+ # 006 grayscale JPEG compression artifact reduction
156
+ # use window_size=7 because JPEG encoding uses 8x8; use img_range=255 because it's sligtly better than 1
157
+ elif args.task == 'jpeg_car':
158
+ model = net(upscale=1, in_chans=1, img_size=126, window_size=7,
159
+ img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
160
+ mlp_ratio=2, upsampler='', resi_connection='1conv')
161
+ param_key_g = 'params'
162
+
163
+ # 006 color JPEG compression artifact reduction
164
+ # use window_size=7 because JPEG encoding uses 8x8; use img_range=255 because it's sligtly better than 1
165
+ elif args.task == 'color_jpeg_car':
166
+ model = net(upscale=1, in_chans=3, img_size=126, window_size=7,
167
+ img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
168
+ mlp_ratio=2, upsampler='', resi_connection='1conv')
169
+ param_key_g = 'params'
170
+
171
+ pretrained_model = torch.load(args.model_path)
172
+ model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)
173
+
174
+ return model
175
+
176
+
177
+ def setup(args):
178
+ # 001 classical image sr/ 002 lightweight image sr
179
+ if args.task in ['classical_sr', 'lightweight_sr', 'compressed_sr']:
180
+ save_dir = f'results/swin2sr_{args.task}_x{args.scale}'
181
+ if args.save_img_only:
182
+ folder = args.folder_lq
183
+ else:
184
+ folder = args.folder_gt
185
+ border = args.scale
186
+ window_size = 8
187
+
188
+ # 003 real-world image sr
189
+ elif args.task in ['real_sr']:
190
+ save_dir = f'results/swin2sr_{args.task}_x{args.scale}'
191
+ if args.large_model:
192
+ save_dir += '_large'
193
+ folder = args.folder_lq
194
+ border = 0
195
+ window_size = 8
196
+
197
+ # 006 JPEG compression artifact reduction
198
+ elif args.task in ['jpeg_car', 'color_jpeg_car']:
199
+ save_dir = f'results/swin2sr_{args.task}_jpeg{args.jpeg}'
200
+ folder = args.folder_gt
201
+ border = 0
202
+ window_size = 7
203
+
204
+ return folder, save_dir, border, window_size
205
+
206
+
207
+ def get_image_pair(args, path):
208
+ (imgname, imgext) = os.path.splitext(os.path.basename(path))
209
+
210
+ # 001 classical image sr/ 002 lightweight image sr (load lq-gt image pairs)
211
+ if args.task in ['classical_sr', 'lightweight_sr']:
212
+ if args.save_img_only:
213
+ img_gt = None
214
+ img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
215
+ else:
216
+ img_gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
217
+ img_lq = cv2.imread(f'{args.folder_lq}/{imgname}x{args.scale}{imgext}', cv2.IMREAD_COLOR).astype(
218
+ np.float32) / 255.
219
 
220
+ elif args.task in ['compressed_sr']:
221
+ if args.save_img_only:
222
+ img_gt = None
223
+ img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
224
+ else:
225
+ img_gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
226
+ img_lq = cv2.imread(f'{args.folder_lq}/{imgname}.jpg', cv2.IMREAD_COLOR).astype(
227
+ np.float32) / 255.
228
+
229
+ # 003 real-world image sr (load lq image only)
230
+ elif args.task in ['real_sr', 'lightweight_sr_infer']:
231
+ img_gt = None
232
+ img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
233
+
234
+ # 006 grayscale JPEG compression artifact reduction (load gt image and generate lq image on-the-fly)
235
+ elif args.task in ['jpeg_car']:
236
+ img_gt = cv2.imread(path, cv2.IMREAD_UNCHANGED)
237
+ if img_gt.ndim != 2:
238
+ img_gt = util.bgr2ycbcr(img_gt, y_only=True)
239
+ result, encimg = cv2.imencode('.jpg', img_gt, [int(cv2.IMWRITE_JPEG_QUALITY), args.jpeg])
240
+ img_lq = cv2.imdecode(encimg, 0)
241
+ img_gt = np.expand_dims(img_gt, axis=2).astype(np.float32) / 255.
242
+ img_lq = np.expand_dims(img_lq, axis=2).astype(np.float32) / 255.
243
+
244
+ # 006 JPEG compression artifact reduction (load gt image and generate lq image on-the-fly)
245
+ elif args.task in ['color_jpeg_car']:
246
+ img_gt = cv2.imread(path)
247
+ result, encimg = cv2.imencode('.jpg', img_gt, [int(cv2.IMWRITE_JPEG_QUALITY), args.jpeg])
248
+ img_lq = cv2.imdecode(encimg, 1)
249
+ img_gt = img_gt.astype(np.float32)/ 255.
250
+ img_lq = img_lq.astype(np.float32)/ 255.
251
+
252
+ return imgname, img_lq, img_gt
253
+
254
+
255
+ def test(img_lq, model, args, window_size):
256
+ if args.tile is None:
257
+ # test the image as a whole
258
+ output = model(img_lq)
259
+ else:
260
+ # test the image tile by tile
261
+ b, c, h, w = img_lq.size()
262
+ tile = min(args.tile, h, w)
263
+ assert tile % window_size == 0, "tile size should be a multiple of window_size"
264
+ tile_overlap = args.tile_overlap
265
+ sf = args.scale
266
+
267
+ stride = tile - tile_overlap
268
+ h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
269
+ w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
270
+ E = torch.zeros(b, c, h*sf, w*sf).type_as(img_lq)
271
+ W = torch.zeros_like(E)
272
+
273
+ for h_idx in h_idx_list:
274
+ for w_idx in w_idx_list:
275
+ in_patch = img_lq[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
276
+ out_patch = model(in_patch)
277
+ out_patch_mask = torch.ones_like(out_patch)
278
+
279
+ E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch)
280
+ W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
281
+ output = E.div_(W)
282
+
283
+ return output
284
+
285
+ if __name__ == '__main__':
286
+
287
+ parser = argparse.ArgumentParser()
288
+ parser.add_argument('--task', type=str, default='color_dn', help='classical_sr, lightweight_sr, real_sr, '
289
+ 'gray_dn, color_dn, jpeg_car, color_jpeg_car')
290
+ parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8') # 1 for dn and jpeg car
291
+ parser.add_argument('--noise', type=int, default=15, help='noise level: 15, 25, 50')
292
+ parser.add_argument('--jpeg', type=int, default=40, help='scale factor: 10, 20, 30, 40')
293
+ parser.add_argument('--training_patch_size', type=int, default=128, help='patch size used in training Swin2SR. '
294
+ 'Just used to differentiate two different settings in Table 2 of the paper. '
295
+ 'Images are NOT tested patch by patch.')
296
+ parser.add_argument('--large_model', action='store_true', help='use large model, only provided for real image sr')
297
+ parser.add_argument('--model_path', type=str,
298
+ default='model_zoo/swin2sr/Swin2SR_ClassicalSR_X2_64.pth')
299
+ parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder')
300
+ parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder')
301
+ parser.add_argument('--tile', type=int, default=None, help='Tile size, None for no tile during testing (testing as a whole)')
302
+ parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
303
+ parser.add_argument('--save_img_only', default=False, action='store_true', help='save image and do not evaluate')
304
+ args = parser.parse_args()
305
+
306
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
307
+ # set up model
308
+ if os.path.exists(args.model_path):
309
+ print(f'loading model from {args.model_path}')
310
+ else:
311
+ os.makedirs(os.path.dirname(args.model_path), exist_ok=True)
312
+ url = 'https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/{}'.format(os.path.basename(args.model_path))
313
+ r = requests.get(url, allow_redirects=True)
314
+ print(f'downloading model {args.model_path}')
315
+ open(args.model_path, 'wb').write(r.content)
316
+
317
+ model = setup_model(args)
318
+
319
+ os.makedirs("test", exist_ok=True)
320
+
321
+ #main(img)
322
+
323
+ title = "Swin2SR"
324
+ description = "Gradio demo for Swin2SR."
325
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2209.11345' target='_blank'>Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration</a> | <a href='https://github.com/mv-lab/swin2sr' target='_blank'>Github Repo</a></p>"
326
+
327
+ examples=[['butterflyx4.png']]
328
+ gr.Interface(
329
+ main,
330
+ gr.inputs.Image(type="pil", label="Input"),
331
+ "image",
332
+ title=title,
333
+ description=description,
334
+ article=article,
335
+ examples=examples,
336
+ ).launch(enable_queue=True)