zhengrongzhang
commited on
Commit
·
3135a01
1
Parent(s):
00dcd01
init model
Browse files- PAN_int8.onnx +3 -0
- README.md +134 -0
- data/__init__.py +23 -0
- data/benchmark.py +18 -0
- data/common.py +31 -0
- data/data_tiling.py +46 -0
- data/srdata.py +130 -0
- eval_onnx.py +52 -0
- infer_onnx.py +56 -0
- option.py +42 -0
- requirements.txt +8 -0
- test_data/sr.png +0 -0
- test_data/test.png +0 -0
- utility.py +89 -0
PAN_int8.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5b5e35f9eeaf54988685263e868a1c54cb075a0560d5228af5f423d123af3be
|
3 |
+
size 1263469
|
README.md
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
datasets:
|
4 |
+
- Set5
|
5 |
+
- Div2K
|
6 |
+
language:
|
7 |
+
- en
|
8 |
+
tags:
|
9 |
+
- RyzenAI
|
10 |
+
- PAN
|
11 |
+
- Pytorch
|
12 |
+
- Super Resolution
|
13 |
+
- Vision
|
14 |
+
pipeline_tag: image-to-image
|
15 |
+
---
|
16 |
+
|
17 |
+
## Model description
|
18 |
+
|
19 |
+
PAN is a lightweight image super-resolution method with pixel attention. It was introduced in the paper [Efficient Image Super-Resolution Using Pixel Attention](https://arxiv.org/abs/2010.01073) by Hengyuan Zhao et al. and first released in [this repository](https://github.com/zhaohengyuan1/PAN).
|
20 |
+
|
21 |
+
We changed the negative slope of the leaky ReLU of the original model and replaced the sigmoid activation with hard sigmoid to make the model compatible with [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html). We loaded the published model parameters and fine-tuned them on the DIV2K dataset.
|
22 |
+
|
23 |
+
|
24 |
+
## Intended uses & limitations
|
25 |
+
|
26 |
+
You can use the raw model for super resolution. See the [model hub](https://huggingface.co/models?search=amd/pan) to look for all available PAN models.
|
27 |
+
|
28 |
+
|
29 |
+
## How to use
|
30 |
+
|
31 |
+
### Installation
|
32 |
+
|
33 |
+
Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI.
|
34 |
+
Run the following script to install pre-requisites for this model.
|
35 |
+
```bash
|
36 |
+
pip install -r requirements.txt
|
37 |
+
```
|
38 |
+
|
39 |
+
|
40 |
+
### Data Preparation (optional: for accuracy evaluation)
|
41 |
+
|
42 |
+
1. Download the [benchmark](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) dataset.
|
43 |
+
2. Unzip the dataset and put it under the project folder. Organize the dataset directory as follows:
|
44 |
+
```Plain
|
45 |
+
PAN
|
46 |
+
└── dataset
|
47 |
+
└── benchmark
|
48 |
+
├── Set5
|
49 |
+
├── HR
|
50 |
+
| ├── baby.png
|
51 |
+
| ├── ...
|
52 |
+
└── LR_bicubic
|
53 |
+
└──X2
|
54 |
+
├──babyx2.png
|
55 |
+
├── ...
|
56 |
+
├── Set14
|
57 |
+
├── ...
|
58 |
+
```
|
59 |
+
|
60 |
+
### Test & Evaluation
|
61 |
+
|
62 |
+
- Code snippet from [`infer_onnx.py`](infer_onnx.py) on how to use
|
63 |
+
```python
|
64 |
+
parser = argparse.ArgumentParser(description='PAN SR')
|
65 |
+
parser.add_argument('--onnx_path',
|
66 |
+
type=str,
|
67 |
+
default='PAN_int8.onnx',
|
68 |
+
help='Onnx path')
|
69 |
+
parser.add_argument('--image_path',
|
70 |
+
type=str,
|
71 |
+
default='test_data/test.png',
|
72 |
+
help='Path to your input image.')
|
73 |
+
parser.add_argument('--output_path',
|
74 |
+
type=str,
|
75 |
+
default='test_data/sr.png',
|
76 |
+
help='Path to your output image.')
|
77 |
+
parser.add_argument('--provider_config',
|
78 |
+
type=str,
|
79 |
+
default="vaip_config.json",
|
80 |
+
help="Path of the config file for seting provider_options.")
|
81 |
+
parser.add_argument('--ipu', action='store_true', help='Use Ipu for interence.')
|
82 |
+
|
83 |
+
args = parser.parse_args()
|
84 |
+
|
85 |
+
onnx_file_name = args.onnx_path
|
86 |
+
image_path = args.image_path
|
87 |
+
output_path = args.output_path
|
88 |
+
|
89 |
+
if args.ipu:
|
90 |
+
providers = ["VitisAIExecutionProvider"]
|
91 |
+
provider_options = [{"config_file": args.provider_config}]
|
92 |
+
else:
|
93 |
+
providers = ['CPUExecutionProvider']
|
94 |
+
provider_options = None
|
95 |
+
ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
|
96 |
+
|
97 |
+
lr = cv2.imread(image_path)[np.newaxis,:,:,:].transpose((0,3,1,2)).astype(np.float32)
|
98 |
+
sr = tiling_inference(ort_session, lr, 8, (56, 56))
|
99 |
+
sr = np.clip(sr, 0, 255)
|
100 |
+
sr = sr.squeeze().transpose((1,2,0)).astype(np.uint8)
|
101 |
+
sr = cv2.imwrite(output_path, sr)
|
102 |
+
```
|
103 |
+
|
104 |
+
- Run inference for a single image
|
105 |
+
```bash
|
106 |
+
python infer_onnx.py --onnx_path PAN_int8.onnx --image_path /Path/To/Your/Image --ipu --provider_config Path\To\vaip_config.json
|
107 |
+
```
|
108 |
+
|
109 |
+
- Test accuracy of the quantized model
|
110 |
+
```bash
|
111 |
+
python eval_onnx.py --onnx_path PAN_int8.onnx --data_test Set5 --ipu --provider_config Path\To\vaip_config.json
|
112 |
+
```
|
113 |
+
|
114 |
+
Note: **vaip_config.json** is located at the setup package of Ryzen AI (refer to [Installation](https://huggingface.co/amd/yolox-s#installation))
|
115 |
+
|
116 |
+
### Performance
|
117 |
+
|
118 |
+
| Method | Scale | Flops | Set5 |
|
119 |
+
|------------|-------|-------|--------------|
|
120 |
+
|PAN (float) |X2 |141G |38.00 / 0.961|
|
121 |
+
|PAN_amd (float) |X2 |141G |37.859 / 0.960|
|
122 |
+
|PAN_amd (int8) |X2 |141G |37.18 / 0.952|
|
123 |
+
- Note: the FLOPs are calculated for an output resolution of 360x640.
|
124 |
+
|
125 |
+
```bibtex
|
126 |
+
@inproceedings{zhao2020efficient,
|
127 |
+
title={Efficient image super-resolution using pixel attention},
|
128 |
+
author={Zhao, Hengyuan and Kong, Xiangtao and He, Jingwen and Qiao, Yu and Dong, Chao},
|
129 |
+
booktitle={European Conference on Computer Vision},
|
130 |
+
pages={56--72},
|
131 |
+
year={2020},
|
132 |
+
organization={Springer}
|
133 |
+
}
|
134 |
+
```
|
data/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
from torch.utils.data import dataloader
|
3 |
+
|
4 |
+
|
5 |
+
class Data:
    """Container for the test data loaders requested on the command line.

    For every name in ``args.data_test`` a ``Benchmark`` dataset is created
    (looked up dynamically in ``data.benchmark``) and wrapped in a
    ``DataLoader`` that yields one sample at a time, in order.  The loaders
    are collected in ``self.loader_test``.

    Raises:
        NotImplementedError: if a requested name is not one of
            ``Set5``, ``Set14``, ``B100`` or ``Urban100``.
    """

    def __init__(self, args):
        supported = ('Set5', 'Set14', 'B100', 'Urban100')
        self.loader_test = []

        for dataset_name in args.data_test:
            if dataset_name not in supported:
                raise NotImplementedError

            # Resolve the dataset class lazily, by module and attribute name.
            benchmark_cls = getattr(import_module('data.benchmark'), 'Benchmark')
            loader = dataloader.DataLoader(
                benchmark_cls(args, name=dataset_name),
                batch_size=1,
                shuffle=False,
                pin_memory=False,
                num_workers=args.n_threads,
            )
            self.loader_test.append(loader)
|
data/benchmark.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from data import srdata
|
4 |
+
|
5 |
+
class Benchmark(srdata.SRData):
    """Benchmark test set stored under ``<dir_data>/benchmark/<name>``."""

    def __init__(self, args, name='', benchmark=True):
        # ``benchmark`` is accepted for signature compatibility only; the
        # parent class is always initialised in benchmark mode.
        super().__init__(args, name=name, benchmark=True)

    def _set_filesystem(self, dir_data):
        """Set the HR/LR directories and the file extensions for this set."""
        lr_folder = 'LR_bicubicL' if self.input_large else 'LR_bicubic'

        self.apath = os.path.join(dir_data, 'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, lr_folder)
        # HR files are globbed with an empty suffix; LR files must be .png.
        self.ext = ('', '.png')
|
data/common.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import skimage.color as sc
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
def set_channel(*args, n_channels=3):
    """Convert every image in ``args`` to ``n_channels`` channels.

    2-D arrays first get a trailing channel axis.  A 3-channel image is
    reduced to its Y (luminance) plane when ``n_channels == 1``; a 1-channel
    image is replicated when ``n_channels == 3``.  Any other combination is
    returned unchanged.

    Returns:
        list: the converted images, in the order they were passed.
    """
    def _convert(image):
        if image.ndim == 2:
            image = np.expand_dims(image, axis=2)

        channels = image.shape[2]
        if n_channels == 1 and channels == 3:
            luma = sc.rgb2ycbcr(image)[:, :, 0]
            return np.expand_dims(luma, 2)
        if n_channels == 3 and channels == 1:
            return np.concatenate([image] * n_channels, 2)
        return image

    return list(map(_convert, args))
|
22 |
+
|
23 |
+
def np2Tensor(*args, rgb_range=255):
    """Turn HWC numpy images into CHW float tensors scaled to ``rgb_range``.

    Each array is transposed from (H, W, C) to (C, H, W), converted to a
    float tensor and multiplied by ``rgb_range / 255``.

    Returns:
        list: one tensor per input image, in the order they were passed.
    """
    factor = rgb_range / 255

    def _to_tensor(image):
        chw = np.ascontiguousarray(image.transpose((2, 0, 1)))
        return torch.from_numpy(chw).float().mul_(factor)

    return [_to_tensor(image) for image in args]
|
data/data_tiling.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
|
5 |
+
def _tile_starts(length, patch, stride):
    """Return the distinct start offsets of the patches covering ``length``."""
    starts = []
    for index in range(math.ceil(length / float(stride))):
        start = index * stride
        if start + patch > length:
            # Shift the trailing patch back so it stays inside the image.
            start = length - patch
        # Clamped starts repeat; running the same patch twice is wasted work.
        if not starts or starts[-1] != start:
            starts.append(start)
    return starts


def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56), scale=2):
    """Super-resolve ``lr`` by running the model on fixed-size overlapping tiles.

    Every tile fed to the model has exactly ``patch_size`` pixels.  From each
    upscaled tile the outer ``overlapping // 2`` pixels are discarded, except
    on sides that touch the image border, and the rest is pasted into the
    output.  Inputs smaller than ``patch_size`` are edge-padded up to one
    full tile and the result is cropped back afterwards.

    Parameters:
        - session: an ONNX Runtime session object (anything exposing
          ``get_inputs()`` and ``run()``) that contains the super-resolution model
        - lr: the low-resolution image, a 4-D array indexed (batch, channel,
          height, width); the output buffer is allocated for batch 1 and
          3 channels
        - overlapping: the number of pixels to overlap between adjacent patches
        - patch_size: a tuple of (height, width) that specifies the size of each patch
        - scale: upscaling factor of the model (defaults to 2)
    Returns: - a float64 numpy array of shape (1, 3, scale*h, scale*w) that
      represents the enhanced image
    Raises:
        ValueError: if ``overlapping`` is negative or not smaller than both
            patch dimensions (the tile stride would not be positive).
    """
    patch_h, patch_w = patch_size
    if overlapping < 0 or overlapping >= min(patch_h, patch_w):
        raise ValueError(
            'overlapping must be in [0, min(patch_size)), got {}'.format(overlapping)
        )

    _, _, h, w = lr.shape
    # A tile can never be smaller than the patch, so grow tiny images by
    # replicating their border pixels.
    pad_h = max(patch_h - h, 0)
    pad_w = max(patch_w - w, 0)
    if pad_h or pad_w:
        lr = np.pad(lr, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='edge')
    full_h, full_w = h + pad_h, w + pad_w

    sr = np.zeros((1, 3, scale * full_h, scale * full_w))
    input_name = session.get_inputs()[0].name
    margin = overlapping // 2

    for h_idx in _tile_starts(full_h, patch_h, patch_h - overlapping):
        # Keep the full tile on sides that coincide with the image border.
        top = 0 if h_idx == 0 else margin
        bottom = patch_h if h_idx + patch_h >= full_h else patch_h - margin
        for w_idx in _tile_starts(full_w, patch_w, patch_w - overlapping):
            left = 0 if w_idx == 0 else margin
            right = patch_w if w_idx + patch_w >= full_w else patch_w - margin

            tile = lr[..., h_idx:h_idx + patch_h, w_idx:w_idx + patch_w]
            sr_tile = session.run(None, {input_name: tile})[0]

            sr[...,
               scale * (h_idx + top):scale * (h_idx + bottom),
               scale * (w_idx + left):scale * (w_idx + right)] = \
                sr_tile[..., scale * top:scale * bottom, scale * left:scale * right]

    if pad_h or pad_w:
        # Drop the area that only exists because of the padding.
        sr = np.ascontiguousarray(sr[..., :scale * h, :scale * w])
    return sr
|
data/srdata.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import random
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
from data import common
|
7 |
+
|
8 |
+
import imageio
|
9 |
+
import torch.utils.data as data
|
10 |
+
|
11 |
+
class SRData(data.Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    File lists are built from the HR directory and the ``X<scale>``
    sub-directories of the LR directory chosen by ``_set_filesystem``.
    When ``args.ext`` contains ``'img'`` (or in benchmark mode) samples are
    read straight from the image files; when it contains ``'sep'`` every
    image is first pickled to ``<apath>/bin`` and read back from that cache.

    Each item is ``(lr_tensor, hr_tensor, filename)``.

    Raises:
        ValueError: if ``args.ext`` selects neither mode and ``benchmark``
            is false.
    """

    def __init__(self, args, name='', benchmark=True):
        self.args = args
        self.name = name
        self.benchmark = benchmark
        self.input_large = False
        self.scale = args.scale
        self.idx_scale = 0

        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_lr = self._scan()
        if self._reads_images():
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            # Mirror the HR / LR directory layout inside the cache folder.
            os.makedirs(
                self.dir_hr.replace(self.apath, path_bin),
                exist_ok=True
            )
            for s in self.scale:
                os.makedirs(
                    os.path.join(
                        self.dir_lr.replace(self.apath, path_bin),
                        'X{}'.format(s)
                    ),
                    exist_ok=True
                )

            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = self._bin_path(h, path_bin)
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = self._bin_path(l, path_bin)
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)
        else:
            # Fail here rather than later with a missing ``images_hr``.
            raise ValueError('Unsupported ext: {!r}'.format(args.ext))

    def _reads_images(self):
        """Return True when samples are read from image files, not the cache."""
        return self.args.ext.find('img') >= 0 or self.benchmark

    def _bin_path(self, path, path_bin):
        """Map an image path to its ``.pt`` cache file under ``path_bin``."""
        cached = path.replace(self.apath, path_bin)
        # Swap only the trailing extension; a plain str.replace would also
        # rewrite matches elsewhere in the path.
        return os.path.splitext(cached)[0] + '.pt'

    # Below functions as used to prepare images
    def _scan(self):
        """Return ``(hr_paths, lr_paths)``; ``lr_paths`` has one list per scale."""
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                # LR files are named "<hr name>x<scale><ext>" inside "X<scale>".
                names_lr[si].append(os.path.join(
                    self.dir_lr, 'X{}/{}x{}{}'.format(
                        s, filename, s, self.ext[1]
                    )
                ))

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        """Set the dataset root, HR/LR directories and file extensions."""
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        """Pickle image ``img`` to ``f`` if ``f`` is missing or ``ext`` has 'reset'."""
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)

        return pair_t[0], pair_t[1], filename

    def __len__(self):
        return len(self.images_hr)

    def _get_index(self, idx):
        return idx

    def _load_file(self, idx):
        """Load the LR/HR pair at ``idx``; return ``(lr, hr, filename)``.

        Raises:
            ValueError: if ``args.ext`` selects no supported loading mode.
        """
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        # Same test as in __init__, so the stored paths and the reader agree.
        if self._reads_images():
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point this at cache files this class wrote itself.
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)
        else:
            raise ValueError('Unsupported ext: {!r}'.format(self.args.ext))

        return lr, hr, filename

    def get_patch(self, lr, hr):
        """Crop ``hr`` to exactly ``scale`` times the spatial size of ``lr``."""
        scale = self.scale[self.idx_scale]

        ih, iw = lr.shape[:2]
        hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        """Select the active scale index (a random one if ``input_large``)."""
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
|
130 |
+
|
eval_onnx.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import pathlib
|
3 |
+
CURRENT_DIR = pathlib.Path(__file__).parent
|
4 |
+
sys.path.append(str(CURRENT_DIR))
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from tqdm import tqdm
|
8 |
+
import utility
|
9 |
+
import data
|
10 |
+
from option import args
|
11 |
+
import onnxruntime
|
12 |
+
from data.data_tiling import tiling_inference
|
13 |
+
|
14 |
+
|
15 |
+
def test_model(session, loader):
    """Evaluate ``session`` on every test loader and report PSNR / SSIM.

    Each low-resolution batch is super-resolved with tiled inference
    (56x56 tiles, 8 px overlap), quantized to the 0-255 range and compared
    with its high-resolution reference.  The per-dataset means are printed.

    Returns:
        tuple: ``(mean_psnr, mean_ssim)`` of the last dataset evaluated.
    """
    # NOTE(review): the nesting of the mean/print/return statements was
    # reconstructed from a flattened listing — confirm against the repository.
    torch.set_grad_enabled(False)
    scales = [2]
    for test_loader in loader.loader_test:
        ssim_total = 0
        psnr_total = 0
        for scale_index, scale in enumerate(scales):
            test_loader.dataset.set_scale(scale_index)
            for lr, hr, _ in tqdm(test_loader, ncols=80):
                # Tiled inference
                upscaled = tiling_inference(session, lr.numpy(), 8, (56, 56))
                sr = utility.quantize(torch.from_numpy(upscaled), 255)
                psnr_total += utility.calc_psnr(
                    sr, hr, scale, 255, benchmark=test_loader)
                ssim_total += utility.calc_ssim(
                    sr, hr, scale, 255, dataset=test_loader)
        sample_count = len(test_loader)
        mean_ssim = ssim_total / sample_count
        mean_psnr = psnr_total / sample_count
        print("psnr: %s, ssim: %s"%(mean_psnr, mean_ssim))
    return mean_psnr, mean_ssim
|
37 |
+
|
38 |
+
def main():
    """Build the test loaders and the ONNX Runtime session, then evaluate.

    With ``--ipu`` the Vitis AI execution provider is used and configured
    from ``args.provider_config``; otherwise the model runs on the CPU
    provider without provider options.
    """
    loader = data.Data(args)

    use_ipu = args.ipu
    providers = ["VitisAIExecutionProvider"] if use_ipu else ['CPUExecutionProvider']
    provider_options = [{"config_file": args.provider_config}] if use_ipu else None

    ort_session = onnxruntime.InferenceSession(
        args.onnx_path,
        providers=providers,
        provider_options=provider_options,
    )
    test_model(ort_session, loader)


if __name__ == '__main__':
    main()
|
infer_onnx.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import pathlib
|
3 |
+
CURRENT_DIR = pathlib.Path(__file__).parent
|
4 |
+
sys.path.append(str(CURRENT_DIR))
|
5 |
+
|
6 |
+
import onnxruntime
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
from data.data_tiling import tiling_inference
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
    """Super-resolve one image with the ONNX model and write the result.

    Args:
        args: parsed command-line namespace providing ``onnx_path``,
            ``image_path``, ``output_path``, ``ipu`` and ``provider_config``.

    Raises:
        FileNotFoundError: if the input image cannot be read.
        OSError: if the output image cannot be written.
    """
    onnx_file_name = args.onnx_path
    image_path = args.image_path
    output_path = args.output_path

    # cv2.imread signals failure by returning None instead of raising, so
    # check before the (potentially slow) session is created.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(
            'Cannot read input image: {}'.format(image_path)
        )

    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CPUExecutionProvider']
        provider_options = None
    ort_session = onnxruntime.InferenceSession(
        onnx_file_name, providers=providers, provider_options=provider_options)

    # HWC uint8 -> 1xCxHxW float32, the layout passed to tiling_inference.
    lr = image[np.newaxis, :, :, :].transpose((0, 3, 1, 2)).astype(np.float32)

    # Tiled inference
    sr = tiling_inference(ort_session, lr, 8, (56, 56))
    sr = np.clip(sr, 0, 255)
    sr = sr[0].transpose((1, 2, 0)).astype(np.uint8)

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(output_path, sr):
        raise OSError('Cannot write output image: {}'.format(output_path))
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description='PAN')

    # (flag, default value, help text) for every string-valued option.
    string_options = (
        ('--onnx_path', 'PAN_int8.onnx', 'Path to onnx model'),
        ('--image_path', 'test_data/test.png',
         'Path to your low resolution input image.'),
        ('--output_path', 'test_data/sr.png',
         'Path to your upscaled output image.'),
        ('--provider_config', "vaip_config.json",
         "Path of the config file for seting provider_options."),
    )
    for flag, default, help_text in string_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    parser.add_argument('--ipu', action='store_true', help='Use Ipu for interence.')

    args = parser.parse_args()
    main(args)
|
option.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
# Command-line options for evaluation.  The parsed result is exposed as the
# module-level ``args`` object.
parser = argparse.ArgumentParser(description='PAN')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=6,
                    help='Number of threads for data loading')
parser.add_argument('--ipu', action='store_true', help='Use Ipu for interence.')

# Data specifications
parser.add_argument('--dir_data', type=str, default='dataset/',
                    help='Dataset directory')
parser.add_argument('--data_test', type=str, default='Set5',
                    help='Test dataset name')
parser.add_argument('--ext', type=str, default='sep',
                    help='Dataset file extension')
parser.add_argument('--scale', type=str, default='2',
                    help='Super resolution scale')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='Maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='Number of color channels to use')
parser.add_argument('--onnx_path', type=str, default='PAN_int8.onnx',
                    help='Path to onnx model')
parser.add_argument('--provider_config',
                    type=str,
                    default="vaip_config.json",
                    help="Path of the config file for seting provider_options.")

args = parser.parse_args()

# "2+3" -> [2, 3] and "Set5+Set14" -> ['Set5', 'Set14'].
args.scale = [int(factor) for factor in args.scale.split('+')]
args.data_test = args.data_test.split('+')

# Turn the literal strings 'True' / 'False' into real booleans.
for arg, value in vars(args).items():
    if value == 'True':
        setattr(args, arg, True)
    elif value == 'False':
        setattr(args, arg, False)
|
42 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.13.1
|
2 |
+
numpy>=1.23.5
|
3 |
+
scipy>=1.9
|
4 |
+
opencv-python
|
5 |
+
pandas
|
6 |
+
pillow
|
7 |
+
scikit-image
|
8 |
+
tqdm
|
test_data/sr.png
ADDED
test_data/test.png
ADDED
utility.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
from scipy import signal
|
4 |
+
|
5 |
+
|
6 |
+
def quantize(img, rgb_range):
    """Snap ``img`` to the 256 levels of an 8-bit image.

    The tensor is rescaled from ``rgb_range`` to 0-255, clamped to that
    interval, rounded to the nearest integer and scaled back.  A new tensor
    is returned; ``img`` itself is not modified.
    """
    to_8bit = 255 / rgb_range
    levels = img.mul(to_8bit).clamp(0, 255).round()
    return levels.div(to_8bit)
|
9 |
+
|
10 |
+
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Return the PSNR in dB between ``sr`` and its reference ``hr``.

    Args:
        sr: super-resolved image tensor, indexed (batch, channel, h, w).
        hr: reference tensor; ``sr`` is cropped to it when ``sr`` is larger.
        scale: upscaling factor, used to size the ignored border.
        rgb_range: value range of the images; the difference is divided by it.
        benchmark: if truthy, multi-channel differences are reduced to a
            single luminance channel and ``scale`` border pixels are shaved;
            otherwise ``scale + 6`` pixels are shaved on every side.

    Returns:
        float: the PSNR, or ``inf`` when the compared regions are identical.
    """
    if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
        print("the dimention of sr image is not equal to hr's! ")
        sr = sr[:, :, :hr.size(-2), :hr.size(-1)]
    diff = (sr - hr).detach().div(rgb_range)

    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            # Weighted sum of the three channels, the same coefficients
            # calc_ssim uses for its luminance conversion.
            convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
            diff = diff.mul(convert).div(256).sum(dim=1, keepdim=True)
    else:
        shave = scale + 6

    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()
    if mse == 0:
        # log10(0) is undefined; identical images have unbounded PSNR.
        return float('inf')

    return -10 * math.log10(mse)
|
31 |
+
|
32 |
+
|
33 |
+
def matlab_style_gauss2D(shape=(3,3),sigma=0.5):
    """
    Build a normalised 2D gaussian kernel - should give the same result as
    MATLAB's fspecial('gaussian',[shape],[sigma]).

    Entries below machine epsilon times the peak are zeroed before the
    kernel is normalised to sum to one.
    Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
    """
    half_rows, half_cols = ((size - 1.) / 2. for size in shape)
    rows, cols = np.ogrid[-half_rows:half_rows + 1, -half_cols:half_cols + 1]

    kernel = np.exp(-(cols * cols + rows * rows) / (2. * sigma * sigma))
    negligible = kernel < np.finfo(kernel.dtype).eps * kernel.max()
    kernel[negligible] = 0

    total = kernel.sum()
    if total == 0:
        return kernel
    return kernel / total
|
46 |
+
|
47 |
+
|
48 |
+
def calc_ssim(X, Y, scale, rgb_range, dataset=None, sigma=1.5, K1=0.01, K2=0.03, R=255):
    '''
    Return the mean SSIM between image tensors X and Y.

    Multi-channel inputs are first reduced to a single luminance channel,
    then ``scale`` border pixels are shaved from every side and the local
    statistics are computed with an 11x11 gaussian window (``sigma``).
    ``K1``, ``K2`` and ``R`` set the stabilising constants
    ``(K1*R)**2`` and ``(K2*R)**2``.  ``rgb_range`` and ``dataset`` are
    accepted but not used.
    '''
    window = matlab_style_gauss2D((11, 11), sigma)
    shave = scale

    if X.size(1) > 1:
        luma_weights = X.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1) / 256
        X = X.mul(luma_weights).sum(dim=1)
        Y = Y.mul(luma_weights).sum(dim=1)

    def _as_plane(tensor):
        # Shave the border, drop singleton axes and move to float64 numpy.
        cropped = tensor[..., shave:-shave, shave:-shave]
        return cropped.squeeze().cpu().numpy().astype(np.float64)

    def _local_mean(plane):
        return signal.convolve2d(plane, window, mode='same', boundary='symm')

    X = _as_plane(X)
    Y = _as_plane(Y)

    # Windowed first and second moments.
    ux, uy = _local_mean(X), _local_mean(Y)
    uxx = _local_mean(X*X)
    uyy = _local_mean(Y*Y)
    uxy = _local_mean(X*Y)

    # Local variances and covariance.
    vx = uxx - ux * ux
    vy = uyy - uy * uy
    vxy = uxy - ux * uy

    C1 = (K1 * R) ** 2
    C2 = (K2 * R) ** 2

    numerator = (2 * ux * uy + C1) * (2 * vxy + C2)
    denominator = (ux ** 2 + uy ** 2 + C1) * (vx + vy + C2)
    ssim_map = numerator / denominator

    return ssim_map.mean()
|
88 |
+
|
89 |
+
|