jjourney1125 committed on
Commit
d53e2c2
1 Parent(s): 0c4e407

Delete main_test_swin2sr.py

Browse files
Files changed (1) hide show
  1. main_test_swin2sr.py +0 -302
main_test_swin2sr.py DELETED
@@ -1,302 +0,0 @@
1
- import argparse
2
- import cv2
3
- import glob
4
- import numpy as np
5
- from collections import OrderedDict
6
- import os
7
- import torch
8
- import requests
9
-
10
- from models.network_swin2sr import Swin2SR as net
11
- from utils import util_calculate_psnr_ssim as util
12
-
13
-
14
def main():
    """Command-line entry point: run Swin2SR on a folder of images.

    Parses the CLI arguments, makes sure the checkpoint exists locally
    (downloading it from the Swin2SR GitHub release otherwise), restores every
    image found in the input folder, writes the results as PNG files and, when
    ground-truth images are available, prints per-image and average metrics.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): the default 'color_dn' (and 'gray_dn') has no matching branch in
    # define_model()/setup() in this file, so a task must be given explicitly.
    parser.add_argument('--task', type=str, default='color_dn', help='classical_sr, lightweight_sr, real_sr, '
                                                                     'gray_dn, color_dn, jpeg_car, color_jpeg_car')
    parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8')  # 1 for dn and jpeg car
    parser.add_argument('--noise', type=int, default=15, help='noise level: 15, 25, 50')
    parser.add_argument('--jpeg', type=int, default=40, help='JPEG quality factor: 10, 20, 30, 40')
    parser.add_argument('--training_patch_size', type=int, default=128, help='patch size used in training Swin2SR. '
                        'Just used to differentiate two different settings in Table 2 of the paper. '
                        'Images are NOT tested patch by patch.')
    parser.add_argument('--large_model', action='store_true', help='use large model, only provided for real image sr')
    parser.add_argument('--model_path', type=str,
                        default='model_zoo/swin2sr/Swin2SR_ClassicalSR_X2_64.pth')
    parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder')
    parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder')
    parser.add_argument('--tile', type=int, default=None, help='Tile size, None for no tile during testing (testing as a whole)')
    parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
    parser.add_argument('--save_img_only', default=False, action='store_true', help='save image and do not evaluate')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set up model
    if os.path.exists(args.model_path):
        print(f'loading model from {args.model_path}')
    else:
        model_dir = os.path.dirname(args.model_path)
        if model_dir:  # os.makedirs('') raises, so skip it for a bare file name
            os.makedirs(model_dir, exist_ok=True)
        url = 'https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/{}'.format(os.path.basename(args.model_path))
        print(f'downloading model {args.model_path}')
        r = requests.get(url, allow_redirects=True)
        # Fail here instead of saving an HTTP error page as the checkpoint.
        r.raise_for_status()
        with open(args.model_path, 'wb') as f:
            f.write(r.content)

    model = define_model(args)
    model.eval()
    model = model.to(device)

    # setup folder and path
    folder, save_dir, border, window_size = setup(args)
    if folder is None:
        parser.error('no input folder given: set --folder_lq and/or --folder_gt as required by the task')
    os.makedirs(save_dir, exist_ok=True)
    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []
    test_results['psnrb'] = []
    test_results['psnrb_y'] = []
    psnr, ssim, psnr_y, ssim_y, psnrb, psnrb_y = 0, 0, 0, 0, 0, 0

    for idx, path in enumerate(sorted(glob.glob(os.path.join(folder, '*')))):
        # read image
        imgname, img_lq, img_gt = get_image_pair(args, path)  # image to HWC-BGR, float32
        img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  # HWC-BGR to CHW-RGB
        img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW-RGB to NCHW-RGB

        # inference
        with torch.no_grad():
            # pad input image to be a multiple of window_size (by appending a flipped copy and cropping)
            _, _, h_old, w_old = img_lq.size()
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
            output = test(img_lq, model, args, window_size)

            # crop the padding away again; for 'compressed_sr' the first element of the output is used
            if args.task == 'compressed_sr':
                output = output[0][..., :h_old * args.scale, :w_old * args.scale]
            else:
                output = output[..., :h_old * args.scale, :w_old * args.scale]

        # save image
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        cv2.imwrite(f'{save_dir}/{imgname}_Swin2SR.png', output)

        # evaluate psnr/ssim/psnr_b
        if img_gt is not None:
            img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float32 to uint8
            img_gt = img_gt[:h_old * args.scale, :w_old * args.scale, ...]  # crop gt
            img_gt = np.squeeze(img_gt)

            psnr = util.calculate_psnr(output, img_gt, crop_border=border)
            ssim = util.calculate_ssim(output, img_gt, crop_border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            if img_gt.ndim == 3:  # RGB image
                psnr_y = util.calculate_psnr(output, img_gt, crop_border=border, test_y_channel=True)
                ssim_y = util.calculate_ssim(output, img_gt, crop_border=border, test_y_channel=True)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
            if args.task in ['jpeg_car', 'color_jpeg_car']:
                psnrb = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=False)
                test_results['psnrb'].append(psnrb)
                if args.task in ['color_jpeg_car']:
                    psnrb_y = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=True)
                    test_results['psnrb_y'].append(psnrb_y)
            print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNRB: {:.2f} dB;'
                  'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}; PSNRB_Y: {:.2f} dB.'.
                  format(idx, imgname, psnr, ssim, psnrb, psnr_y, ssim_y, psnrb_y))
        else:
            print('Testing {:d} {:20s}'.format(idx, imgname))

    # summarize psnr/ssim
    # Driven by the collected results rather than the last loop variable, so an
    # empty input folder no longer ends in a NameError / division by zero.
    if test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        print('\n{} \n-- Average PSNR/SSIM(RGB): {:.2f} dB; {:.4f}'.format(save_dir, ave_psnr, ave_ssim))
        if test_results['psnr_y']:
            ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
            print('-- Average PSNR_Y/SSIM_Y: {:.2f} dB; {:.4f}'.format(ave_psnr_y, ave_ssim_y))
        if test_results['psnrb']:
            ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb'])
            print('-- Average PSNRB: {:.2f} dB'.format(ave_psnrb))
        if test_results['psnrb_y']:
            ave_psnrb_y = sum(test_results['psnrb_y']) / len(test_results['psnrb_y'])
            print('-- Average PSNRB_Y: {:.2f} dB'.format(ave_psnrb_y))
132
-
133
-
134
def define_model(args):
    """Build the Swin2SR network for ``args.task`` and load its checkpoint.

    Args:
        args: parsed CLI namespace; reads ``task``, ``scale``,
            ``training_patch_size``, ``large_model`` and ``model_path``.

    Returns:
        The network with the weights from ``args.model_path`` loaded
        (``strict=True``).

    Raises:
        ValueError: if ``args.task`` is not one of the tasks handled here.
    """
    # 001 classical image sr
    if args.task == 'classical_sr':
        model = net(upscale=args.scale, in_chans=3, img_size=args.training_patch_size, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv')
        param_key_g = 'params'

    # 002 lightweight image sr
    # use 'pixelshuffledirect' to save parameters
    elif args.task in ['lightweight_sr']:
        model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
        param_key_g = 'params'

    # compressed image sr (auxiliary-output upsampler)
    elif args.task == 'compressed_sr':
        model = net(upscale=args.scale, in_chans=3, img_size=args.training_patch_size, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffle_aux', resi_connection='1conv')
        param_key_g = 'params'

    # 003 real-world image sr
    elif args.task == 'real_sr':
        if not args.large_model:
            # use 'nearest+conv' to avoid block artifacts
            model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
        else:
            # larger model size; use '3conv' to save parameters and memory; use ema for GAN training
            model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                        num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                        mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
        param_key_g = 'params_ema'

    # 006 grayscale JPEG compression artifact reduction
    # use window_size=7 because JPEG encoding uses 8x8; use img_range=255 because it's slightly better than 1
    elif args.task == 'jpeg_car':
        model = net(upscale=1, in_chans=1, img_size=126, window_size=7,
                    img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    # 006 color JPEG compression artifact reduction
    # use window_size=7 because JPEG encoding uses 8x8; use img_range=255 because it's slightly better than 1
    elif args.task == 'color_jpeg_car':
        model = net(upscale=1, in_chans=3, img_size=126, window_size=7,
                    img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    else:
        # Previously an unknown task fell through and crashed later with an
        # UnboundLocalError on `model`; report the real problem instead.
        raise ValueError(
            f"unsupported task '{args.task}'; expected one of: classical_sr, lightweight_sr, "
            "compressed_sr, real_sr, jpeg_car, color_jpeg_car")

    # Load onto the CPU so that checkpoints saved from a GPU also work on
    # CPU-only machines.
    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    pretrained_model = torch.load(args.model_path, map_location='cpu')
    # Checkpoints either wrap the weights under `param_key_g` or are a bare state dict.
    state_dict = pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model
    model.load_state_dict(state_dict, strict=True)

    return model
191
-
192
-
193
- def setup(args):
194
- # 001 classical image sr/ 002 lightweight image sr
195
- if args.task in ['classical_sr', 'lightweight_sr', 'compressed_sr']:
196
- save_dir = f'results/swin2sr_{args.task}_x{args.scale}'
197
- if args.save_img_only:
198
- folder = args.folder_lq
199
- else:
200
- folder = args.folder_gt
201
- border = args.scale
202
- window_size = 8
203
-
204
- # 003 real-world image sr
205
- elif args.task in ['real_sr']:
206
- save_dir = f'results/swin2sr_{args.task}_x{args.scale}'
207
- if args.large_model:
208
- save_dir += '_large'
209
- folder = args.folder_lq
210
- border = 0
211
- window_size = 8
212
-
213
- # 006 JPEG compression artifact reduction
214
- elif args.task in ['jpeg_car', 'color_jpeg_car']:
215
- save_dir = f'results/swin2sr_{args.task}_jpeg{args.jpeg}'
216
- folder = args.folder_gt
217
- border = 0
218
- window_size = 7
219
-
220
- return folder, save_dir, border, window_size
221
-
222
-
223
def _read_image(path, flags):
    """Read ``path`` with OpenCV, raising instead of returning ``None`` on failure."""
    img = cv2.imread(path, flags)
    if img is None:
        raise FileNotFoundError(f'cannot read image: {path}')
    return img


def _to_unit_float(img):
    """Convert an 8-bit image array to float32 scaled by 1/255."""
    return img.astype(np.float32) / 255.


def get_image_pair(args, path):
    """Load the low-quality image and, if available, the ground truth for ``path``.

    Args:
        args: parsed CLI namespace; reads ``task``, ``scale``, ``jpeg``,
            ``save_img_only`` and ``folder_lq``.
        path: path of the file currently being processed (a ground-truth image
            when evaluating, otherwise the low-quality image itself).

    Returns:
        Tuple ``(imgname, img_lq, img_gt)``: the file name without extension,
        the low-quality image as float32 (divided by 255) and the ground truth
        in the same format, or ``None`` when no ground truth is used.

    Raises:
        FileNotFoundError: if an image cannot be read by OpenCV.
        ValueError: if ``args.task`` is not one of the tasks handled here.
    """
    (imgname, imgext) = os.path.splitext(os.path.basename(path))

    # 001 classical image sr/ 002 lightweight image sr/ compressed sr (load lq-gt image pairs)
    if args.task in ['classical_sr', 'lightweight_sr', 'compressed_sr']:
        if args.save_img_only:
            img_gt = None
            img_lq = _to_unit_float(_read_image(path, cv2.IMREAD_COLOR))
        else:
            # The low-quality counterpart is looked up by name in folder_lq.
            if args.task == 'compressed_sr':
                lq_path = f'{args.folder_lq}/{imgname}.jpg'
            else:
                lq_path = f'{args.folder_lq}/{imgname}x{args.scale}{imgext}'
            img_gt = _to_unit_float(_read_image(path, cv2.IMREAD_COLOR))
            img_lq = _to_unit_float(_read_image(lq_path, cv2.IMREAD_COLOR))

    # 003 real-world image sr (load lq image only)
    elif args.task in ['real_sr', 'lightweight_sr_infer']:
        img_gt = None
        img_lq = _to_unit_float(_read_image(path, cv2.IMREAD_COLOR))

    # 006 grayscale JPEG compression artifact reduction (load gt image and generate lq image on-the-fly)
    elif args.task in ['jpeg_car']:
        img_gt = _read_image(path, cv2.IMREAD_UNCHANGED)
        if img_gt.ndim != 2:
            img_gt = util.bgr2ycbcr(img_gt, y_only=True)
        # Compress in memory at the requested quality and decode as grayscale.
        result, encimg = cv2.imencode('.jpg', img_gt, [int(cv2.IMWRITE_JPEG_QUALITY), args.jpeg])
        img_lq = cv2.imdecode(encimg, 0)
        img_gt = _to_unit_float(np.expand_dims(img_gt, axis=2))
        img_lq = _to_unit_float(np.expand_dims(img_lq, axis=2))

    # 006 JPEG compression artifact reduction (load gt image and generate lq image on-the-fly)
    elif args.task in ['color_jpeg_car']:
        img_gt = _read_image(path, cv2.IMREAD_COLOR)
        # Compress in memory at the requested quality and decode as color.
        result, encimg = cv2.imencode('.jpg', img_gt, [int(cv2.IMWRITE_JPEG_QUALITY), args.jpeg])
        img_lq = cv2.imdecode(encimg, 1)
        img_gt = _to_unit_float(img_gt)
        img_lq = _to_unit_float(img_lq)

    else:
        # Previously an unknown task crashed with an UnboundLocalError at the return.
        raise ValueError(
            f"unsupported task '{args.task}'; expected one of: classical_sr, lightweight_sr, "
            "compressed_sr, real_sr, lightweight_sr_infer, jpeg_car, color_jpeg_car")

    return imgname, img_lq, img_gt
269
-
270
-
271
- def test(img_lq, model, args, window_size):
272
- if args.tile is None:
273
- # test the image as a whole
274
- output = model(img_lq)
275
- else:
276
- # test the image tile by tile
277
- b, c, h, w = img_lq.size()
278
- tile = min(args.tile, h, w)
279
- assert tile % window_size == 0, "tile size should be a multiple of window_size"
280
- tile_overlap = args.tile_overlap
281
- sf = args.scale
282
-
283
- stride = tile - tile_overlap
284
- h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
285
- w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
286
- E = torch.zeros(b, c, h*sf, w*sf).type_as(img_lq)
287
- W = torch.zeros_like(E)
288
-
289
- for h_idx in h_idx_list:
290
- for w_idx in w_idx_list:
291
- in_patch = img_lq[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
292
- out_patch = model(in_patch)
293
- out_patch_mask = torch.ones_like(out_patch)
294
-
295
- E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch)
296
- W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
297
- output = E.div_(W)
298
-
299
- return output
300
-
301
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()