zhengrongzhang committed on
Commit
3135a01
·
1 Parent(s): 00dcd01

init model

PAN_int8.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5b5e35f9eeaf54988685263e868a1c54cb075a0560d5228af5f423d123af3be
+ size 1263469
README.md ADDED
@@ -0,0 +1,134 @@
+ ---
+ license: apache-2.0
+ datasets:
+ - Set5
+ - Div2K
+ language:
+ - en
+ tags:
+ - RyzenAI
+ - PAN
+ - Pytorch
+ - Super Resolution
+ - Vision
+ pipeline_tag: image-to-image
+ ---
+
+ ## Model description
+
+ PAN is a lightweight image super-resolution method with pixel attention. It was introduced in the paper [Efficient Image Super-Resolution Using Pixel Attention](https://arxiv.org/abs/2010.01073) by Hengyuan Zhao et al. and first released in [this repository](https://github.com/zhaohengyuan1/PAN).
+
+ We changed the negative slope of the leaky ReLU of the original model and replaced the sigmoid activation with hard sigmoid to make the model compatible with [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html). We loaded the published model parameters and fine-tuned them on the DIV2K dataset.
+
+
+ ## Intended uses & limitations
+
+ You can use the raw model for super resolution. See the [model hub](https://huggingface.co/models?search=amd/pan) to find all available PAN models.
+
+
+ ## How to use
+
+ ### Installation
+
+ Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI.
+ Run the following command to install the prerequisites for this model.
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+
+ ### Data Preparation (optional: for accuracy evaluation)
+
+ 1. Download the [benchmark](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) dataset.
+ 2. Extract the dataset archive and put it under the project folder. Organize the dataset directory as follows:
+ ```Plain
+ PAN
+ └── dataset
+      └── benchmark
+           ├── Set5
+                ├── HR
+                |   ├── baby.png
+                |   ├── ...
+                └── LR_bicubic
+                    └──X2
+                        ├──babyx2.png
+                        ├── ...
+           ├── Set14
+           ├── ...
+ ```
+
+ ### Test & Evaluation
+
+ - Code snippet from [`infer_onnx.py`](infer_onnx.py) showing how to use the model
+ ```python
+ parser = argparse.ArgumentParser(description='PAN SR')
+ parser.add_argument('--onnx_path',
+                     type=str,
+                     default='PAN_int8.onnx',
+                     help='Onnx path')
+ parser.add_argument('--image_path',
+                     type=str,
+                     default='test_data/test.png',
+                     help='Path to your input image.')
+ parser.add_argument('--output_path',
+                     type=str,
+                     default='test_data/sr.png',
+                     help='Path to your output image.')
+ parser.add_argument('--provider_config',
+                     type=str,
+                     default="vaip_config.json",
+                     help="Path of the config file for setting provider_options.")
+ parser.add_argument('--ipu', action='store_true', help='Use IPU for inference.')
+
+ args = parser.parse_args()
+
+ onnx_file_name = args.onnx_path
+ image_path = args.image_path
+ output_path = args.output_path
+
+ if args.ipu:
+     providers = ["VitisAIExecutionProvider"]
+     provider_options = [{"config_file": args.provider_config}]
+ else:
+     providers = ['CPUExecutionProvider']
+     provider_options = None
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
+
+ lr = cv2.imread(image_path)[np.newaxis,:,:,:].transpose((0,3,1,2)).astype(np.float32)
+ sr = tiling_inference(ort_session, lr, 8, (56, 56))
+ sr = np.clip(sr, 0, 255)
+ sr = sr.squeeze().transpose((1,2,0)).astype(np.uint8)
+ cv2.imwrite(output_path, sr)
+ ```
+
+ - Run inference for a single image
+ ```bash
+ python infer_onnx.py --onnx_path PAN_int8.onnx --image_path /Path/To/Your/Image --ipu --provider_config /Path/To/vaip_config.json
+ ```
+
+ - Test accuracy of the quantized model
+ ```bash
+ python eval_onnx.py --onnx_path PAN_int8.onnx --data_test Set5 --ipu --provider_config /Path/To/vaip_config.json
+ ```
+
+ Note: **vaip_config.json** is located in the setup package of Ryzen AI (refer to [Installation](https://huggingface.co/amd/yolox-s#installation))
+
+ ### Performance
+
+ | Method          | Scale | FLOPs | Set5 (PSNR / SSIM) |
+ |-----------------|-------|-------|--------------------|
+ | PAN (float)     | X2    | 141G  | 38.00 / 0.961      |
+ | PAN_amd (float) | X2    | 141G  | 37.859 / 0.960     |
+ | PAN_amd (int8)  | X2    | 141G  | 37.18 / 0.952      |
+ - Note: FLOPs are calculated for an output resolution of 360x640.
+
+ ```bibtex
+ @inproceedings{zhao2020efficient,
+   title={Efficient image super-resolution using pixel attention},
+   author={Zhao, Hengyuan and Kong, Xiangtao and He, Jingwen and Qiao, Yu and Dong, Chao},
+   booktitle={European Conference on Computer Vision},
+   pages={56--72},
+   year={2020},
+   organization={Springer}
+ }
+ ```
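The model description in the README above mentions replacing the sigmoid activation with a hard sigmoid. As a rough illustration only, the sketch below compares the two, assuming the piecewise-linear form `clip(x / 6 + 0.5, 0, 1)`. The commit does not state which slope and offset the fine-tuned model uses, so the constants here are an assumption, and the function names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Standard logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))


def hard_sigmoid(x: float) -> float:
    """Piecewise-linear stand-in for sigmoid: clip(x / 6 + 0.5, 0, 1).

    Assumption: slope 1/6 and offset 0.5; the model in this commit
    may use different constants.
    """
    return min(1.0, max(0.0, x / 6.0 + 0.5))


if __name__ == "__main__":
    # Print both activations side by side for a few sample inputs.
    for x in (-4.0, -1.0, 0.0, 1.0, 4.0):
        print(f"x={x:+.1f}  sigmoid={sigmoid(x):.4f}  hard_sigmoid={hard_sigmoid(x):.4f}")
```

The hard variant needs only a multiply, an add, and a clamp, which is one plausible reason it is easier to map onto integer-only hardware.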
data/__init__.py ADDED
@@ -0,0 +1,23 @@
+ from importlib import import_module
+ from torch.utils.data import dataloader
+
+
+ class Data:
+     def __init__(self, args):
+         self.loader_test = []
+         for d in args.data_test:
+             if d in ['Set5', 'Set14', 'B100', 'Urban100']:
+                 m = import_module('data.benchmark')
+                 testset = getattr(m, 'Benchmark')(args, name=d)
+             else:
+                 raise NotImplementedError
+
+             self.loader_test.append(
+                 dataloader.DataLoader(
+                     testset,
+                     batch_size=1,
+                     shuffle=False,
+                     pin_memory=False,
+                     num_workers=args.n_threads,
+                 )
+             )
data/benchmark.py ADDED
@@ -0,0 +1,18 @@
+ import os
+
+ from data import srdata
+
+ class Benchmark(srdata.SRData):
+     def __init__(self, args, name='', benchmark=True):
+         super(Benchmark, self).__init__(
+             args, name=name, benchmark=True
+         )
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, 'benchmark', self.name)
+         self.dir_hr = os.path.join(self.apath, 'HR')
+         if self.input_large:
+             self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
+         else:
+             self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+         self.ext = ('', '.png')
data/common.py ADDED
@@ -0,0 +1,31 @@
+ import random
+
+ import numpy as np
+ import skimage.color as sc
+
+ import torch
+
+ def set_channel(*args, n_channels=3):
+     def _set_channel(img):
+         if img.ndim == 2:
+             img = np.expand_dims(img, axis=2)
+
+         c = img.shape[2]
+         if n_channels == 1 and c == 3:
+             img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
+         elif n_channels == 3 and c == 1:
+             img = np.concatenate([img] * n_channels, 2)
+
+         return img
+
+     return [_set_channel(a) for a in args]
+
+ def np2Tensor(*args, rgb_range=255):
+     def _np2Tensor(img):
+         np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
+         tensor = torch.from_numpy(np_transpose).float()
+         tensor.mul_(rgb_range / 255)
+
+         return tensor
+
+     return [_np2Tensor(a) for a in args]
data/data_tiling.py ADDED
@@ -0,0 +1,46 @@
+ import numpy as np
+ import math
+
+
+ def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
+     """
+     Parameters:
+     - session: an ONNX Runtime session object that contains the super-resolution model
+     - lr: the low-resolution image
+     - overlapping: the number of pixels to overlap between adjacent patches
+     - patch_size: a tuple of (height, width) that specifies the size of each patch
+     Returns: - a numpy array that represents the enhanced image
+     """
+     _, _, h, w = lr.shape
+     sr = np.zeros((1, 3, 2*h, 2*w))
+     n_h = math.ceil(h / float(patch_size[0] - overlapping))
+     n_w = math.ceil(w / float(patch_size[1] - overlapping))
+     # every tile fed to the model has the same size as patch_size
+     for ih in range(n_h):
+         h_idx = ih * (patch_size[0] - overlapping)
+         h_idx = h_idx if h_idx + patch_size[0] <= h else h - patch_size[0]
+         for iw in range(n_w):
+             w_idx = iw * (patch_size[1] - overlapping)
+             w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
+
+             tilling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
+             sr_tiling = session.run(None, {session.get_inputs()[0].name: tilling_lr})[0]
+
+             left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
+             left += overlapping//2
+             right -= overlapping//2
+             top += overlapping//2
+             bottom -= overlapping//2
+             # processing edge pixels
+             if w_idx == 0:
+                 left -= overlapping//2
+             if h_idx == 0:
+                 top -= overlapping//2
+             if h_idx+patch_size[0]>=h:
+                 bottom += overlapping//2
+             if w_idx+patch_size[1]>=w:
+                 right += overlapping//2
+
+             # get predictions
+             sr[... , 2*(h_idx+top): 2*(h_idx+bottom), 2*(w_idx+left): 2*(w_idx+right)] = sr_tiling[..., 2*top:2*bottom, 2*left:2*right]
+     return sr
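To make the tile placement in `tiling_inference` easier to follow, here is a minimal one-axis sketch of how the start offsets are chosen. It mirrors the index arithmetic of the function above but is not part of the commit; the helper name `tile_starts` is hypothetical.

```python
import math


def tile_starts(length: int, patch: int = 56, overlap: int = 8) -> list:
    """Start offsets of the tiles along one axis, following tiling_inference."""
    stride = patch - overlap
    n_tiles = math.ceil(length / float(stride))
    starts = []
    for i in range(n_tiles):
        idx = i * stride
        # A tile that would run past the border is shifted back
        # so that it ends exactly at the border.
        if idx + patch > length:
            idx = length - patch
        starts.append(idx)
    return starts


if __name__ == "__main__":
    print(tile_starts(112))  # [0, 48, 56]
    print(tile_starts(100))  # [0, 44, 44]
```

For an axis of 100 pixels the last two tiles coincide, so the same patch is evaluated twice. For an axis shorter than the patch size the shifted offset becomes negative, so the function appears to assume inputs of at least 56x56 pixels.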
data/srdata.py ADDED
@@ -0,0 +1,130 @@
+ import os
+ import glob
+ import random
+ import pickle
+
+ from data import common
+
+ import imageio
+ import torch.utils.data as data
+
+ class SRData(data.Dataset):
+     def __init__(self, args, name='', benchmark=True):
+         self.args = args
+         self.name = name
+         self.benchmark = benchmark
+         self.input_large = False
+         self.scale = args.scale
+         self.idx_scale = 0
+
+         self._set_filesystem(args.dir_data)
+         if args.ext.find('img') < 0:
+             path_bin = os.path.join(self.apath, 'bin')
+             os.makedirs(path_bin, exist_ok=True)
+
+         list_hr, list_lr = self._scan()
+         if args.ext.find('img') >= 0 or benchmark:
+             self.images_hr, self.images_lr = list_hr, list_lr
+         elif args.ext.find('sep') >= 0:
+             os.makedirs(
+                 self.dir_hr.replace(self.apath, path_bin),
+                 exist_ok=True
+             )
+             for s in self.scale:
+                 os.makedirs(
+                     os.path.join(
+                         self.dir_lr.replace(self.apath, path_bin),
+                         'X{}'.format(s)
+                     ),
+                     exist_ok=True
+                 )
+
+             self.images_hr, self.images_lr = [], [[] for _ in self.scale]
+             for h in list_hr:
+                 b = h.replace(self.apath, path_bin)
+                 b = b.replace(self.ext[0], '.pt')
+                 self.images_hr.append(b)
+                 self._check_and_load(args.ext, h, b, verbose=True)
+             for i, ll in enumerate(list_lr):
+                 for l in ll:
+                     b = l.replace(self.apath, path_bin)
+                     b = b.replace(self.ext[1], '.pt')
+                     self.images_lr[i].append(b)
+                     self._check_and_load(args.ext, l, b, verbose=True)
+
+     # The functions below are used to prepare images
+     def _scan(self):
+         names_hr = sorted(
+             glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
+         )
+         names_lr = [[] for _ in self.scale]
+         for f in names_hr:
+             filename, _ = os.path.splitext(os.path.basename(f))
+             for si, s in enumerate(self.scale):
+                 names_lr[si].append(os.path.join(
+                     self.dir_lr, 'X{}/{}x{}{}'.format(
+                         s, filename, s, self.ext[1]
+                     )
+                 ))
+
+         return names_hr, names_lr
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, self.name)
+         self.dir_hr = os.path.join(self.apath, 'HR')
+         self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+         if self.input_large: self.dir_lr += 'L'
+         self.ext = ('.png', '.png')
+
+     def _check_and_load(self, ext, img, f, verbose=True):
+         if not os.path.isfile(f) or ext.find('reset') >= 0:
+             if verbose:
+                 print('Making a binary: {}'.format(f))
+             with open(f, 'wb') as _f:
+                 pickle.dump(imageio.imread(img), _f)
+
+     def __getitem__(self, idx):
+         lr, hr, filename = self._load_file(idx)
+         pair = self.get_patch(lr, hr)
+         pair = common.set_channel(*pair, n_channels=self.args.n_colors)
+         pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
+
+         return pair_t[0], pair_t[1], filename
+
+     def __len__(self):
+         return len(self.images_hr)
+
+     def _get_index(self, idx):
+         return idx
+
+     def _load_file(self, idx):
+         idx = self._get_index(idx)
+         f_hr = self.images_hr[idx]
+         f_lr = self.images_lr[self.idx_scale][idx]
+
+         filename, _ = os.path.splitext(os.path.basename(f_hr))
+         if self.args.ext == 'img' or self.benchmark:
+             hr = imageio.imread(f_hr)
+             lr = imageio.imread(f_lr)
+         elif self.args.ext.find('sep') >= 0:
+             with open(f_hr, 'rb') as _f:
+                 hr = pickle.load(_f)
+             with open(f_lr, 'rb') as _f:
+                 lr = pickle.load(_f)
+
+         return lr, hr, filename
+
+     def get_patch(self, lr, hr):
+         scale = self.scale[self.idx_scale]
+
+         ih, iw = lr.shape[:2]
+         hr = hr[0:ih * scale, 0:iw * scale]
+
+         return lr, hr
+
+     def set_scale(self, idx_scale):
+         if not self.input_large:
+             self.idx_scale = idx_scale
+         else:
+             self.idx_scale = random.randint(0, len(self.scale) - 1)
+
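The pairing of high-resolution and low-resolution files in `SRData._scan` relies on a naming convention (`HR/baby.png` maps to `LR_bicubic/X2/babyx2.png`). The following sketch isolates that mapping so it can be checked without a dataset on disk; the helper name `lr_path_for` is hypothetical and not part of the commit.

```python
import os


def lr_path_for(hr_path: str, dir_lr: str, scale: int, ext: str = ".png") -> str:
    """Build the low-resolution file path that _scan pairs with an HR file."""
    # Strip the directory and extension: ".../HR/baby.png" -> "baby"
    filename, _ = os.path.splitext(os.path.basename(hr_path))
    # Same format string as in _scan: X<scale>/<name>x<scale><ext>
    return os.path.join(dir_lr, "X{}/{}x{}{}".format(scale, filename, scale, ext))


if __name__ == "__main__":
    print(lr_path_for("dataset/benchmark/Set5/HR/baby.png",
                      "dataset/benchmark/Set5/LR_bicubic", 2))
```

Because the LR path is derived purely from the HR file name, a missing or differently named LR image would only surface later, when `imageio.imread` is called in `_load_file`.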
eval_onnx.py ADDED
@@ -0,0 +1,52 @@
+ import sys
+ import pathlib
+ CURRENT_DIR = pathlib.Path(__file__).parent
+ sys.path.append(str(CURRENT_DIR))
+
+ import torch
+ from tqdm import tqdm
+ import utility
+ import data
+ from option import args
+ import onnxruntime
+ from data.data_tiling import tiling_inference
+
+
+ def test_model(session, loader):
+     torch.set_grad_enabled(False)
+     self_scale = [2]
+     for idx_data, d in enumerate(loader.loader_test):
+         eval_ssim = 0
+         eval_psnr = 0
+         for idx_scale, scale in enumerate(self_scale):
+             d.dataset.set_scale(idx_scale)
+             for lr, hr, filename in tqdm(d, ncols=80):
+
+                 # Tiled inference
+                 sr = tiling_inference(session, lr.numpy(), 8, (56, 56))
+                 sr = torch.from_numpy(sr)
+                 sr = utility.quantize(sr, 255)
+                 eval_psnr += utility.calc_psnr(
+                     sr, hr, scale, 255, benchmark=d)
+                 eval_ssim += utility.calc_ssim(
+                     sr, hr, scale, 255, dataset=d)
+             mean_ssim = eval_ssim / len(d)
+             mean_psnr = eval_psnr / len(d)
+         print("psnr: %s, ssim: %s"%(mean_psnr, mean_ssim))
+     return mean_psnr, mean_ssim
+
+ def main():
+     loader = data.Data(args)
+     onnx_file_name = args.onnx_path
+     if args.ipu:
+         providers = ["VitisAIExecutionProvider"]
+         provider_options = [{"config_file": args.provider_config}]
+     else:
+         providers = ['CPUExecutionProvider']
+         provider_options = None
+     ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
+     test_model(ort_session, loader)
+
+
+ if __name__ == '__main__':
+     main()
infer_onnx.py ADDED
@@ -0,0 +1,56 @@
+ import sys
+ import pathlib
+ CURRENT_DIR = pathlib.Path(__file__).parent
+ sys.path.append(str(CURRENT_DIR))
+
+ import onnxruntime
+ import cv2
+ import numpy as np
+ from data.data_tiling import tiling_inference
+ import argparse
+
+
+ def main(args):
+     onnx_file_name = args.onnx_path
+     image_path = args.image_path
+     output_path = args.output_path
+
+     if args.ipu:
+         providers = ["VitisAIExecutionProvider"]
+         provider_options = [{"config_file": args.provider_config}]
+     else:
+         providers = ['CPUExecutionProvider']
+         provider_options = None
+     ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
+     lr = cv2.imread(image_path)[np.newaxis,:,:,:].transpose((0,3,1,2)).astype(np.float32)
+
+     # Tiled inference
+     sr = tiling_inference(ort_session, lr, 8, (56, 56))
+     sr = np.clip(sr, 0, 255)
+     sr = sr.squeeze().transpose((1,2,0)).astype(np.uint8)
+     cv2.imwrite(output_path, sr)
+
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='PAN')
+     parser.add_argument('--onnx_path',
+                         type=str,
+                         default='PAN_int8.onnx',
+                         help='Path to onnx model')
+     parser.add_argument('--image_path',
+                         type=str,
+                         default='test_data/test.png',
+                         help='Path to your low resolution input image.')
+     parser.add_argument('--output_path',
+                         type=str,
+                         default='test_data/sr.png',
+                         help='Path to your upscaled output image.')
+     parser.add_argument('--provider_config',
+                         type=str,
+                         default="vaip_config.json",
+                         help="Path of the config file for setting provider_options.")
+     parser.add_argument('--ipu', action='store_true', help='Use IPU for inference.')
+
+     args = parser.parse_args()
+     main(args)
option.py ADDED
@@ -0,0 +1,42 @@
+ import argparse
+
+ parser = argparse.ArgumentParser(description='PAN')
+
+ # Hardware specifications
+ parser.add_argument('--n_threads', type=int, default=6,
+                     help='Number of threads for data loading')
+ parser.add_argument('--ipu', action='store_true', help='Use IPU for inference.')
+
+ # Data specifications
+ parser.add_argument('--dir_data', type=str, default='dataset/',
+                     help='Dataset directory')
+ parser.add_argument('--data_test', type=str, default='Set5',
+                     help='Test dataset name')
+ parser.add_argument('--ext', type=str, default='sep',
+                     help='Dataset file extension')
+ parser.add_argument('--scale', type=str, default='2',
+                     help='Super resolution scale')
+ parser.add_argument('--rgb_range', type=int, default=255,
+                     help='Maximum value of RGB')
+ parser.add_argument('--n_colors', type=int, default=3,
+                     help='Number of color channels to use')
+ parser.add_argument('--onnx_path', type=str, default='PAN_int8.onnx',
+                     help='Path to onnx model')
+ parser.add_argument('--provider_config',
+                     type=str,
+                     default="vaip_config.json",
+                     help="Path of the config file for setting provider_options.")
+
+
+ args = parser.parse_args()
+
+ args.scale = list(map(lambda x: int(x), args.scale.split('+')))
+ args.data_test = args.data_test.split('+')
+
+
+ for arg in vars(args):
+     if vars(args)[arg] == 'True':
+         vars(args)[arg] = True
+     elif vars(args)[arg] == 'False':
+         vars(args)[arg] = False
+
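`option.py` accepts several scales or test sets in one flag by joining them with `+` and splitting after parsing. A minimal sketch of that convention follows, reduced to the two affected flags; the function name `parse_eval_args` is hypothetical, and unlike `option.py` it takes an explicit argument list so it can be called without touching `sys.argv`.

```python
import argparse


def parse_eval_args(argv):
    """Parse '+'-separated --scale and --data_test values into lists."""
    parser = argparse.ArgumentParser(description="PAN option sketch")
    parser.add_argument("--scale", type=str, default="2")
    parser.add_argument("--data_test", type=str, default="Set5")
    args = parser.parse_args(argv)
    # "2+4" -> [2, 4]; "Set5+Set14" -> ["Set5", "Set14"]
    args.scale = [int(s) for s in args.scale.split("+")]
    args.data_test = args.data_test.split("+")
    return args


if __name__ == "__main__":
    parsed = parse_eval_args(["--scale", "2+4", "--data_test", "Set5+Set14"])
    print(parsed.scale, parsed.data_test)
```

Note that `option.py` itself calls `parser.parse_args()` at import time, so the flags are taken from the command line of whichever script imports it (here `eval_onnx.py`).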
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch==1.13.1
+ numpy>=1.23.5
+ scipy>=1.9
+ opencv-python
+ pandas
+ pillow
+ scikit-image
+ tqdm
test_data/sr.png ADDED
test_data/test.png ADDED
utility.py ADDED
@@ -0,0 +1,89 @@
+ import math
+ import numpy as np
+ from scipy import signal
+
+
+ def quantize(img, rgb_range):
+     pixel_range = 255 / rgb_range
+     return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
+
+ def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
+     if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
+         print("the dimension of the sr image is not equal to hr's! ")
+         sr = sr[:,:,:hr.size(-2),:hr.size(-1)]
+     diff = (sr - hr).data.div(rgb_range)
+
+     if benchmark:
+         shave = scale
+         if diff.size(1) > 1:
+             convert = diff.new(1, 3, 1, 1)
+             convert[0, 0, 0, 0] = 65.738
+             convert[0, 1, 0, 0] = 129.057
+             convert[0, 2, 0, 0] = 25.064
+             diff.mul_(convert).div_(256)
+             diff = diff.sum(dim=1, keepdim=True)
+     else:
+         shave = scale + 6
+     valid = diff[:, :, shave:-shave, shave:-shave]
+     mse = valid.pow(2).mean()
+
+     return -10 * math.log10(mse)
+
+
+ def matlab_style_gauss2D(shape=(3,3),sigma=0.5):
+     """
+     2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian',[shape],[sigma])
+     Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
+     """
+     m,n = [(ss-1.)/2. for ss in shape]
+     y,x = np.ogrid[-m:m+1,-n:n+1]
+     h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
+     h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
+     sumh = h.sum()
+     if sumh != 0:
+         h /= sumh
+     return h
+
+
+ def calc_ssim(X, Y, scale, rgb_range, dataset=None, sigma=1.5, K1=0.01, K2=0.03, R=255):
+     '''
+     X : y channel (i.e., luminance) of transformed YCbCr space of X
+     Y : y channel (i.e., luminance) of transformed YCbCr space of Y
+     '''
+     gaussian_filter = matlab_style_gauss2D((11, 11), sigma)
+
+     shave = scale
+     if X.size(1) > 1:
+         gray_coeffs = [65.738, 129.057, 25.064]
+         convert = X.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
+         X = X.mul(convert).sum(dim=1)
+         Y = Y.mul(convert).sum(dim=1)
+
+
+     X = X[..., shave:-shave, shave:-shave].squeeze().cpu().numpy().astype(np.float64)
+     Y = Y[..., shave:-shave, shave:-shave].squeeze().cpu().numpy().astype(np.float64)
+
+     window = gaussian_filter
+
+     ux = signal.convolve2d(X, window, mode='same', boundary='symm')
+     uy = signal.convolve2d(Y, window, mode='same', boundary='symm')
+
+     uxx = signal.convolve2d(X*X, window, mode='same', boundary='symm')
+     uyy = signal.convolve2d(Y*Y, window, mode='same', boundary='symm')
+     uxy = signal.convolve2d(X*Y, window, mode='same', boundary='symm')
+
+     vx = uxx - ux * ux
+     vy = uyy - uy * uy
+     vxy = uxy - ux * uy
+
+     C1 = (K1 * R) ** 2
+     C2 = (K2 * R) ** 2
+
+     A1, A2, B1, B2 = ((2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2))
+     D = B1 * B2
+     S = (A1 * A2) / D
+     mssim = S.mean()
+
+     return mssim
+
+
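The benchmark branch of `calc_psnr` measures error on the luminance channel only: the RGB difference is scaled by `rgb_range`, weighted by the coefficients `65.738, 129.057, 25.064`, divided by 256, and then turned into PSNR as `-10 * log10(mse)`. The sketch below reproduces that arithmetic on plain Python tuples so the numbers can be followed by hand. It omits the border shaving, and the helper names `luma_diff` and `psnr_y` are hypothetical.

```python
import math

# Luminance weights used in utility.py (applied to R, G, B, then divided by 256).
Y_COEFFS = (65.738, 129.057, 25.064)


def luma_diff(rgb_a, rgb_b, rgb_range=255.0):
    """Luminance-weighted difference of two RGB pixels, normalised by rgb_range."""
    weighted = sum(c * (a - b) / rgb_range for c, a, b in zip(Y_COEFFS, rgb_a, rgb_b))
    return weighted / 256.0


def psnr_y(pixels_sr, pixels_hr, rgb_range=255.0):
    """PSNR in dB over paired lists of RGB pixels, computed on luminance."""
    diffs = [luma_diff(a, b, rgb_range) for a, b in zip(pixels_sr, pixels_hr)]
    mse = sum(d * d for d in diffs) / len(diffs)
    return -10.0 * math.log10(mse)


if __name__ == "__main__":
    # Every channel of every pixel is off by exactly one intensity level.
    sr = [(10, 10, 10), (200, 200, 200)]
    hr = [(11, 11, 11), (201, 201, 201)]
    print(round(psnr_y(sr, hr), 2))
```

As in the original function, identical inputs give a mean squared error of zero, for which `math.log10` raises a `ValueError`.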