Spaces:
Runtime error
Runtime error
menghanxia
commited on
Commit
·
6e70c4a
1
Parent(s):
7f58cb0
created the space
Browse files- LICENSE +26 -0
- app.py +80 -0
- inference.py +83 -0
- model/__init__.py +0 -0
- model/base_module.py +81 -0
- model/hourglass.py +70 -0
- model/loss.py +93 -0
- model/model.py +66 -0
- requirements.txt +17 -0
- scripts/invhalf_full.json +42 -0
- scripts/invhalf_warm.json +39 -0
- train.py +272 -0
- train_warm.py +256 -0
- utils/__init__.py +1 -0
- utils/_dct.py +268 -0
- utils/dataset.py +39 -0
- utils/dct.py +29 -0
- utils/filters_tensor.py +81 -0
- utils/pytorch_ssim.py +77 -0
- utils/util.py +78 -0
LICENSE
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Deep Halftoning with Reversible Binary Pattern
|
2 |
+
|
3 |
+
Copyright (c) 2021 The Chinese University of Hong Kong
|
4 |
+
|
5 |
+
Copyright and License Information: The source code, the binary executable, and all data files (hereafter, Software) are copyrighted by The Chinese University of Hong Kong and Tien-Tsin Wong (hereafter, Author), Copyright (c) 2021 The Chinese University of Hong Kong. All Rights Reserved.
|
6 |
+
|
7 |
+
The Author grants to you ("Licensee") a non-exclusive license to use the Software for academic, research and commercial purposes, without fee. For commercial use, Licensee should submit a WRITTEN NOTICE to the Author. The notice should clearly identify the software package/system/hardware (name, version, and/or model number) using the Software. Licensee may distribute the Software to third parties provided that the copyright notice and this statement appears on all copies. Licensee agrees that the copyright notice and this statement will appear on all copies of the Software, or portions thereof. The Author retains exclusive ownership of the Software.
|
8 |
+
|
9 |
+
Licensee may make derivatives of the Software, provided that such derivatives can only be used for the purposes specified in the license grant above.
|
10 |
+
|
11 |
+
THE AUTHOR MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. THE AUTHOR SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE OR ITS DERIVATIVES.
|
12 |
+
|
13 |
+
By using the source code, Licensee agrees to cite the following papers in
|
14 |
+
Licensee's publication/work:
|
15 |
+
|
16 |
+
Menghan Xia, Wenbo Hu, Xueting Liu and Tien-Tsin Wong
|
17 |
+
"Deep Halftoning with Reversible Binary Pattern"
|
18 |
+
IEEE International Conference on Computer Vision (ICCV), 2021.
|
19 |
+
|
20 |
+
|
21 |
+
By using or copying the Software, Licensee agrees to abide by the intellectual property laws, and all other applicable laws of the U.S., and the terms of this license.
|
22 |
+
|
23 |
+
Author shall have the right to terminate this license immediately by written notice upon Licensee's breach of, or non-compliance with, any of its terms.
|
24 |
+
Licensee may be held legally responsible for any infringement that is caused or encouraged by Licensee's failure to abide by the terms of this license.
|
25 |
+
|
26 |
+
For more information or comments, send mail to: [email protected]
|
app.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os, requests
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from model.model import ResHalf
|
6 |
+
from inference import Inferencer
|
7 |
+
from utils import util
|
8 |
+
|
9 |
+
## local | remote
|
10 |
+
RUN_MODE = "remote"
|
11 |
+
if RUN_MODE != "local":
|
12 |
+
os.system("wget https://huggingface.co/menghanxia/ReversibleHalftoning/resolve/main/model_best.pth.tar")
|
13 |
+
os.rename("model_best.pth.tar", "./checkpoints/model_best.pth.tar")
|
14 |
+
## examples
|
15 |
+
os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/girl.png")
|
16 |
+
os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/wave.png")
|
17 |
+
os.system("wget https://huggingface.co/menghanxia/disco/resolve/main/painting.png")
|
18 |
+
|
19 |
+
## step 1: set up model
|
20 |
+
device = "cpu"
|
21 |
+
checkpt_path = "checkpoints/model_best.pth.tar"
|
22 |
+
invhalfer = Inferencer(checkpoint_path=checkpt_path, model=ResHalf(train=False), use_cuda=False)
|
23 |
+
|
24 |
+
|
25 |
+
def prepare_data(input_img, decoding_only=False):
|
26 |
+
input_img = np.array(input_img / 255., np.float32)
|
27 |
+
if decoding_only:
|
28 |
+
input_img = input_img[:,:,:1]
|
29 |
+
input_img = util.img2tensor(input_img * 2. - 1.)
|
30 |
+
return input_img
|
31 |
+
|
32 |
+
|
33 |
+
def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
|
34 |
+
input_img = prepare_data(input_img, decoding_only)
|
35 |
+
input_img = input_img.to(device)
|
36 |
+
if decoding_only:
|
37 |
+
print('>>>:restoration mode')
|
38 |
+
resColor = invhalfer(input_img, decoding_only=decoding_only)
|
39 |
+
output = util.tensor2img(resColor / 2. + 0.5) * 255.
|
40 |
+
else:
|
41 |
+
print('>>>:halftoning mode')
|
42 |
+
resHalftone, resColor = invhalfer(input_img, decoding_only=decoding_only)
|
43 |
+
output = util.tensor2img(resHalftone / 2. + 0.5) * 255.
|
44 |
+
return (output+0.5).astype(np.uint8)
|
45 |
+
|
46 |
+
|
47 |
+
def click_run(input_img, decoding_only):
|
48 |
+
output = run_invhalf(invhalfer, input_img, decoding_only, device)
|
49 |
+
return output
|
50 |
+
|
51 |
+
## step 2: configure interface
|
52 |
+
demo = gr.Blocks(title="ReversibleHalftoning")
|
53 |
+
with demo:
|
54 |
+
gr.Markdown(value="""
|
55 |
+
**Gradio demo for ReversibleHalftoning: Deep Halftoning with Reversible Binary Pattern**. Check our [github page](https://github.com/MenghanXia/ReversibleHalftoning) 😛.
|
56 |
+
""")
|
57 |
+
with gr.Row():
|
58 |
+
with gr.Column():
|
59 |
+
Image_input = gr.Image(type="numpy", label="Input", interactive=True)
|
60 |
+
with gr.Row():
|
61 |
+
Radio_mode = gr.Radio(type="index", choices=["Halftoning (Photo2Halftone)", "Restoration (Halftone2Photo)"], \
|
62 |
+
label="Choose a running mode", value="Halftoning (Photo2Halftone)")
|
63 |
+
Button_run = gr.Button(value="Run")
|
64 |
+
with gr.Column():
|
65 |
+
Image_output = gr.Image(type="numpy", label="Output").style(height=480)
|
66 |
+
|
67 |
+
Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)
|
68 |
+
|
69 |
+
if RUN_MODE == "local":
|
70 |
+
gr.Examples(examples=[
|
71 |
+
['girl.png', "Halftoning (Photo2Halftone)"],
|
72 |
+
['wave.png', "Halftoning (Photo2Halftone)"],
|
73 |
+
['painting.png', "Restoration (Halftone2Photo)"],
|
74 |
+
],
|
75 |
+
inputs=[Image_input,Radio_mode], outputs=[Image_output], label="Examples")
|
76 |
+
|
77 |
+
if RUN_MODE != "local":
|
78 |
+
demo.launch(server_name='9.134.253.83',server_port=7788)
|
79 |
+
else:
|
80 |
+
demo.launch()
|
inference.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import os, argparse, json
|
4 |
+
from os.path import join
|
5 |
+
from glob import glob
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from model.model import ResHalf
|
11 |
+
from model.model import Quantize
|
12 |
+
from model.loss import l1_loss
|
13 |
+
from utils import util
|
14 |
+
from utils.dct import DCT_Lowfrequency
|
15 |
+
from utils.filters_tensor import bgr2gray
|
16 |
+
|
17 |
+
|
18 |
+
class Inferencer:
|
19 |
+
def __init__(self, checkpoint_path, model, use_cuda=True, multi_gpu=True):
|
20 |
+
self.checkpoint = torch.load(checkpoint_path)
|
21 |
+
self.use_cuda = use_cuda
|
22 |
+
self.model = model.eval()
|
23 |
+
if multi_gpu:
|
24 |
+
self.model = torch.nn.DataParallel(self.model)
|
25 |
+
if self.use_cuda:
|
26 |
+
self.model = self.model.cuda()
|
27 |
+
self.model.load_state_dict(self.checkpoint['state_dict'])
|
28 |
+
|
29 |
+
def __call__(self, input_img, decoding_only=False):
|
30 |
+
with torch.no_grad():
|
31 |
+
scale = 8
|
32 |
+
_, _, H, W = input_img.shape
|
33 |
+
if H % scale != 0 or W % scale != 0:
|
34 |
+
input_img = F.pad(input_img, [0, scale - W % scale, 0, scale - H % scale], mode='reflect')
|
35 |
+
if self.use_cuda:
|
36 |
+
input_img = input_img.cuda()
|
37 |
+
if decoding_only:
|
38 |
+
resColor = self.model(input_img, decoding_only)
|
39 |
+
if H % scale != 0 or W % scale != 0:
|
40 |
+
resColor = resColor[:, :, :H, :W]
|
41 |
+
return resColor
|
42 |
+
else:
|
43 |
+
resHalftone, resColor = self.model(input_img, decoding_only)
|
44 |
+
resHalftone = Quantize.apply((resHalftone + 1.0) * 0.5) * 2.0 - 1.
|
45 |
+
if H % scale != 0 or W % scale != 0:
|
46 |
+
resHalftone = resHalftone[:, :, :H, :W]
|
47 |
+
resColor = resColor[:, :, :H, :W]
|
48 |
+
return resHalftone, resColor
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == '__main__':
|
52 |
+
parser = argparse.ArgumentParser(description='invHalf')
|
53 |
+
parser.add_argument('--model', default=None, type=str,
|
54 |
+
help='model weight file path')
|
55 |
+
parser.add_argument('--decoding', action='store_true', default=False, help='restoration from halftone input')
|
56 |
+
parser.add_argument('--data_dir', default=None, type=str,
|
57 |
+
help='where to load input data (RGB images)')
|
58 |
+
parser.add_argument('--save_dir', default=None, type=str,
|
59 |
+
help='where to save the result')
|
60 |
+
args = parser.parse_args()
|
61 |
+
|
62 |
+
invhalfer = Inferencer(
|
63 |
+
checkpoint_path=args.model,
|
64 |
+
model=ResHalf(train=False)
|
65 |
+
)
|
66 |
+
save_dir = os.path.join(args.save_dir)
|
67 |
+
util.ensure_dir(save_dir)
|
68 |
+
test_imgs = glob(join(args.data_dir, '*.*g'))
|
69 |
+
print('------loaded %d images.' % len(test_imgs) )
|
70 |
+
for img in test_imgs:
|
71 |
+
print('[*] processing %s ...' % img)
|
72 |
+
if args.decoding:
|
73 |
+
input_img = cv2.imread(img, flags=cv2.IMREAD_GRAYSCALE) / 127.5 - 1.
|
74 |
+
c = invhalfer(util.img2tensor(input_img), decoding_only=True)
|
75 |
+
c = util.tensor2img(c / 2. + 0.5) * 255.
|
76 |
+
cv2.imwrite(join(save_dir, 'restored_' + img.split('/')[-1].split('.')[0] + '.png'), c)
|
77 |
+
else:
|
78 |
+
input_img = cv2.imread(img, flags=cv2.IMREAD_COLOR) / 127.5 - 1.
|
79 |
+
h, c = invhalfer(util.img2tensor(input_img), decoding_only=False)
|
80 |
+
h = util.tensor2img(h / 2. + 0.5) * 255.
|
81 |
+
c = util.tensor2img(c / 2. + 0.5) * 255.
|
82 |
+
cv2.imwrite(join(save_dir, 'halftone_' + img.split('/')[-1].split('.')[0] + '.png'), h)
|
83 |
+
cv2.imwrite(join(save_dir, 'restored_' + img.split('/')[-1].split('.')[0] + '.png'), c)
|
model/__init__.py
ADDED
File without changes
|
model/base_module.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from torch.nn import functional as F
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def tensor2array(tensors):
|
7 |
+
arrays = tensors.detach().to("cpu").numpy()
|
8 |
+
return np.transpose(arrays, (0, 2, 3, 1))
|
9 |
+
|
10 |
+
|
11 |
+
class ResidualBlock(nn.Module):
|
12 |
+
def __init__(self, channels):
|
13 |
+
super(ResidualBlock, self).__init__()
|
14 |
+
self.conv = nn.Sequential(
|
15 |
+
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
16 |
+
nn.ReLU(inplace=True),
|
17 |
+
nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
18 |
+
)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
residual = self.conv(x)
|
22 |
+
return x + residual
|
23 |
+
|
24 |
+
|
25 |
+
class DownsampleBlock(nn.Module):
|
26 |
+
def __init__(self, in_channels, out_channels, withConvRelu=True):
|
27 |
+
super(DownsampleBlock, self).__init__()
|
28 |
+
if withConvRelu:
|
29 |
+
self.conv = nn.Sequential(
|
30 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2),
|
31 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
32 |
+
nn.ReLU(inplace=True)
|
33 |
+
)
|
34 |
+
else:
|
35 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.conv(x)
|
39 |
+
|
40 |
+
|
41 |
+
class ConvBlock(nn.Module):
|
42 |
+
def __init__(self, inChannels, outChannels, convNum):
|
43 |
+
super(ConvBlock, self).__init__()
|
44 |
+
self.inConv = nn.Sequential(
|
45 |
+
nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
|
46 |
+
nn.ReLU(inplace=True)
|
47 |
+
)
|
48 |
+
layers = []
|
49 |
+
for _ in range(convNum - 1):
|
50 |
+
layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
|
51 |
+
layers.append(nn.ReLU(inplace=True))
|
52 |
+
self.conv = nn.Sequential(*layers)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
x = self.inConv(x)
|
56 |
+
x = self.conv(x)
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
class UpsampleBlock(nn.Module):
|
61 |
+
def __init__(self, in_channels, out_channels):
|
62 |
+
super(UpsampleBlock, self).__init__()
|
63 |
+
self.conv = nn.Sequential(
|
64 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
|
65 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
66 |
+
nn.ReLU(inplace=True)
|
67 |
+
)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
71 |
+
return self.conv(x)
|
72 |
+
|
73 |
+
|
74 |
+
class SkipConnection(nn.Module):
|
75 |
+
def __init__(self, channels):
|
76 |
+
super(SkipConnection, self).__init__()
|
77 |
+
self.conv = nn.Conv2d(2 * channels, channels, 1, bias=False)
|
78 |
+
|
79 |
+
def forward(self, x, y):
|
80 |
+
x = torch.cat((x, y), 1)
|
81 |
+
return self.conv(x)
|
model/hourglass.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from .base_module import ConvBlock, DownsampleBlock, ResidualBlock, SkipConnection, UpsampleBlock
|
3 |
+
|
4 |
+
|
5 |
+
class HourGlass(nn.Module):
|
6 |
+
def __init__(self, convNum=4, resNum=4, inChannel=6, outChannel=3):
|
7 |
+
super(HourGlass, self).__init__()
|
8 |
+
self.inConv = ConvBlock(inChannel, 64, convNum=2)
|
9 |
+
self.down1 = nn.Sequential(*[DownsampleBlock(64, 128, withConvRelu=False), ConvBlock(128, 128, convNum=2)])
|
10 |
+
self.down2 = nn.Sequential(
|
11 |
+
*[DownsampleBlock(128, 256, withConvRelu=False), ConvBlock(256, 256, convNum=convNum)])
|
12 |
+
self.down3 = nn.Sequential(
|
13 |
+
*[DownsampleBlock(256, 512, withConvRelu=False), ConvBlock(512, 512, convNum=convNum)])
|
14 |
+
self.residual = nn.Sequential(*[ResidualBlock(512) for _ in range(resNum)])
|
15 |
+
self.up3 = nn.Sequential(*[UpsampleBlock(512, 256), ConvBlock(256, 256, convNum=convNum)])
|
16 |
+
self.skip3 = SkipConnection(256)
|
17 |
+
self.up2 = nn.Sequential(*[UpsampleBlock(256, 128), ConvBlock(128, 128, convNum=2)])
|
18 |
+
self.skip2 = SkipConnection(128)
|
19 |
+
self.up1 = nn.Sequential(*[UpsampleBlock(128, 64), ConvBlock(64, 64, convNum=2)])
|
20 |
+
self.skip1 = SkipConnection(64)
|
21 |
+
self.outConv = nn.Sequential(
|
22 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
23 |
+
nn.ReLU(inplace=True),
|
24 |
+
nn.Conv2d(64, outChannel, kernel_size=1, padding=0)
|
25 |
+
)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
f1 = self.inConv(x)
|
29 |
+
f2 = self.down1(f1)
|
30 |
+
f3 = self.down2(f2)
|
31 |
+
f4 = self.down3(f3)
|
32 |
+
r4 = self.residual(f4)
|
33 |
+
r3 = self.skip3(self.up3(r4), f3)
|
34 |
+
r2 = self.skip2(self.up2(r3), f2)
|
35 |
+
r1 = self.skip1(self.up1(r2), f1)
|
36 |
+
y = self.outConv(r1)
|
37 |
+
return y
|
38 |
+
|
39 |
+
|
40 |
+
class ResidualHourGlass(nn.Module):
|
41 |
+
def __init__(self, resNum=4, inChannel=6, outChannel=3):
|
42 |
+
super(ResidualHourGlass, self).__init__()
|
43 |
+
self.inConv = nn.Conv2d(inChannel, 64, kernel_size=3, padding=1)
|
44 |
+
self.residualBefore = nn.Sequential(*[ResidualBlock(64) for _ in range(2)])
|
45 |
+
self.down1 = nn.Sequential(
|
46 |
+
*[DownsampleBlock(64, 128, withConvRelu=False), ConvBlock(128, 128, convNum=2)])
|
47 |
+
self.down2 = nn.Sequential(
|
48 |
+
*[DownsampleBlock(128, 256, withConvRelu=False), ConvBlock(256, 256, convNum=2)])
|
49 |
+
self.residual = nn.Sequential(*[ResidualBlock(256) for _ in range(resNum)])
|
50 |
+
self.up2 = nn.Sequential(*[UpsampleBlock(256, 128), ConvBlock(128, 128, convNum=2)])
|
51 |
+
self.skip2 = SkipConnection(128)
|
52 |
+
self.up1 = nn.Sequential(*[UpsampleBlock(128, 64), ConvBlock(64, 64, convNum=2)])
|
53 |
+
self.skip1 = SkipConnection(64)
|
54 |
+
self.residualAfter = nn.Sequential(*[ResidualBlock(64) for _ in range(2)])
|
55 |
+
self.outConv = nn.Sequential(
|
56 |
+
nn.Conv2d(64, outChannel, kernel_size=3, padding=1),
|
57 |
+
nn.Tanh()
|
58 |
+
)
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
f1 = self.inConv(x)
|
62 |
+
f1 = self.residualBefore(f1)
|
63 |
+
f2 = self.down1(f1)
|
64 |
+
f3 = self.down2(f2)
|
65 |
+
r3 = self.residual(f3)
|
66 |
+
r2 = self.skip2(self.up2(r3), f2)
|
67 |
+
r1 = self.skip1(self.up1(r2), f1)
|
68 |
+
y = self.residualAfter(r1)
|
69 |
+
y = self.outConv(y)
|
70 |
+
return y
|
model/loss.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch
|
4 |
+
from utils.filters_tensor import GaussianSmoothing, bgr2gray
|
5 |
+
from utils import pytorch_ssim
|
6 |
+
from torch import nn
|
7 |
+
from .hourglass import HourGlass
|
8 |
+
from torchvision.models.vgg import vgg19
|
9 |
+
|
10 |
+
|
11 |
+
def l2_loss(y_input, y_target):
|
12 |
+
return F.mse_loss(y_input, y_target)
|
13 |
+
|
14 |
+
|
15 |
+
def l1_loss(y_input, y_target):
|
16 |
+
return F.l1_loss(y_input, y_target)
|
17 |
+
|
18 |
+
|
19 |
+
def gaussianL2(yInput, yTarget):
|
20 |
+
# data range [-1,1]
|
21 |
+
smoother = GaussianSmoothing(channels=1, kernel_size=11, sigma=2.0)
|
22 |
+
gaussianInput = smoother(yInput)
|
23 |
+
gaussianTarget = smoother(bgr2gray(yTarget))
|
24 |
+
return F.mse_loss(gaussianInput, gaussianTarget)
|
25 |
+
|
26 |
+
|
27 |
+
def binL1(yInput):
|
28 |
+
# data range is [-1,1]
|
29 |
+
return (yInput.abs() - 1.0).abs().mean()
|
30 |
+
|
31 |
+
|
32 |
+
def ssimLoss(yInput, yTarget):
|
33 |
+
# data range is [-1,1]
|
34 |
+
ssim = pytorch_ssim.ssim(yInput / 2. + 0.5, bgr2gray(yTarget / 2. + 0.5), window_size=11)
|
35 |
+
return 1. - ssim
|
36 |
+
|
37 |
+
|
38 |
+
class InverseHalf(nn.Module):
|
39 |
+
def __init__(self):
|
40 |
+
super(InverseHalf, self).__init__()
|
41 |
+
self.net = HourGlass(inChannel=1, outChannel=1)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
grayscale = self.net(x)
|
45 |
+
return grayscale
|
46 |
+
|
47 |
+
|
48 |
+
class FeatureLoss:
|
49 |
+
def __init__(self, pretrainedPath, requireGrad=False, multiGpu=True):
|
50 |
+
self.featureExactor = InverseHalf()
|
51 |
+
if multiGpu:
|
52 |
+
self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
|
53 |
+
print("-loading feature extractor: {} ...".format(pretrainedPath))
|
54 |
+
checkpoint = torch.load(pretrainedPath)
|
55 |
+
self.featureExactor.load_state_dict(checkpoint['state_dict'])
|
56 |
+
print("-feature network loaded")
|
57 |
+
if not requireGrad:
|
58 |
+
for param in self.featureExactor.parameters():
|
59 |
+
param.requires_grad = False
|
60 |
+
|
61 |
+
def __call__(self, yInput, yTarget):
|
62 |
+
inFeature = self.featureExactor(yInput)
|
63 |
+
return l2_loss(inFeature, yTarget)
|
64 |
+
|
65 |
+
|
66 |
+
class Vgg19Loss:
|
67 |
+
def __init__(self, multiGpu=True):
|
68 |
+
os.environ['TORCH_HOME']='~/bigdata/0ProgramS/checkpoints'
|
69 |
+
# data in BGR format, [0,1] range
|
70 |
+
self.mean = [0.485, 0.456, 0.406]
|
71 |
+
self.mean.reverse()
|
72 |
+
self.std = [0.229, 0.224, 0.225]
|
73 |
+
self.std.reverse()
|
74 |
+
vgg = vgg19(pretrained=True)
|
75 |
+
# maxpoll after conv4_4
|
76 |
+
self.featureExactor = nn.Sequential(*list(vgg.features)[:28]).eval()
|
77 |
+
for param in self.featureExactor.parameters():
|
78 |
+
param.requires_grad = False
|
79 |
+
if multiGpu:
|
80 |
+
self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
|
81 |
+
print('[*] Vgg19Loss init!')
|
82 |
+
|
83 |
+
def normalize(self, tensor):
|
84 |
+
tensor = tensor.clone()
|
85 |
+
mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
|
86 |
+
std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
|
87 |
+
tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
|
88 |
+
return tensor
|
89 |
+
|
90 |
+
def __call__(self, yInput, yTarget):
|
91 |
+
inFeature = self.featureExactor(self.normalize(yInput).flip(1))
|
92 |
+
targetFeature = self.featureExactor(self.normalize(yTarget).flip(1))
|
93 |
+
return l2_loss(inFeature, targetFeature)
|
model/model.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.autograd import Function
|
5 |
+
|
6 |
+
from .hourglass import HourGlass
|
7 |
+
from utils.dct import DCT_Lowfrequency
|
8 |
+
from utils.filters_tensor import bgr2gray
|
9 |
+
|
10 |
+
from collections import OrderedDict
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
class Quantize(Function):
|
15 |
+
@staticmethod
|
16 |
+
def forward(ctx, x):
|
17 |
+
ctx.save_for_backward(x)
|
18 |
+
y = x.round()
|
19 |
+
return y
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def backward(ctx, grad_output):
|
23 |
+
inputX = ctx.saved_tensors
|
24 |
+
return grad_output
|
25 |
+
|
26 |
+
|
27 |
+
class ResHalf(nn.Module):
|
28 |
+
def __init__(self, train=True, warm_stage=False):
|
29 |
+
super(ResHalf, self).__init__()
|
30 |
+
self.encoder = HourGlass(inChannel=4, outChannel=1, resNum=4, convNum=4)
|
31 |
+
self.decoder = HourGlass(inChannel=1, outChannel=3, resNum=4, convNum=4)
|
32 |
+
self.dcter = DCT_Lowfrequency(size=256, fLimit=50)
|
33 |
+
# quantize [-1,1] data to be {-1,1}
|
34 |
+
self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
|
35 |
+
self.isTrain = train
|
36 |
+
if warm_stage:
|
37 |
+
for name, param in self.decoder.named_parameters():
|
38 |
+
param.requires_grad = False
|
39 |
+
|
40 |
+
def add_impluse_noise(self, input_halfs, p=0.0):
|
41 |
+
N,C,H,W = input_halfs.shape
|
42 |
+
SNR = 1-p
|
43 |
+
np_input_halfs = input_halfs.detach().to("cpu").numpy()
|
44 |
+
np_input_halfs = np.transpose(np_input_halfs, (0, 2, 3, 1))
|
45 |
+
for i in range(N):
|
46 |
+
mask = np.random.choice((0, 1, 2), size=(H, W, 1), p=[SNR, (1 - SNR) / 2., (1 - SNR) / 2.])
|
47 |
+
np_input_halfs[i, mask==1] = 1.0
|
48 |
+
np_input_halfs[i, mask==2] = -1.0
|
49 |
+
return torch.from_numpy(np_input_halfs.transpose((0, 3, 1, 2))).to(input_halfs.device)
|
50 |
+
|
51 |
+
def forward(self, input_img, decoding_only=False):
|
52 |
+
if decoding_only:
|
53 |
+
halfResQ = self.quantizer(input_img)
|
54 |
+
restored = self.decoder(halfResQ)
|
55 |
+
return restored
|
56 |
+
|
57 |
+
noise = torch.randn_like(input_img) * 0.3
|
58 |
+
halfRes = self.encoder(torch.cat((input_img, noise[:,:1,:,:]), dim=1))
|
59 |
+
halfResQ = self.quantizer(halfRes)
|
60 |
+
restored = self.decoder(halfResQ)
|
61 |
+
if self.isTrain:
|
62 |
+
halfDCT = self.dcter(halfRes / 2. + 0.5)
|
63 |
+
refDCT = self.dcter(bgr2gray(input_img / 2. + 0.5))
|
64 |
+
return halfRes, halfDCT, refDCT, restored
|
65 |
+
else:
|
66 |
+
return halfRes, restored
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
addict
|
2 |
+
future
|
3 |
+
numpy
|
4 |
+
opencv-python
|
5 |
+
pandas
|
6 |
+
Pillow
|
7 |
+
pyyaml
|
8 |
+
requests
|
9 |
+
scikit-image
|
10 |
+
scikit-learn
|
11 |
+
scipy
|
12 |
+
torch>=1.8.0
|
13 |
+
torchvision
|
14 |
+
tensorboardx>=2.4
|
15 |
+
tqdm
|
16 |
+
yapf
|
17 |
+
lpips
|
scripts/invhalf_full.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "invhalf_full",
|
3 |
+
"initial_ckpt": "checkpoints/model_warm.pth.tar",
|
4 |
+
"model": "ResHalf",
|
5 |
+
"data_dir": "dataset/",
|
6 |
+
"save_dir": "./",
|
7 |
+
"trainer": {
|
8 |
+
"epochs": 1000,
|
9 |
+
"save_epochs": 5
|
10 |
+
},
|
11 |
+
"data_loader": {
|
12 |
+
"dataset": "HalftoneVOC2012.json",
|
13 |
+
"special_set": "special_color.json",
|
14 |
+
"batch_size": 1,
|
15 |
+
"shuffle": true,
|
16 |
+
"num_workers": 32
|
17 |
+
},
|
18 |
+
"quantizeLoss": "binL1",
|
19 |
+
"quantizeLossWeight": 0.1,
|
20 |
+
"toneLoss": "gaussianL2",
|
21 |
+
"toneLossWeight": 0.6,
|
22 |
+
"structureLoss": "ssimLoss",
|
23 |
+
"structureLossWeight": 0.0,
|
24 |
+
"restoreLoss": "l2_loss",
|
25 |
+
"restoreLossWeight": 1.0,
|
26 |
+
"blueNoiseLossWeight": 0.3,
|
27 |
+
"vggLossWeight": 0.0002,
|
28 |
+
"cuda": true,
|
29 |
+
"multi-gpus": true,
|
30 |
+
"optimizer_type": "Adam",
|
31 |
+
"optimizer": {
|
32 |
+
"lr": 0.0001,
|
33 |
+
"weight_decay": 0
|
34 |
+
},
|
35 |
+
"lr_sheduler": {
|
36 |
+
"factor": 0.5,
|
37 |
+
"patience": 3,
|
38 |
+
"threshold": 1e-05,
|
39 |
+
"cooldown": 0
|
40 |
+
},
|
41 |
+
"seed": 131
|
42 |
+
}
|
scripts/invhalf_warm.json
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "invhalf_warmup",
|
3 |
+
"model": "ResHalf",
|
4 |
+
"data_dir": "dataset/",
|
5 |
+
"save_dir": "./",
|
6 |
+
"trainer": {
|
7 |
+
"epochs": 1000,
|
8 |
+
"save_epochs": 5
|
9 |
+
},
|
10 |
+
"data_loader": {
|
11 |
+
"dataset": "HalftoneVOC2012.json",
|
12 |
+
"special_set": "special_color.json",
|
13 |
+
"batch_size": 8,
|
14 |
+
"shuffle": true,
|
15 |
+
"num_workers": 32
|
16 |
+
},
|
17 |
+
"quantizeLoss": "binL1",
|
18 |
+
"quantizeLossWeight": 0.2,
|
19 |
+
"toneLoss": "gaussianL2",
|
20 |
+
"toneLossWeight": 0.6,
|
21 |
+
"structureLoss": "ssimLoss",
|
22 |
+
"structureLossWeight": 0.0,
|
23 |
+
"blueNoiseLossWeight": 0.3,
|
24 |
+
"featureLossWeight": 1.0,
|
25 |
+
"cuda": true,
|
26 |
+
"multi-gpus": true,
|
27 |
+
"optimizer_type": "Adam",
|
28 |
+
"optimizer": {
|
29 |
+
"lr": 0.0001,
|
30 |
+
"weight_decay": 0
|
31 |
+
},
|
32 |
+
"lr_sheduler": {
|
33 |
+
"factor": 0.5,
|
34 |
+
"patience": 3,
|
35 |
+
"threshold": 1e-05,
|
36 |
+
"cooldown": 0
|
37 |
+
},
|
38 |
+
"seed": 131
|
39 |
+
}
|
train.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob, datetime, time
|
2 |
+
import argparse, json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.optim as optim
|
6 |
+
from torch.autograd import Variable
|
7 |
+
import torchvision
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from torch.backends import cudnn
|
10 |
+
|
11 |
+
from model.base_module import tensor2array
|
12 |
+
from model.model import ResHalf
|
13 |
+
from model.loss import *
|
14 |
+
from utils.dataset import HalftoneVOC2012 as Dataset
|
15 |
+
from utils.util import ensure_dir, save_list, save_images_from_batch
|
16 |
+
|
17 |
+
|
18 |
+
class Trainer():
|
19 |
+
def __init__(self, config, resume):
|
20 |
+
self.config = config
|
21 |
+
self.name = config['name']
|
22 |
+
self.resume_path = resume
|
23 |
+
self.n_epochs = config['trainer']['epochs']
|
24 |
+
self.with_cuda = config['cuda'] and torch.cuda.is_available()
|
25 |
+
self.seed = config['seed']
|
26 |
+
self.start_epoch = 0
|
27 |
+
self.save_freq = config['trainer']['save_epochs']
|
28 |
+
self.checkpoint_dir = os.path.join(config['save_dir'], self.name)
|
29 |
+
ensure_dir(self.checkpoint_dir)
|
30 |
+
json.dump(config, open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'),
|
31 |
+
indent=4, sort_keys=False)
|
32 |
+
print("@Workspace: %s *************"%self.checkpoint_dir)
|
33 |
+
self.cache = os.path.join(self.checkpoint_dir, 'train_cache')
|
34 |
+
self.val_halftone = os.path.join(self.cache, 'halftone')
|
35 |
+
self.val_restored = os.path.join(self.cache, 'restored')
|
36 |
+
ensure_dir(self.val_halftone)
|
37 |
+
ensure_dir(self.val_restored)
|
38 |
+
|
39 |
+
## model
|
40 |
+
self.model = eval(config['model'])()
|
41 |
+
if self.config['multi-gpus']:
|
42 |
+
self.model = torch.nn.DataParallel(self.model).cuda()
|
43 |
+
elif self.with_cuda:
|
44 |
+
self.model = self.model.cuda()
|
45 |
+
|
46 |
+
## optimizer
|
47 |
+
self.optimizer = getattr(optim, config['optimizer_type'])(self.model.parameters(), **config['optimizer'])
|
48 |
+
self.lr_sheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **config['lr_sheduler'])
|
49 |
+
|
50 |
+
## dataset loader
|
51 |
+
with open(os.path.join(config['data_dir'], config['data_loader']['dataset'])) as f:
|
52 |
+
dataset = json.load(f)
|
53 |
+
train_set = Dataset(dataset['train'])
|
54 |
+
self.train_data_loader = DataLoader(train_set, batch_size=config['data_loader']['batch_size'],
|
55 |
+
shuffle=config['data_loader']['shuffle'],
|
56 |
+
num_workers=config['data_loader']['num_workers'])
|
57 |
+
val_set = Dataset(dataset['val'])
|
58 |
+
self.valid_data_loader = DataLoader(val_set, batch_size=config['data_loader']['batch_size'],
|
59 |
+
shuffle=False,
|
60 |
+
num_workers=config['data_loader']['num_workers'])
|
61 |
+
# special dataloader: constant color images
|
62 |
+
with open(os.path.join(config['data_dir'], config['data_loader']['special_set'])) as f:
|
63 |
+
dataset = json.load(f)
|
64 |
+
specialSet = Dataset(dataset['train'])
|
65 |
+
self.specialDataloader = DataLoader(specialSet, batch_size=config['data_loader']['batch_size'],
|
66 |
+
shuffle=config['data_loader']['shuffle'],
|
67 |
+
num_workers=config['data_loader']['num_workers'])
|
68 |
+
|
69 |
+
## loss function
|
70 |
+
self.quantizeLoss = eval(config['quantizeLoss'])
|
71 |
+
self.quantizeLossWeight = config['quantizeLossWeight']
|
72 |
+
self.toneLoss = eval(config['toneLoss'])
|
73 |
+
self.toneLossWeight = config['toneLossWeight']
|
74 |
+
self.structureLoss = eval(config['structureLoss'])
|
75 |
+
self.structureLossWeight = config['structureLossWeight']
|
76 |
+
self.restoreLoss = eval(config['restoreLoss'])
|
77 |
+
self.restoreLossWeight = config['restoreLossWeight']
|
78 |
+
# quantize [-1,1] data to be {-1,1}
|
79 |
+
self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
|
80 |
+
self.blueNoiseLossWeight = config['blueNoiseLossWeight']
|
81 |
+
self.vggloss = Vgg19Loss()
|
82 |
+
self.vggLossWeight = config['vggLossWeight']
|
83 |
+
|
84 |
+
# resume checkpoint or load warm-up checkpoint
|
85 |
+
checkpt_path = self.config['initial_ckpt']
|
86 |
+
if self.resume_path:
|
87 |
+
checkpt_path = self.resume_path
|
88 |
+
assert os.path.exists(checkpt_path), 'Invalid checkpoint Path: %s' % checkpt_path
|
89 |
+
self.load_checkpoint(checkpt_path)
|
90 |
+
|
91 |
+
|
92 |
+
def _train(self):
|
93 |
+
torch.manual_seed(self.config['seed'])
|
94 |
+
torch.cuda.manual_seed(self.config['seed'])
|
95 |
+
cudnn.benchmark = True
|
96 |
+
|
97 |
+
start_time = time.time()
|
98 |
+
self.monitor_best = 999.
|
99 |
+
for epoch in range(self.start_epoch, self.n_epochs + 1):
|
100 |
+
ep_st = time.time()
|
101 |
+
epoch_loss = self._train_epoch(epoch)
|
102 |
+
# perform lr_sheduler
|
103 |
+
self.lr_sheduler.step(epoch_loss['total_loss'])
|
104 |
+
epoch_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
|
105 |
+
epoch_metric = self._valid_epoch(epoch)
|
106 |
+
print("[*] --- epoch: %d/%d | loss: %4.4f | metric: %4.4f | time-consumed: %4.2f ---" % \
|
107 |
+
(epoch+1, self.n_epochs, epoch_loss['total_loss'], epoch_metric, (time.time()-ep_st)))
|
108 |
+
|
109 |
+
# save losses and learning rate
|
110 |
+
epoch_loss['metric'] = epoch_metric
|
111 |
+
epoch_loss['lr'] = epoch_lr
|
112 |
+
self.save_loss(epoch_loss, epoch)
|
113 |
+
if ((epoch+1) % self.save_freq == 0 or epoch == (self.n_epochs-1)):
|
114 |
+
print('---------- saving model ...')
|
115 |
+
self.save_checkpoint(epoch)
|
116 |
+
if self.monitor_best > epoch_metric:
|
117 |
+
self.monitor_best = epoch_metric
|
118 |
+
self.save_checkpoint(epoch, save_best=True)
|
119 |
+
|
120 |
+
print("Training finished! consumed %f sec" % (time.time() - start_time))
|
121 |
+
|
122 |
+
|
123 |
+
def _to_variable(self, data, target):
|
124 |
+
data, target = Variable(data), Variable(target)
|
125 |
+
if self.with_cuda:
|
126 |
+
data, target = data.cuda(), target.cuda()
|
127 |
+
return data, target
|
128 |
+
|
129 |
+
|
130 |
+
def _train_epoch(self, epoch):
|
131 |
+
self.model.train()
|
132 |
+
total_loss, quantize_loss, restore_loss = 0, 0, 0
|
133 |
+
tone_loss, structure_loss, blue_noise_loss = 0, 0, 0
|
134 |
+
|
135 |
+
specialIter = iter(self.specialDataloader)
|
136 |
+
time_stamp = time.time()
|
137 |
+
for batch_idx, (color, halftone) in enumerate(self.train_data_loader):
|
138 |
+
color, halftone = self._to_variable(color, halftone)
|
139 |
+
# special data
|
140 |
+
try:
|
141 |
+
specialColor, specialHalftone = next(specialIter)
|
142 |
+
except StopIteration:
|
143 |
+
# reinitialize data loader
|
144 |
+
specialIter = iter(self.specialDataloader)
|
145 |
+
specialColor, specialHalftone = next(specialIter)
|
146 |
+
specialColor, specialHalftone = self._to_variable(specialColor, specialHalftone)
|
147 |
+
self.optimizer.zero_grad()
|
148 |
+
output = self.model(color, halftone)
|
149 |
+
quantizeLoss = self.quantizeLoss(output[0])
|
150 |
+
toneLoss = self.toneLoss(output[0], color)
|
151 |
+
structureLoss = self.structureLoss(output[0], color)
|
152 |
+
|
153 |
+
# restore
|
154 |
+
restoredColor = output[-1]
|
155 |
+
restoreLoss = self.restoreLoss(restoredColor, color)
|
156 |
+
vggLoss = self.vggloss(restoredColor / 2. + 0.5, color / 2. + 0.5)
|
157 |
+
|
158 |
+
# special data
|
159 |
+
output = self.model(specialColor, specialHalftone)
|
160 |
+
toneLossSpecial = self.toneLoss(output[0], specialColor)
|
161 |
+
blueNoiseLoss = l1_loss(output[1], output[2])
|
162 |
+
quantizeLossSpecial = self.quantizeLoss(output[0])
|
163 |
+
loss = (self.toneLossWeight * toneLoss + self.blueNoiseLossWeight*toneLossSpecial) \
|
164 |
+
+ self.quantizeLossWeight * (0.5*quantizeLoss + 0.5*quantizeLossSpecial) \
|
165 |
+
+ self.structureLossWeight * structureLoss \
|
166 |
+
+ self.blueNoiseLossWeight * blueNoiseLoss \
|
167 |
+
+ self.vggLossWeight * vggLoss \
|
168 |
+
+ self.restoreLossWeight * restoreLoss
|
169 |
+
|
170 |
+
loss.backward()
|
171 |
+
# apply grad clip to make training roboust
|
172 |
+
# torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.0001)
|
173 |
+
self.optimizer.step()
|
174 |
+
|
175 |
+
total_loss += loss.item()
|
176 |
+
quantize_loss += quantizeLoss.item()
|
177 |
+
restore_loss += (self.restoreLossWeight*restoreLoss + self.vggLossWeight*vggLoss).item()
|
178 |
+
tone_loss += toneLoss.item()
|
179 |
+
structure_loss += structureLoss.item()
|
180 |
+
blue_noise_loss += blueNoiseLoss.item()
|
181 |
+
if batch_idx % 100 == 0:
|
182 |
+
tm = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
|
183 |
+
print("%s >> [%d/%d] iter:%d loss:%4.4f "%(tm, epoch+1, self.n_epochs, batch_idx+1, loss.item()))
|
184 |
+
|
185 |
+
epoch_loss = dict()
|
186 |
+
epoch_loss['total_loss'] = total_loss / (batch_idx+1)
|
187 |
+
epoch_loss['quantize_loss'] = quantize_loss / (batch_idx+1)
|
188 |
+
epoch_loss['tone_loss'] = tone_loss / (batch_idx+1)
|
189 |
+
epoch_loss['structure_loss'] = structure_loss / (batch_idx+1)
|
190 |
+
epoch_loss['bluenoise_loss'] = blue_noise_loss / (batch_idx+1)
|
191 |
+
epoch_loss['restore_loss'] = restore_loss / (batch_idx+1)
|
192 |
+
|
193 |
+
return epoch_loss
|
194 |
+
|
195 |
+
|
196 |
+
def _valid_epoch(self, epoch):
|
197 |
+
self.model.eval()
|
198 |
+
total_loss = 0
|
199 |
+
with torch.no_grad():
|
200 |
+
for batch_idx, (color, halftone) in enumerate(self.valid_data_loader):
|
201 |
+
color, halftone = self._to_variable(color, halftone)
|
202 |
+
output = self.model(color, halftone)
|
203 |
+
quantizeLoss = self.quantizeLoss(output[0])
|
204 |
+
toneLoss = self.toneLoss(output[0], color)
|
205 |
+
structureLoss = self.structureLoss(output[0], color)
|
206 |
+
# restore
|
207 |
+
restoredColor = output[-1]
|
208 |
+
restoreLoss = self.restoreLoss(restoredColor, color)
|
209 |
+
vggLoss = self.vggloss(restoredColor / 2. + 0.5, color / 2. + 0.5)
|
210 |
+
|
211 |
+
loss = self.toneLossWeight * toneLoss \
|
212 |
+
+ self.quantizeLossWeight * quantizeLoss \
|
213 |
+
+ self.structureLossWeight * structureLoss \
|
214 |
+
+ self.vggLossWeight * vggLoss \
|
215 |
+
+ self.restoreLossWeight * restoreLoss
|
216 |
+
|
217 |
+
total_loss += loss.item()
|
218 |
+
#! save intermediate images
|
219 |
+
gray_imgs = tensor2array(output[0])
|
220 |
+
color_imgs = tensor2array(output[-1])
|
221 |
+
save_images_from_batch(gray_imgs, self.val_halftone, None, batch_idx)
|
222 |
+
save_images_from_batch(color_imgs, self.val_restored, None, batch_idx)
|
223 |
+
|
224 |
+
return total_loss
|
225 |
+
|
226 |
+
|
227 |
+
def save_loss(self, epoch_loss, epoch):
|
228 |
+
if epoch == 0:
|
229 |
+
for key in epoch_loss:
|
230 |
+
save_list(os.path.join(self.cache, key), [epoch_loss[key]], append_mode=False)
|
231 |
+
else:
|
232 |
+
for key in epoch_loss:
|
233 |
+
save_list(os.path.join(self.cache, key), [epoch_loss[key]], append_mode=True)
|
234 |
+
|
235 |
+
|
236 |
+
def load_checkpoint(self, checkpt_path):
|
237 |
+
print("-loading checkpoint from: {} ...".format(checkpt_path))
|
238 |
+
if self.resume_path:
|
239 |
+
checkpoint = torch.load(checkpt_path)
|
240 |
+
self.start_epoch = checkpoint['epoch'] + 1
|
241 |
+
self.monitor_best = checkpoint['monitor_best']
|
242 |
+
self.model.load_state_dict(checkpoint['state_dict'])
|
243 |
+
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
244 |
+
else:
|
245 |
+
checkpoint = torch.load(checkpt_path)
|
246 |
+
self.model.load_state_dict(checkpoint['state_dict'], strict=False)
|
247 |
+
print("-pretrained checkpoint loaded.")
|
248 |
+
|
249 |
+
|
250 |
+
def save_checkpoint(self, epoch, save_best=False):
|
251 |
+
state = {
|
252 |
+
'epoch': epoch,
|
253 |
+
'state_dict': self.model.state_dict(),
|
254 |
+
'optimizer': self.optimizer.state_dict(),
|
255 |
+
'monitor_best': self.monitor_best
|
256 |
+
}
|
257 |
+
save_path = os.path.join(self.checkpoint_dir, 'model_last.pth.tar')
|
258 |
+
if save_best:
|
259 |
+
save_path = os.path.join(self.checkpoint_dir, 'model_best.pth.tar')
|
260 |
+
torch.save(state, save_path)
|
261 |
+
|
262 |
+
|
263 |
+
if __name__ == '__main__':
|
264 |
+
parser = argparse.ArgumentParser(description='InvHalf')
|
265 |
+
parser.add_argument('-c', '--config', default=None, type=str,
|
266 |
+
help='config file path (default: None)')
|
267 |
+
parser.add_argument('-r', '--resume', default=None, type=str,
|
268 |
+
help='path to latest checkpoint (default: None)')
|
269 |
+
args = parser.parse_args()
|
270 |
+
config_dict = json.load(open(args.config))
|
271 |
+
node = Trainer(config_dict, resume=args.resume)
|
272 |
+
node._train()
|
train_warm.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob, datetime, time
|
2 |
+
import argparse, json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.optim as optim
|
6 |
+
from torch.autograd import Variable
|
7 |
+
import torchvision
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from torch.backends import cudnn
|
10 |
+
|
11 |
+
from model.base_module import tensor2array
|
12 |
+
from model.model import ResHalf
|
13 |
+
from model.loss import *
|
14 |
+
from utils.dataset import HalftoneVOC2012 as Dataset
|
15 |
+
from utils.util import ensure_dir, save_list, save_images_from_batch
|
16 |
+
|
17 |
+
|
18 |
+
class Trainer():
|
19 |
+
def __init__(self, config, resume):
|
20 |
+
self.config = config
|
21 |
+
self.name = config['name']
|
22 |
+
self.resume_path = resume
|
23 |
+
self.n_epochs = config['trainer']['epochs']
|
24 |
+
self.with_cuda = config['cuda'] and torch.cuda.is_available()
|
25 |
+
self.seed = config['seed']
|
26 |
+
self.start_epoch = 0
|
27 |
+
self.save_freq = config['trainer']['save_epochs']
|
28 |
+
self.checkpoint_dir = os.path.join(config['save_dir'], self.name)
|
29 |
+
ensure_dir(self.checkpoint_dir)
|
30 |
+
json.dump(config, open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'),
|
31 |
+
indent=4, sort_keys=False)
|
32 |
+
print("@Workspace: %s *************"%self.checkpoint_dir)
|
33 |
+
self.cache = os.path.join(self.checkpoint_dir, 'train_cache')
|
34 |
+
self.val_halftone = os.path.join(self.cache, 'halftone')
|
35 |
+
self.val_restored = os.path.join(self.cache, 'restored')
|
36 |
+
ensure_dir(self.val_halftone)
|
37 |
+
ensure_dir(self.val_restored)
|
38 |
+
|
39 |
+
## model
|
40 |
+
self.model = eval(config['model'])(train=True, warm_stage=True)
|
41 |
+
if self.config['multi-gpus']:
|
42 |
+
self.model = torch.nn.DataParallel(self.model).cuda()
|
43 |
+
elif self.with_cuda:
|
44 |
+
self.model = self.model.cuda()
|
45 |
+
|
46 |
+
## optimizer
|
47 |
+
self.optimizer = getattr(optim, config['optimizer_type'])(self.model.parameters(), **config['optimizer'])
|
48 |
+
self.lr_sheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **config['lr_sheduler'])
|
49 |
+
|
50 |
+
## dataset loader
|
51 |
+
with open(os.path.join(config['data_dir'], config['data_loader']['dataset'])) as f:
|
52 |
+
dataset = json.load(f)
|
53 |
+
train_set = Dataset(dataset['train'])
|
54 |
+
self.train_data_loader = DataLoader(train_set, batch_size=config['data_loader']['batch_size'],
|
55 |
+
shuffle=config['data_loader']['shuffle'],
|
56 |
+
num_workers=config['data_loader']['num_workers'])
|
57 |
+
val_set = Dataset(dataset['val'])
|
58 |
+
self.valid_data_loader = DataLoader(val_set, batch_size=config['data_loader']['batch_size'],
|
59 |
+
shuffle=False,
|
60 |
+
num_workers=config['data_loader']['num_workers'])
|
61 |
+
# special dataloader: constant color images
|
62 |
+
with open(os.path.join(config['data_dir'], config['data_loader']['special_set'])) as f:
|
63 |
+
dataset = json.load(f)
|
64 |
+
specialSet = Dataset(dataset['train'])
|
65 |
+
self.specialDataloader = DataLoader(specialSet, batch_size=config['data_loader']['batch_size'],
|
66 |
+
shuffle=config['data_loader']['shuffle'],
|
67 |
+
num_workers=config['data_loader']['num_workers'])
|
68 |
+
|
69 |
+
## loss function
|
70 |
+
self.quantizeLoss = eval(config['quantizeLoss'])
|
71 |
+
self.quantizeLossWeight = config['quantizeLossWeight']
|
72 |
+
self.toneLoss = eval(config['toneLoss'])
|
73 |
+
self.toneLossWeight = config['toneLossWeight']
|
74 |
+
self.structureLoss = eval(config['structureLoss'])
|
75 |
+
self.structureLossWeight = config['structureLossWeight']
|
76 |
+
# quantize [-1,1] data to be {-1,1}
|
77 |
+
self.quantizer = lambda x: Quantize.apply(0.5 * (x + 1.)) * 2. - 1.
|
78 |
+
self.blueNoiseLossWeight = config['blueNoiseLossWeight']
|
79 |
+
self.featureLoss = FeatureLoss(
|
80 |
+
requireGrad=False, pretrainedPath='checkpoints/invhalftone_checkpt/model_best.pth.tar')
|
81 |
+
self.featureLossWeight = config['featureLossWeight']
|
82 |
+
|
83 |
+
# resume checkpoint
|
84 |
+
if self.resume_path:
|
85 |
+
assert os.path.exists(resume_path), 'Invalid checkpoint Path: %s' % resume_path
|
86 |
+
self.load_checkpoint(self.resume_path)
|
87 |
+
|
88 |
+
|
89 |
+
def _train(self):
|
90 |
+
torch.manual_seed(self.config['seed'])
|
91 |
+
torch.cuda.manual_seed(self.config['seed'])
|
92 |
+
cudnn.benchmark = True
|
93 |
+
|
94 |
+
start_time = time.time()
|
95 |
+
self.monitor_best = 999.
|
96 |
+
for epoch in range(self.start_epoch, self.n_epochs + 1):
|
97 |
+
ep_st = time.time()
|
98 |
+
epoch_loss = self._train_epoch(epoch)
|
99 |
+
# perform lr_sheduler
|
100 |
+
self.lr_sheduler.step(epoch_loss['total_loss'])
|
101 |
+
epoch_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
|
102 |
+
epoch_metric = self._valid_epoch(epoch)
|
103 |
+
print("[*] --- epoch: %d/%d | loss: %4.4f | metric: %4.4f | time-consumed: %4.2f ---" % \
|
104 |
+
(epoch+1, self.n_epochs, epoch_loss['total_loss'], epoch_metric, (time.time()-ep_st)))
|
105 |
+
|
106 |
+
# save losses and learning rate
|
107 |
+
epoch_loss['metric'] = epoch_metric
|
108 |
+
epoch_loss['lr'] = epoch_lr
|
109 |
+
self.save_loss(epoch_loss, epoch)
|
110 |
+
if ((epoch+1) % self.save_freq == 0 or epoch == (self.n_epochs-1)):
|
111 |
+
print('---------- saving model ...')
|
112 |
+
self.save_checkpoint(epoch)
|
113 |
+
if self.monitor_best > epoch_metric:
|
114 |
+
self.monitor_best = epoch_metric
|
115 |
+
self.save_checkpoint(epoch, save_best=True)
|
116 |
+
|
117 |
+
print("Training finished! consumed %f sec" % (time.time() - start_time))
|
118 |
+
|
119 |
+
|
120 |
+
def _to_variable(self, data, target):
|
121 |
+
data, target = Variable(data), Variable(target)
|
122 |
+
if self.with_cuda:
|
123 |
+
data, target = data.cuda(), target.cuda()
|
124 |
+
return data, target
|
125 |
+
|
126 |
+
|
127 |
+
def _train_epoch(self, epoch):
|
128 |
+
self.model.train()
|
129 |
+
total_loss, quantize_loss, feature_loss = 0, 0, 0
|
130 |
+
tone_loss, structure_loss, blue_noise_loss = 0, 0, 0
|
131 |
+
|
132 |
+
specialIter = iter(self.specialDataloader)
|
133 |
+
time_stamp = time.time()
|
134 |
+
for batch_idx, (color, halftone) in enumerate(self.train_data_loader):
|
135 |
+
color, halftone = self._to_variable(color, halftone)
|
136 |
+
# special data
|
137 |
+
try:
|
138 |
+
specialColor, specialHalftone = next(specialIter)
|
139 |
+
except StopIteration:
|
140 |
+
# reinitialize data loader
|
141 |
+
specialIter = iter(self.specialDataloader)
|
142 |
+
specialColor, specialHalftone = next(specialIter)
|
143 |
+
specialColor, specialHalftone = self._to_variable(specialColor, specialHalftone)
|
144 |
+
self.optimizer.zero_grad()
|
145 |
+
output = self.model(color, halftone)
|
146 |
+
quantizeLoss = self.quantizeLoss(output[0])
|
147 |
+
toneLoss = self.toneLoss(output[0], color)
|
148 |
+
structureLoss = self.structureLoss(output[0], color)
|
149 |
+
featureLoss = self.featureLoss(output[0], bgr2gray(color))
|
150 |
+
|
151 |
+
# special data
|
152 |
+
output = self.model(specialColor, specialHalftone)
|
153 |
+
toneLossSpecial = self.toneLoss(output[0], specialColor)
|
154 |
+
blueNoiseLoss = l1_loss(output[1], output[2])
|
155 |
+
quantizeLossSpecial = self.quantizeLoss(output[0])
|
156 |
+
loss = (self.toneLossWeight * toneLoss + self.blueNoiseLossWeight*toneLossSpecial) \
|
157 |
+
+ self.quantizeLossWeight * (0.5*quantizeLoss + 0.5*quantizeLossSpecial) \
|
158 |
+
+ self.structureLossWeight * structureLoss \
|
159 |
+
+ self.blueNoiseLossWeight * blueNoiseLoss \
|
160 |
+
+ self.featureLossWeight * featureLoss
|
161 |
+
|
162 |
+
loss.backward()
|
163 |
+
# apply grad clip to make training roboust
|
164 |
+
# torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.0001)
|
165 |
+
self.optimizer.step()
|
166 |
+
|
167 |
+
total_loss += loss.item()
|
168 |
+
quantize_loss += quantizeLoss.item()
|
169 |
+
feature_loss += featureLoss.item()
|
170 |
+
tone_loss += toneLoss.item()
|
171 |
+
structure_loss += structureLoss.item()
|
172 |
+
blue_noise_loss += blueNoiseLoss.item()
|
173 |
+
if batch_idx % 100 == 0:
|
174 |
+
tm = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
|
175 |
+
print("%s >> [%d/%d] iter:%d loss:%4.4f "%(tm, epoch+1, self.n_epochs, batch_idx+1, loss.item()))
|
176 |
+
|
177 |
+
epoch_loss = dict()
|
178 |
+
epoch_loss['total_loss'] = total_loss / (batch_idx+1)
|
179 |
+
epoch_loss['quantize_loss'] = quantize_loss / (batch_idx+1)
|
180 |
+
epoch_loss['tone_loss'] = tone_loss / (batch_idx+1)
|
181 |
+
epoch_loss['structure_loss'] = structure_loss / (batch_idx+1)
|
182 |
+
epoch_loss['bluenoise_loss'] = blue_noise_loss / (batch_idx+1)
|
183 |
+
epoch_loss['feature_loss'] = feature_loss / (batch_idx+1)
|
184 |
+
|
185 |
+
return epoch_loss
|
186 |
+
|
187 |
+
|
188 |
+
def _valid_epoch(self, epoch):
|
189 |
+
self.model.eval()
|
190 |
+
total_loss = 0
|
191 |
+
with torch.no_grad():
|
192 |
+
for batch_idx, (color, halftone) in enumerate(self.valid_data_loader):
|
193 |
+
color, halftone = self._to_variable(color, halftone)
|
194 |
+
output = self.model(color, halftone)
|
195 |
+
quantizeLoss = self.quantizeLoss(output[0])
|
196 |
+
toneLoss = self.toneLoss(output[0], color)
|
197 |
+
structureLoss = self.structureLoss(output[0], color)
|
198 |
+
featureLoss = self.featureLoss(output[0], bgr2gray(color))
|
199 |
+
|
200 |
+
loss = self.toneLossWeight * toneLoss \
|
201 |
+
+ self.quantizeLossWeight * quantizeLoss \
|
202 |
+
+ self.structureLossWeight * structureLoss \
|
203 |
+
+ self.featureLossWeight * featureLoss
|
204 |
+
|
205 |
+
total_loss += loss.item()
|
206 |
+
#! save intermediate images
|
207 |
+
gray_imgs = tensor2array(output[0])
|
208 |
+
color_imgs = tensor2array(output[-1])
|
209 |
+
save_images_from_batch(gray_imgs, self.val_halftone, None, batch_idx)
|
210 |
+
save_images_from_batch(color_imgs, self.val_restored, None, batch_idx)
|
211 |
+
|
212 |
+
return total_loss
|
213 |
+
|
214 |
+
|
215 |
+
def save_loss(self, epoch_loss, epoch):
|
216 |
+
if epoch == 0:
|
217 |
+
for key in epoch_loss:
|
218 |
+
save_list(os.path.join(self.cache, key), [epoch_loss[key]], append_mode=False)
|
219 |
+
else:
|
220 |
+
for key in epoch_loss:
|
221 |
+
save_list(os.path.join(self.cache, key), [epoch_loss[key]], append_mode=True)
|
222 |
+
|
223 |
+
|
224 |
+
def load_checkpoint(self, checkpt_path):
|
225 |
+
print("-loading checkpoint from: {} ...".format(checkpt_path))
|
226 |
+
checkpoint = torch.load(checkpt_path)
|
227 |
+
self.start_epoch = checkpoint['epoch'] + 1
|
228 |
+
self.monitor_best = checkpoint['monitor_best']
|
229 |
+
self.model.load_state_dict(checkpoint['state_dict'])
|
230 |
+
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
231 |
+
print("-pretrained checkpoint loaded.")
|
232 |
+
|
233 |
+
|
234 |
+
def save_checkpoint(self, epoch, save_best=False):
|
235 |
+
state = {
|
236 |
+
'epoch': epoch,
|
237 |
+
'state_dict': self.model.state_dict(),
|
238 |
+
'optimizer': self.optimizer.state_dict(),
|
239 |
+
'monitor_best': self.monitor_best
|
240 |
+
}
|
241 |
+
save_path = os.path.join(self.checkpoint_dir, 'model_last.pth.tar')
|
242 |
+
if save_best:
|
243 |
+
save_path = os.path.join(self.checkpoint_dir, 'model_best.pth.tar')
|
244 |
+
torch.save(state, save_path)
|
245 |
+
|
246 |
+
|
247 |
+
if __name__ == '__main__':
|
248 |
+
parser = argparse.ArgumentParser(description='InvHalf')
|
249 |
+
parser.add_argument('-c', '--config', default=None, type=str,
|
250 |
+
help='config file path (default: None)')
|
251 |
+
parser.add_argument('-r', '--resume', default=None, type=str,
|
252 |
+
help='path to latest checkpoint (default: None)')
|
253 |
+
args = parser.parse_args()
|
254 |
+
config_dict = json.load(open(args.config))
|
255 |
+
node = Trainer(config_dict, resume=args.resume)
|
256 |
+
node._train()
|
utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .util import *
|
utils/_dct.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def dct1(x):
|
8 |
+
"""
|
9 |
+
Discrete Cosine Transform, Type I
|
10 |
+
|
11 |
+
:param x: the input signal
|
12 |
+
:return: the DCT-I of the signal over the last dimension
|
13 |
+
"""
|
14 |
+
x_shape = x.shape
|
15 |
+
x = x.view(-1, x_shape[-1])
|
16 |
+
|
17 |
+
#return torch.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), 1)[:, :, 0].view(*x_shape)
|
18 |
+
return torch.fft.fft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), 1)[:, :, 0].view(*x_shape)
|
19 |
+
|
20 |
+
|
21 |
+
def idct1(X):
|
22 |
+
"""
|
23 |
+
The inverse of DCT-I, which is just a scaled DCT-I
|
24 |
+
|
25 |
+
Our definition if idct1 is such that idct1(dct1(x)) == x
|
26 |
+
|
27 |
+
:param X: the input signal
|
28 |
+
:return: the inverse DCT-I of the signal over the last dimension
|
29 |
+
"""
|
30 |
+
n = X.shape[-1]
|
31 |
+
return dct1(X) / (2 * (n - 1))
|
32 |
+
|
33 |
+
|
34 |
+
def dct(x, norm=None):
|
35 |
+
"""
|
36 |
+
Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
37 |
+
|
38 |
+
For the meaning of the parameter `norm`, see:
|
39 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
40 |
+
|
41 |
+
:param x: the input signal
|
42 |
+
:param norm: the normalization, None or 'ortho'
|
43 |
+
:return: the DCT-II of the signal over the last dimension
|
44 |
+
"""
|
45 |
+
x_shape = x.shape
|
46 |
+
N = x_shape[-1]
|
47 |
+
x = x.contiguous().view(-1, N)
|
48 |
+
|
49 |
+
v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
|
50 |
+
|
51 |
+
#Vc = torch.fft.rfft(v, 1, onesided=False)
|
52 |
+
Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
|
53 |
+
|
54 |
+
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
|
55 |
+
W_r = torch.cos(k)
|
56 |
+
W_i = torch.sin(k)
|
57 |
+
|
58 |
+
V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
|
59 |
+
|
60 |
+
if norm == 'ortho':
|
61 |
+
V[:, 0] /= np.sqrt(N) * 2
|
62 |
+
V[:, 1:] /= np.sqrt(N / 2) * 2
|
63 |
+
|
64 |
+
V = 2 * V.view(*x_shape)
|
65 |
+
|
66 |
+
return V
|
67 |
+
|
68 |
+
|
69 |
+
def idct(X, norm=None):
|
70 |
+
"""
|
71 |
+
The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
|
72 |
+
|
73 |
+
Our definition of idct is that idct(dct(x)) == x
|
74 |
+
|
75 |
+
For the meaning of the parameter `norm`, see:
|
76 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
77 |
+
|
78 |
+
:param X: the input signal
|
79 |
+
:param norm: the normalization, None or 'ortho'
|
80 |
+
:return: the inverse DCT-II of the signal over the last dimension
|
81 |
+
"""
|
82 |
+
|
83 |
+
x_shape = X.shape
|
84 |
+
N = x_shape[-1]
|
85 |
+
|
86 |
+
X_v = X.contiguous().view(-1, x_shape[-1]) / 2
|
87 |
+
|
88 |
+
if norm == 'ortho':
|
89 |
+
X_v[:, 0] *= np.sqrt(N) * 2
|
90 |
+
X_v[:, 1:] *= np.sqrt(N / 2) * 2
|
91 |
+
|
92 |
+
k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
|
93 |
+
W_r = torch.cos(k)
|
94 |
+
W_i = torch.sin(k)
|
95 |
+
|
96 |
+
V_t_r = X_v
|
97 |
+
V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
|
98 |
+
|
99 |
+
V_r = V_t_r * W_r - V_t_i * W_i
|
100 |
+
V_i = V_t_r * W_i + V_t_i * W_r
|
101 |
+
|
102 |
+
V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
|
103 |
+
|
104 |
+
#v = torch.irfft(V, 1, onesided=False)
|
105 |
+
v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
|
106 |
+
x = v.new_zeros(v.shape)
|
107 |
+
x[:, ::2] += v[:, :N - (N // 2)]
|
108 |
+
x[:, 1::2] += v.flip([1])[:, :N // 2]
|
109 |
+
|
110 |
+
return x.view(*x_shape)
|
111 |
+
|
112 |
+
|
113 |
+
def dct_2d(x, norm=None):
|
114 |
+
"""
|
115 |
+
2-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
116 |
+
|
117 |
+
For the meaning of the parameter `norm`, see:
|
118 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
119 |
+
|
120 |
+
:param x: the input signal
|
121 |
+
:param norm: the normalization, None or 'ortho'
|
122 |
+
:return: the DCT-II of the signal over the last 2 dimensions
|
123 |
+
"""
|
124 |
+
X1 = dct(x, norm=norm)
|
125 |
+
X2 = dct(X1.transpose(-1, -2), norm=norm)
|
126 |
+
return X2.transpose(-1, -2)
|
127 |
+
|
128 |
+
|
129 |
+
def idct_2d(X, norm=None):
|
130 |
+
"""
|
131 |
+
The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
|
132 |
+
|
133 |
+
Our definition of idct is that idct_2d(dct_2d(x)) == x
|
134 |
+
|
135 |
+
For the meaning of the parameter `norm`, see:
|
136 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
137 |
+
|
138 |
+
:param X: the input signal
|
139 |
+
:param norm: the normalization, None or 'ortho'
|
140 |
+
:return: the DCT-II of the signal over the last 2 dimensions
|
141 |
+
"""
|
142 |
+
x1 = idct(X, norm=norm)
|
143 |
+
x2 = idct(x1.transpose(-1, -2), norm=norm)
|
144 |
+
return x2.transpose(-1, -2)
|
145 |
+
|
146 |
+
|
147 |
+
def dct_3d(x, norm=None):
|
148 |
+
"""
|
149 |
+
3-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
150 |
+
|
151 |
+
For the meaning of the parameter `norm`, see:
|
152 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
153 |
+
|
154 |
+
:param x: the input signal
|
155 |
+
:param norm: the normalization, None or 'ortho'
|
156 |
+
:return: the DCT-II of the signal over the last 3 dimensions
|
157 |
+
"""
|
158 |
+
X1 = dct(x, norm=norm)
|
159 |
+
X2 = dct(X1.transpose(-1, -2), norm=norm)
|
160 |
+
X3 = dct(X2.transpose(-1, -3), norm=norm)
|
161 |
+
return X3.transpose(-1, -3).transpose(-1, -2)
|
162 |
+
|
163 |
+
|
164 |
+
def idct_3d(X, norm=None):
|
165 |
+
"""
|
166 |
+
The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
|
167 |
+
|
168 |
+
Our definition of idct is that idct_3d(dct_3d(x)) == x
|
169 |
+
|
170 |
+
For the meaning of the parameter `norm`, see:
|
171 |
+
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
172 |
+
|
173 |
+
:param X: the input signal
|
174 |
+
:param norm: the normalization, None or 'ortho'
|
175 |
+
:return: the DCT-II of the signal over the last 3 dimensions
|
176 |
+
"""
|
177 |
+
x1 = idct(X, norm=norm)
|
178 |
+
x2 = idct(x1.transpose(-1, -2), norm=norm)
|
179 |
+
x3 = idct(x2.transpose(-1, -3), norm=norm)
|
180 |
+
return x3.transpose(-1, -3).transpose(-1, -2)
|
181 |
+
|
182 |
+
|
183 |
+
# class LinearDCT(nn.Linear):
|
184 |
+
# """Implement any DCT as a linear layer; in practice this executes around
|
185 |
+
# 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
|
186 |
+
# increase memory usage.
|
187 |
+
# :param in_features: size of expected input
|
188 |
+
# :param type: which dct function in this file to use"""
|
189 |
+
#
|
190 |
+
# def __init__(self, in_features, type, norm=None, bias=False):
|
191 |
+
# self.type = type
|
192 |
+
# self.N = in_features
|
193 |
+
# self.norm = norm
|
194 |
+
# super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
|
195 |
+
#
|
196 |
+
# def reset_parameters(self):
|
197 |
+
# # initialise using dct function
|
198 |
+
# I = torch.eye(self.N)
|
199 |
+
# if self.type == 'dct1':
|
200 |
+
# self.weight.data = dct1(I).data.t()
|
201 |
+
# elif self.type == 'idct1':
|
202 |
+
# self.weight.data = idct1(I).data.t()
|
203 |
+
# elif self.type == 'dct':
|
204 |
+
# self.weight.data = dct(I, norm=self.norm).data.t()
|
205 |
+
# elif self.type == 'idct':
|
206 |
+
# self.weight.data = idct(I, norm=self.norm).data.t()
|
207 |
+
# self.weight.require_grad = False # don't learn this!
|
208 |
+
|
209 |
+
class LinearDCT(nn.Module):
|
210 |
+
"""Implement any DCT as a linear layer; in practice this executes around
|
211 |
+
50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
|
212 |
+
increase memory usage.
|
213 |
+
:param in_features: size of expected input
|
214 |
+
:param type: which dct function in this file to use"""
|
215 |
+
|
216 |
+
def __init__(self, in_features, type, norm=None):
|
217 |
+
super(LinearDCT, self).__init__()
|
218 |
+
self.type = type
|
219 |
+
self.N = in_features
|
220 |
+
self.norm = norm
|
221 |
+
I = torch.eye(self.N)
|
222 |
+
if self.type == 'dct1':
|
223 |
+
self.weight = dct1(I).data.t()
|
224 |
+
elif self.type == 'idct1':
|
225 |
+
self.weight = idct1(I).data.t()
|
226 |
+
elif self.type == 'dct':
|
227 |
+
self.weight = dct(I, norm=self.norm).data.t()
|
228 |
+
elif self.type == 'idct':
|
229 |
+
self.weight = idct(I, norm=self.norm).data.t()
|
230 |
+
# self.register_buffer('weight', kernel)
|
231 |
+
# self.weight = kernel
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
return F.linear(x, weight=self.weight.cuda(x.get_device()))
|
235 |
+
|
236 |
+
|
237 |
+
def apply_linear_2d(x, linear_layer):
|
238 |
+
"""Can be used with a LinearDCT layer to do a 2D DCT.
|
239 |
+
:param x: the input signal
|
240 |
+
:param linear_layer: any PyTorch Linear layer
|
241 |
+
:return: result of linear layer applied to last 2 dimensions
|
242 |
+
"""
|
243 |
+
X1 = linear_layer(x)
|
244 |
+
X2 = linear_layer(X1.transpose(-1, -2))
|
245 |
+
return X2.transpose(-1, -2)
|
246 |
+
|
247 |
+
|
248 |
+
def apply_linear_3d(x, linear_layer):
|
249 |
+
"""Can be used with a LinearDCT layer to do a 3D DCT.
|
250 |
+
:param x: the input signal
|
251 |
+
:param linear_layer: any PyTorch Linear layer
|
252 |
+
:return: result of linear layer applied to last 3 dimensions
|
253 |
+
"""
|
254 |
+
X1 = linear_layer(x)
|
255 |
+
X2 = linear_layer(X1.transpose(-1, -2))
|
256 |
+
X3 = linear_layer(X2.transpose(-1, -3))
|
257 |
+
return X3.transpose(-1, -3).transpose(-1, -2)
|
258 |
+
|
259 |
+
|
260 |
+
if __name__ == '__main__':
|
261 |
+
x = torch.Tensor(1000, 4096)
|
262 |
+
x.normal_(0, 1)
|
263 |
+
linear_dct = LinearDCT(4096, 'dct')
|
264 |
+
error = torch.abs(dct(x) - linear_dct(x))
|
265 |
+
assert error.max() < 1e-3, (error, error.max())
|
266 |
+
linear_idct = LinearDCT(4096, 'idct')
|
267 |
+
error = torch.abs(idct(x) - linear_idct(x))
|
268 |
+
assert error.max() < 1e-3, (error, error.max())
|
utils/dataset.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.utils.data as data
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from os.path import join
|
6 |
+
|
7 |
+
|
8 |
+
class HalftoneVOC2012(data.Dataset):
|
9 |
+
# data range is [-1,1], color image is in BGR format
|
10 |
+
def __init__(self, data_list):
|
11 |
+
super(HalftoneVOC2012, self).__init__()
|
12 |
+
self.inputs = [join('Data', x) for x in data_list['inputs']]
|
13 |
+
self.labels = [join('Data', x) for x in data_list['labels']]
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
def load_input(name):
|
17 |
+
img = cv2.imread(name, flags=cv2.IMREAD_COLOR)
|
18 |
+
# transpose data
|
19 |
+
img = img.transpose((2, 0, 1))
|
20 |
+
# to Tensor
|
21 |
+
img = torch.from_numpy(img.astype(np.float32) / 127.5 - 1.0)
|
22 |
+
return img
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def load_label(name):
|
26 |
+
img = cv2.imread(name, flags=cv2.IMREAD_GRAYSCALE)
|
27 |
+
# transpose data
|
28 |
+
img = img[np.newaxis, :, :]
|
29 |
+
# to Tensor
|
30 |
+
img = torch.from_numpy(img.astype(np.float32) / 127.5 - 1.0)
|
31 |
+
return img
|
32 |
+
|
33 |
+
def __getitem__(self, index):
|
34 |
+
input_data = self.load_input(self.inputs[index])
|
35 |
+
label_data = self.load_label(self.labels[index])
|
36 |
+
return input_data, label_data
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.inputs)
|
utils/dct.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
-------------------------------------------------
|
4 |
+
File Name: dct
|
5 |
+
Author : wenbo
|
6 |
+
date: 12/4/2019
|
7 |
+
Description :
|
8 |
+
-------------------------------------------------
|
9 |
+
Change Activity:
|
10 |
+
12/4/2019:
|
11 |
+
-------------------------------------------------
|
12 |
+
"""
|
13 |
+
__author__ = 'wenbo'
|
14 |
+
|
15 |
+
from torch import nn
|
16 |
+
from ._dct import LinearDCT, apply_linear_2d
|
17 |
+
|
18 |
+
|
19 |
+
class DCT_Lowfrequency(nn.Module):
|
20 |
+
def __init__(self, size=256, fLimit=50):
|
21 |
+
super(DCT_Lowfrequency, self).__init__()
|
22 |
+
self.fLimit = fLimit
|
23 |
+
self.dct = LinearDCT(size, type='dct', norm='ortho')
|
24 |
+
self.dctTransformer = lambda x: apply_linear_2d(x, self.dct)
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.dctTransformer(x)
|
28 |
+
x = x[:, :, :self.fLimit, :self.fLimit]
|
29 |
+
return x
|
utils/filters_tensor.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numbers
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class GaussianSmoothing(nn.Module):
|
9 |
+
"""
|
10 |
+
Apply gaussian smoothing on a
|
11 |
+
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
12 |
+
in the input using a depthwise convolution.
|
13 |
+
Arguments:
|
14 |
+
channels (int, sequence): Number of channels of the input tensors. Output will
|
15 |
+
have this number of channels as well.
|
16 |
+
kernel_size (int, sequence): Size of the gaussian kernel.
|
17 |
+
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
18 |
+
dim (int, optional): The number of dimensions of the data.
|
19 |
+
Default value is 2 (spatial).
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, channels, kernel_size, sigma, dim=2, cuda=True):
|
23 |
+
super(GaussianSmoothing, self).__init__()
|
24 |
+
if isinstance(kernel_size, numbers.Number):
|
25 |
+
kernel_size = [kernel_size] * dim
|
26 |
+
if isinstance(sigma, numbers.Number):
|
27 |
+
sigma = [sigma] * dim
|
28 |
+
|
29 |
+
# The gaussian kernel is the product of the
|
30 |
+
# gaussian function of each dimension.
|
31 |
+
kernel = 1
|
32 |
+
meshgrids = torch.meshgrid(
|
33 |
+
[
|
34 |
+
torch.arange(size, dtype=torch.float32)
|
35 |
+
for size in kernel_size
|
36 |
+
]
|
37 |
+
)
|
38 |
+
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
39 |
+
mean = (size - 1) / 2
|
40 |
+
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / std) ** 2 / 2)
|
41 |
+
|
42 |
+
# Make sure sum of values in gaussian kernel equals 1.
|
43 |
+
kernel = kernel / torch.sum(kernel)
|
44 |
+
|
45 |
+
# Reshape to depthwise convolutional weight
|
46 |
+
kernel = kernel.view(1, 1, *kernel.size())
|
47 |
+
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
48 |
+
|
49 |
+
# if cuda:
|
50 |
+
# kernel = kernel.cuda()
|
51 |
+
# self.register_buffer('weight', kernel)
|
52 |
+
self.weight = kernel
|
53 |
+
self.groups = channels
|
54 |
+
|
55 |
+
if dim == 1:
|
56 |
+
self.conv = F.conv1d
|
57 |
+
elif dim == 2:
|
58 |
+
self.conv = F.conv2d
|
59 |
+
elif dim == 3:
|
60 |
+
self.conv = F.conv3d
|
61 |
+
else:
|
62 |
+
raise RuntimeError(
|
63 |
+
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
|
64 |
+
)
|
65 |
+
|
66 |
+
def forward(self, input):
|
67 |
+
"""
|
68 |
+
Apply gaussian filter to input.
|
69 |
+
Arguments:
|
70 |
+
input (torch.Tensor): Input to apply gaussian filter on.
|
71 |
+
Returns:
|
72 |
+
filtered (torch.Tensor): Filtered output.
|
73 |
+
"""
|
74 |
+
return self.conv(input, weight=self.weight.cuda(input.get_device()), groups=self.groups)
|
75 |
+
|
76 |
+
|
77 |
+
def bgr2gray(color):
|
78 |
+
# gray = 0.299⋅R+0.587⋅G+0.114⋅B
|
79 |
+
gray = color[:, 0, ...] * 0.114 + color[:, 1, ...] * 0.587 + color[:, 2, ...] * 0.299
|
80 |
+
gray = gray.unsqueeze_(1)
|
81 |
+
return gray
|
utils/pytorch_ssim.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch.autograd import Variable
|
4 |
+
import numpy as np
|
5 |
+
from math import exp
|
6 |
+
|
7 |
+
|
8 |
+
def gaussian(window_size, sigma):
|
9 |
+
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
10 |
+
return gauss / gauss.sum()
|
11 |
+
|
12 |
+
|
13 |
+
def create_window(window_size, channel):
|
14 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
15 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
16 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
17 |
+
return window
|
18 |
+
|
19 |
+
|
20 |
+
def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
21 |
+
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
22 |
+
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
23 |
+
|
24 |
+
mu1_sq = mu1.pow(2)
|
25 |
+
mu2_sq = mu2.pow(2)
|
26 |
+
mu1_mu2 = mu1 * mu2
|
27 |
+
|
28 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
29 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
30 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
31 |
+
|
32 |
+
C1 = 0.01 ** 2
|
33 |
+
C2 = 0.03 ** 2
|
34 |
+
|
35 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
36 |
+
|
37 |
+
if size_average:
|
38 |
+
return ssim_map.mean()
|
39 |
+
else:
|
40 |
+
return ssim_map.mean(1).mean(1).mean(1)
|
41 |
+
|
42 |
+
|
43 |
+
class SSIM(torch.nn.Module):
|
44 |
+
def __init__(self, window_size=11, size_average=True):
|
45 |
+
super(SSIM, self).__init__()
|
46 |
+
self.window_size = window_size
|
47 |
+
self.size_average = size_average
|
48 |
+
self.channel = 1
|
49 |
+
self.window = create_window(window_size, self.channel)
|
50 |
+
|
51 |
+
def forward(self, img1, img2):
|
52 |
+
(_, channel, _, _) = img1.size()
|
53 |
+
|
54 |
+
if channel == self.channel and self.window.data.type() == img1.data.type():
|
55 |
+
window = self.window
|
56 |
+
else:
|
57 |
+
window = create_window(self.window_size, channel)
|
58 |
+
|
59 |
+
if img1.is_cuda:
|
60 |
+
window = window.cuda(img1.get_device())
|
61 |
+
window = window.type_as(img1)
|
62 |
+
|
63 |
+
self.window = window
|
64 |
+
self.channel = channel
|
65 |
+
|
66 |
+
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
|
67 |
+
|
68 |
+
|
69 |
+
def ssim(img1, img2, window_size=11, size_average=True):
|
70 |
+
(_, channel, _, _) = img1.size()
|
71 |
+
window = create_window(window_size, channel)
|
72 |
+
|
73 |
+
if img1.is_cuda:
|
74 |
+
window = window.cuda(img1.get_device())
|
75 |
+
window = window.type_as(img1)
|
76 |
+
|
77 |
+
return _ssim(img1, img2, window, window_size, channel, size_average)
|
utils/util.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
|
6 |
+
def ensure_dir(path):
|
7 |
+
if not os.path.exists(path):
|
8 |
+
os.makedirs(path)
|
9 |
+
|
10 |
+
|
11 |
+
def get_filelist(data_dir):
|
12 |
+
file_list = glob.glob(os.path.join(data_dir, '*.*'))
|
13 |
+
file_list.sort()
|
14 |
+
return file_list
|
15 |
+
|
16 |
+
|
17 |
+
def collect_filenames(data_dir):
|
18 |
+
file_list = get_filelist(data_dir)
|
19 |
+
name_list = []
|
20 |
+
for file_path in file_list:
|
21 |
+
_, file_name = os.path.split(file_path)
|
22 |
+
name_list.append(file_name)
|
23 |
+
name_list.sort()
|
24 |
+
return name_list
|
25 |
+
|
26 |
+
|
27 |
+
def save_list(save_path, data_list, append_mode=False):
|
28 |
+
n = len(data_list)
|
29 |
+
if append_mode:
|
30 |
+
with open(save_path, 'a') as f:
|
31 |
+
f.writelines([str(data_list[i]) + '\n' for i in range(n-1,n)])
|
32 |
+
else:
|
33 |
+
with open(save_path, 'w') as f:
|
34 |
+
f.writelines([str(data_list[i]) + '\n' for i in range(n)])
|
35 |
+
return None
|
36 |
+
|
37 |
+
|
38 |
+
def save_images_from_batch(img_batch, save_dir, filename_list, batch_no=-1):
|
39 |
+
N,H,W,C = img_batch.shape
|
40 |
+
if C == 3:
|
41 |
+
#! rgb color image
|
42 |
+
for i in range(N):
|
43 |
+
# [-1,1] >>> [0,255]
|
44 |
+
img_batch_i = np.clip(img_batch[i,:,:,:]*0.5+0.5, 0, 1)
|
45 |
+
image = (255.0*img_batch_i).astype(np.uint8)
|
46 |
+
save_name = filename_list[i] if batch_no==-1 else '%05d.png' % (batch_no*N+i)
|
47 |
+
cv2.imwrite(os.path.join(save_dir, save_name), image)
|
48 |
+
elif C == 1:
|
49 |
+
#! single-channel gray image
|
50 |
+
for i in range(N):
|
51 |
+
# [-1,1] >>> [0,255]
|
52 |
+
img_batch_i = np.clip(img_batch[i,:,:,0]*0.5+0.5, 0, 1)
|
53 |
+
image = (255.0*img_batch_i).astype(np.uint8)
|
54 |
+
save_name = filename_list[i] if batch_no==-1 else '%05d.png' % (batch_no*img_batch.shape[0]+i)
|
55 |
+
cv2.imwrite(os.path.join(save_dir, save_name), image)
|
56 |
+
return None
|
57 |
+
|
58 |
+
|
59 |
+
def imagesc(nd_array):
|
60 |
+
plt.imshow(nd_array)
|
61 |
+
plt.colorbar()
|
62 |
+
plt.show()
|
63 |
+
|
64 |
+
|
65 |
+
def img2tensor(img):
|
66 |
+
if len(img.shape) == 2:
|
67 |
+
img = img[..., np.newaxis]
|
68 |
+
img_t = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
69 |
+
img_t = torch.from_numpy(img_t.astype(np.float32))
|
70 |
+
return img_t
|
71 |
+
|
72 |
+
|
73 |
+
def tensor2img(img_t):
|
74 |
+
img = img_t[0].detach().to("cpu").numpy()
|
75 |
+
img = np.transpose(img, (1, 2, 0))
|
76 |
+
if img.shape[-1] == 1:
|
77 |
+
img = img[..., 0]
|
78 |
+
return img
|