Spaces:
Build error
Build error
BayesCap demo to EuroCrypt
Browse files- LICENSE +201 -0
- README.md +19 -6
- app.py +152 -0
- demo_examples/baby.png +0 -0
- demo_examples/bird.png +0 -0
- demo_examples/butterfly.png +0 -0
- demo_examples/head.png +0 -0
- demo_examples/tue.jpeg +0 -0
- demo_examples/woman.png +0 -0
- ds.py +485 -0
- losses.py +131 -0
- networks_SRGAN.py +347 -0
- networks_T1toT2.py +477 -0
- requirements.txt +158 -0
- src/.gitkeep +0 -0
- src/README.md +26 -0
- src/__pycache__/ds.cpython-310.pyc +0 -0
- src/__pycache__/losses.cpython-310.pyc +0 -0
- src/__pycache__/networks_SRGAN.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/app.py +115 -0
- src/ds.py +485 -0
- src/flagged/Alpha/0.png +0 -0
- src/flagged/Beta/0.png +0 -0
- src/flagged/Low-res/0.png +0 -0
- src/flagged/Orignal/0.png +0 -0
- src/flagged/Super-res/0.png +0 -0
- src/flagged/Uncertainty/0.png +0 -0
- src/flagged/log.csv +2 -0
- src/losses.py +131 -0
- src/networks_SRGAN.py +347 -0
- src/networks_T1toT2.py +477 -0
- src/utils.py +1273 -0
- utils.py +117 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,13 +1,26 @@
|
|
1 |
---
|
2 |
title: BayesCap
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.0.24
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: cc
|
11 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
|
|
|
1 |
---
|
2 |
title: BayesCap
|
3 |
+
emoji: 🔥
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: purple
|
6 |
sdk: gradio
|
|
|
7 |
app_file: app.py
|
8 |
pinned: false
|
|
|
9 |
---
|
10 |
+
# Configuration
|
11 |
+
`title`: _string_
|
12 |
+
Display title for the Space
|
13 |
+
`emoji`: _string_
|
14 |
+
Space emoji (emoji-only character allowed)
|
15 |
+
`colorFrom`: _string_
|
16 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
17 |
+
`colorTo`: _string_
|
18 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
19 |
+
`sdk`: _string_
|
20 |
+
Can be either `gradio` or `streamlit`
|
21 |
+
`app_file`: _string_
|
22 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
23 |
+
Path is relative to the root of the repository.
|
24 |
|
25 |
+
`pinned`: _boolean_
|
26 |
+
Whether the Space stays on top of your list.
|
app.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from matplotlib import cm
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.models as models
|
10 |
+
from torch.utils.data import Dataset, DataLoader
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from ds import *
|
17 |
+
from losses import *
|
18 |
+
from networks_SRGAN import *
|
19 |
+
from utils import *
|
20 |
+
|
21 |
+
device = 'cpu'
|
22 |
+
if device == 'cuda':
|
23 |
+
dtype = torch.cuda.FloatTensor
|
24 |
+
else:
|
25 |
+
dtype = torch.FloatTensor
|
26 |
+
|
27 |
+
NetG = Generator()
|
28 |
+
model_parameters = filter(lambda p: True, NetG.parameters())
|
29 |
+
params = sum([np.prod(p.size()) for p in model_parameters])
|
30 |
+
print("Number of Parameters:", params)
|
31 |
+
NetC = BayesCap(in_channels=3, out_channels=3)
|
32 |
+
|
33 |
+
ensure_checkpoint_exists('BayesCap_SRGAN.pth')
|
34 |
+
NetG.load_state_dict(torch.load('BayesCap_SRGAN.pth', map_location=device))
|
35 |
+
NetG.to(device)
|
36 |
+
NetG.eval()
|
37 |
+
|
38 |
+
ensure_checkpoint_exists('BayesCap_ckpt.pth')
|
39 |
+
NetC.load_state_dict(torch.load('BayesCap_ckpt.pth', map_location=device))
|
40 |
+
NetC.to(device)
|
41 |
+
NetC.eval()
|
42 |
+
|
43 |
+
def tensor01_to_pil(xt):
|
44 |
+
r = transforms.ToPILImage(mode='RGB')(xt.squeeze())
|
45 |
+
return r
|
46 |
+
|
47 |
+
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
|
48 |
+
"""Convert ``PIL.Image`` to Tensor.
|
49 |
+
Args:
|
50 |
+
image (np.ndarray): The image data read by ``PIL.Image``
|
51 |
+
range_norm (bool): Scale [0, 1] data to between [-1, 1]
|
52 |
+
half (bool): Whether to convert torch.float32 similarly to torch.half type.
|
53 |
+
Returns:
|
54 |
+
Normalized image data
|
55 |
+
Examples:
|
56 |
+
>>> image = Image.open("image.bmp")
|
57 |
+
>>> tensor_image = image2tensor(image, range_norm=False, half=False)
|
58 |
+
"""
|
59 |
+
tensor = F.to_tensor(image)
|
60 |
+
|
61 |
+
if range_norm:
|
62 |
+
tensor = tensor.mul_(2.0).sub_(1.0)
|
63 |
+
if half:
|
64 |
+
tensor = tensor.half()
|
65 |
+
|
66 |
+
return tensor
|
67 |
+
|
68 |
+
|
69 |
+
def predict(img):
|
70 |
+
"""
|
71 |
+
img: image
|
72 |
+
"""
|
73 |
+
image_size = (256,256)
|
74 |
+
upscale_factor = 4
|
75 |
+
# lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
|
76 |
+
# to retain aspect ratio
|
77 |
+
lr_transforms = transforms.Resize(image_size[0]//upscale_factor, interpolation=IMode.BICUBIC, antialias=True)
|
78 |
+
# lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)
|
79 |
+
|
80 |
+
img = Image.fromarray(np.array(img))
|
81 |
+
img = lr_transforms(img)
|
82 |
+
lr_tensor = image2tensor(img, range_norm=False, half=False)
|
83 |
+
|
84 |
+
xLR = lr_tensor.to(device).unsqueeze(0)
|
85 |
+
xLR = xLR.type(dtype)
|
86 |
+
# pass them through the network
|
87 |
+
with torch.no_grad():
|
88 |
+
xSR = NetG(xLR)
|
89 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
90 |
+
|
91 |
+
a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
|
92 |
+
b_map = xSRC_beta[0].to('cpu').data
|
93 |
+
u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
94 |
+
|
95 |
+
|
96 |
+
x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
97 |
+
|
98 |
+
x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
99 |
+
|
100 |
+
#im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))
|
101 |
+
|
102 |
+
a_map = torch.clamp(a_map, min=0, max=0.1)
|
103 |
+
a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
|
104 |
+
x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
105 |
+
|
106 |
+
b_map = torch.clamp(b_map, min=0.45, max=0.75)
|
107 |
+
b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
|
108 |
+
x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
109 |
+
|
110 |
+
u_map = torch.clamp(u_map, min=0, max=0.15)
|
111 |
+
u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
|
112 |
+
x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
113 |
+
|
114 |
+
return x_LR, x_mean, x_alpha, x_beta, x_uncer
|
115 |
+
|
116 |
+
import gradio as gr
|
117 |
+
|
118 |
+
title = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks"
|
119 |
+
|
120 |
+
abstract="<b>Abstract.</b> High-quality calibrated uncertainty estimates are crucial for numerous real-world applications, especially for deep learning-based deployed ML systems. While Bayesian deep learning techniques allow uncertainty estimation, training them with large-scale datasets is an expensive process that does not always yield models competitive with non-Bayesian counterparts. Moreover, many of the high-performing deep learning models that are already trained and deployed are non-Bayesian in nature and do not provide uncertainty estimates. To address these issues, we propose BayesCap that learns a Bayesian identity mapping for the frozen model, allowing uncertainty estimation. BayesCap is a memory-efficient method that can be trained on a small fraction of the original dataset, enhancing pretrained non-Bayesian computer vision models by providing calibrated uncertainty estimates for the predictions without (i) hampering the performance of the model and (ii) the need for expensive retraining the model from scratch. The proposed method is agnostic to various architectures and tasks. We show the efficacy of our method on a wide variety of tasks with a diverse set of architectures, including image super-resolution, deblurring, inpainting, and crucial application such as medical image translation. Moreover, we apply the derived uncertainty estimates to detect out-of-distribution samples in critical scenarios like depth estimation in autonomous driving. Code is available <a href='https://github.com/ExplainableML/BayesCap'>here</a>. <br> <br>"
|
121 |
+
|
122 |
+
method = "In this demo, we show an application of BayesCap on top of SRGAN for the task of super resolution. BayesCap estimates the per-pixel uncertainty of a pretrained computer vision model like SRGAN (used for super-resolution). BayesCap takes the ouput of the pretrained model (in this case SRGAN), and predicts the per-pixel distribution parameters for the output, that can be used to quantify the per-pixel uncertainty. In our work, we model the per-pixel output as a <a href='https://en.wikipedia.org/wiki/Generalized_normal_distribution'>Generalized Gaussian distribution</a> that is parameterized by 3 parameters the mean, scale (alpha), and the shape (beta). As a result our model predicts these three parameters as shown below. From these 3 parameters one can compute the uncertainty as shown in <a href='https://en.wikipedia.org/wiki/Generalized_normal_distribution'>this article</a>. <br><br>"
|
123 |
+
|
124 |
+
closing = "For more details, please find the <a href='https://arxiv.org/'>ECCV 2022 paper here</a>."
|
125 |
+
|
126 |
+
description = abstract + method + closing
|
127 |
+
|
128 |
+
article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks| <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
|
129 |
+
|
130 |
+
|
131 |
+
gr.Interface(
|
132 |
+
fn=predict,
|
133 |
+
inputs=gr.inputs.Image(type='pil', label="Orignal"),
|
134 |
+
outputs=[
|
135 |
+
gr.outputs.Image(type='pil', label="Low-resolution image (input to SRGAN)"),
|
136 |
+
gr.outputs.Image(type='pil', label="Super-resolved image (output of SRGAN)"),
|
137 |
+
gr.outputs.Image(type='pil', label="Alpha parameter map characterizing per-pixel distribution (output of BayesCap)"),
|
138 |
+
gr.outputs.Image(type='pil', label="Beta parameter map characterizing per-pixel distribution (output of BayesCap)"),
|
139 |
+
gr.outputs.Image(type='pil', label="Per-pixel uncertainty map (derived using outputs of BayesCap)")
|
140 |
+
],
|
141 |
+
title=title,
|
142 |
+
description=description,
|
143 |
+
article=article,
|
144 |
+
examples=[
|
145 |
+
["./demo_examples/tue.jpeg"],
|
146 |
+
["./demo_examples/baby.png"],
|
147 |
+
["./demo_examples/bird.png"],
|
148 |
+
["./demo_examples/butterfly.png"],
|
149 |
+
["./demo_examples/head.png"],
|
150 |
+
["./demo_examples/woman.png"],
|
151 |
+
]
|
152 |
+
).launch()
|
demo_examples/baby.png
ADDED
![]() |
demo_examples/bird.png
ADDED
![]() |
demo_examples/butterfly.png
ADDED
![]() |
demo_examples/head.png
ADDED
![]() |
demo_examples/tue.jpeg
ADDED
![]() |
demo_examples/woman.png
ADDED
![]() |
ds.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import random
|
4 |
+
import copy
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import skimage.transform
|
10 |
+
from collections import Counter
|
11 |
+
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.utils.data as data
|
15 |
+
from torch import Tensor
|
16 |
+
from torch.utils.data import Dataset
|
17 |
+
from torchvision import transforms
|
18 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
19 |
+
|
20 |
+
import utils
|
21 |
+
|
22 |
+
class ImgDset(Dataset):
|
23 |
+
"""Customize the data set loading function and prepare low/high resolution image data in advance.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
dataroot (str): Training data set address
|
27 |
+
image_size (int): High resolution image size
|
28 |
+
upscale_factor (int): Image magnification
|
29 |
+
mode (str): Data set loading method, the training data set is for data enhancement,
|
30 |
+
and the verification data set is not for data enhancement
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
|
35 |
+
super(ImgDset, self).__init__()
|
36 |
+
self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
|
37 |
+
|
38 |
+
if mode == "train":
|
39 |
+
self.hr_transforms = transforms.Compose([
|
40 |
+
transforms.RandomCrop(image_size),
|
41 |
+
transforms.RandomRotation(90),
|
42 |
+
transforms.RandomHorizontalFlip(0.5),
|
43 |
+
])
|
44 |
+
else:
|
45 |
+
self.hr_transforms = transforms.Resize(image_size)
|
46 |
+
|
47 |
+
self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
|
48 |
+
|
49 |
+
def __getitem__(self, batch_index: int) -> [Tensor, Tensor]:
|
50 |
+
# Read a batch of image data
|
51 |
+
image = Image.open(self.filenames[batch_index])
|
52 |
+
|
53 |
+
# Transform image
|
54 |
+
hr_image = self.hr_transforms(image)
|
55 |
+
lr_image = self.lr_transforms(hr_image)
|
56 |
+
|
57 |
+
# Convert image data into Tensor stream format (PyTorch).
|
58 |
+
# Note: The range of input and output is between [0, 1]
|
59 |
+
lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
|
60 |
+
hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
|
61 |
+
|
62 |
+
return lr_tensor, hr_tensor
|
63 |
+
|
64 |
+
def __len__(self) -> int:
|
65 |
+
return len(self.filenames)
|
66 |
+
|
67 |
+
|
68 |
+
class PairedImages_w_nameList(Dataset):
|
69 |
+
'''
|
70 |
+
can act as supervised or un-supervised based on flists
|
71 |
+
'''
|
72 |
+
def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
|
73 |
+
self.flist1 = flist1
|
74 |
+
self.flist2 = flist2
|
75 |
+
self.transform1 = transform1
|
76 |
+
self.transform2 = transform2
|
77 |
+
self.do_aug = do_aug
|
78 |
+
def __getitem__(self, index):
|
79 |
+
impath1 = self.flist1[index]
|
80 |
+
img1 = Image.open(impath1).convert('RGB')
|
81 |
+
impath2 = self.flist2[index]
|
82 |
+
img2 = Image.open(impath2).convert('RGB')
|
83 |
+
|
84 |
+
img1 = utils.image2tensor(img1, range_norm=False, half=False)
|
85 |
+
img2 = utils.image2tensor(img2, range_norm=False, half=False)
|
86 |
+
|
87 |
+
if self.transform1 is not None:
|
88 |
+
img1 = self.transform1(img1)
|
89 |
+
if self.transform2 is not None:
|
90 |
+
img2 = self.transform2(img2)
|
91 |
+
|
92 |
+
return img1, img2
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.flist1)
|
95 |
+
|
96 |
+
class PairedImages_w_nameList_npy(Dataset):
|
97 |
+
'''
|
98 |
+
can act as supervised or un-supervised based on flists
|
99 |
+
'''
|
100 |
+
def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
|
101 |
+
self.flist1 = flist1
|
102 |
+
self.flist2 = flist2
|
103 |
+
self.transform1 = transform1
|
104 |
+
self.transform2 = transform2
|
105 |
+
self.do_aug = do_aug
|
106 |
+
def __getitem__(self, index):
|
107 |
+
impath1 = self.flist1[index]
|
108 |
+
img1 = np.load(impath1)
|
109 |
+
impath2 = self.flist2[index]
|
110 |
+
img2 = np.load(impath2)
|
111 |
+
|
112 |
+
if self.transform1 is not None:
|
113 |
+
img1 = self.transform1(img1)
|
114 |
+
if self.transform2 is not None:
|
115 |
+
img2 = self.transform2(img2)
|
116 |
+
|
117 |
+
return img1, img2
|
118 |
+
def __len__(self):
|
119 |
+
return len(self.flist1)
|
120 |
+
|
121 |
+
# def call_paired():
|
122 |
+
# root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
|
123 |
+
# root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
|
124 |
+
|
125 |
+
# flist1=glob.glob(root1+'/*/*.png')
|
126 |
+
# flist2=glob.glob(root2+'/*/*.png')
|
127 |
+
|
128 |
+
# dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
|
129 |
+
|
130 |
+
#### KITTI depth
|
131 |
+
|
132 |
+
def load_velodyne_points(filename):
|
133 |
+
"""Load 3D point cloud from KITTI file format
|
134 |
+
(adapted from https://github.com/hunse/kitti)
|
135 |
+
"""
|
136 |
+
points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
|
137 |
+
points[:, 3] = 1.0 # homogeneous
|
138 |
+
return points
|
139 |
+
|
140 |
+
|
141 |
+
def read_calib_file(path):
|
142 |
+
"""Read KITTI calibration file
|
143 |
+
(from https://github.com/hunse/kitti)
|
144 |
+
"""
|
145 |
+
float_chars = set("0123456789.e+- ")
|
146 |
+
data = {}
|
147 |
+
with open(path, 'r') as f:
|
148 |
+
for line in f.readlines():
|
149 |
+
key, value = line.split(':', 1)
|
150 |
+
value = value.strip()
|
151 |
+
data[key] = value
|
152 |
+
if float_chars.issuperset(value):
|
153 |
+
# try to cast to float array
|
154 |
+
try:
|
155 |
+
data[key] = np.array(list(map(float, value.split(' '))))
|
156 |
+
except ValueError:
|
157 |
+
# casting error: data[key] already eq. value, so pass
|
158 |
+
pass
|
159 |
+
|
160 |
+
return data
|
161 |
+
|
162 |
+
|
163 |
+
def sub2ind(matrixSize, rowSub, colSub):
|
164 |
+
"""Convert row, col matrix subscripts to linear indices
|
165 |
+
"""
|
166 |
+
m, n = matrixSize
|
167 |
+
return rowSub * (n-1) + colSub - 1
|
168 |
+
|
169 |
+
|
170 |
+
def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
|
171 |
+
"""Generate a depth map from velodyne data
|
172 |
+
"""
|
173 |
+
# load calibration files
|
174 |
+
cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
|
175 |
+
velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
|
176 |
+
velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
|
177 |
+
velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))
|
178 |
+
|
179 |
+
# get image shape
|
180 |
+
im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)
|
181 |
+
|
182 |
+
# compute projection matrix velodyne->image plane
|
183 |
+
R_cam2rect = np.eye(4)
|
184 |
+
R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
|
185 |
+
P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
|
186 |
+
P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)
|
187 |
+
|
188 |
+
# load velodyne points and remove all behind image plane (approximation)
|
189 |
+
# each row of the velodyne data is forward, left, up, reflectance
|
190 |
+
velo = load_velodyne_points(velo_filename)
|
191 |
+
velo = velo[velo[:, 0] >= 0, :]
|
192 |
+
|
193 |
+
# project the points to the camera
|
194 |
+
velo_pts_im = np.dot(P_velo2im, velo.T).T
|
195 |
+
velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]
|
196 |
+
|
197 |
+
if vel_depth:
|
198 |
+
velo_pts_im[:, 2] = velo[:, 0]
|
199 |
+
|
200 |
+
# check if in bounds
|
201 |
+
# use minus 1 to get the exact same value as KITTI matlab code
|
202 |
+
velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
|
203 |
+
velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
|
204 |
+
val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
|
205 |
+
val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
|
206 |
+
velo_pts_im = velo_pts_im[val_inds, :]
|
207 |
+
|
208 |
+
# project to image
|
209 |
+
depth = np.zeros((im_shape[:2]))
|
210 |
+
depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2]
|
211 |
+
|
212 |
+
# find the duplicate points and choose the closest depth
|
213 |
+
inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
|
214 |
+
dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
|
215 |
+
for dd in dupe_inds:
|
216 |
+
pts = np.where(inds == dd)[0]
|
217 |
+
x_loc = int(velo_pts_im[pts[0], 0])
|
218 |
+
y_loc = int(velo_pts_im[pts[0], 1])
|
219 |
+
depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
|
220 |
+
depth[depth < 0] = 0
|
221 |
+
|
222 |
+
return depth
|
223 |
+
|
224 |
+
def pil_loader(path):
|
225 |
+
# open path as file to avoid ResourceWarning
|
226 |
+
# (https://github.com/python-pillow/Pillow/issues/835)
|
227 |
+
with open(path, 'rb') as f:
|
228 |
+
with Image.open(f) as img:
|
229 |
+
return img.convert('RGB')
|
230 |
+
|
231 |
+
|
232 |
+
class MonoDataset(data.Dataset):
|
233 |
+
"""Superclass for monocular dataloaders
|
234 |
+
|
235 |
+
Args:
|
236 |
+
data_path
|
237 |
+
filenames
|
238 |
+
height
|
239 |
+
width
|
240 |
+
frame_idxs
|
241 |
+
num_scales
|
242 |
+
is_train
|
243 |
+
img_ext
|
244 |
+
"""
|
245 |
+
def __init__(self,
|
246 |
+
data_path,
|
247 |
+
filenames,
|
248 |
+
height,
|
249 |
+
width,
|
250 |
+
frame_idxs,
|
251 |
+
num_scales,
|
252 |
+
is_train=False,
|
253 |
+
img_ext='.jpg'):
|
254 |
+
super(MonoDataset, self).__init__()
|
255 |
+
|
256 |
+
self.data_path = data_path
|
257 |
+
self.filenames = filenames
|
258 |
+
self.height = height
|
259 |
+
self.width = width
|
260 |
+
self.num_scales = num_scales
|
261 |
+
self.interp = Image.ANTIALIAS
|
262 |
+
|
263 |
+
self.frame_idxs = frame_idxs
|
264 |
+
|
265 |
+
self.is_train = is_train
|
266 |
+
self.img_ext = img_ext
|
267 |
+
|
268 |
+
self.loader = pil_loader
|
269 |
+
self.to_tensor = transforms.ToTensor()
|
270 |
+
|
271 |
+
# We need to specify augmentations differently in newer versions of torchvision.
|
272 |
+
# We first try the newer tuple version; if this fails we fall back to scalars
|
273 |
+
try:
|
274 |
+
self.brightness = (0.8, 1.2)
|
275 |
+
self.contrast = (0.8, 1.2)
|
276 |
+
self.saturation = (0.8, 1.2)
|
277 |
+
self.hue = (-0.1, 0.1)
|
278 |
+
transforms.ColorJitter.get_params(
|
279 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
280 |
+
except TypeError:
|
281 |
+
self.brightness = 0.2
|
282 |
+
self.contrast = 0.2
|
283 |
+
self.saturation = 0.2
|
284 |
+
self.hue = 0.1
|
285 |
+
|
286 |
+
self.resize = {}
|
287 |
+
for i in range(self.num_scales):
|
288 |
+
s = 2 ** i
|
289 |
+
self.resize[i] = transforms.Resize((self.height // s, self.width // s),
|
290 |
+
interpolation=self.interp)
|
291 |
+
|
292 |
+
self.load_depth = self.check_depth()
|
293 |
+
|
294 |
+
def preprocess(self, inputs, color_aug):
|
295 |
+
"""Resize colour images to the required scales and augment if required
|
296 |
+
|
297 |
+
We create the color_aug object in advance and apply the same augmentation to all
|
298 |
+
images in this item. This ensures that all images input to the pose network receive the
|
299 |
+
same augmentation.
|
300 |
+
"""
|
301 |
+
for k in list(inputs):
|
302 |
+
frame = inputs[k]
|
303 |
+
if "color" in k:
|
304 |
+
n, im, i = k
|
305 |
+
for i in range(self.num_scales):
|
306 |
+
inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])
|
307 |
+
|
308 |
+
for k in list(inputs):
|
309 |
+
f = inputs[k]
|
310 |
+
if "color" in k:
|
311 |
+
n, im, i = k
|
312 |
+
inputs[(n, im, i)] = self.to_tensor(f)
|
313 |
+
inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
|
314 |
+
|
315 |
+
def __len__(self):
|
316 |
+
return len(self.filenames)
|
317 |
+
|
318 |
+
def __getitem__(self, index):
|
319 |
+
"""Returns a single training item from the dataset as a dictionary.
|
320 |
+
|
321 |
+
Values correspond to torch tensors.
|
322 |
+
Keys in the dictionary are either strings or tuples:
|
323 |
+
|
324 |
+
("color", <frame_id>, <scale>) for raw colour images,
|
325 |
+
("color_aug", <frame_id>, <scale>) for augmented colour images,
|
326 |
+
("K", scale) or ("inv_K", scale) for camera intrinsics,
|
327 |
+
"stereo_T" for camera extrinsics, and
|
328 |
+
"depth_gt" for ground truth depth maps.
|
329 |
+
|
330 |
+
<frame_id> is either:
|
331 |
+
an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
|
332 |
+
or
|
333 |
+
"s" for the opposite image in the stereo pair.
|
334 |
+
|
335 |
+
<scale> is an integer representing the scale of the image relative to the fullsize image:
|
336 |
+
-1 images at native resolution as loaded from disk
|
337 |
+
0 images resized to (self.width, self.height )
|
338 |
+
1 images resized to (self.width // 2, self.height // 2)
|
339 |
+
2 images resized to (self.width // 4, self.height // 4)
|
340 |
+
3 images resized to (self.width // 8, self.height // 8)
|
341 |
+
"""
|
342 |
+
inputs = {}
|
343 |
+
|
344 |
+
do_color_aug = self.is_train and random.random() > 0.5
|
345 |
+
do_flip = self.is_train and random.random() > 0.5
|
346 |
+
|
347 |
+
line = self.filenames[index].split()
|
348 |
+
folder = line[0]
|
349 |
+
|
350 |
+
if len(line) == 3:
|
351 |
+
frame_index = int(line[1])
|
352 |
+
else:
|
353 |
+
frame_index = 0
|
354 |
+
|
355 |
+
if len(line) == 3:
|
356 |
+
side = line[2]
|
357 |
+
else:
|
358 |
+
side = None
|
359 |
+
|
360 |
+
for i in self.frame_idxs:
|
361 |
+
if i == "s":
|
362 |
+
other_side = {"r": "l", "l": "r"}[side]
|
363 |
+
inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
|
364 |
+
else:
|
365 |
+
inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)
|
366 |
+
|
367 |
+
# adjusting intrinsics to match each scale in the pyramid
|
368 |
+
for scale in range(self.num_scales):
|
369 |
+
K = self.K.copy()
|
370 |
+
|
371 |
+
K[0, :] *= self.width // (2 ** scale)
|
372 |
+
K[1, :] *= self.height // (2 ** scale)
|
373 |
+
|
374 |
+
inv_K = np.linalg.pinv(K)
|
375 |
+
|
376 |
+
inputs[("K", scale)] = torch.from_numpy(K)
|
377 |
+
inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
|
378 |
+
|
379 |
+
if do_color_aug:
|
380 |
+
color_aug = transforms.ColorJitter.get_params(
|
381 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
382 |
+
else:
|
383 |
+
color_aug = (lambda x: x)
|
384 |
+
|
385 |
+
self.preprocess(inputs, color_aug)
|
386 |
+
|
387 |
+
for i in self.frame_idxs:
|
388 |
+
del inputs[("color", i, -1)]
|
389 |
+
del inputs[("color_aug", i, -1)]
|
390 |
+
|
391 |
+
if self.load_depth:
|
392 |
+
depth_gt = self.get_depth(folder, frame_index, side, do_flip)
|
393 |
+
inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
|
394 |
+
inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))
|
395 |
+
|
396 |
+
if "s" in self.frame_idxs:
|
397 |
+
stereo_T = np.eye(4, dtype=np.float32)
|
398 |
+
baseline_sign = -1 if do_flip else 1
|
399 |
+
side_sign = -1 if side == "l" else 1
|
400 |
+
stereo_T[0, 3] = side_sign * baseline_sign * 0.1
|
401 |
+
|
402 |
+
inputs["stereo_T"] = torch.from_numpy(stereo_T)
|
403 |
+
|
404 |
+
return inputs
|
405 |
+
|
406 |
+
def get_color(self, folder, frame_index, side, do_flip):
|
407 |
+
raise NotImplementedError
|
408 |
+
|
409 |
+
def check_depth(self):
|
410 |
+
raise NotImplementedError
|
411 |
+
|
412 |
+
def get_depth(self, folder, frame_index, side, do_flip):
|
413 |
+
raise NotImplementedError
|
414 |
+
|
415 |
+
class KITTIDataset(MonoDataset):
|
416 |
+
"""Superclass for different types of KITTI dataset loaders
|
417 |
+
"""
|
418 |
+
def __init__(self, *args, **kwargs):
|
419 |
+
super(KITTIDataset, self).__init__(*args, **kwargs)
|
420 |
+
|
421 |
+
# NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
|
422 |
+
# To normalize you need to scale the first row by 1 / image_width and the second row
|
423 |
+
# by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
|
424 |
+
# If your principal point is far from the center you might need to disable the horizontal
|
425 |
+
# flip augmentation.
|
426 |
+
self.K = np.array([[0.58, 0, 0.5, 0],
|
427 |
+
[0, 1.92, 0.5, 0],
|
428 |
+
[0, 0, 1, 0],
|
429 |
+
[0, 0, 0, 1]], dtype=np.float32)
|
430 |
+
|
431 |
+
self.full_res_shape = (1242, 375)
|
432 |
+
self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}
|
433 |
+
|
434 |
+
def check_depth(self):
|
435 |
+
line = self.filenames[0].split()
|
436 |
+
scene_name = line[0]
|
437 |
+
frame_index = int(line[1])
|
438 |
+
|
439 |
+
velo_filename = os.path.join(
|
440 |
+
self.data_path,
|
441 |
+
scene_name,
|
442 |
+
"velodyne_points/data/{:010d}.bin".format(int(frame_index)))
|
443 |
+
|
444 |
+
return os.path.isfile(velo_filename)
|
445 |
+
|
446 |
+
def get_color(self, folder, frame_index, side, do_flip):
|
447 |
+
color = self.loader(self.get_image_path(folder, frame_index, side))
|
448 |
+
|
449 |
+
if do_flip:
|
450 |
+
color = color.transpose(Image.FLIP_LEFT_RIGHT)
|
451 |
+
|
452 |
+
return color
|
453 |
+
|
454 |
+
|
455 |
+
class KITTIDepthDataset(KITTIDataset):
|
456 |
+
"""KITTI dataset which uses the updated ground truth depth maps
|
457 |
+
"""
|
458 |
+
def __init__(self, *args, **kwargs):
|
459 |
+
super(KITTIDepthDataset, self).__init__(*args, **kwargs)
|
460 |
+
|
461 |
+
def get_image_path(self, folder, frame_index, side):
|
462 |
+
f_str = "{:010d}{}".format(frame_index, self.img_ext)
|
463 |
+
image_path = os.path.join(
|
464 |
+
self.data_path,
|
465 |
+
folder,
|
466 |
+
"image_0{}/data".format(self.side_map[side]),
|
467 |
+
f_str)
|
468 |
+
return image_path
|
469 |
+
|
470 |
+
def get_depth(self, folder, frame_index, side, do_flip):
|
471 |
+
f_str = "{:010d}.png".format(frame_index)
|
472 |
+
depth_path = os.path.join(
|
473 |
+
self.data_path,
|
474 |
+
folder,
|
475 |
+
"proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
|
476 |
+
f_str)
|
477 |
+
|
478 |
+
depth_gt = Image.open(depth_path)
|
479 |
+
depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
|
480 |
+
depth_gt = np.array(depth_gt).astype(np.float32) / 256
|
481 |
+
|
482 |
+
if do_flip:
|
483 |
+
depth_gt = np.fliplr(depth_gt)
|
484 |
+
|
485 |
+
return depth_gt
|
losses.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as models
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
class ContentLoss(nn.Module):
|
8 |
+
"""Constructs a content loss function based on the VGG19 network.
|
9 |
+
Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.
|
10 |
+
|
11 |
+
Paper reference list:
|
12 |
+
-`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
|
13 |
+
-`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
|
14 |
+
-`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
|
15 |
+
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self) -> None:
|
19 |
+
super(ContentLoss, self).__init__()
|
20 |
+
# Load the VGG19 model trained on the ImageNet dataset.
|
21 |
+
vgg19 = models.vgg19(pretrained=True).eval()
|
22 |
+
# Extract the thirty-sixth layer output in the VGG19 model as the content loss.
|
23 |
+
self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
|
24 |
+
# Freeze model parameters.
|
25 |
+
for parameters in self.feature_extractor.parameters():
|
26 |
+
parameters.requires_grad = False
|
27 |
+
|
28 |
+
# The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
|
29 |
+
self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
30 |
+
self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
31 |
+
|
32 |
+
def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
|
33 |
+
# Standardized operations
|
34 |
+
sr = sr.sub(self.mean).div(self.std)
|
35 |
+
hr = hr.sub(self.mean).div(self.std)
|
36 |
+
|
37 |
+
# Find the feature map difference between the two images
|
38 |
+
loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))
|
39 |
+
|
40 |
+
return loss
|
41 |
+
|
42 |
+
|
43 |
+
class GenGaussLoss(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self, reduction='mean',
|
46 |
+
alpha_eps = 1e-4, beta_eps=1e-4,
|
47 |
+
resi_min = 1e-4, resi_max=1e3
|
48 |
+
) -> None:
|
49 |
+
super(GenGaussLoss, self).__init__()
|
50 |
+
self.reduction = reduction
|
51 |
+
self.alpha_eps = alpha_eps
|
52 |
+
self.beta_eps = beta_eps
|
53 |
+
self.resi_min = resi_min
|
54 |
+
self.resi_max = resi_max
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
|
59 |
+
):
|
60 |
+
one_over_alpha1 = one_over_alpha + self.alpha_eps
|
61 |
+
beta1 = beta + self.beta_eps
|
62 |
+
|
63 |
+
resi = torch.abs(mean - target)
|
64 |
+
# resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
|
65 |
+
resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
|
66 |
+
## check if resi has nans
|
67 |
+
if torch.sum(resi != resi) > 0:
|
68 |
+
print('resi has nans!!')
|
69 |
+
return None
|
70 |
+
|
71 |
+
log_one_over_alpha = torch.log(one_over_alpha1)
|
72 |
+
log_beta = torch.log(beta1)
|
73 |
+
lgamma_beta = torch.lgamma(torch.pow(beta1, -1))
|
74 |
+
|
75 |
+
if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
|
76 |
+
print('log_one_over_alpha has nan')
|
77 |
+
if torch.sum(lgamma_beta != lgamma_beta) > 0:
|
78 |
+
print('lgamma_beta has nan')
|
79 |
+
if torch.sum(log_beta != log_beta) > 0:
|
80 |
+
print('log_beta has nan')
|
81 |
+
|
82 |
+
l = resi - log_one_over_alpha + lgamma_beta - log_beta
|
83 |
+
|
84 |
+
if self.reduction == 'mean':
|
85 |
+
return l.mean()
|
86 |
+
elif self.reduction == 'sum':
|
87 |
+
return l.sum()
|
88 |
+
else:
|
89 |
+
print('Reduction not supported')
|
90 |
+
return None
|
91 |
+
|
92 |
+
class TempCombLoss(nn.Module):
|
93 |
+
def __init__(
|
94 |
+
self, reduction='mean',
|
95 |
+
alpha_eps = 1e-4, beta_eps=1e-4,
|
96 |
+
resi_min = 1e-4, resi_max=1e3
|
97 |
+
) -> None:
|
98 |
+
super(TempCombLoss, self).__init__()
|
99 |
+
self.reduction = reduction
|
100 |
+
self.alpha_eps = alpha_eps
|
101 |
+
self.beta_eps = beta_eps
|
102 |
+
self.resi_min = resi_min
|
103 |
+
self.resi_max = resi_max
|
104 |
+
|
105 |
+
self.L_GenGauss = GenGaussLoss(
|
106 |
+
reduction=self.reduction,
|
107 |
+
alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
|
108 |
+
resi_min=self.resi_min, resi_max=self.resi_max
|
109 |
+
)
|
110 |
+
self.L_l1 = nn.L1Loss(reduction=self.reduction)
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
|
115 |
+
T1: float, T2: float
|
116 |
+
):
|
117 |
+
l1 = self.L_l1(mean, target)
|
118 |
+
l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
|
119 |
+
l = T1*l1 + T2*l2
|
120 |
+
|
121 |
+
return l
|
122 |
+
|
123 |
+
|
124 |
+
# x1 = torch.randn(4,3,32,32)
|
125 |
+
# x2 = torch.rand(4,3,32,32)
|
126 |
+
# x3 = torch.rand(4,3,32,32)
|
127 |
+
# x4 = torch.randn(4,3,32,32)
|
128 |
+
|
129 |
+
# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
|
130 |
+
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
|
131 |
+
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
|
networks_SRGAN.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as models
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
# __all__ = [
|
8 |
+
# "ResidualConvBlock",
|
9 |
+
# "Discriminator", "Generator",
|
10 |
+
# ]
|
11 |
+
|
12 |
+
|
13 |
+
class ResidualConvBlock(nn.Module):
|
14 |
+
"""Implements residual conv function.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
channels (int): Number of channels in the input image.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, channels: int) -> None:
|
21 |
+
super(ResidualConvBlock, self).__init__()
|
22 |
+
self.rcb = nn.Sequential(
|
23 |
+
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
|
24 |
+
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.rcb(x)
        out = torch.add(out, identity)

        return out


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # input size. (3) x 96 x 96
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.LeakyReLU(0.2, True),
            # state size. (64) x 48 x 48
            nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # state size. (128) x 24 x 24
            nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # state size. (256) x 12 x 12
            nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # state size. (512) x 6 x 6
            nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out


class Generator(nn.Module):
    def __init__(self) -> None:
        super(Generator, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(64),
        )

        # Upscale conv block.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        # Output layer.
        self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor, dop=None) -> Tensor:
        if not dop:
            return self._forward_impl(x)
        else:
            return self._forward_w_dop_impl(x, dop)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = F.dropout2d(self.conv_block2(out), p=dop)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)


#### BayesCap
class BayesCap(nn.Module):
    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output layer.
        self.conv_block3_mu = nn.Conv2d(
            64, out_channels=out_channels,
            kernel_size=9, stride=1, padding=4
        )
        self.conv_block3_alpha = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )
        self.conv_block3_beta = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_mu, out_alpha, out_beta

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)


class BayesCap_noID(nn.Module):
    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap_noID, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output layer.
        # self.conv_block3_mu = nn.Conv2d(
        #     64, out_channels=out_channels,
        #     kernel_size=9, stride=1, padding=4
        # )
        self.conv_block3_alpha = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )
        self.conv_block3_beta = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        # out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_alpha, out_beta

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
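For readers skimming this diff: the `BayesCap` module above is the cap that sits on top of the frozen SRGAN `Generator`. A minimal shape-check sketch (randomly initialized weights here; the demo's real checkpoints are loaded in `src/app.py`, and the 24x24 input size is only an assumption for illustration):

```python
# Minimal sketch: run a dummy low-res batch through the frozen Generator,
# then through BayesCap to get per-pixel mu, alpha, beta maps.
import torch
from networks_SRGAN import Generator, BayesCap

net_g = Generator().eval()                        # frozen super-resolution network
net_c = BayesCap(in_channels=3, out_channels=3).eval()

x_lr = torch.rand(1, 3, 24, 24)                   # dummy low-res RGB input in [0, 1]
with torch.no_grad():
    x_sr = net_g(x_lr)                            # 4x super-resolved output: (1, 3, 96, 96)
    mu, alpha, beta = net_c(x_sr)                 # mu: (1, 3, 96, 96); alpha, beta: (1, 1, 96, 96)
```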
networks_T1toT2.py
ADDED
@@ -0,0 +1,477 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

### components
class ResConv(nn.Module):
    """
    Residual convolutional block, where
    convolutional block consists: (convolution => [BN] => ReLU) * 3
    residual connection adds the input to the output
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.double_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x_in = self.double_conv1(x)
        x1 = self.double_conv(x)
        return self.double_conv(x) + x_in

class Down(nn.Module):
    """Downscaling with maxpool then Resconv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            ResConv(in_channels, out_channels)
        )
    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = ResConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = ResConv(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(
            x1,
            [
                diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2
            ]
        )
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        # return F.relu(self.conv(x))
        return self.conv(x)

##### The composite networks
class UNet(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, out_channels)
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y = self.outc(x)
        return y

class CasUNet(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x, dop=None):
        y = x
        for i in range(self.n_unet):
            if i==0:
                if dop is not None:
                    y = F.dropout2d(self.unet_list[i](y), p=dop)
                else:
                    y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        return y

class CasUNet_2head(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_2head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            if i != self.n_unet-1:
                self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
            else:
                self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x):
        y = x
        for i in range(self.n_unet):
            if i==0:
                y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        y_mean, y_sigma = y[0], y[1]
        return y_mean, y_sigma

class CasUNet_3head(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_3head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            if i != self.n_unet-1:
                self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
            else:
                self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x):
        y = x
        for i in range(self.n_unet):
            if i==0:
                y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        y_mean, y_alpha, y_beta = y[0], y[1], y[2]
        return y_mean, y_alpha, y_beta

class UNet_2head(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_2head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        #per pixel multiple channels may exist
        self.out_mean = OutConv(64, out_channels)
        #variance will always be a single number for a pixel
        self.out_var = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
        )
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y_mean, y_var = self.out_mean(x), self.out_var(x)
        return y_mean, y_var

class UNet_3head(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_3head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        #per pixel multiple channels may exist
        self.out_mean = OutConv(64, out_channels)
        #variance will always be a single number for a pixel
        self.out_alpha = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
        self.out_beta = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y_mean, y_alpha, y_beta = self.out_mean(x), \
            self.out_alpha(x), self.out_beta(x)
        return y_mean, y_alpha, y_beta

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        ]
        self.conv_block = nn.Sequential(*conv_block)
    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()
        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
        ]
        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features*2
        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]
        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features//2
        # Output layer
        model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)
    def forward(self, x):
        return self.model(x)


class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator
        Parameters:
            input_nc (int) -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            ngf (int) -- the number of filters in the last conv layer
            norm_layer -- normalization layer
            use_dropout (bool) -- if use dropout layers
            n_blocks (int) -- the number of ResNet blocks
            padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block
        A resnet block is a conv block with skip connections
        We construct a conv block with build_conv_block function,
        and implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.
        Parameters:
            dim (int) -- the number of channels in the conv layer.
            padding_type (str) -- the name of padding layer: reflect | replicate | zero
            norm_layer -- normalization layer
            use_dropout (bool) -- if use dropout layers.
            use_bias (bool) -- if the conv layer uses bias or not
        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out

### discriminator
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int) -- the number of filters in the last conv layer
            n_layers (int) -- the number of conv layers in the discriminator
            norm_layer -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)
    def forward(self, input):
        """Standard forward."""
        return self.model(input)
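As an aside on the file above: `CasUNet_3head` is the cascaded U-Net whose last stage carries the three heads (mean, alpha, beta) analogous to `BayesCap`. A minimal interface-check sketch (the single-channel input and the 128x128 size are assumptions for illustration, not values taken from this repo):

```python
# Minimal sketch: a cascade of three U-Nets on 1-channel inputs; the last
# stage is UNet_3head, so the forward pass returns (mean, alpha, beta).
import torch
from networks_T1toT2 import CasUNet_3head

net = CasUNet_3head(n_unet=3, io_channels=1).eval()
x_t1 = torch.rand(1, 1, 128, 128)                 # dummy input slice
with torch.no_grad():
    y_mean, y_alpha, y_beta = net(x_t1)           # each of shape (1, 1, 128, 128)
```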
requirements.txt
ADDED
@@ -0,0 +1,158 @@
aiohttp==3.8.1
aiosignal==1.2.0
albumentations
analytics-python==1.4.0
anyio==3.6.1
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens==2.0.5
async-timeout==4.0.2
attrs==21.4.0
Babel==2.10.1
backcall==0.2.0
backoff==1.10.0
bcrypt==3.2.2
beautifulsoup4==4.11.1
bleach==5.0.0
brotlipy==0.7.0
certifi
cffi
charset-normalizer
click==8.1.3
cloudpickle
cryptography
cycler==0.11.0
cytoolz==0.11.2
dask
debugpy==1.6.0
decorator==5.1.1
defusedxml==0.7.1
entrypoints==0.4
executing==0.8.3
fastapi==0.78.0
fastjsonschema==2.15.3
ffmpy==0.3.0
filelock==3.7.1
fire==0.4.0
fonttools==4.33.3
frozenlist==1.3.0
fsspec
ftfy==6.1.1
gdown==4.5.1
gradio==3.0.24
h11==0.12.0
httpcore==0.15.0
httpx==0.23.0
idna
imagecodecs
imageio
ipykernel==6.13.0
ipython==8.4.0
ipython-genutils==0.2.0
jedi==0.18.1
Jinja2==3.1.2
joblib
json5==0.9.8
jsonschema==4.6.0
jupyter-client==7.3.1
jupyter-core==4.10.0
jupyter-server==1.17.0
jupyterlab==3.4.2
jupyterlab-pygments==0.2.2
jupyterlab-server==2.14.0
kiwisolver==1.4.2
kornia==0.6.5
linkify-it-py==1.0.3
locket
markdown-it-py==2.1.0
MarkupSafe==2.1.1
matplotlib==3.5.2
matplotlib-inline==0.1.3
mdit-py-plugins==0.3.0
mdurl==0.1.1
mistune==0.8.4
mkl-fft==1.3.1
mkl-random
mkl-service==2.4.0
mltk==0.0.5
monotonic==1.6
multidict==6.0.2
munch==2.5.0
nbclassic==0.3.7
nbclient==0.6.4
nbconvert==6.5.0
nbformat==5.4.0
nest-asyncio==1.5.5
networkx
nltk==3.7
notebook==6.4.11
notebook-shim==0.1.0
ntk==1.1.3
numpy
opencv-python==4.6.0.66
orjson==3.7.7
packaging
pandas==1.4.2
pandocfilters==1.5.0
paramiko==2.11.0
parso==0.8.3
partd
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.0.1
prometheus-client==0.14.1
prompt-toolkit==3.0.29
psutil==5.9.1
ptyprocess==0.7.0
pure-eval==0.2.2
pycocotools==2.0.4
pycparser
pycryptodome==3.15.0
pydantic==1.9.1
pydub==0.25.1
Pygments==2.12.0
PyNaCl==1.5.0
pyOpenSSL
pyparsing
pyrsistent==0.18.1
PySocks
python-dateutil==2.8.2
python-multipart==0.0.5
pytz==2022.1
PyWavelets
PyYAML
pyzmq==23.1.0
qudida
regex==2022.6.2
requests
rfc3986==1.5.0
scikit-image
scikit-learn
scipy
seaborn==0.11.2
Send2Trash==1.8.0
six
sniffio==1.2.0
soupsieve==2.3.2.post1
stack-data==0.2.0
starlette==0.19.1
termcolor==1.1.0
terminado==0.15.0
threadpoolctl
tifffile
tinycss2==1.1.1
toolz
torch==1.11.0
torchaudio==0.11.0
torchvision==0.12.0
tornado==6.1
tqdm==4.64.0
traitlets==5.2.2.post1
typing_extensions
uc-micro-py==1.0.1
urllib3
uvicorn==0.18.2
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.3.2
yarl==1.7.2
src/.gitkeep
ADDED
File without changes
src/README.md
ADDED
@@ -0,0 +1,26 @@
---
title: BayesCap
emoji: 🔥
colorFrom: indigo
colorTo: purple
sdk: gradio
app_file: app.py
pinned: false
---
# Configuration
`title`: _string_
Display title for the Space
`emoji`: _string_
Space emoji (emoji-only character allowed)
`colorFrom`: _string_
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
`colorTo`: _string_
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
`sdk`: _string_
Can be either `gradio` or `streamlit`
`app_file`: _string_
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
Path is relative to the root of the repository.

`pinned`: _boolean_
Whether the Space stays on top of your list.
src/__pycache__/ds.cpython-310.pyc
ADDED
Binary file (14.6 kB).
src/__pycache__/losses.cpython-310.pyc
ADDED
Binary file (4.17 kB).
src/__pycache__/networks_SRGAN.cpython-310.pyc
ADDED
Binary file (6.99 kB).
src/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (34 kB).
src/app.py
ADDED
@@ -0,0 +1,115 @@
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib import cm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

from PIL import Image

from ds import *
from losses import *
from networks_SRGAN import *
from utils import *
import utils  # needed because image2tensor is called below as utils.image2tensor


NetG = Generator()
model_parameters = filter(lambda p: True, NetG.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of Parameters:", params)
NetC = BayesCap(in_channels=3, out_channels=3)


NetG = Generator()
NetG.load_state_dict(torch.load('../ckpt/srgan-ImageNet-bc347d67.pth', map_location='cuda:0'))
NetG.to('cuda')
NetG.eval()

NetC = BayesCap(in_channels=3, out_channels=3)
NetC.load_state_dict(torch.load('../ckpt/BayesCap_SRGAN_best.pth', map_location='cuda:0'))
NetC.to('cuda')
NetC.eval()

def tensor01_to_pil(xt):
    r = transforms.ToPILImage(mode='RGB')(xt.squeeze())
    return r


def predict(img):
    """
    img: image
    """
    image_size = (256, 256)
    upscale_factor = 4
    lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
    # lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)

    img = Image.fromarray(np.array(img))
    img = lr_transforms(img)
    lr_tensor = utils.image2tensor(img, range_norm=False, half=False)

    device = 'cuda'
    dtype = torch.cuda.FloatTensor
    xLR = lr_tensor.to(device).unsqueeze(0)
    xLR = xLR.type(dtype)
    # pass them through the network
    with torch.no_grad():
        xSR = NetG(xLR)
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)

    a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
    b_map = xSRC_beta[0].to('cpu').data
    u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))


    x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))

    x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))

    #im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))

    a_map = torch.clamp(a_map, min=0, max=0.1)
    a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
    x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))

    b_map = torch.clamp(b_map, min=0.45, max=0.75)
    b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
    x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))

    u_map = torch.clamp(u_map, min=0, max=0.15)
    u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
    x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))

    return x_LR, x_mean, x_alpha, x_beta, x_uncer

import gradio as gr

title = "BayesCap"
description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks| <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"


gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='pil', label="Orignal"),
    outputs=[
        gr.outputs.Image(type='pil', label="Low-res"),
        gr.outputs.Image(type='pil', label="Super-res"),
        gr.outputs.Image(type='pil', label="Alpha"),
        gr.outputs.Image(type='pil', label="Beta"),
        gr.outputs.Image(type='pil', label="Uncertainty")
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ["../demo_examples/baby.png"],
        ["../demo_examples/bird.png"]
    ]
).launch(share=True)
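A note on the `u_map` expression in `predict()` above: with `1/alpha` as the scale and `beta` as the shape parameter of a generalized Gaussian, the per-pixel variance is scale² · Γ(3/β) / Γ(1/β), which is what the `torch.lgamma` ratio evaluates; the `1e-5` and `1e-2` offsets only guard against division by zero. A standalone restatement of that computation (the function name and test tensors are illustrative, not part of the repo):

```python
# Sketch: variance of a generalized Gaussian with scale 1/alpha and shape beta,
# mirroring the u_map computation in predict().
import torch

def gg_variance(alpha: torch.Tensor, beta: torch.Tensor, eps1: float = 1e-5, eps2: float = 1e-2) -> torch.Tensor:
    scale = 1.0 / (alpha + eps1)
    return scale ** 2 * torch.exp(torch.lgamma(3.0 / (beta + eps2)) - torch.lgamma(1.0 / (beta + eps2)))

# beta ~ 1 behaves Laplace-like: variance ~ 2 * scale^2
u = gg_variance(torch.full((1, 96, 96), 2.0), torch.ones(1, 96, 96))
```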
src/ds.py
ADDED
@@ -0,0 +1,485 @@
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import random
|
4 |
+
import copy
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import skimage.transform
|
10 |
+
from collections import Counter
|
11 |
+
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.utils.data as data
|
15 |
+
from torch import Tensor
|
16 |
+
from torch.utils.data import Dataset
|
17 |
+
from torchvision import transforms
|
18 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
19 |
+
|
20 |
+
import utils
|
21 |
+
|
22 |
+
class ImgDset(Dataset):
|
23 |
+
"""Customize the data set loading function and prepare low/high resolution image data in advance.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
dataroot (str): Training data set address
|
27 |
+
image_size (int): High resolution image size
|
28 |
+
upscale_factor (int): Image magnification
|
29 |
+
mode (str): Data set loading method, the training data set is for data enhancement,
|
30 |
+
and the verification data set is not for data enhancement
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
|
35 |
+
super(ImgDset, self).__init__()
|
36 |
+
self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
|
37 |
+
|
38 |
+
if mode == "train":
|
39 |
+
self.hr_transforms = transforms.Compose([
|
40 |
+
transforms.RandomCrop(image_size),
|
41 |
+
transforms.RandomRotation(90),
|
42 |
+
transforms.RandomHorizontalFlip(0.5),
|
43 |
+
])
|
44 |
+
else:
|
45 |
+
self.hr_transforms = transforms.Resize(image_size)
|
46 |
+
|
47 |
+
self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
|
48 |
+
|
49 |
+
def __getitem__(self, batch_index: int) -> [Tensor, Tensor]:
|
50 |
+
# Read a batch of image data
|
51 |
+
image = Image.open(self.filenames[batch_index])
|
52 |
+
|
53 |
+
# Transform image
|
54 |
+
hr_image = self.hr_transforms(image)
|
55 |
+
lr_image = self.lr_transforms(hr_image)
|
56 |
+
|
57 |
+
# Convert image data into Tensor stream format (PyTorch).
|
58 |
+
# Note: The range of input and output is between [0, 1]
|
59 |
+
lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
|
60 |
+
hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
|
61 |
+
|
62 |
+
return lr_tensor, hr_tensor
|
63 |
+
|
64 |
+
def __len__(self) -> int:
|
65 |
+
return len(self.filenames)
|
66 |
+
|
67 |
+
|
68 |
+
class PairedImages_w_nameList(Dataset):
|
69 |
+
'''
|
70 |
+
can act as supervised or un-supervised based on flists
|
71 |
+
'''
|
72 |
+
def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
|
73 |
+
self.flist1 = flist1
|
74 |
+
self.flist2 = flist2
|
75 |
+
self.transform1 = transform1
|
76 |
+
self.transform2 = transform2
|
77 |
+
self.do_aug = do_aug
|
78 |
+
def __getitem__(self, index):
|
79 |
+
impath1 = self.flist1[index]
|
80 |
+
img1 = Image.open(impath1).convert('RGB')
|
81 |
+
impath2 = self.flist2[index]
|
82 |
+
img2 = Image.open(impath2).convert('RGB')
|
83 |
+
|
84 |
+
img1 = utils.image2tensor(img1, range_norm=False, half=False)
|
85 |
+
img2 = utils.image2tensor(img2, range_norm=False, half=False)
|
86 |
+
|
87 |
+
if self.transform1 is not None:
|
88 |
+
img1 = self.transform1(img1)
|
89 |
+
if self.transform2 is not None:
|
90 |
+
img2 = self.transform2(img2)
|
91 |
+
|
92 |
+
return img1, img2
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.flist1)
|
95 |
+
|
96 |
+
class PairedImages_w_nameList_npy(Dataset):
|
97 |
+
'''
|
98 |
+
can act as supervised or un-supervised based on flists
|
99 |
+
'''
|
100 |
+
def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
|
101 |
+
self.flist1 = flist1
|
102 |
+
self.flist2 = flist2
|
103 |
+
self.transform1 = transform1
|
104 |
+
self.transform2 = transform2
|
105 |
+
self.do_aug = do_aug
|
106 |
+
def __getitem__(self, index):
|
107 |
+
impath1 = self.flist1[index]
|
108 |
+
img1 = np.load(impath1)
|
109 |
+
impath2 = self.flist2[index]
|
110 |
+
img2 = np.load(impath2)
|
111 |
+
|
112 |
+
if self.transform1 is not None:
|
113 |
+
img1 = self.transform1(img1)
|
114 |
+
if self.transform2 is not None:
|
115 |
+
img2 = self.transform2(img2)
|
116 |
+
|
117 |
+
return img1, img2
|
118 |
+
def __len__(self):
|
119 |
+
return len(self.flist1)
|
120 |
+
|
121 |
+
# def call_paired():
|
122 |
+
# root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
|
123 |
+
# root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
|
124 |
+
|
125 |
+
# flist1=glob.glob(root1+'/*/*.png')
|
126 |
+
# flist2=glob.glob(root2+'/*/*.png')
|
127 |
+
|
128 |
+
# dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
|
129 |
+
|
130 |
+
#### KITTI depth
|
131 |
+
|
132 |
+
def load_velodyne_points(filename):
|
133 |
+
"""Load 3D point cloud from KITTI file format
|
134 |
+
(adapted from https://github.com/hunse/kitti)
|
135 |
+
"""
|
136 |
+
points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
|
137 |
+
points[:, 3] = 1.0 # homogeneous
|
138 |
+
return points
|
139 |
+
|
140 |
+
|
141 |
+
def read_calib_file(path):
|
142 |
+
"""Read KITTI calibration file
|
143 |
+
(from https://github.com/hunse/kitti)
|
144 |
+
"""
|
145 |
+
float_chars = set("0123456789.e+- ")
|
146 |
+
data = {}
|
147 |
+
with open(path, 'r') as f:
|
148 |
+
for line in f.readlines():
|
149 |
+
key, value = line.split(':', 1)
|
150 |
+
value = value.strip()
|
151 |
+
data[key] = value
|
152 |
+
if float_chars.issuperset(value):
|
153 |
+
# try to cast to float array
|
154 |
+
try:
|
155 |
+
data[key] = np.array(list(map(float, value.split(' '))))
|
156 |
+
except ValueError:
|
157 |
+
# casting error: data[key] already eq. value, so pass
|
158 |
+
pass
|
159 |
+
|
160 |
+
return data
|
161 |
+
|
162 |
+
|
163 |
+
def sub2ind(matrixSize, rowSub, colSub):
|
164 |
+
"""Convert row, col matrix subscripts to linear indices
|
165 |
+
"""
|
166 |
+
m, n = matrixSize
|
167 |
+
return rowSub * (n-1) + colSub - 1
|
168 |
+
|
169 |
+
|
170 |
+
def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
|
171 |
+
"""Generate a depth map from velodyne data
|
172 |
+
"""
|
173 |
+
# load calibration files
|
174 |
+
cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
|
175 |
+
velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
|
176 |
+
velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
|
177 |
+
velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))
|
178 |
+
|
179 |
+
# get image shape
|
180 |
+
im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)
|
181 |
+
|
182 |
+
# compute projection matrix velodyne->image plane
|
183 |
+
R_cam2rect = np.eye(4)
|
184 |
+
R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
|
185 |
+
P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
|
186 |
+
P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)
|
187 |
+
|
188 |
+
# load velodyne points and remove all behind image plane (approximation)
|
189 |
+
# each row of the velodyne data is forward, left, up, reflectance
|
190 |
+
velo = load_velodyne_points(velo_filename)
|
191 |
+
velo = velo[velo[:, 0] >= 0, :]
|
192 |
+
|
193 |
+
# project the points to the camera
|
194 |
+
velo_pts_im = np.dot(P_velo2im, velo.T).T
|
195 |
+
velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]
|
196 |
+
|
197 |
+
if vel_depth:
|
198 |
+
velo_pts_im[:, 2] = velo[:, 0]
|
199 |
+
|
200 |
+
# check if in bounds
|
201 |
+
# use minus 1 to get the exact same value as KITTI matlab code
|
202 |
+
velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
|
203 |
+
velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
|
204 |
+
val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
|
205 |
+
val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
|
206 |
+
velo_pts_im = velo_pts_im[val_inds, :]
|
207 |
+
|
208 |
+
# project to image
|
209 |
+
depth = np.zeros((im_shape[:2]))
|
210 |
+
depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2]
|
211 |
+
|
212 |
+
# find the duplicate points and choose the closest depth
|
213 |
+
inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
|
214 |
+
dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
|
215 |
+
for dd in dupe_inds:
|
216 |
+
pts = np.where(inds == dd)[0]
|
217 |
+
x_loc = int(velo_pts_im[pts[0], 0])
|
218 |
+
y_loc = int(velo_pts_im[pts[0], 1])
|
219 |
+
depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
|
220 |
+
depth[depth < 0] = 0
|
221 |
+
|
222 |
+
return depth
|
223 |
+
|
224 |
+
def pil_loader(path):
|
225 |
+
# open path as file to avoid ResourceWarning
|
226 |
+
# (https://github.com/python-pillow/Pillow/issues/835)
|
227 |
+
with open(path, 'rb') as f:
|
228 |
+
with Image.open(f) as img:
|
229 |
+
return img.convert('RGB')
|
230 |
+
|
231 |
+
|
232 |
+
class MonoDataset(data.Dataset):
|
233 |
+
"""Superclass for monocular dataloaders
|
234 |
+
|
235 |
+
Args:
|
236 |
+
data_path
|
237 |
+
filenames
|
238 |
+
height
|
239 |
+
width
|
240 |
+
frame_idxs
|
241 |
+
num_scales
|
242 |
+
is_train
|
243 |
+
img_ext
|
244 |
+
"""
|
245 |
+
def __init__(self,
|
246 |
+
data_path,
|
247 |
+
filenames,
|
248 |
+
height,
|
249 |
+
width,
|
250 |
+
frame_idxs,
|
251 |
+
num_scales,
|
252 |
+
is_train=False,
|
253 |
+
img_ext='.jpg'):
|
254 |
+
super(MonoDataset, self).__init__()
|
255 |
+
|
256 |
+
self.data_path = data_path
|
257 |
+
self.filenames = filenames
|
258 |
+
self.height = height
|
259 |
+
self.width = width
|
260 |
+
self.num_scales = num_scales
|
261 |
+
self.interp = Image.ANTIALIAS
|
262 |
+
|
263 |
+
self.frame_idxs = frame_idxs
|
264 |
+
|
265 |
+
self.is_train = is_train
|
266 |
+
self.img_ext = img_ext
|
267 |
+
|
268 |
+
self.loader = pil_loader
|
269 |
+
self.to_tensor = transforms.ToTensor()
|
270 |
+
|
271 |
+
# We need to specify augmentations differently in newer versions of torchvision.
|
272 |
+
# We first try the newer tuple version; if this fails we fall back to scalars
|
273 |
+
try:
|
274 |
+
self.brightness = (0.8, 1.2)
|
275 |
+
self.contrast = (0.8, 1.2)
|
276 |
+
self.saturation = (0.8, 1.2)
|
277 |
+
self.hue = (-0.1, 0.1)
|
278 |
+
transforms.ColorJitter.get_params(
|
279 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
280 |
+
except TypeError:
|
281 |
+
self.brightness = 0.2
|
282 |
+
self.contrast = 0.2
|
283 |
+
self.saturation = 0.2
|
284 |
+
self.hue = 0.1
|
285 |
+
|
286 |
+
self.resize = {}
|
287 |
+
for i in range(self.num_scales):
|
288 |
+
s = 2 ** i
|
289 |
+
self.resize[i] = transforms.Resize((self.height // s, self.width // s),
|
290 |
+
interpolation=self.interp)
|
291 |
+
|
292 |
+
self.load_depth = self.check_depth()
|
293 |
+
|
294 |
+
def preprocess(self, inputs, color_aug):
|
295 |
+
"""Resize colour images to the required scales and augment if required
|
296 |
+
|
297 |
+
We create the color_aug object in advance and apply the same augmentation to all
|
298 |
+
images in this item. This ensures that all images input to the pose network receive the
|
299 |
+
same augmentation.
|
300 |
+
"""
|
301 |
+
for k in list(inputs):
|
302 |
+
frame = inputs[k]
|
303 |
+
if "color" in k:
|
304 |
+
n, im, i = k
|
305 |
+
for i in range(self.num_scales):
|
306 |
+
inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])
|
307 |
+
|
308 |
+
for k in list(inputs):
|
309 |
+
f = inputs[k]
|
310 |
+
if "color" in k:
|
311 |
+
n, im, i = k
|
312 |
+
inputs[(n, im, i)] = self.to_tensor(f)
|
313 |
+
inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
|
314 |
+
|
315 |
+
def __len__(self):
|
316 |
+
return len(self.filenames)
|
317 |
+
|
318 |
+
def __getitem__(self, index):
|
319 |
+
"""Returns a single training item from the dataset as a dictionary.
|
320 |
+
|
321 |
+
Values correspond to torch tensors.
|
322 |
+
Keys in the dictionary are either strings or tuples:
|
323 |
+
|
324 |
+
("color", <frame_id>, <scale>) for raw colour images,
|
325 |
+
            ("color_aug", <frame_id>, <scale>)      for augmented colour images,
            ("K", scale) or ("inv_K", scale)        for camera intrinsics,
            "stereo_T"                              for camera extrinsics, and
            "depth_gt"                              for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the fullsize image:
            -1      images at native resolution as loaded from disk
            0       images resized to (self.width, self.height)
            1       images resized to (self.width // 2, self.height // 2)
            2       images resized to (self.width // 4, self.height // 4)
            3       images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
        else:
            frame_index = 0

        if len(line) == 3:
            side = line[2]
        else:
            side = None

        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            color_aug = transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

    def check_depth(self):
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError


class KITTIDataset(MonoDataset):
    """Superclass for different types of KITTI dataset loaders
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)

        # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
        # To normalize you need to scale the first row by 1 / image_width and the second row
        # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
        # If your principal point is far from the center you might need to disable the horizontal
        # flip augmentation.
        self.K = np.array([[0.58, 0, 0.5, 0],
                           [0, 1.92, 0.5, 0],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

        self.full_res_shape = (1242, 375)
        self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}

    def check_depth(self):
        line = self.filenames[0].split()
        scene_name = line[0]
        frame_index = int(line[1])

        velo_filename = os.path.join(
            self.data_path,
            scene_name,
            "velodyne_points/data/{:010d}.bin".format(int(frame_index)))

        return os.path.isfile(velo_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        color = self.loader(self.get_image_path(folder, frame_index, side))

        if do_flip:
            color = color.transpose(Image.FLIP_LEFT_RIGHT)

        return color


class KITTIDepthDataset(KITTIDataset):
    """KITTI dataset which uses the updated ground truth depth maps
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDepthDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        f_str = "{:010d}{}".format(frame_index, self.img_ext)
        image_path = os.path.join(
            self.data_path,
            folder,
            "image_0{}/data".format(self.side_map[side]),
            f_str)
        return image_path

    def get_depth(self, folder, frame_index, side, do_flip):
        f_str = "{:010d}.png".format(frame_index)
        depth_path = os.path.join(
            self.data_path,
            folder,
            "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
            f_str)

        depth_gt = Image.open(depth_path)
        depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
        depth_gt = np.array(depth_gt).astype(np.float32) / 256

        if do_flip:
            depth_gt = np.fliplr(depth_gt)

        return depth_gt
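A minimal consumption sketch for the loader above. Only the key layout of the returned `inputs` dict is taken from the `__getitem__` docstring; the way the dataset itself is constructed (data path, filename list, image size, frame indices, number of scales) follows the monodepth2-style interface and is an assumption, not shown in this diff.

# Hypothetical usage sketch: `inputs` is one item returned by MonoDataset.__getitem__;
# the constructor arguments of the dataset are assumed (monodepth2-style) and omitted here.
def describe_item(inputs: dict) -> None:
    # ("color_aug", frame_id, scale): augmented colour images at each pyramid scale
    for key, value in inputs.items():
        if isinstance(key, tuple) and key[0] == "color_aug":
            _, frame_id, scale = key
            print("frame", frame_id, "scale", scale, tuple(value.shape))
    # intrinsics at full training resolution, plus stereo extrinsics when "s" frames are used
    print("K0:", tuple(inputs[("K", 0)].shape), "has stereo_T:", "stereo_T" in inputs)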
src/flagged/Alpha/0.png
ADDED
src/flagged/Beta/0.png
ADDED
src/flagged/Low-res/0.png
ADDED
src/flagged/Orignal/0.png
ADDED
src/flagged/Super-res/0.png
ADDED
src/flagged/Uncertainty/0.png
ADDED
src/flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
'Orignal','Low-res','Super-res','Alpha','Beta','Uncertainty','flag','username','timestamp'
'Orignal/0.png','Low-res/0.png','Super-res/0.png','Alpha/0.png','Beta/0.png','Uncertainty/0.png','','','2022-07-09 14:01:12.964411'
src/losses.py
ADDED
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor

class ContentLoss(nn.Module):
    """Constructs a content loss function based on the VGG19 network.
    Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.

    Paper reference list:
        -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
        -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
        -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.

    """

    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
        # Load the VGG19 model trained on the ImageNet dataset.
        vgg19 = models.vgg19(pretrained=True).eval()
        # Extract the thirty-sixth layer output in the VGG19 model as the content loss.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
        # Freeze model parameters.
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False

        # The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        # Standardized operations
        sr = sr.sub(self.mean).div(self.std)
        hr = hr.sub(self.mean).div(self.std)

        # Find the feature map difference between the two images
        loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))

        return loss


class GenGaussLoss(nn.Module):
    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(GenGaussLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
    ):
        one_over_alpha1 = one_over_alpha + self.alpha_eps
        beta1 = beta + self.beta_eps

        resi = torch.abs(mean - target)
        # resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
        resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
        ## check if resi has nans
        if torch.sum(resi != resi) > 0:
            print('resi has nans!!')
            return None

        log_one_over_alpha = torch.log(one_over_alpha1)
        log_beta = torch.log(beta1)
        lgamma_beta = torch.lgamma(torch.pow(beta1, -1))

        if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
            print('log_one_over_alpha has nan')
        if torch.sum(lgamma_beta != lgamma_beta) > 0:
            print('lgamma_beta has nan')
        if torch.sum(log_beta != log_beta) > 0:
            print('log_beta has nan')

        l = resi - log_one_over_alpha + lgamma_beta - log_beta

        if self.reduction == 'mean':
            return l.mean()
        elif self.reduction == 'sum':
            return l.sum()
        else:
            print('Reduction not supported')
            return None

class TempCombLoss(nn.Module):
    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(TempCombLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

        self.L_GenGauss = GenGaussLoss(
            reduction=self.reduction,
            alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
            resi_min=self.resi_min, resi_max=self.resi_max
        )
        self.L_l1 = nn.L1Loss(reduction=self.reduction)

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
        T1: float, T2: float
    ):
        l1 = self.L_l1(mean, target)
        l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
        l = T1*l1 + T2*l2

        return l


# x1 = torch.randn(4,3,32,32)
# x2 = torch.rand(4,3,32,32)
# x3 = torch.rand(4,3,32,32)
# x4 = torch.randn(4,3,32,32)

# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
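A small runnable sketch of how these losses are called, expanding the commented-out example at the bottom of losses.py. The tensor shapes are arbitrary and only illustrate the call signatures.

# Hedged usage sketch (shapes chosen arbitrarily): GenGaussLoss expects the predicted
# mean, the predicted 1/alpha and beta maps, and the target; TempCombLoss additionally
# takes the two temperature weights T1 and T2.
import torch
from losses import GenGaussLoss, TempCombLoss

mean = torch.randn(4, 3, 32, 32)
one_over_alpha = torch.rand(4, 3, 32, 32)
beta = torch.rand(4, 3, 32, 32)
target = torch.randn(4, 3, 32, 32)

gg = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
comb = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)

print(gg(mean, one_over_alpha, beta, target))               # generalized-Gaussian NLL term
print(comb(mean, one_over_alpha, beta, target, 1e0, 1e-2))  # T1 * L1 + T2 * GenGauss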
src/networks_SRGAN.py
ADDED
@@ -0,0 +1,347 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor

# __all__ = [
#     "ResidualConvBlock",
#     "Discriminator", "Generator",
# ]


class ResidualConvBlock(nn.Module):
    """Implements residual conv function.

    Args:
        channels (int): Number of channels in the input image.
    """

    def __init__(self, channels: int) -> None:
        super(ResidualConvBlock, self).__init__()
        self.rcb = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.rcb(x)
        out = torch.add(out, identity)

        return out


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # input size. (3) x 96 x 96
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.LeakyReLU(0.2, True),
            # state size. (64) x 48 x 48
            nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # state size. (128) x 24 x 24
            nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # state size. (256) x 12 x 12
            nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # state size. (512) x 6 x 6
            nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out


class Generator(nn.Module):
    def __init__(self) -> None:
        super(Generator, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(64),
        )

        # Upscale conv block.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        # Output layer.
        self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor, dop=None) -> Tensor:
        if not dop:
            return self._forward_impl(x)
        else:
            return self._forward_w_dop_impl(x, dop)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = F.dropout2d(self.conv_block2(out), p=dop)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)


#### BayesCap
class BayesCap(nn.Module):
    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output layer.
        self.conv_block3_mu = nn.Conv2d(
            64, out_channels=out_channels,
            kernel_size=9, stride=1, padding=4
        )
        self.conv_block3_alpha = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )
        self.conv_block3_beta = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_mu, out_alpha, out_beta

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)


class BayesCap_noID(nn.Module):
    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap_noID, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output layer.
        # self.conv_block3_mu = nn.Conv2d(
        #     64, out_channels=out_channels,
        #     kernel_size=9, stride=1, padding=4
        # )
        self.conv_block3_alpha = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )
        self.conv_block3_beta = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        # out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_alpha, out_beta

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
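A hedged forward-pass sketch for these networks: the input size is illustrative only, and the interpretation of the two single-channel heads as the alpha/beta maps consumed by GenGaussLoss follows the naming in the code, not anything stated in this diff.

# Usage sketch (assumed sizes): frozen SRGAN generator followed by the BayesCap head.
import torch
from networks_SRGAN import Generator, BayesCap

net_g = Generator().eval()
net_cap = BayesCap(in_channels=3, out_channels=3).eval()

lr = torch.randn(1, 3, 64, 64)               # low-res input, arbitrary size
with torch.no_grad():
    sr = net_g(lr)                           # 4x super-resolved image, (1, 3, 256, 256)
    out_mu, out_alpha, out_beta = net_cap(sr)  # reconstruction plus per-pixel alpha/beta maps
print(sr.shape, out_mu.shape, out_alpha.shape, out_beta.shape)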
src/networks_T1toT2.py
ADDED
@@ -0,0 +1,477 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

### components
class ResConv(nn.Module):
    """
    Residual convolutional block, where
    convolutional block consists: (convolution => [BN] => ReLU) * 3
    residual connection adds the input to the output
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.double_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x_in = self.double_conv1(x)
        x1 = self.double_conv(x)
        return self.double_conv(x) + x_in

class Down(nn.Module):
    """Downscaling with maxpool then Resconv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            ResConv(in_channels, out_channels)
        )
    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = ResConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = ResConv(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(
            x1,
            [
                diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2
            ]
        )
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        # return F.relu(self.conv(x))
        return self.conv(x)

##### The composite networks
class UNet(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, out_channels)
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y = self.outc(x)
        return y

class CasUNet(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x, dop=None):
        y = x
        for i in range(self.n_unet):
            if i==0:
                if dop is not None:
                    y = F.dropout2d(self.unet_list[i](y), p=dop)
                else:
                    y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        return y

class CasUNet_2head(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_2head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            if i != self.n_unet-1:
                self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
            else:
                self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x):
        y = x
        for i in range(self.n_unet):
            if i==0:
                y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        y_mean, y_sigma = y[0], y[1]
        return y_mean, y_sigma

class CasUNet_3head(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_3head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            if i != self.n_unet-1:
                self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
            else:
                self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x):
        y = x
        for i in range(self.n_unet):
            if i==0:
                y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        y_mean, y_alpha, y_beta = y[0], y[1], y[2]
        return y_mean, y_alpha, y_beta

class UNet_2head(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_2head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        #per pixel multiple channels may exist
        self.out_mean = OutConv(64, out_channels)
        #variance will always be a single number for a pixel
        self.out_var = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
        )
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y_mean, y_var = self.out_mean(x), self.out_var(x)
        return y_mean, y_var

class UNet_3head(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_3head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        #per pixel multiple channels may exist
        self.out_mean = OutConv(64, out_channels)
        #variance will always be a single number for a pixel
        self.out_alpha = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
        self.out_beta = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y_mean, y_alpha, y_beta = self.out_mean(x), \
            self.out_alpha(x), self.out_beta(x)
        return y_mean, y_alpha, y_beta

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        ]
        self.conv_block = nn.Sequential(*conv_block)
    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()
        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
        ]
        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features*2
        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]
        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features//2
        # Output layer
        model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)
    def forward(self, x):
        return self.model(x)


class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator
        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks

            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block
        A resnet block is a conv block with skip connections
        We construct a conv block with build_conv_block function,
        and implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.
        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not
        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out

### discriminator
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)
    def forward(self, input):
        """Standard forward."""
        return self.model(input)
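A hedged sketch of the cascaded three-head U-Net defined above; input size and channel counts are illustrative only and are not prescribed by this diff.

# Usage sketch (assumed sizes): single-channel translation network with alpha/beta heads,
# paired with the PatchGAN discriminator from the same file.
import torch
from networks_T1toT2 import CasUNet_3head, NLayerDiscriminator

net_g = CasUNet_3head(n_unet=3, io_channels=1)
disc = NLayerDiscriminator(input_nc=1, n_layers=3)

x = torch.randn(1, 1, 128, 128)
y_mean, y_alpha, y_beta = net_g(x)  # translated image plus the two per-pixel uncertainty heads
print(y_mean.shape, y_alpha.shape, y_beta.shape, disc(y_mean).shape)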
src/utils.py
ADDED
@@ -0,0 +1,1273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import Any, Optional
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
from glob import glob
|
7 |
+
from PIL import Image, ImageDraw
|
8 |
+
from tqdm import tqdm
|
9 |
+
import kornia
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import seaborn as sns
|
12 |
+
import albumentations as albu
|
13 |
+
import functools
|
14 |
+
import math
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from torch import Tensor
|
19 |
+
import torchvision as tv
|
20 |
+
import torchvision.models as models
|
21 |
+
from torchvision import transforms
|
22 |
+
from torchvision.transforms import functional as F
|
23 |
+
from losses import TempCombLoss
|
24 |
+
|
25 |
+
########### DeblurGAN function
|
26 |
+
def get_norm_layer(norm_type='instance'):
|
27 |
+
if norm_type == 'batch':
|
28 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
29 |
+
elif norm_type == 'instance':
|
30 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
|
31 |
+
else:
|
32 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
33 |
+
return norm_layer
|
34 |
+
|
35 |
+
def _array_to_batch(x):
|
36 |
+
x = np.transpose(x, (2, 0, 1))
|
37 |
+
x = np.expand_dims(x, 0)
|
38 |
+
return torch.from_numpy(x)
|
39 |
+
|
40 |
+
def get_normalize():
|
41 |
+
normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
42 |
+
normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
|
43 |
+
|
44 |
+
def process(a, b):
|
45 |
+
r = normalize(image=a, target=b)
|
46 |
+
return r['image'], r['target']
|
47 |
+
|
48 |
+
return process
|
49 |
+
|
50 |
+
def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
|
51 |
+
x, _ = get_normalize()(x, x)
|
52 |
+
if mask is None:
|
53 |
+
mask = np.ones_like(x, dtype=np.float32)
|
54 |
+
else:
|
55 |
+
mask = np.round(mask.astype('float32') / 255)
|
56 |
+
|
57 |
+
h, w, _ = x.shape
|
58 |
+
block_size = 32
|
59 |
+
min_height = (h // block_size + 1) * block_size
|
60 |
+
min_width = (w // block_size + 1) * block_size
|
61 |
+
|
62 |
+
pad_params = {'mode': 'constant',
|
63 |
+
'constant_values': 0,
|
64 |
+
'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
|
65 |
+
}
|
66 |
+
x = np.pad(x, **pad_params)
|
67 |
+
mask = np.pad(mask, **pad_params)
|
68 |
+
|
69 |
+
return map(_array_to_batch, (x, mask)), h, w
|
70 |
+
|
71 |
+
def postprocess(x: torch.Tensor) -> np.ndarray:
|
72 |
+
x, = x
|
73 |
+
x = x.detach().cpu().float().numpy()
|
74 |
+
x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
|
75 |
+
return x.astype('uint8')
|
76 |
+
|
77 |
+
def sorted_glob(pattern):
|
78 |
+
return sorted(glob(pattern))
|
79 |
+
###########
|
80 |
+
|
81 |
+
def normalize(image: np.ndarray) -> np.ndarray:
|
82 |
+
"""Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
|
83 |
+
Args:
|
84 |
+
image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
|
85 |
+
Returns:
|
86 |
+
Normalized image data. Data range [0, 1].
|
87 |
+
"""
|
88 |
+
return image.astype(np.float64) / 255.0
|
89 |
+
|
90 |
+
|
91 |
+
def unnormalize(image: np.ndarray) -> np.ndarray:
|
92 |
+
"""Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
|
93 |
+
Args:
|
94 |
+
image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
|
95 |
+
Returns:
|
96 |
+
Denormalized image data. Data range [0, 255].
|
97 |
+
"""
|
98 |
+
return image.astype(np.float64) * 255.0
|
99 |
+
|
100 |
+
|
101 |
+
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
|
102 |
+
"""Convert ``PIL.Image`` to Tensor.
|
103 |
+
Args:
|
104 |
+
image (np.ndarray): The image data read by ``PIL.Image``
|
105 |
+
range_norm (bool): Scale [0, 1] data to between [-1, 1]
|
106 |
+
half (bool): Whether to convert torch.float32 similarly to torch.half type.
|
107 |
+
Returns:
|
108 |
+
Normalized image data
|
109 |
+
Examples:
|
110 |
+
>>> image = Image.open("image.bmp")
|
111 |
+
>>> tensor_image = image2tensor(image, range_norm=False, half=False)
|
112 |
+
"""
|
113 |
+
tensor = F.to_tensor(image)
|
114 |
+
|
115 |
+
if range_norm:
|
116 |
+
tensor = tensor.mul_(2.0).sub_(1.0)
|
117 |
+
if half:
|
118 |
+
tensor = tensor.half()
|
119 |
+
|
120 |
+
return tensor
|
121 |
+
|
122 |
+
|
123 |
+
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
|
124 |
+
"""Converts ``torch.Tensor`` to ``PIL.Image``.
|
125 |
+
Args:
|
126 |
+
tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
|
127 |
+
range_norm (bool): Scale [-1, 1] data to between [0, 1]
|
128 |
+
half (bool): Whether to convert torch.float32 similarly to torch.half type.
|
129 |
+
Returns:
|
130 |
+
Convert image data to support PIL library
|
131 |
+
Examples:
|
132 |
+
>>> tensor = torch.randn([1, 3, 128, 128])
|
133 |
+
>>> image = tensor2image(tensor, range_norm=False, half=False)
|
134 |
+
"""
|
135 |
+
if range_norm:
|
136 |
+
tensor = tensor.add_(1.0).div_(2.0)
|
137 |
+
if half:
|
138 |
+
tensor = tensor.half()
|
139 |
+
|
140 |
+
image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
|
141 |
+
|
142 |
+
return image
|
143 |
+
|
144 |
+
|
145 |
+
def convert_rgb_to_y(image: Any) -> Any:
|
146 |
+
"""Convert RGB image or tensor image data to YCbCr(Y) format.
|
147 |
+
Args:
|
148 |
+
image: RGB image data read by ``PIL.Image''.
|
149 |
+
Returns:
|
150 |
+
Y image array data.
|
151 |
+
"""
|
152 |
+
if type(image) == np.ndarray:
|
153 |
+
return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
|
154 |
+
elif type(image) == torch.Tensor:
|
155 |
+
if len(image.shape) == 4:
|
156 |
+
image = image.squeeze_(0)
|
157 |
+
return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
|
158 |
+
else:
|
159 |
+
raise Exception("Unknown Type", type(image))
|
160 |
+
|
161 |
+
|
162 |
+
def convert_rgb_to_ycbcr(image: Any) -> Any:
|
163 |
+
"""Convert RGB image or tensor image data to YCbCr format.
|
164 |
+
Args:
|
165 |
+
image: RGB image data read by ``PIL.Image''.
|
166 |
+
Returns:
|
167 |
+
YCbCr image array data.
|
168 |
+
"""
|
169 |
+
if type(image) == np.ndarray:
|
170 |
+
y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
|
171 |
+
cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
|
172 |
+
cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
|
173 |
+
return np.array([y, cb, cr]).transpose([1, 2, 0])
|
174 |
+
elif type(image) == torch.Tensor:
|
175 |
+
if len(image.shape) == 4:
|
176 |
+
image = image.squeeze(0)
|
177 |
+
y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
|
178 |
+
cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
|
179 |
+
cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
|
180 |
+
return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
|
181 |
+
else:
|
182 |
+
raise Exception("Unknown Type", type(image))
|
183 |
+
|
184 |
+
|
185 |
+
def convert_ycbcr_to_rgb(image: Any) -> Any:
|
186 |
+
"""Convert YCbCr format image to RGB format.
|
187 |
+
Args:
|
188 |
+
image: YCbCr image data read by ``PIL.Image''.
|
189 |
+
Returns:
|
190 |
+
RGB image array data.
|
191 |
+
"""
|
192 |
+
if type(image) == np.ndarray:
|
193 |
+
r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
|
194 |
+
g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
|
195 |
+
b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
|
196 |
+
return np.array([r, g, b]).transpose([1, 2, 0])
|
197 |
+
elif type(image) == torch.Tensor:
|
198 |
+
if len(image.shape) == 4:
|
199 |
+
image = image.squeeze(0)
|
200 |
+
r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
|
201 |
+
g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
|
202 |
+
b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
|
203 |
+
return torch.cat([r, g, b], 0).permute(1, 2, 0)
|
204 |
+
else:
|
205 |
+
raise Exception("Unknown Type", type(image))
|
206 |
+
|
207 |
+
|
208 |
+
def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
|
209 |
+
"""Cut ``PIL.Image`` in the center area of the image.
|
210 |
+
Args:
|
211 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
212 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
213 |
+
image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
|
214 |
+
upscale_factor (int): magnification factor.
|
215 |
+
Returns:
|
216 |
+
Randomly cropped low-resolution images and high-resolution images.
|
217 |
+
"""
|
218 |
+
w, h = hr.size
|
219 |
+
|
220 |
+
left = (w - image_size) // 2
|
221 |
+
top = (h - image_size) // 2
|
222 |
+
right = left + image_size
|
223 |
+
bottom = top + image_size
|
224 |
+
|
225 |
+
lr = lr.crop((left // upscale_factor,
|
226 |
+
top // upscale_factor,
|
227 |
+
right // upscale_factor,
|
228 |
+
bottom // upscale_factor))
|
229 |
+
hr = hr.crop((left, top, right, bottom))
|
230 |
+
|
231 |
+
return lr, hr
|
232 |
+
|
233 |
+
|
234 |
+
def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
|
235 |
+
"""Will ``PIL.Image`` randomly capture the specified area of the image.
|
236 |
+
Args:
|
237 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
238 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
239 |
+
image_size (int): Size of the cropped area, given in high-resolution pixels.
|
240 |
+
upscale_factor (int): Magnification factor between the low- and high-resolution images.
|
241 |
+
Returns:
|
242 |
+
Randomly cropped low-resolution images and high-resolution images.
|
243 |
+
"""
|
244 |
+
w, h = hr.size
|
245 |
+
left = torch.randint(0, w - image_size + 1, size=(1,)).item()
|
246 |
+
top = torch.randint(0, h - image_size + 1, size=(1,)).item()
|
247 |
+
right = left + image_size
|
248 |
+
bottom = top + image_size
|
249 |
+
|
250 |
+
lr = lr.crop((left // upscale_factor,
|
251 |
+
top // upscale_factor,
|
252 |
+
right // upscale_factor,
|
253 |
+
bottom // upscale_factor))
|
254 |
+
hr = hr.crop((left, top, right, bottom))
|
255 |
+
|
256 |
+
return lr, hr
|
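The two cropping helpers keep the LR and HR patches aligned by dividing the HR crop box by the upscale factor. A short usage sketch, assuming the helpers above are in scope; the image path and the x4 factor are illustrative (any RGB image whose sides are multiples of the factor and at least `image_size` pixels works):

```python
from PIL import Image

hr = Image.open("demo_examples/butterfly.png").convert("RGB")
lr = hr.resize((hr.width // 4, hr.height // 4), Image.BICUBIC)   # synthetic x4 LR input

lr_p, hr_p = random_crop(lr, hr, image_size=96, upscale_factor=4)
lr_p, hr_p = random_rotate(lr_p, hr_p, angle=90)
lr_p, hr_p = random_horizontally_flip(lr_p, hr_p, p=0.5)
print(lr_p.size, hr_p.size)   # (24, 24) and (96, 96)
```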
257 |
+
|
258 |
+
|
259 |
+
def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
|
260 |
+
"""Will ``PIL.Image`` randomly rotate the image.
|
261 |
+
Args:
|
262 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
263 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
264 |
+
angle (int): Rotation angle; the sign (clockwise or counterclockwise) is chosen at random.
|
265 |
+
Returns:
|
266 |
+
Randomly rotated low-resolution images and high-resolution images.
|
267 |
+
"""
|
268 |
+
angle = random.choice((+angle, -angle))
|
269 |
+
lr = F.rotate(lr, angle)
|
270 |
+
hr = F.rotate(hr, angle)
|
271 |
+
|
272 |
+
return lr, hr
|
273 |
+
|
274 |
+
|
275 |
+
def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
|
276 |
+
"""Flip the ``PIL.Image`` image horizontally randomly.
|
277 |
+
Args:
|
278 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
279 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
280 |
+
p (optional, float): Flip probability. (Default: 0.5)
|
281 |
+
Returns:
|
282 |
+
Low-resolution image and high-resolution image after random horizontal flip.
|
283 |
+
"""
|
284 |
+
if torch.rand(1).item() > p:
|
285 |
+
lr = F.hflip(lr)
|
286 |
+
hr = F.hflip(hr)
|
287 |
+
|
288 |
+
return lr, hr
|
289 |
+
|
290 |
+
|
291 |
+
def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
|
292 |
+
"""Turn the ``PIL.Image`` image upside down randomly.
|
293 |
+
Args:
|
294 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
295 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
296 |
+
p (optional, float): Flip probability. (Default: 0.5)
|
297 |
+
Returns:
|
298 |
+
Low-resolution and high-resolution images after a random vertical flip.
|
299 |
+
"""
|
300 |
+
if torch.rand(1).item() > p:
|
301 |
+
lr = F.vflip(lr)
|
302 |
+
hr = F.vflip(hr)
|
303 |
+
|
304 |
+
return lr, hr
|
305 |
+
|
306 |
+
|
307 |
+
def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
|
308 |
+
"""Set ``PIL.Image`` to randomly adjust the image brightness.
|
309 |
+
Args:
|
310 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
311 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
312 |
+
Returns:
|
313 |
+
Low-resolution image and high-resolution image with randomly adjusted brightness.
|
314 |
+
"""
|
315 |
+
# Randomly adjust the brightness gain range.
|
316 |
+
factor = random.uniform(0.5, 2)
|
317 |
+
lr = F.adjust_brightness(lr, factor)
|
318 |
+
hr = F.adjust_brightness(hr, factor)
|
319 |
+
|
320 |
+
return lr, hr
|
321 |
+
|
322 |
+
|
323 |
+
def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
|
324 |
+
"""Set ``PIL.Image`` to randomly adjust the image contrast.
|
325 |
+
Args:
|
326 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
327 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
328 |
+
Returns:
|
329 |
+
Low-resolution image and high-resolution image with randomly adjusted contrast.
|
330 |
+
"""
|
331 |
+
# Randomly adjust the contrast gain range.
|
332 |
+
factor = random.uniform(0.5, 2)
|
333 |
+
lr = F.adjust_contrast(lr, factor)
|
334 |
+
hr = F.adjust_contrast(hr, factor)
|
335 |
+
|
336 |
+
return lr, hr
|
337 |
+
|
338 |
+
#### metrics to compute -- assumes single images, i.e., tensor of 3 dims
|
339 |
+
def img_mae(x1, x2):
|
340 |
+
m = torch.abs(x1-x2).mean()
|
341 |
+
return m
|
342 |
+
|
343 |
+
def img_mse(x1, x2):
|
344 |
+
m = torch.pow(torch.abs(x1-x2),2).mean()
|
345 |
+
return m
|
346 |
+
|
347 |
+
def img_psnr(x1, x2):
|
348 |
+
m = kornia.metrics.psnr(x1, x2, 1)
|
349 |
+
return m
|
350 |
+
|
351 |
+
def img_ssim(x1, x2):
|
352 |
+
m = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
|
353 |
+
m = m.mean()
|
354 |
+
return m
|
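The four metric helpers above operate on a single image tensor of shape CxHxW with values in [0, 1] (`img_ssim` adds the batch dimension that kornia expects). A quick self-contained check with random data, assuming the helpers are in scope:

```python
import torch

x = torch.rand(3, 64, 64)
y = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)   # noisy copy of x
print('MAE ', img_mae(x, y).item())
print('MSE ', img_mse(x, y).item())
print('PSNR', img_psnr(x, y).item())
print('SSIM', img_ssim(x, y).item())
```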
355 |
+
|
356 |
+
def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
|
357 |
+
'''
|
358 |
+
xLR/SR/HR: 3xHxW
|
359 |
+
xSRvar: 1xHxW
|
360 |
+
'''
|
361 |
+
plt.figure(figsize=(30,10))
|
362 |
+
|
363 |
+
plt.subplot(1,5,1)
|
364 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
365 |
+
plt.axis('off')
|
366 |
+
|
367 |
+
plt.subplot(1,5,2)
|
368 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
369 |
+
plt.axis('off')
|
370 |
+
|
371 |
+
plt.subplot(1,5,3)
|
372 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
373 |
+
plt.axis('off')
|
374 |
+
|
375 |
+
plt.subplot(1,5,4)
|
376 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
|
377 |
+
print('error', error_map.min(), error_map.max())
|
378 |
+
plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
|
379 |
+
plt.clim(elim[0], elim[1])
|
380 |
+
plt.axis('off')
|
381 |
+
|
382 |
+
plt.subplot(1,5,5)
|
383 |
+
print('uncer', xSRvar.min(), xSRvar.max())
|
384 |
+
plt.imshow(xSRvar.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
385 |
+
plt.clim(ulim[0], ulim[1])
|
386 |
+
plt.axis('off')
|
387 |
+
|
388 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
389 |
+
plt.show()
|
390 |
+
|
391 |
+
def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
|
392 |
+
'''
|
393 |
+
xLR/SR/HR: 3xHxW
|
394 |
+
'''
|
395 |
+
plt.figure(figsize=(30,10))
|
396 |
+
|
397 |
+
if task != 'm':
|
398 |
+
plt.subplot(1,4,1)
|
399 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
400 |
+
plt.axis('off')
|
401 |
+
|
402 |
+
plt.subplot(1,4,2)
|
403 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
404 |
+
plt.axis('off')
|
405 |
+
|
406 |
+
plt.subplot(1,4,3)
|
407 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
408 |
+
plt.axis('off')
|
409 |
+
else:
|
410 |
+
plt.subplot(1,4,1)
|
411 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
412 |
+
plt.clim(0,0.9)
|
413 |
+
plt.axis('off')
|
414 |
+
|
415 |
+
plt.subplot(1,4,2)
|
416 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
417 |
+
plt.clim(0,0.9)
|
418 |
+
plt.axis('off')
|
419 |
+
|
420 |
+
plt.subplot(1,4,3)
|
421 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
422 |
+
plt.clim(0,0.9)
|
423 |
+
plt.axis('off')
|
424 |
+
|
425 |
+
plt.subplot(1,4,4)
|
426 |
+
if task == 'inpainting':
|
427 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
|
428 |
+
else:
|
429 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
|
430 |
+
print('error', error_map.min(), error_map.max())
|
431 |
+
plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
|
432 |
+
plt.clim(elim[0], elim[1])
|
433 |
+
plt.axis('off')
|
434 |
+
|
435 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
436 |
+
plt.show()
|
437 |
+
|
438 |
+
def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
|
439 |
+
'''
|
440 |
+
xSRvar: 1xHxW
|
441 |
+
'''
|
442 |
+
plt.figure(figsize=(30,10))
|
443 |
+
|
444 |
+
plt.subplot(1,4,1)
|
445 |
+
print('uncer', xSRvar1.min(), xSRvar1.max())
|
446 |
+
plt.imshow(xSRvar1.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
447 |
+
plt.clim(ulim[0], ulim[1])
|
448 |
+
plt.axis('off')
|
449 |
+
|
450 |
+
plt.subplot(1,4,2)
|
451 |
+
print('uncer', xSRvar2.min(), xSRvar2.max())
|
452 |
+
plt.imshow(xSRvar2.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
453 |
+
plt.clim(ulim[0], ulim[1])
|
454 |
+
plt.axis('off')
|
455 |
+
|
456 |
+
plt.subplot(1,4,3)
|
457 |
+
print('uncer', xSRvar3.min(), xSRvar3.max())
|
458 |
+
plt.imshow(xSRvar3.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
459 |
+
plt.clim(ulim[0], ulim[1])
|
460 |
+
plt.axis('off')
|
461 |
+
|
462 |
+
plt.subplot(1,4,4)
|
463 |
+
print('uncer', xSRvar4.min(), xSRvar4.max())
|
464 |
+
plt.imshow(xSRvar4.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
465 |
+
plt.clim(ulim[0], ulim[1])
|
466 |
+
plt.axis('off')
|
467 |
+
|
468 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
469 |
+
plt.show()
|
470 |
+
|
471 |
+
def get_UCE(list_err, list_yout_var, num_bins=100):
|
472 |
+
err_min = np.min(list_err)
|
473 |
+
err_max = np.max(list_err)
|
474 |
+
err_len = (err_max-err_min)/num_bins
|
475 |
+
num_points = len(list_err)
|
476 |
+
|
477 |
+
bin_stats = {}
|
478 |
+
for i in range(num_bins):
|
479 |
+
bin_stats[i] = {
|
480 |
+
'start_idx': err_min + i*err_len,
|
481 |
+
'end_idx': err_min + (i+1)*err_len,
|
482 |
+
'num_points': 0,
|
483 |
+
'mean_err': 0,
|
484 |
+
'mean_var': 0,
|
485 |
+
}
|
486 |
+
|
487 |
+
for e,v in zip(list_err, list_yout_var):
|
488 |
+
for i in range(num_bins):
|
489 |
+
if e>=bin_stats[i]['start_idx'] and e<bin_stats[i]['end_idx']:
|
490 |
+
bin_stats[i]['num_points'] += 1
|
491 |
+
bin_stats[i]['mean_err'] += e
|
492 |
+
bin_stats[i]['mean_var'] += v
|
493 |
+
|
494 |
+
uce = 0
|
495 |
+
eps = 1e-8
|
496 |
+
for i in range(num_bins):
|
497 |
+
bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
|
498 |
+
bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
|
499 |
+
bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points']/num_points) \
|
500 |
+
*(np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
|
501 |
+
uce += bin_stats[i]['uce_bin']
|
502 |
+
|
503 |
+
list_x, list_y = [], []
|
504 |
+
for i in range(num_bins):
|
505 |
+
if bin_stats[i]['num_points']>0:
|
506 |
+
list_x.append(bin_stats[i]['mean_err'])
|
507 |
+
list_y.append(bin_stats[i]['mean_var'])
|
508 |
+
|
509 |
+
# sns.set_style('darkgrid')
|
510 |
+
# sns.scatterplot(x=list_x, y=list_y)
|
511 |
+
# sns.regplot(x=list_x, y=list_y, order=1)
|
512 |
+
# plt.xlabel('MSE', fontsize=34)
|
513 |
+
# plt.ylabel('Uncertainty', fontsize=34)
|
514 |
+
# plt.plot(list_x, list_x, color='r')
|
515 |
+
# plt.xlim(np.min(list_x), np.max(list_x))
|
516 |
+
# plt.ylim(np.min(list_err), np.max(list_x))
|
517 |
+
# plt.show()
|
518 |
+
|
519 |
+
return bin_stats, uce
|
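`get_UCE` bins the per-pixel errors, compares the mean error with the mean predicted variance inside each bin, and sums the weighted absolute gaps, so a well-calibrated uncertainty map drives the score towards zero. A small synthetic sanity check, assuming `get_UCE` is in scope (the magnitudes are arbitrary):

```python
import numpy as np

err = list(np.random.rand(10000) * 0.01)                       # stand-in per-pixel squared errors
var_calibrated = list(np.array(err) + 1e-4 * np.random.randn(10000))
var_uninformative = list(np.random.rand(10000) * 0.01)         # ignores the actual errors

print(get_UCE(err, var_calibrated, num_bins=100)[1])           # typically much smaller than the line below
print(get_UCE(err, var_uninformative, num_bins=100)[1])
```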
520 |
+
|
521 |
+
##################### training BayesCap
|
522 |
+
def train_BayesCap(
|
523 |
+
NetC,
|
524 |
+
NetG,
|
525 |
+
train_loader,
|
526 |
+
eval_loader,
|
527 |
+
Cri = TempCombLoss(),
|
528 |
+
device='cuda',
|
529 |
+
dtype=torch.cuda.FloatTensor,  # pass the tensor type itself (as the eval functions do), not an instance
|
530 |
+
init_lr=1e-4,
|
531 |
+
num_epochs=100,
|
532 |
+
eval_every=1,
|
533 |
+
ckpt_path='../ckpt/BayesCap',
|
534 |
+
T1=1e0,
|
535 |
+
T2=5e-2,
|
536 |
+
task=None,
|
537 |
+
):
|
538 |
+
NetC.to(device)
|
539 |
+
NetC.train()
|
540 |
+
NetG.to(device)
|
541 |
+
NetG.eval()
|
542 |
+
optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
|
543 |
+
optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
|
544 |
+
|
545 |
+
score = -1e8
|
546 |
+
all_loss = []
|
547 |
+
for eph in range(num_epochs):
|
548 |
+
eph_loss = 0
|
549 |
+
with tqdm(train_loader, unit='batch') as tepoch:
|
550 |
+
for (idx, batch) in enumerate(tepoch):
|
551 |
+
if idx>2000:
|
552 |
+
break
|
553 |
+
tepoch.set_description('Epoch {}'.format(eph))
|
554 |
+
##
|
555 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
556 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
557 |
+
if task == 'inpainting':
|
558 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
559 |
+
xMask = xMask.to(device).type(dtype)
|
560 |
+
# pass them through the network
|
561 |
+
with torch.no_grad():
|
562 |
+
if task == 'inpainting':
|
563 |
+
_, xSR1 = NetG(xLR, xMask)
|
564 |
+
elif task == 'depth':
|
565 |
+
xSR1 = NetG(xLR)[("disp", 0)]
|
566 |
+
else:
|
567 |
+
xSR1 = NetG(xLR)
|
568 |
+
# with torch.autograd.set_detect_anomaly(True):
|
569 |
+
xSR = xSR1.clone()
|
570 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
571 |
+
# print(xSRC_alpha)
|
572 |
+
optimizer.zero_grad()
|
573 |
+
if task == 'depth':
|
574 |
+
loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
|
575 |
+
else:
|
576 |
+
loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
|
577 |
+
# print(loss)
|
578 |
+
loss.backward()
|
579 |
+
optimizer.step()
|
580 |
+
##
|
581 |
+
eph_loss += loss.item()
|
582 |
+
tepoch.set_postfix(loss=loss.item())
|
583 |
+
eph_loss /= len(train_loader)
|
584 |
+
all_loss.append(eph_loss)
|
585 |
+
print('Avg. loss: {}'.format(eph_loss))
|
586 |
+
# evaluate and save the models
|
587 |
+
torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
|
588 |
+
if eph%eval_every == 0:
|
589 |
+
curr_score = eval_BayesCap(
|
590 |
+
NetC,
|
591 |
+
NetG,
|
592 |
+
eval_loader,
|
593 |
+
device=device,
|
594 |
+
dtype=dtype,
|
595 |
+
task=task,
|
596 |
+
)
|
597 |
+
print('current score: {} | Last best score: {}'.format(curr_score, score))
|
598 |
+
if curr_score >= score:
|
599 |
+
score = curr_score
|
600 |
+
torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
|
601 |
+
optim_scheduler.step()
|
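The loop above fine-tunes only the BayesCap head `NetC` while the task network `NetG` stays frozen. Below is a self-contained sketch of how it is wired up; the two toy modules merely stand in for the real SRGAN generator and BayesCap model from `networks_SRGAN.py`, the random tensors stand in for paired (LR, HR) data, and whether such toy shapes satisfy `TempCombLoss` exactly depends on `losses.py`, so treat it purely as an illustration of the expected inputs and call signature. `device`/`dtype` are overridden so the sketch does not need a GPU, and `ckpt_path` writes to the working directory.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class ToyG(nn.Module):                     # stand-in for the frozen task network (e.g. the SRGAN generator)
    def __init__(self):
        super().__init__()
        self.body = nn.Conv2d(3, 3, 3, padding=1)
    def forward(self, x):
        return torch.sigmoid(self.body(x))

class ToyBayesCap(nn.Module):              # stand-in for NetC: predicts (mu, one-over-alpha, beta)
    def __init__(self):
        super().__init__()
        self.mu = nn.Conv2d(3, 3, 3, padding=1)
        self.one_over_alpha = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.Softplus())
        self.beta = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.Softplus())
    def forward(self, x):
        return self.mu(x), self.one_over_alpha(x) + 1e-2, self.beta(x) + 1e-2

pairs = TensorDataset(torch.rand(8, 3, 32, 32), torch.rand(8, 3, 32, 32))   # (LR, HR) stand-ins
loader = DataLoader(pairs, batch_size=4)

train_BayesCap(ToyBayesCap(), ToyG(), loader, loader,
               device='cpu', dtype=torch.FloatTensor,
               num_epochs=1, ckpt_path='./toy_BayesCap')
```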
602 |
+
|
603 |
+
#### get different uncertainty maps
|
604 |
+
def get_uncer_BayesCap(
|
605 |
+
NetC,
|
606 |
+
NetG,
|
607 |
+
xin,
|
608 |
+
task=None,
|
609 |
+
xMask=None,
|
610 |
+
):
|
611 |
+
with torch.no_grad():
|
612 |
+
if task == 'inpainting':
|
613 |
+
_, xSR = NetG(xin, xMask)
|
614 |
+
else:
|
615 |
+
xSR = NetG(xin)
|
616 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
617 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
618 |
+
b_map = xSRC_beta.to('cpu').data
|
619 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
620 |
+
|
621 |
+
return xSRvar
|
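`get_uncer_BayesCap` turns the predicted one-over-alpha and beta maps into a per-pixel variance via the generalized Gaussian identity Var = alpha^2 * Gamma(3/beta) / Gamma(1/beta), which is exactly the expression computed above. A small numeric check of that identity, independent of the networks:

```python
import torch

def gg_var(alpha, beta):
    alpha = torch.as_tensor(alpha, dtype=torch.float64)
    beta = torch.as_tensor(beta, dtype=torch.float64)
    return (alpha ** 2) * torch.exp(torch.lgamma(3 / beta)) / torch.exp(torch.lgamma(1 / beta))

sigma = 0.1
print(gg_var(sigma * 2 ** 0.5, 2.0).item())   # beta = 2 is the Gaussian case: recovers sigma^2 = 0.01
print(gg_var(0.1, 1.0).item())                # beta = 1 is Laplace-like: 0.1^2 * Gamma(3)/Gamma(1) = 0.02
```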
622 |
+
|
623 |
+
def get_uncer_TTDAp(
|
624 |
+
NetG,
|
625 |
+
xin,
|
626 |
+
p_mag=0.05,
|
627 |
+
num_runs=50,
|
628 |
+
task=None,
|
629 |
+
xMask=None,
|
630 |
+
):
|
631 |
+
list_xSR = []
|
632 |
+
with torch.no_grad():
|
633 |
+
for z in range(num_runs):
|
634 |
+
if task == 'inpainting':
|
635 |
+
_, xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin), xMask)
|
636 |
+
else:
|
637 |
+
xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin))
|
638 |
+
list_xSR.append(xSRz)
|
639 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
640 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
641 |
+
return xSRvar
|
642 |
+
|
643 |
+
def get_uncer_DO(
|
644 |
+
NetG,
|
645 |
+
xin,
|
646 |
+
dop=0.2,
|
647 |
+
num_runs=50,
|
648 |
+
task=None,
|
649 |
+
xMask=None,
|
650 |
+
):
|
651 |
+
list_xSR = []
|
652 |
+
with torch.no_grad():
|
653 |
+
for z in range(num_runs):
|
654 |
+
if task == 'inpainting':
|
655 |
+
_, xSRz = NetG(xin, xMask, dop=dop)
|
656 |
+
else:
|
657 |
+
xSRz = NetG(xin, dop=dop)
|
658 |
+
list_xSR.append(xSRz)
|
659 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
660 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
661 |
+
return xSRvar
|
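The two baselines above estimate uncertainty by running the frozen generator many times, either on noise-perturbed inputs (test-time data augmentation) or with dropout active, and taking the per-pixel variance of the outputs. A self-contained sketch on a toy network; the real `NetG` is the repo's generator, and the `dop` keyword below only mimics the interface that `get_uncer_DO` relies on:

```python
import torch
import torch.nn as nn
import torch.nn.functional as nnF

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
    def forward(self, x, dop=0.0):
        # dropout is applied unconditionally, so it stays stochastic under torch.no_grad()/eval()
        return torch.sigmoid(self.conv(nnF.dropout(x, p=dop)))

net = ToyNet()
x = torch.rand(1, 3, 32, 32)
var_ttda = get_uncer_TTDAp(net, x, p_mag=0.05, num_runs=20)
var_do = get_uncer_DO(net, x, dop=0.2, num_runs=20)
print(var_ttda.shape, var_do.shape)   # both 1x1x32x32 per-pixel variance maps
```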
662 |
+
|
663 |
+
################### Different eval functions
|
664 |
+
|
665 |
+
def eval_BayesCap(
|
666 |
+
NetC,
|
667 |
+
NetG,
|
668 |
+
eval_loader,
|
669 |
+
device='cuda',
|
670 |
+
dtype=torch.cuda.FloatTensor,
|
671 |
+
task=None,
|
672 |
+
xMask=None,
|
673 |
+
):
|
674 |
+
NetC.to(device)
|
675 |
+
NetC.eval()
|
676 |
+
NetG.to(device)
|
677 |
+
NetG.eval()
|
678 |
+
|
679 |
+
mean_ssim = 0
|
680 |
+
mean_psnr = 0
|
681 |
+
mean_mse = 0
|
682 |
+
mean_mae = 0
|
683 |
+
num_imgs = 0
|
684 |
+
list_error = []
|
685 |
+
list_var = []
|
686 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
687 |
+
for (idx, batch) in enumerate(tepoch):
|
688 |
+
tepoch.set_description('Validating ...')
|
689 |
+
##
|
690 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
691 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
692 |
+
if task == 'inpainting':
|
693 |
+
if xMask is None:
|
694 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
695 |
+
xMask = xMask.to(device).type(dtype)
|
696 |
+
else:
|
697 |
+
xMask = xMask.to(device).type(dtype)
|
698 |
+
# pass them through the network
|
699 |
+
with torch.no_grad():
|
700 |
+
if task == 'inpainting':
|
701 |
+
_, xSR = NetG(xLR, xMask)
|
702 |
+
elif task == 'depth':
|
703 |
+
xSR = NetG(xLR)[("disp", 0)]
|
704 |
+
else:
|
705 |
+
xSR = NetG(xLR)
|
706 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
707 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
708 |
+
b_map = xSRC_beta.to('cpu').data
|
709 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
710 |
+
n_batch = xSRC_mu.shape[0]
|
711 |
+
if task == 'depth':
|
712 |
+
xHR = xSR
|
713 |
+
for j in range(n_batch):
|
714 |
+
num_imgs += 1
|
715 |
+
mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
|
716 |
+
mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
|
717 |
+
mean_mse += img_mse(xSRC_mu[j], xHR[j])
|
718 |
+
mean_mae += img_mae(xSRC_mu[j], xHR[j])
|
719 |
+
|
720 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
721 |
+
|
722 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
723 |
+
var_map = xSRvar[j].to('cpu').data.reshape(-1)
|
724 |
+
list_error.extend(list(error_map.numpy()))
|
725 |
+
list_var.extend(list(var_map.numpy()))
|
726 |
+
##
|
727 |
+
mean_ssim /= num_imgs
|
728 |
+
mean_psnr /= num_imgs
|
729 |
+
mean_mse /= num_imgs
|
730 |
+
mean_mae /= num_imgs
|
731 |
+
print(
|
732 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
733 |
+
(
|
734 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
735 |
+
)
|
736 |
+
)
|
737 |
+
# print(len(list_error), len(list_var))
|
738 |
+
# print('UCE: ', get_UCE(list_error[::10], list_var[::10], num_bins=500)[1])
|
739 |
+
# print('C.Coeff: ', np.corrcoef(np.array(list_error[::10]), np.array(list_var[::10])))
|
740 |
+
return mean_ssim
|
741 |
+
|
742 |
+
def eval_TTDA_p(
|
743 |
+
NetG,
|
744 |
+
eval_loader,
|
745 |
+
device='cuda',
|
746 |
+
dtype=torch.cuda.FloatTensor,
|
747 |
+
p_mag=0.05,
|
748 |
+
num_runs=50,
|
749 |
+
task = None,
|
750 |
+
xMask = None,
|
751 |
+
):
|
752 |
+
NetG.to(device)
|
753 |
+
NetG.eval()
|
754 |
+
|
755 |
+
mean_ssim = 0
|
756 |
+
mean_psnr = 0
|
757 |
+
mean_mse = 0
|
758 |
+
mean_mae = 0
|
759 |
+
num_imgs = 0
|
760 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
761 |
+
for (idx, batch) in enumerate(tepoch):
|
762 |
+
tepoch.set_description('Validating ...')
|
763 |
+
##
|
764 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
765 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
766 |
+
# pass them through the network
|
767 |
+
list_xSR = []
|
768 |
+
with torch.no_grad():
|
769 |
+
if task=='inpainting':
|
770 |
+
_, xSR = NetG(xLR, xMask)
|
771 |
+
else:
|
772 |
+
xSR = NetG(xLR)
|
773 |
+
for z in range(num_runs):
|
774 |
+
xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
|
775 |
+
list_xSR.append(xSRz)
|
776 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
777 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
778 |
+
n_batch = xSR.shape[0]
|
779 |
+
for j in range(n_batch):
|
780 |
+
num_imgs += 1
|
781 |
+
mean_ssim += img_ssim(xSR[j], xHR[j])
|
782 |
+
mean_psnr += img_psnr(xSR[j], xHR[j])
|
783 |
+
mean_mse += img_mse(xSR[j], xHR[j])
|
784 |
+
mean_mae += img_mae(xSR[j], xHR[j])
|
785 |
+
|
786 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
787 |
+
|
788 |
+
mean_ssim /= num_imgs
|
789 |
+
mean_psnr /= num_imgs
|
790 |
+
mean_mse /= num_imgs
|
791 |
+
mean_mae /= num_imgs
|
792 |
+
print(
|
793 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
794 |
+
(
|
795 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
796 |
+
)
|
797 |
+
)
|
798 |
+
|
799 |
+
return mean_ssim
|
800 |
+
|
801 |
+
def eval_DO(
|
802 |
+
NetG,
|
803 |
+
eval_loader,
|
804 |
+
device='cuda',
|
805 |
+
dtype=torch.cuda.FloatTensor,
|
806 |
+
dop=0.2,
|
807 |
+
num_runs=50,
|
808 |
+
task=None,
|
809 |
+
xMask=None,
|
810 |
+
):
|
811 |
+
NetG.to(device)
|
812 |
+
NetG.eval()
|
813 |
+
|
814 |
+
mean_ssim = 0
|
815 |
+
mean_psnr = 0
|
816 |
+
mean_mse = 0
|
817 |
+
mean_mae = 0
|
818 |
+
num_imgs = 0
|
819 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
820 |
+
for (idx, batch) in enumerate(tepoch):
|
821 |
+
tepoch.set_description('Validating ...')
|
822 |
+
##
|
823 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
824 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
825 |
+
# pass them through the network
|
826 |
+
list_xSR = []
|
827 |
+
with torch.no_grad():
|
828 |
+
if task == 'inpainting':
|
829 |
+
_, xSR = NetG(xLR, xMask)
|
830 |
+
else:
|
831 |
+
xSR = NetG(xLR)
|
832 |
+
for z in range(num_runs):
|
833 |
+
xSRz = NetG(xLR, dop=dop)
|
834 |
+
list_xSR.append(xSRz)
|
835 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
836 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
837 |
+
n_batch = xSR.shape[0]
|
838 |
+
for j in range(n_batch):
|
839 |
+
num_imgs += 1
|
840 |
+
mean_ssim += img_ssim(xSR[j], xHR[j])
|
841 |
+
mean_psnr += img_psnr(xSR[j], xHR[j])
|
842 |
+
mean_mse += img_mse(xSR[j], xHR[j])
|
843 |
+
mean_mae += img_mae(xSR[j], xHR[j])
|
844 |
+
|
845 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
846 |
+
##
|
847 |
+
mean_ssim /= num_imgs
|
848 |
+
mean_psnr /= num_imgs
|
849 |
+
mean_mse /= num_imgs
|
850 |
+
mean_mae /= num_imgs
|
851 |
+
print(
|
852 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
853 |
+
(
|
854 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
855 |
+
)
|
856 |
+
)
|
857 |
+
|
858 |
+
return mean_ssim
|
859 |
+
|
860 |
+
|
861 |
+
############### compare all function
|
862 |
+
def compare_all(
|
863 |
+
NetC,
|
864 |
+
NetG,
|
865 |
+
eval_loader,
|
866 |
+
p_mag = 0.05,
|
867 |
+
dop = 0.2,
|
868 |
+
num_runs = 100,
|
869 |
+
device='cuda',
|
870 |
+
dtype=torch.cuda.FloatTensor,
|
871 |
+
task=None,
|
872 |
+
):
|
873 |
+
NetC.to(device)
|
874 |
+
NetC.eval()
|
875 |
+
NetG.to(device)
|
876 |
+
NetG.eval()
|
877 |
+
|
878 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
879 |
+
for (idx, batch) in enumerate(tepoch):
|
880 |
+
tepoch.set_description('Comparing ...')
|
881 |
+
##
|
882 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
883 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
884 |
+
if task == 'inpainting':
|
885 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
886 |
+
xMask = xMask.to(device).type(dtype)
|
887 |
+
# pass them through the network
|
888 |
+
with torch.no_grad():
|
889 |
+
if task == 'inpainting':
|
890 |
+
_, xSR = NetG(xLR, xMask)
|
891 |
+
else:
|
892 |
+
xSR = NetG(xLR)
|
893 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
894 |
+
|
895 |
+
if task == 'inpainting':
|
896 |
+
xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
|
897 |
+
xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
|
898 |
+
xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
|
899 |
+
else:
|
900 |
+
xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
|
901 |
+
xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
|
902 |
+
xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)
|
903 |
+
|
904 |
+
print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
|
905 |
+
|
906 |
+
n_batch = xSR.shape[0]
|
907 |
+
for j in range(n_batch):
|
908 |
+
if task=='s':
|
909 |
+
show_SR_w_err(xLR[j], xHR[j], xSR[j])
|
910 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
|
911 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
|
912 |
+
if task=='d':
|
913 |
+
show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
|
914 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
|
915 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
|
916 |
+
if task=='inpainting':
|
917 |
+
show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
|
918 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
|
919 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
|
920 |
+
if task=='m':
|
921 |
+
show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
|
922 |
+
show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
|
923 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
|
924 |
+
|
925 |
+
|
926 |
+
################# Degrading Identity
|
927 |
+
def degrage_BayesCap_p(
|
928 |
+
NetC,
|
929 |
+
NetG,
|
930 |
+
eval_loader,
|
931 |
+
device='cuda',
|
932 |
+
dtype=torch.cuda.FloatTensor,
|
933 |
+
num_runs=50,
|
934 |
+
):
|
935 |
+
NetC.to(device)
|
936 |
+
NetC.eval()
|
937 |
+
NetG.to(device)
|
938 |
+
NetG.eval()
|
939 |
+
|
940 |
+
p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
|
941 |
+
list_s = []
|
942 |
+
list_p = []
|
943 |
+
list_u1 = []
|
944 |
+
list_u2 = []
|
945 |
+
list_c = []
|
946 |
+
for p_mag in p_mag_list:
|
947 |
+
mean_ssim = 0
|
948 |
+
mean_psnr = 0
|
949 |
+
mean_mse = 0
|
950 |
+
mean_mae = 0
|
951 |
+
num_imgs = 0
|
952 |
+
list_error = []
|
953 |
+
list_error2 = []
|
954 |
+
list_var = []
|
955 |
+
|
956 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
957 |
+
for (idx, batch) in enumerate(tepoch):
|
958 |
+
tepoch.set_description('Validating ...')
|
959 |
+
##
|
960 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
961 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
962 |
+
# pass them through the network
|
963 |
+
with torch.no_grad():
|
964 |
+
xSR = NetG(xLR)
|
965 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
|
966 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
967 |
+
b_map = xSRC_beta.to('cpu').data
|
968 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
969 |
+
n_batch = xSRC_mu.shape[0]
|
970 |
+
for j in range(n_batch):
|
971 |
+
num_imgs += 1
|
972 |
+
mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
|
973 |
+
mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
|
974 |
+
mean_mse += img_mse(xSRC_mu[j], xSR[j])
|
975 |
+
mean_mae += img_mae(xSRC_mu[j], xSR[j])
|
976 |
+
|
977 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
978 |
+
error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
979 |
+
var_map = xSRvar[j].to('cpu').data.reshape(-1)
|
980 |
+
list_error.extend(list(error_map.numpy()))
|
981 |
+
list_error2.extend(list(error_map2.numpy()))
|
982 |
+
list_var.extend(list(var_map.numpy()))
|
983 |
+
##
|
984 |
+
mean_ssim /= num_imgs
|
985 |
+
mean_psnr /= num_imgs
|
986 |
+
mean_mse /= num_imgs
|
987 |
+
mean_mae /= num_imgs
|
988 |
+
print(
|
989 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
990 |
+
(
|
991 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
992 |
+
)
|
993 |
+
)
|
994 |
+
uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
|
995 |
+
uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
|
996 |
+
print('UCE1: ', uce1)
|
997 |
+
print('UCE2: ', uce2)
|
998 |
+
list_s.append(mean_ssim.item())
|
999 |
+
list_p.append(mean_psnr.item())
|
1000 |
+
list_u1.append(uce1)
|
1001 |
+
list_u2.append(uce2)
|
1002 |
+
|
1003 |
+
plt.plot(list_s)
|
1004 |
+
plt.show()
|
1005 |
+
plt.plot(list_p)
|
1006 |
+
plt.show()
|
1007 |
+
|
1008 |
+
plt.plot(list_u1, label='wrt SR output')
|
1009 |
+
plt.plot(list_u2, label='wrt BayesCap output')
|
1010 |
+
plt.legend()
|
1011 |
+
plt.show()
|
1012 |
+
|
1013 |
+
sns.set_style('darkgrid')
|
1014 |
+
fig,ax = plt.subplots()
|
1015 |
+
# make a plot
|
1016 |
+
ax.plot(p_mag_list, list_s, color="red", marker="o")
|
1017 |
+
# set x-axis label
|
1018 |
+
ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
|
1019 |
+
# set y-axis label
|
1020 |
+
ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)
|
1021 |
+
|
1022 |
+
# twin object for two different y-axis on the sample plot
|
1023 |
+
ax2=ax.twinx()
|
1024 |
+
# make a plot with different y-axis using second axis object
|
1025 |
+
ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt to error btwn SRGAN output and GT')
|
1026 |
+
ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt to error btwn BayesCap output and GT')
|
1027 |
+
ax2.set_ylabel("UCE", color="green", fontsize=10)
|
1028 |
+
plt.legend(fontsize=10)
|
1029 |
+
plt.tight_layout()
|
1030 |
+
plt.show()
|
1031 |
+
|
1032 |
+
################# DeepFill_v2
|
1033 |
+
|
1034 |
+
# ----------------------------------------
|
1035 |
+
# PATH processing
|
1036 |
+
# ----------------------------------------
|
1037 |
+
def text_readlines(filename):
|
1038 |
+
# Try to read a txt file and return a list of lines. Return [] if the file cannot be opened.
|
1039 |
+
try:
|
1040 |
+
file = open(filename, 'r')
|
1041 |
+
except IOError:
|
1042 |
+
error = []
|
1043 |
+
return error
|
1044 |
+
content = file.readlines()
|
1045 |
+
# Strip the trailing newline character from each line
|
1046 |
+
for i in range(len(content)):
|
1047 |
+
content[i] = content[i][:len(content[i])-1]
|
1048 |
+
file.close()
|
1049 |
+
return content
|
1050 |
+
|
1051 |
+
def savetxt(name, loss_log):
|
1052 |
+
np_loss_log = np.array(loss_log)
|
1053 |
+
np.savetxt(name, np_loss_log)
|
1054 |
+
|
1055 |
+
def get_files(path):
|
1056 |
+
# read a folder, return the complete path
|
1057 |
+
ret = []
|
1058 |
+
for root, dirs, files in os.walk(path):
|
1059 |
+
for filespath in files:
|
1060 |
+
ret.append(os.path.join(root, filespath))
|
1061 |
+
return ret
|
1062 |
+
|
1063 |
+
def get_names(path):
|
1064 |
+
# read a folder, return the image name
|
1065 |
+
ret = []
|
1066 |
+
for root, dirs, files in os.walk(path):
|
1067 |
+
for filespath in files:
|
1068 |
+
ret.append(filespath)
|
1069 |
+
return ret
|
1070 |
+
|
1071 |
+
def text_save(content, filename, mode = 'a'):
|
1072 |
+
# save a list to a txt
|
1073 |
+
# Try to save a list variable in txt file.
|
1074 |
+
file = open(filename, mode)
|
1075 |
+
for i in range(len(content)):
|
1076 |
+
file.write(str(content[i]) + '\n')
|
1077 |
+
file.close()
|
1078 |
+
|
1079 |
+
def check_path(path):
|
1080 |
+
if not os.path.exists(path):
|
1081 |
+
os.makedirs(path)
|
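A quick sketch of the path helpers above on a throwaway directory (the directory and file names are illustrative):

```python
check_path('./tmp_demo')
text_save(['first line', 'second line'], './tmp_demo/log.txt')
print(text_readlines('./tmp_demo/log.txt'))   # ['first line', 'second line']
print(get_names('./tmp_demo'))                # ['log.txt']
print(get_files('./tmp_demo'))                # ['./tmp_demo/log.txt']
```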
1082 |
+
|
1083 |
+
# ----------------------------------------
|
1084 |
+
# Validation and Sample at training
|
1085 |
+
# ----------------------------------------
|
1086 |
+
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
|
1087 |
+
# Save image one-by-one
|
1088 |
+
for i in range(len(img_list)):
|
1089 |
+
img = img_list[i]
|
1090 |
+
# Recover normalization: * 255 because last layer is sigmoid activated
|
1091 |
+
img = img * 255
|
1092 |
+
# Process img_copy and do not destroy the data of img
|
1093 |
+
img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
|
1094 |
+
img_copy = np.clip(img_copy, 0, pixel_max_cnt)
|
1095 |
+
img_copy = img_copy.astype(np.uint8)
|
1096 |
+
img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
|
1097 |
+
# Save to certain path
|
1098 |
+
save_img_name = sample_name + '_' + name_list[i] + '.jpg'
|
1099 |
+
save_img_path = os.path.join(sample_folder, save_img_name)
|
1100 |
+
cv2.imwrite(save_img_path, img_copy)
|
1101 |
+
|
1102 |
+
def psnr(pred, target, pixel_max_cnt = 255):
|
1103 |
+
mse = torch.mul(target - pred, target - pred)
|
1104 |
+
rmse_avg = (torch.mean(mse).item()) ** 0.5
|
1105 |
+
p = 20 * np.log10(pixel_max_cnt / rmse_avg)
|
1106 |
+
return p
|
1107 |
+
|
1108 |
+
def grey_psnr(pred, target, pixel_max_cnt = 255):
|
1109 |
+
pred = torch.sum(pred, dim = 0)
|
1110 |
+
target = torch.sum(target, dim = 0)
|
1111 |
+
mse = torch.mul(target - pred, target - pred)
|
1112 |
+
rmse_avg = (torch.mean(mse).item()) ** 0.5
|
1113 |
+
p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
|
1114 |
+
return p
|
1115 |
+
|
1116 |
+
def ssim(pred, target):
|
1117 |
+
pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()
|
1118 |
+
target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()
|
1119 |
+
target = target[0]
|
1120 |
+
pred = pred[0]
|
1121 |
+
ssim = skimage.measure.compare_ssim(target, pred, multichannel = True)
|
1122 |
+
return ssim
|
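These DeepFill-side metrics expect tensors in the 0-255 range, unlike `img_psnr`/`img_ssim` further up, which assume [0, 1]. (Note that `skimage.measure.compare_ssim`, used by `ssim` above, has been removed from recent scikit-image releases in favour of `skimage.metrics.structural_similarity`.) A one-line numeric check of the PSNR helper: a constant error of one grey level gives 20*log10(255), roughly 48.13 dB.

```python
import torch

target = torch.rand(1, 3, 64, 64) * 255
pred = target + 1.0                    # off by exactly one grey level everywhere
print(psnr(pred, target))              # ~48.13
```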
1123 |
+
|
1124 |
+
## for contextual attention
|
1125 |
+
|
1126 |
+
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
|
1127 |
+
"""
|
1128 |
+
Extract patches from images and put them in the C output dimension.
|
1129 |
+
:param padding:
|
1130 |
+
:param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
|
1131 |
+
:param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
|
1132 |
+
each dimension of images
|
1133 |
+
:param strides: [stride_rows, stride_cols]
|
1134 |
+
:param rates: [dilation_rows, dilation_cols]
|
1135 |
+
:return: A Tensor
|
1136 |
+
"""
|
1137 |
+
assert len(images.size()) == 4
|
1138 |
+
assert padding in ['same', 'valid']
|
1139 |
+
batch_size, channel, height, width = images.size()
|
1140 |
+
|
1141 |
+
if padding == 'same':
|
1142 |
+
images = same_padding(images, ksizes, strides, rates)
|
1143 |
+
elif padding == 'valid':
|
1144 |
+
pass
|
1145 |
+
else:
|
1146 |
+
raise NotImplementedError('Unsupported padding type: {}.\
|
1147 |
+
Only "same" or "valid" are supported.'.format(padding))
|
1148 |
+
|
1149 |
+
unfold = torch.nn.Unfold(kernel_size=ksizes,
|
1150 |
+
dilation=rates,
|
1151 |
+
padding=0,
|
1152 |
+
stride=strides)
|
1153 |
+
patches = unfold(images)
|
1154 |
+
return patches # [N, C*k*k, L], L is the total number of such blocks
|
1155 |
+
|
1156 |
+
def same_padding(images, ksizes, strides, rates):
|
1157 |
+
assert len(images.size()) == 4
|
1158 |
+
batch_size, channel, rows, cols = images.size()
|
1159 |
+
out_rows = (rows + strides[0] - 1) // strides[0]
|
1160 |
+
out_cols = (cols + strides[1] - 1) // strides[1]
|
1161 |
+
effective_k_row = (ksizes[0] - 1) * rates[0] + 1
|
1162 |
+
effective_k_col = (ksizes[1] - 1) * rates[1] + 1
|
1163 |
+
padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
|
1164 |
+
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
|
1165 |
+
# Pad the input
|
1166 |
+
padding_top = int(padding_rows / 2.)
|
1167 |
+
padding_left = int(padding_cols / 2.)
|
1168 |
+
padding_bottom = padding_rows - padding_top
|
1169 |
+
padding_right = padding_cols - padding_left
|
1170 |
+
paddings = (padding_left, padding_right, padding_top, padding_bottom)
|
1171 |
+
images = torch.nn.ZeroPad2d(paddings)(images)
|
1172 |
+
return images
|
1173 |
+
|
1174 |
+
def reduce_mean(x, axis=None, keepdim=False):
|
1175 |
+
if not axis:
|
1176 |
+
axis = range(len(x.shape))
|
1177 |
+
for i in sorted(axis, reverse=True):
|
1178 |
+
x = torch.mean(x, dim=i, keepdim=keepdim)
|
1179 |
+
return x
|
1180 |
+
|
1181 |
+
|
1182 |
+
def reduce_std(x, axis=None, keepdim=False):
|
1183 |
+
if not axis:
|
1184 |
+
axis = range(len(x.shape))
|
1185 |
+
for i in sorted(axis, reverse=True):
|
1186 |
+
x = torch.std(x, dim=i, keepdim=keepdim)
|
1187 |
+
return x
|
1188 |
+
|
1189 |
+
|
1190 |
+
def reduce_sum(x, axis=None, keepdim=False):
|
1191 |
+
if not axis:
|
1192 |
+
axis = range(len(x.shape))
|
1193 |
+
for i in sorted(axis, reverse=True):
|
1194 |
+
x = torch.sum(x, dim=i, keepdim=keepdim)
|
1195 |
+
return x
|
1196 |
+
|
1197 |
+
def random_mask(num_batch=1, mask_shape=(256,256)):
|
1198 |
+
list_mask = []
|
1199 |
+
for _ in range(num_batch):
|
1200 |
+
# rectangle mask
|
1201 |
+
image_height = mask_shape[0]
|
1202 |
+
image_width = mask_shape[1]
|
1203 |
+
max_delta_height = image_height//8
|
1204 |
+
max_delta_width = image_width//8
|
1205 |
+
height = image_height//4
|
1206 |
+
width = image_width//4
|
1207 |
+
max_t = image_height - height
|
1208 |
+
max_l = image_width - width
|
1209 |
+
t = random.randint(0, max_t)
|
1210 |
+
l = random.randint(0, max_l)
|
1211 |
+
# bbox = (t, l, height, width)
|
1212 |
+
h = random.randint(0, max_delta_height//2)
|
1213 |
+
w = random.randint(0, max_delta_width//2)
|
1214 |
+
mask = torch.zeros((1, 1, image_height, image_width))
|
1215 |
+
mask[:, :, t+h:t+height-h, l+w:l+width-w] = 1
|
1216 |
+
rect_mask = mask
|
1217 |
+
|
1218 |
+
# brush mask
|
1219 |
+
min_num_vertex = 4
|
1220 |
+
max_num_vertex = 12
|
1221 |
+
mean_angle = 2 * math.pi / 5
|
1222 |
+
angle_range = 2 * math.pi / 15
|
1223 |
+
min_width = 12
|
1224 |
+
max_width = 40
|
1225 |
+
H, W = image_height, image_width
|
1226 |
+
average_radius = math.sqrt(H*H+W*W) / 8
|
1227 |
+
mask = Image.new('L', (W, H), 0)
|
1228 |
+
|
1229 |
+
for _ in range(np.random.randint(1, 4)):
|
1230 |
+
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
|
1231 |
+
angle_min = mean_angle - np.random.uniform(0, angle_range)
|
1232 |
+
angle_max = mean_angle + np.random.uniform(0, angle_range)
|
1233 |
+
angles = []
|
1234 |
+
vertex = []
|
1235 |
+
for i in range(num_vertex):
|
1236 |
+
if i % 2 == 0:
|
1237 |
+
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
|
1238 |
+
else:
|
1239 |
+
angles.append(np.random.uniform(angle_min, angle_max))
|
1240 |
+
|
1241 |
+
h, w = mask.size
|
1242 |
+
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
|
1243 |
+
for i in range(num_vertex):
|
1244 |
+
r = np.clip(
|
1245 |
+
np.random.normal(loc=average_radius, scale=average_radius//2),
|
1246 |
+
0, 2*average_radius)
|
1247 |
+
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
|
1248 |
+
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
|
1249 |
+
vertex.append((int(new_x), int(new_y)))
|
1250 |
+
|
1251 |
+
draw = ImageDraw.Draw(mask)
|
1252 |
+
width = int(np.random.uniform(min_width, max_width))
|
1253 |
+
draw.line(vertex, fill=255, width=width)
|
1254 |
+
for v in vertex:
|
1255 |
+
draw.ellipse((v[0] - width//2,
|
1256 |
+
v[1] - width//2,
|
1257 |
+
v[0] + width//2,
|
1258 |
+
v[1] + width//2),
|
1259 |
+
fill=255)
|
1260 |
+
|
1261 |
+
if np.random.normal() > 0:
|
1262 |
+
mask.transpose(Image.FLIP_LEFT_RIGHT)
|
1263 |
+
if np.random.normal() > 0:
|
1264 |
+
mask.transpose(Image.FLIP_TOP_BOTTOM)
|
1265 |
+
|
1266 |
+
mask = transforms.ToTensor()(mask)
|
1267 |
+
mask = mask.reshape((1, 1, H, W))
|
1268 |
+
brush_mask = mask
|
1269 |
+
|
1270 |
+
mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
|
1271 |
+
list_mask.append(mask)
|
1272 |
+
mask = torch.cat(list_mask, dim=0)
|
1273 |
+
return mask
|
utils.py
ADDED
@@ -0,0 +1,117 @@
1 |
+
import random
|
2 |
+
from typing import Any, Optional
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
from glob import glob
|
7 |
+
from PIL import Image, ImageDraw
|
8 |
+
from tqdm import tqdm
|
9 |
+
import kornia
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import seaborn as sns
|
12 |
+
import albumentations as albu
|
13 |
+
import functools
|
14 |
+
import math
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from torch import Tensor
|
19 |
+
import torchvision as tv
|
20 |
+
import torchvision.models as models
|
21 |
+
from torchvision import transforms
|
22 |
+
from torchvision.transforms import functional as F
|
23 |
+
from losses import TempCombLoss
|
24 |
+
|
25 |
+
|
26 |
+
######## for loading checkpoint from googledrive
|
27 |
+
google_drive_paths = {
|
28 |
+
"BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
|
29 |
+
"BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
|
30 |
+
}
|
31 |
+
|
32 |
+
def ensure_checkpoint_exists(model_weights_filename):
|
33 |
+
if not os.path.isfile(model_weights_filename) and (
|
34 |
+
model_weights_filename in google_drive_paths
|
35 |
+
):
|
36 |
+
gdrive_url = google_drive_paths[model_weights_filename]
|
37 |
+
try:
|
38 |
+
from gdown import download as drive_download
|
39 |
+
|
40 |
+
drive_download(gdrive_url, model_weights_filename, quiet=False)
|
41 |
+
except ModuleNotFoundError:
|
42 |
+
print(
|
43 |
+
"gdown module not found.",
|
44 |
+
"pip3 install gdown or, manually download the checkpoint file:",
|
45 |
+
gdrive_url
|
46 |
+
)
|
47 |
+
|
48 |
+
if not os.path.isfile(model_weights_filename) and (
|
49 |
+
model_weights_filename not in google_drive_paths
|
50 |
+
):
|
51 |
+
print(
|
52 |
+
model_weights_filename,
|
53 |
+
" not found, you may need to manually download the model weights."
|
54 |
+
)
|
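A minimal usage sketch: fetch the two demo checkpoints named in `google_drive_paths` before building the models (requires `gdown`; otherwise the helper only prints the download URL):

```python
ensure_checkpoint_exists("BayesCap_SRGAN.pth")
ensure_checkpoint_exists("BayesCap_ckpt.pth")
```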
55 |
+
|
56 |
+
def normalize(image: np.ndarray) -> np.ndarray:
|
57 |
+
"""Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
|
58 |
+
Args:
|
59 |
+
image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
|
60 |
+
Returns:
|
61 |
+
Normalized image data. Data range [0, 1].
|
62 |
+
"""
|
63 |
+
return image.astype(np.float64) / 255.0
|
64 |
+
|
65 |
+
|
66 |
+
def unnormalize(image: np.ndarray) -> np.ndarray:
|
67 |
+
"""Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
|
68 |
+
Args:
|
69 |
+
image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
|
70 |
+
Returns:
|
71 |
+
Denormalized image data. Data range [0, 255].
|
72 |
+
"""
|
73 |
+
return image.astype(np.float64) * 255.0
|
74 |
+
|
75 |
+
|
76 |
+
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
|
77 |
+
"""Convert ``PIL.Image`` to Tensor.
|
78 |
+
Args:
|
79 |
+
image (np.ndarray): The image data read by ``PIL.Image``
|
80 |
+
range_norm (bool): Scale [0, 1] data to between [-1, 1]
|
81 |
+
half (bool): Whether to cast the tensor from torch.float32 to torch.half.
|
82 |
+
Returns:
|
83 |
+
Normalized image data
|
84 |
+
Examples:
|
85 |
+
>>> image = Image.open("image.bmp")
|
86 |
+
>>> tensor_image = image2tensor(image, range_norm=False, half=False)
|
87 |
+
"""
|
88 |
+
tensor = F.to_tensor(image)
|
89 |
+
|
90 |
+
if range_norm:
|
91 |
+
tensor = tensor.mul_(2.0).sub_(1.0)
|
92 |
+
if half:
|
93 |
+
tensor = tensor.half()
|
94 |
+
|
95 |
+
return tensor
|
96 |
+
|
97 |
+
|
98 |
+
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
|
99 |
+
"""Converts ``torch.Tensor`` to ``PIL.Image``.
|
100 |
+
Args:
|
101 |
+
tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
|
102 |
+
range_norm (bool): Scale [-1, 1] data to between [0, 1]
|
103 |
+
half (bool): Whether to cast the tensor from torch.float32 to torch.half.
|
104 |
+
Returns:
|
105 |
+
HxWxC uint8 image array, ready for ``PIL.Image.fromarray``
|
106 |
+
Examples:
|
107 |
+
>>> tensor = torch.randn([1, 3, 128, 128])
|
108 |
+
>>> image = tensor2image(tensor, range_norm=False, half=False)
|
109 |
+
"""
|
110 |
+
if range_norm:
|
111 |
+
tensor = tensor.add_(1.0).div_(2.0)
|
112 |
+
if half:
|
113 |
+
tensor = tensor.half()
|
114 |
+
|
115 |
+
image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
|
116 |
+
|
117 |
+
return image
|
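A round-trip sketch for the two converters above, using one of the bundled demo images (any RGB image works). Note that `tensor2image` modifies its input in place (`squeeze_`/`mul_`), so pass a clone if you still need the tensor afterwards:

```python
import numpy as np
from PIL import Image

img = Image.open("demo_examples/baby.png").convert("RGB")
t = image2tensor(img, range_norm=False, half=False)                        # CxHxW float in [0, 1]
back = tensor2image(t.clone().unsqueeze(0), range_norm=False, half=False)  # HxWxC uint8
print(back.shape, np.abs(back.astype(int) - np.array(img).astype(int)).max())  # differs by at most 1 from rounding
```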