Init project

- README.md +59 -0
- data/MBD/MBD.py +110 -0
- data/MBD/MBD_utils.py +291 -0
- data/MBD/infer.py +151 -0
- data/MBD/model/__init__.py +50 -0
- data/MBD/model/cbam.py +95 -0
- data/MBD/model/deep_lab_model/__init__.py +0 -0
- data/MBD/model/deep_lab_model/aspp.py +95 -0
- data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- data/MBD/model/deep_lab_model/decoder.py +59 -0
- data/MBD/model/deep_lab_model/deeplab.py +81 -0
- data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- data/MBD/model/densenetccnl.py +382 -0
- data/MBD/model/gienet.py +742 -0
- data/MBD/model/unetnc.py +86 -0
- data/MBD/modify_stn_model/stn_head.py +123 -0
- data/MBD/modify_stn_model/tps_spatial_transformer.py +194 -0
- data/MBD/stn_model/stn_head.py +123 -0
- data/MBD/stn_model/tps_spatial_transformer.py +155 -0
- data/MBD/tps_grid_gen.py +70 -0
- data/MBD/utils.py +234 -0
- data/README.md +135 -0
- data/preprocess/crop_merge_image.py +142 -0
- data/preprocess/sauvola_binarize.py +91 -0
- data/preprocess/shadow_extraction.py +68 -0
- eval.py +369 -0
- inference.py +341 -0
- loaders/docres_loader.py +558 -0
- models/restormer_arch.py +308 -0
- requirements.txt +10 -0
- start_train.sh +1 -0
- train.py +221 -0
- utils.py +464 -0
README.md
ADDED
@@ -0,0 +1,59 @@
<div align=center>

# DocRes: A Generalist Model Toward Unifying Document Image Restoration Tasks

</div>

<p align="center">
    <img src="images/motivation.jpg" width="400">
</p>

This is the official implementation of our paper [DocRes: A Generalist Model Toward Unifying Document Image Restoration Tasks](https://arxiv.org/abs/2405.04408).

## News
🔥 A comprehensive [Recommendation for Document Image Processing](https://github.com/ZZZHANG-jx/Recommendations-Document-Image-Processing) is available.


## Inference
1. Put the MBD model weights [mbd.pkl](https://1drv.ms/f/s!Ak15mSdV3Wy4iahoKckhDPVP5e2Czw?e=iClwdK) in `./data/MBD/checkpoint/`
2. Put the DocRes model weights [docres.pkl](https://1drv.ms/f/s!Ak15mSdV3Wy4iahoKckhDPVP5e2Czw?e=iClwdK) in `./checkpoints/`
3. Run the following script; the results will be saved in `./restorted/`. We have provided some distorted examples in `./input/`.
```bash
python inference.py --im_path ./input/for_dewarping.png --task dewarping --save_dtsprompt 1
```
- `--im_path`: path of the input document image
- `--task`: the task to execute; it must be one of _dewarping_, _deshadowing_, _appearance_, _deblurring_, _binarization_, or _end2end_
- `--save_dtsprompt`: whether to save the DTSPrompt

## Evaluation
1. Dataset preparation, see [dataset instruction](./data/README.md)
2. Put the MBD model weights [mbd.pkl](https://1drv.ms/f/s!Ak15mSdV3Wy4iahoKckhDPVP5e2Czw?e=iClwdK) in `data/MBD/checkpoint/`
3. Put the DocRes model weights [docres.pkl](https://1drv.ms/f/s!Ak15mSdV3Wy4iahoKckhDPVP5e2Czw?e=iClwdK) in `./checkpoints/`
4. Run the following script
```bash
python eval.py --dataset realdae
```
- `--dataset`: the dataset to evaluate; it can be set to _dir300_, _kligler_, _jung_, _osr_, _docunet\_docaligner_, _realdae_, _tdd_, or _dibco18_.

## Training
1. Dataset preparation, see [dataset instruction](./data/README.md)
2. Specify the `datasets_setting` within `train.py` based on your dataset path and experimental setting.
3. Run the following script
```bash
bash start_train.sh
```


## Citation
```
@inproceedings{zhangdocres2024,
  Author = {Jiaxin Zhang, Dezhi Peng, Chongyu Liu, Peirong Zhang and Lianwen Jin},
  Booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  Title = {DocRes: A Generalist Model Toward Unifying Document Image Restoration Tasks},
  Year = {2024}}
```
## ⭐ Star Rising
[![Star Rising](https://api.star-history.com/svg?repos=ZZZHANG-jx/DocRes&type=Timeline)](https://star-history.com/#ZZZHANG-jx/DocRes&Timeline)
data/MBD/MBD.py
ADDED
@@ -0,0 +1,110 @@
import cv2
import numpy as np
import MBD_utils
import torch
import torch.nn.functional as F


def mask_base_dewarper(image,mask):
    '''
    input:
        image -> ndarray HxWx3 uint8
        mask  -> ndarray HxW uint8
    return
        dewarped -> ndarray HxWx3 uint8
        grid (optional) -> ndarray HxWx2 -1~1
    '''

    ## get contours
    # _, contours, hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) ## cv2.__version__ == 3.x
    contours,hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE) ## cv2.__version__ == 4.x

    ## get biggest contour and four corners based on the Douglas-Peucker algorithm
    four_corners, maxArea, contour = MBD_utils.DP_algorithm(contours)
    four_corners = MBD_utils.reorder(four_corners)

    ## reserve the biggest contour and remove other noisy contours
    new_mask = np.zeros_like(mask)
    new_mask = cv2.drawContours(new_mask,[contour],-1,255,cv2.FILLED)

    ## obtain middle points
    # ratios = [0.125,0.25,0.375,0.5,0.625,0.75,0.875]
    ratios = [0.25,0.5,0.75]
    # ratios = [0.0625,0.125,0.1875,0.25,0.3125,0.375,0.4475,0.5,0.5625,0.625,0.06875,0.75,0.8125,0.875,0.9375]
    middle = MBD_utils.findMiddle(corners=four_corners,mask=new_mask,points=ratios)

    ## all points
    source_points = np.concatenate((four_corners,middle),axis=0) ## all_point = four_corners(topleft,topright,bottomleft,bottomright)+top+bottom+left+right

    ## target points
    h,w = image.shape[:2]
    padding = 0
    target_points = [[padding, padding],[w-padding, padding], [padding, h-padding],[w-padding, h-padding]]
    for ratio in ratios:
        target_points.append([int((w-2*padding)*ratio)+padding,padding])
    for ratio in ratios:
        target_points.append([int((w-2*padding)*ratio)+padding,h-padding])
    for ratio in ratios:
        target_points.append([padding,int((h-2*padding)*ratio)+padding])
    for ratio in ratios:
        target_points.append([w-padding,int((h-2*padding)*ratio)+padding])

    ## dewarp based on cv2 (alternative)
    # pts1 = np.float32(source_points)
    # pts2 = np.float32(target_points)
    # tps = cv2.createThinPlateSplineShapeTransformer()
    # matches = []
    # N = pts1.shape[0]
    # for i in range(0,N):
    #     matches.append(cv2.DMatch(i,i,0))
    # pts1 = pts1.reshape(1,-1,2)
    # pts2 = pts2.reshape(1,-1,2)
    # tps.estimateTransformation(pts2,pts1,matches)
    # dewarped = tps.warpImage(image)

    ## dewarp based on a generated grid
    source_points = source_points.reshape(-1,2)/np.array([image.shape[:2][::-1]]).reshape(1,2)
    source_points = torch.from_numpy(source_points).float().cuda()
    source_points = source_points.unsqueeze(0)
    source_points = (source_points-0.5)*2
    target_points = np.asarray(target_points).reshape(-1,2)/np.array([image.shape[:2][::-1]]).reshape(1,2)
    target_points = torch.from_numpy(target_points).float()
    target_points = (target_points-0.5)*2

    model = MBD_utils.TPSGridGen(target_height=256,target_width=256,target_control_points=target_points)
    model = model.cuda()
    grid = model(source_points).view(-1,256,256,2).permute(0,3,1,2)
    grid = F.interpolate(grid,(h,w),mode='bilinear').permute(0,2,3,1)
    dewarped = MBD_utils.torch2cvimg(F.grid_sample(MBD_utils.cvimg2torch(image).cuda(),grid))[0]
    return dewarped,grid[0].cpu().numpy()

def mask_base_cropper(image,mask):
    '''
    input:
        image -> ndarray HxWx3 uint8
        mask  -> ndarray HxW uint8
    return
        dewarped -> ndarray HxWx3 uint8
        grid (optional) -> ndarray HxWx2 -1~1
    '''

    ## get contours
    _, contours, hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) ## cv2.__version__ == 3.x
    # contours,hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE) ## cv2.__version__ == 4.x

    ## get biggest contour and four corners based on the Douglas-Peucker algorithm
    four_corners, maxArea, contour = MBD_utils.DP_algorithm(contours)
    four_corners = MBD_utils.reorder(four_corners)

    ## reserve the biggest contour and remove other noisy contours
    new_mask = np.zeros_like(mask)
    new_mask = cv2.drawContours(new_mask,[contour],-1,255,cv2.FILLED)

    ## minimum-area bounding rectangle
    rect = cv2.minAreaRect(contour)  # (center(x,y), (width,height), rotation angle) of the minimum-area rectangle
    box = cv2.boxPoints(rect)        # the 4 corner points of the minimum-area rectangle (cv2.boxPoints for OpenCV 3.x)
    box = np.int0(box)
    box = box.reshape((4,1,2))
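A minimal usage sketch of `mask_base_dewarper` (not part of the commit; the file names are hypothetical, a CUDA device is required, and the mask is expected to be a clean 0/255 document mask such as the one produced in `infer.py`):

```python
import cv2
from MBD import mask_base_dewarper

# hypothetical file names: a distorted document photo and its binary (0/255) document mask
image = cv2.imread('warped_doc.png')                              # HxWx3 uint8
mask = cv2.imread('warped_doc_mask.png', cv2.IMREAD_GRAYSCALE)    # HxW uint8

dewarped, grid = mask_base_dewarper(image, mask)   # grid is HxWx2 in the -1~1 range
cv2.imwrite('warped_doc_dewarped.png', dewarped)
```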
data/MBD/MBD_utils.py
ADDED
@@ -0,0 +1,291 @@
import cv2
import numpy as np
import copy
import torch
import itertools
import torch.nn as nn
from torch.autograd import Function, Variable

def reorder(myPoints):
    myPoints = myPoints.reshape((4, 2))
    myPointsNew = np.zeros((4, 1, 2), dtype=np.int32)
    add = myPoints.sum(1)
    myPointsNew[0] = myPoints[np.argmin(add)]
    myPointsNew[3] = myPoints[np.argmax(add)]
    diff = np.diff(myPoints, axis=1)
    myPointsNew[1] = myPoints[np.argmin(diff)]
    myPointsNew[2] = myPoints[np.argmax(diff)]
    return myPointsNew


def findMiddle(corners,mask,points=[0.25,0.5,0.75]):
    num_middle_points = len(points)
    top = [np.array([])]*num_middle_points
    bottom = [np.array([])]*num_middle_points
    left = [np.array([])]*num_middle_points
    right = [np.array([])]*num_middle_points

    center_top = []
    center_bottom = []
    center_left = []
    center_right = []

    center = (int((corners[0][0][1]+corners[3][0][1])/2),int((corners[0][0][0]+corners[3][0][0])/2))
    for ratio in points:
        center_top.append( (center[0],int(corners[0][0][0]*(1-ratio)+corners[1][0][0]*ratio)) )
        center_bottom.append( (center[0],int(corners[2][0][0]*(1-ratio)+corners[3][0][0]*ratio)) )
        center_left.append( (int(corners[0][0][1]*(1-ratio)+corners[2][0][1]*ratio),center[1]) )
        center_right.append( (int(corners[1][0][1]*(1-ratio)+corners[3][0][1]*ratio),center[1]) )

    for i in range(0,center[0],1):
        for j in range(num_middle_points):
            if top[j].size==0:
                if mask[i,center_top[j][1]]==255:
                    top[j] = np.asarray([center_top[j][1],i])
                    top[j] = top[j].reshape(1,2)

    for i in range(mask.shape[0]-1,center[0],-1):
        for j in range(num_middle_points):
            if bottom[j].size==0:
                if mask[i,center_bottom[j][1]]==255:
                    bottom[j] = np.asarray([center_bottom[j][1],i])
                    bottom[j] = bottom[j].reshape(1,2)

    for i in range(mask.shape[1]-1,center[1],-1):
        for j in range(num_middle_points):
            if right[j].size==0:
                if mask[center_right[j][0],i]==255:
                    right[j] = np.asarray([i,center_right[j][0]])
                    right[j] = right[j].reshape(1,2)

    for i in range(0,center[1]):
        for j in range(num_middle_points):
            if left[j].size==0:
                if mask[center_left[j][0],i]==255:
                    left[j] = np.asarray([i,center_left[j][0]])
                    left[j] = left[j].reshape(1,2)

    return np.asarray(top+bottom+left+right)

def DP_algorithmv1(contours):
    biggest = np.array([])
    max_area = 0
    step = 0.001
    count = 0
    # while biggest.size==0:
    while True:
        for i in contours:
            # print(i.shape)
            area = cv2.contourArea(i)
            # print(area,cv2.arcLength(i, True))
            if area > cv2.arcLength(i, True)*10:
                peri = cv2.arcLength(i, True)
                approx = cv2.approxPolyDP(i, (0.01+step*count) * peri, True)
                if area > max_area and len(approx) == 4:
                    max_area = area
                    biggest_contours = i
                    biggest = approx
                    break
        if abs(max_area - cv2.contourArea(biggest))/max_area > 0.3:
            biggest = np.array([])
        count += 1
        if count > 200:
            break
    temp = biggest[0]
    return biggest, max_area, biggest_contours

def DP_algorithm(contours):
    biggest = np.array([])
    max_area = 0
    step = 0.001
    count = 0

    ### largest contour
    for i in contours:
        area = cv2.contourArea(i)
        if area > max_area:
            max_area = area
            biggest_contours = i
    peri = cv2.arcLength(biggest_contours, True)

    ### find four corners
    while True:
        approx = cv2.approxPolyDP(biggest_contours, (0.01+step*count) * peri, True)
        if len(approx) == 4:
            biggest = approx
            break
        # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.2:
        # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.4:
        #     biggest = np.array([])
        count += 1
        if count > 200:
            break
    return biggest, max_area, biggest_contours

def drawRectangle(img,biggest,color,thickness):
    cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
    cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
    cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
    cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
    return img

def minAreaRect(contours,img):
    # biggest = np.array([])
    max_area = 0
    for i in contours:
        area = cv2.contourArea(i)
        if area > max_area:
            peri = cv2.arcLength(i, True)
            rect = cv2.minAreaRect(i)
            points = cv2.boxPoints(rect)
            max_area = area
    return points

def cropRectangle(img,biggest):
    # print(biggest)
    w = np.abs(biggest[0][0][0] - biggest[1][0][0])
    h = np.abs(biggest[0][0][1] - biggest[2][0][1])
    new_img = np.zeros((w,h,img.shape[-1]),dtype=np.uint8)
    new_img = img[biggest[0][0][1]:biggest[0][0][1]+h,biggest[0][0][0]:biggest[0][0][0]+w]
    return new_img

def cvimg2torch(img,min=0,max=1):
    '''
    input:
        im -> ndarray uint8 HxWxC
    return
        tensor -> torch.tensor BxCxHxW
    '''
    if len(img.shape)==2:
        img = np.expand_dims(img,axis=-1)
    img = img.astype(float) / 255.0
    img = img.transpose(2, 0, 1) # HWC -> CHW
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()
    return img

def torch2cvimg(tensor,min=0,max=1):
    '''
    input:
        tensor -> torch.tensor BxCxHxW C can be 1,3
    return
        im -> ndarray uint8 HxWxC
    '''
    im_list = []
    for i in range(tensor.shape[0]):
        im = tensor.detach().cpu().data.numpy()[i]
        im = im.transpose(1,2,0)
        im = np.clip(im,min,max)
        im = ((im-min)/(max-min)*255).astype(np.uint8)
        im_list.append(im)
    return im_list



class TPSGridGen(nn.Module):
    def __init__(self, target_height, target_width, target_control_points):
        '''
        target_control_points -> torch.tensor num_pointx2 -1~1
        source_control_points -> torch.tensor batch_size x num_point x 2 -1~1
        return:
            grid -> batch_size x hw x 2 -1~1
        '''
        super(TPSGridGen, self).__init__()
        assert target_control_points.ndimension() == 2
        assert target_control_points.size(1) == 2
        N = target_control_points.size(0)
        self.num_points = N
        target_control_points = target_control_points.float()

        # create padded kernel matrix
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = self.compute_partial_repr(target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel)

        # create target coordinate matrix
        HW = target_height * target_width
        target_coordinate = list(itertools.product(range(target_height), range(target_width)))
        target_coordinate = torch.Tensor(target_coordinate) # HW x 2
        Y, X = target_coordinate.split(1, dim = 1)
        Y = Y * 2 / (target_height - 1) - 1
        X = X * 2 / (target_width - 1) - 1
        target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = self.compute_partial_repr(target_coordinate.to(target_control_points.device), target_control_points)
        target_coordinate_repr = torch.cat([
            target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
        ], dim = 1)

        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)

    def forward(self, source_control_points):
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_points
        assert source_control_points.size(2) == 2
        batch_size = source_control_points.size(0)

        Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
        mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
        source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
        return source_coordinate

    # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
    def compute_partial_repr(self, input_points, control_points):
        N = input_points.size(0)
        M = control_points.size(0)
        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
        # original implementation, very slow
        # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
        pairwise_diff_square = pairwise_diff * pairwise_diff
        pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
        # fix numerical error for 0 * log(0), substitute all nan with 0
        mask = repr_matrix != repr_matrix
        repr_matrix.masked_fill_(mask, 0)
        return repr_matrix




### decide whether further processing is needed
# point_area = cv2.contourArea(np.concatenate((biggest_angle[0].reshape(1,1,2),middle[0:3],biggest_angle[1].reshape(1,1,2),middle[9:12],biggest_angle[3].reshape(1,1,2),middle[3:6][::-1],biggest_angle[2].reshape(1,1,2),middle[6:9][::-1]),axis=0))
#### minimum-area bounding rectangle
# rect = cv2.minAreaRect(contour) # (center(x,y), (width,height), rotation angle) of the minimum-area rectangle
# box = cv2.boxPoints(rect) # cv2.boxPoints(rect) for OpenCV 3.x; the 4 corner points of the rectangle
# box = np.int0(box)
# box = box.reshape((4,1,2))
# minrect_area = cv2.contourArea(box)
# print(abs(minrect_area-point_area)/point_area)
#### IOU of the four corner points
# biggest_box = np.concatenate((biggest_angle[0,:,:].reshape(1,1,2),biggest_angle[2,:,:].reshape(1,1,2),biggest_angle[3,:,:].reshape(1,1,2),biggest_angle[1,:,:].reshape(1,1,2)),axis=0)
# biggest_mask = np.zeros_like(mask)
# # corner_area = cv2.contourArea(biggest_box)
# cv2.drawContours(biggest_mask,[biggest_box], -1, color=255, thickness=-1)

# smooth = 1e-5
# biggest_mask_ = biggest_mask > 50
# mask_ = mask > 50
# intersection = (biggest_mask_ & mask_).sum()
# union = (biggest_mask_ | mask_).sum()
# iou = (intersection + smooth) / (union + smooth)
# if iou > 0.975:
#     skip = True
# else:
#     skip = False
# print(iou)
# cv2.imshow('mask',cv2.resize(mask,(512,512)))
# cv2.imshow('biggest_mask',cv2.resize(biggest_mask,(512,512)))
# cv2.waitKey(0)
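A small sanity-check sketch for `TPSGridGen` (not part of the commit): when the source control points coincide with the target control points, the thin-plate-spline solution is the identity mapping, so the generated grid resamples the input almost unchanged.

```python
import torch
import torch.nn.functional as F
from MBD_utils import TPSGridGen

# four corner control points in the normalized -1~1 coordinate system
target_pts = torch.tensor([[-1., -1.], [1., -1.], [-1., 1.], [1., 1.]])
tps = TPSGridGen(target_height=256, target_width=256, target_control_points=target_pts)

# source == target control points -> identity warp
grid = tps(target_pts.unsqueeze(0)).view(1, 256, 256, 2)

img = torch.rand(1, 3, 256, 256)
warped = F.grid_sample(img, grid, align_corners=True)
print((warped - img).abs().max())  # should be close to zero
```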
data/MBD/infer.py
ADDED
@@ -0,0 +1,151 @@
import torch
import argparse
import numpy as np
import torch.nn.functional as F
import glob
import cv2
from tqdm import tqdm

import time
import os
from model.deep_lab_model.deeplab import *
from MBD import mask_base_dewarper

from utils import cvimg2torch,torch2cvimg



def net1_net2_infer(model,img_paths,args):

    ### validate on the real datasets
    seg_model=model
    seg_model.eval()
    for img_path in tqdm(img_paths):
        if os.path.exists(img_path.replace('_origin','_capture')):
            continue
        t1 = time.time()
        ### segmentation mask predict
        img_org = cv2.imread(img_path)
        h_org,w_org = img_org.shape[:2]
        img = cv2.resize(img_org,(448, 448))
        img = cv2.GaussianBlur(img,(15,15),0,0)
        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        img = cvimg2torch(img)

        with torch.no_grad():
            pred = seg_model(img.cuda())
            mask_pred = pred[:,0,:,:].unsqueeze(1)
            mask_pred = F.interpolate(mask_pred,(h_org,w_org))
            mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
            mask_pred = (mask_pred*255).astype(np.uint8)
            kernel = np.ones((3,3))
            mask_pred = cv2.dilate(mask_pred,kernel,iterations=3)
            mask_pred = cv2.erode(mask_pred,kernel,iterations=3)
            mask_pred[mask_pred>100] = 255
            mask_pred[mask_pred<100] = 0
        ### tps transform base on the mask
        # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
        try:
            dewarp, grid = mask_base_dewarper(img_org,mask_pred)
        except:
            print('fail')
            grid = np.meshgrid(np.arange(w_org),np.arange(h_org))/np.array([w_org,h_org]).reshape(2,1,1)
            grid = torch.from_numpy((grid-0.5)*2).float().unsqueeze(0).permute(0,2,3,1)
            dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org),grid))[0]
            grid = grid[0].numpy()
        # cv2.imshow('in',cv2.resize(img_org,(512,512)))
        # cv2.imshow('out',cv2.resize(dewarp,(512,512)))
        # cv2.waitKey(0)
        cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
        cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)

        grid0 = cv2.resize(grid[:,:,0],(128,128))
        grid1 = cv2.resize(grid[:,:,1],(128,128))
        grid = np.stack((grid0,grid1),axis=-1)
        np.save(img_path.replace('_origin','_grid1'),grid)


def net1_net2_infer_single_im(img,model_path):
    seg_model = DeepLab(num_classes=1,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=None,
                    freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    checkpoint = torch.load(model_path)
    seg_model.load_state_dict(checkpoint['model_state'])
    ### validate on the real datasets
    seg_model.eval()
    ### segmentation mask predict
    img_org = img
    h_org,w_org = img_org.shape[:2]
    img = cv2.resize(img_org,(448, 448))
    img = cv2.GaussianBlur(img,(15,15),0,0)
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = cvimg2torch(img)

    with torch.no_grad():
        # from torchtoolbox.tools import summary
        # print(summary(seg_model,torch.rand((1, 3, 448, 448)).cuda())) 59.4M 135.6G

        pred = seg_model(img.cuda())
        mask_pred = pred[:,0,:,:].unsqueeze(1)
        mask_pred = F.interpolate(mask_pred,(h_org,w_org))
        mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
        mask_pred = (mask_pred*255).astype(np.uint8)
        kernel = np.ones((3,3))
        mask_pred = cv2.dilate(mask_pred,kernel,iterations=3)
        mask_pred = cv2.erode(mask_pred,kernel,iterations=3)
        mask_pred[mask_pred>100] = 255
        mask_pred[mask_pred<100] = 0
    ### tps transform base on the mask
    # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
    # try:
    #     dewarp, grid = mask_base_dewarper(img_org,mask_pred)
    # except:
    #     print('fail')
    #     grid = np.meshgrid(np.arange(w_org),np.arange(h_org))/np.array([w_org,h_org]).reshape(2,1,1)
    #     grid = torch.from_numpy((grid-0.5)*2).float().unsqueeze(0).permute(0,2,3,1)
    #     dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org),grid))[0]
    #     grid = grid[0].numpy()
    # cv2.imshow('in',cv2.resize(img_org,(512,512)))
    # cv2.imshow('out',cv2.resize(dewarp,(512,512)))
    # cv2.waitKey(0)
    # cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
    # cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)

    # grid0 = cv2.resize(grid[:,:,0],(128,128))
    # grid1 = cv2.resize(grid[:,:,1],(128,128))
    # grid = np.stack((grid0,grid1),axis=-1)
    # np.save(img_path.replace('_origin','_grid1'),grid)
    return mask_pred



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data',help='Data path to load data')
    parser.add_argument('--img_rows', nargs='?', type=int, default=448,
                        help='Height of the input image')
    parser.add_argument('--img_cols', nargs='?', type=int, default=448,
                        help='Width of the input image')
    parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
                        help='Path to previous saved model to restart from')
    args = parser.parse_args()

    seg_model = DeepLab(num_classes=1,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=None,
                    freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    checkpoint = torch.load(args.seg_model_path)
    seg_model.load_state_dict(checkpoint['model_state'])

    im_paths = glob.glob(os.path.join(args.img_folder,'*_origin.*'))

    net1_net2_infer(seg_model,im_paths,args)
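A sketch of the single-image path combined with the dewarper (not part of the commit; assumes it is run from `data/MBD/`, that a CUDA device is available, and that the released `mbd.pkl` is placed as described in the README):

```python
import cv2
from infer import net1_net2_infer_single_im
from MBD import mask_base_dewarper

img = cv2.imread('warped_doc.png')  # hypothetical input photo, BGR uint8
mask = net1_net2_infer_single_im(img, 'checkpoint/mbd.pkl')  # HxW uint8 document mask
dewarped, _ = mask_base_dewarper(img, mask)
cv2.imwrite('warped_doc_dewarped.png', dewarped)
```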
data/MBD/model/__init__.py
ADDED
@@ -0,0 +1,50 @@
import torchvision.models as models
from model.densenetccnl import *
from model.unetnc import *
from model.gienet import *


def get_model(name, n_classes=1, filters=64,version=None,in_channels=3, is_batchnorm=True, norm='batch', model_path=None, use_sigmoid=True, layers=3,img_size=512):
    model = _get_model_instance(name)


    if name == 'dnetccnl':
        model = model(img_size=128, in_channels=in_channels, out_channels=n_classes, filters=32)
    elif name == 'dnetccnl512':
        model = model(img_size=img_size, in_channels=in_channels, out_channels=n_classes, filters=32)
    elif name == 'unetnc':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    elif name == 'gie':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    elif name == 'giecbam':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    elif name == 'gie2head':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    elif name == 'giemask':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    elif name == 'giemask2':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    elif name == 'giedilated':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    elif name == 'bmp':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    elif name == 'displacement':
        model = model(n_classes=2, num_filter=32, BatchNorm='GN', in_channels=5)
    return model

def _get_model_instance(name):
    try:
        return {
            'dnetccnl': dnetccnl,
            'dnetccnl512': dnetccnl512,
            'unetnc': UnetGenerator,
            'gie':GieGenerator,
            'giecbam':GiecbamGenerator,
            'giedilated':DilatedSingleUnet,
            'gie2head':Gie2headGenerator,
            'giemask':GiemaskGenerator,
            'giemask2':Giemask2Generator,
            'bmp':BmpGenerator,
        }[name]
    except:
        print('Model {} not available'.format(name))
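A short sketch of the factory above (not part of the commit; assumes it is run from `data/MBD/` so that the `model` package is importable):

```python
from model import get_model

# 'unetnc' resolves to UnetGenerator(input_nc=3, output_nc=1, num_downs=7) via the registry above
net = get_model('unetnc', n_classes=1, in_channels=3)
print(sum(p.numel() for p in net.parameters()))  # rough parameter-count check
```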
data/MBD/model/cbam.py
ADDED
@@ -0,0 +1,95 @@
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )
            elif pool_type=='lp':
                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( lp_pool )
            elif pool_type=='lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp( lse_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
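A tiny sketch of applying the CBAM block to a feature map (not part of the commit; shapes are illustrative):

```python
import torch
from model.cbam import CBAM

attn = CBAM(gate_channels=64, reduction_ratio=16)  # channel attention followed by spatial attention
feat = torch.rand(2, 64, 32, 32)                   # BxCxHxW feature map
out = attn(feat)                                   # same shape, re-weighted by the attention maps
assert out.shape == feat.shape
```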
data/MBD/model/deep_lab_model/__init__.py
ADDED
File without changes
data/MBD/model/deep_lab_model/aspp.py
ADDED
@@ -0,0 +1,95 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                     stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    def __init__(self, backbone, output_stride, BatchNorm):
        super(ASPP, self).__init__()
        if backbone == 'drn':
            inplanes = 512
        elif backbone == 'mobilenet':
            inplanes = 320
        else:
            inplanes = 2048
        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
                                             BatchNorm(256),
                                             nn.ReLU())
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return self.dropout(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def build_aspp(backbone, output_stride, BatchNorm):
    return ASPP(backbone, output_stride, BatchNorm)
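A brief sketch of building the ASPP head on its own (not part of the commit; the 2048 input channels follow the `resnet` branch above, and the 32×32 spatial size is illustrative):

```python
import torch
import torch.nn as nn
from model.deep_lab_model.aspp import build_aspp

aspp = build_aspp(backbone='resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)
aspp.eval()  # eval mode so the 1x1 global-pooling branch's BatchNorm uses running statistics
feat = torch.rand(1, 2048, 32, 32)  # simulated backbone output
out = aspp(feat)                    # 1 x 256 x 32 x 32 fused multi-scale features
```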
data/MBD/model/deep_lab_model/backbone/__init__.py
ADDED
@@ -0,0 +1,13 @@
from model.deep_lab_model.backbone import resnet, xception, drn, mobilenet

def build_backbone(backbone, output_stride, BatchNorm):
    if backbone == 'resnet':
        return resnet.ResNet101(output_stride, BatchNorm)
    elif backbone == 'xception':
        return xception.AlignedXception(output_stride, BatchNorm)
    elif backbone == 'drn':
        return drn.drn_d_54(BatchNorm)
    elif backbone == 'mobilenet':
        return mobilenet.MobileNetV2(output_stride, BatchNorm)
    else:
        raise NotImplementedError
data/MBD/model/deep_lab_model/backbone/drn.py
ADDED
@@ -0,0 +1,402 @@
1 |
+
import torch.nn as nn
|
2 |
+
import math
|
3 |
+
import torch.utils.model_zoo as model_zoo
|
4 |
+
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
|
5 |
+
|
6 |
+
webroot = 'http://dl.yf.io/drn/'
|
7 |
+
|
8 |
+
model_urls = {
|
9 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
10 |
+
'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
|
11 |
+
'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
|
12 |
+
'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
|
13 |
+
'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth',
|
14 |
+
'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth',
|
15 |
+
'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth',
|
16 |
+
'drn-d-105': webroot + 'drn_d_105-12b40979.pth'
|
17 |
+
}
|
18 |
+
|
19 |
+
|
20 |
+
def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
|
21 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
22 |
+
padding=padding, bias=False, dilation=dilation)
|
23 |
+
|
24 |
+
|
25 |
+
class BasicBlock(nn.Module):
|
26 |
+
expansion = 1
|
27 |
+
|
28 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
29 |
+
dilation=(1, 1), residual=True, BatchNorm=None):
|
30 |
+
super(BasicBlock, self).__init__()
|
31 |
+
self.conv1 = conv3x3(inplanes, planes, stride,
|
32 |
+
padding=dilation[0], dilation=dilation[0])
|
33 |
+
self.bn1 = BatchNorm(planes)
|
34 |
+
self.relu = nn.ReLU(inplace=True)
|
35 |
+
self.conv2 = conv3x3(planes, planes,
|
36 |
+
padding=dilation[1], dilation=dilation[1])
|
37 |
+
self.bn2 = BatchNorm(planes)
|
38 |
+
self.downsample = downsample
|
39 |
+
self.stride = stride
|
40 |
+
self.residual = residual
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
residual = x
|
44 |
+
|
45 |
+
out = self.conv1(x)
|
46 |
+
out = self.bn1(out)
|
47 |
+
out = self.relu(out)
|
48 |
+
|
49 |
+
out = self.conv2(out)
|
50 |
+
out = self.bn2(out)
|
51 |
+
|
52 |
+
if self.downsample is not None:
|
53 |
+
residual = self.downsample(x)
|
54 |
+
if self.residual:
|
55 |
+
out += residual
|
56 |
+
out = self.relu(out)
|
57 |
+
|
58 |
+
return out
|
59 |
+
|
60 |
+
|
61 |
+
class Bottleneck(nn.Module):
|
62 |
+
expansion = 4
|
63 |
+
|
64 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
65 |
+
dilation=(1, 1), residual=True, BatchNorm=None):
|
66 |
+
super(Bottleneck, self).__init__()
|
67 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
68 |
+
self.bn1 = BatchNorm(planes)
|
69 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
70 |
+
padding=dilation[1], bias=False,
|
71 |
+
dilation=dilation[1])
|
72 |
+
self.bn2 = BatchNorm(planes)
|
73 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
74 |
+
self.bn3 = BatchNorm(planes * 4)
|
75 |
+
self.relu = nn.ReLU(inplace=True)
|
76 |
+
self.downsample = downsample
|
77 |
+
self.stride = stride
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
residual = x
|
81 |
+
|
82 |
+
out = self.conv1(x)
|
83 |
+
out = self.bn1(out)
|
84 |
+
out = self.relu(out)
|
85 |
+
|
86 |
+
out = self.conv2(out)
|
87 |
+
out = self.bn2(out)
|
88 |
+
out = self.relu(out)
|
89 |
+
|
90 |
+
out = self.conv3(out)
|
91 |
+
out = self.bn3(out)
|
92 |
+
|
93 |
+
if self.downsample is not None:
|
94 |
+
residual = self.downsample(x)
|
95 |
+
|
96 |
+
out += residual
|
97 |
+
out = self.relu(out)
|
98 |
+
|
99 |
+
return out
|
100 |
+
|
101 |
+
|
102 |
+
class DRN(nn.Module):
|
103 |
+
|
104 |
+
def __init__(self, block, layers, arch='D',
|
105 |
+
channels=(16, 32, 64, 128, 256, 512, 512, 512),
|
106 |
+
BatchNorm=None):
|
107 |
+
super(DRN, self).__init__()
|
108 |
+
self.inplanes = channels[0]
|
109 |
+
self.out_dim = channels[-1]
|
110 |
+
self.arch = arch
|
111 |
+
|
112 |
+
if arch == 'C':
|
113 |
+
self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
|
114 |
+
padding=3, bias=False)
|
115 |
+
self.bn1 = BatchNorm(channels[0])
|
116 |
+
self.relu = nn.ReLU(inplace=True)
|
117 |
+
|
118 |
+
self.layer1 = self._make_layer(
|
119 |
+
BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
|
120 |
+
self.layer2 = self._make_layer(
|
121 |
+
BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
|
122 |
+
|
123 |
+
elif arch == 'D':
|
124 |
+
self.layer0 = nn.Sequential(
|
125 |
+
nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
|
126 |
+
bias=False),
|
127 |
+
BatchNorm(channels[0]),
|
128 |
+
nn.ReLU(inplace=True)
|
129 |
+
)
|
130 |
+
|
131 |
+
self.layer1 = self._make_conv_layers(
|
132 |
+
channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
|
133 |
+
self.layer2 = self._make_conv_layers(
|
134 |
+
channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
|
135 |
+
|
136 |
+
self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm)
|
137 |
+
self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm)
|
138 |
+
self.layer5 = self._make_layer(block, channels[4], layers[4],
|
139 |
+
dilation=2, new_level=False, BatchNorm=BatchNorm)
|
140 |
+
self.layer6 = None if layers[5] == 0 else \
|
141 |
+
self._make_layer(block, channels[5], layers[5], dilation=4,
|
142 |
+
new_level=False, BatchNorm=BatchNorm)
|
143 |
+
|
144 |
+
if arch == 'C':
|
145 |
+
self.layer7 = None if layers[6] == 0 else \
|
146 |
+
self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
|
147 |
+
new_level=False, residual=False, BatchNorm=BatchNorm)
|
148 |
+
self.layer8 = None if layers[7] == 0 else \
|
149 |
+
self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
|
150 |
+
new_level=False, residual=False, BatchNorm=BatchNorm)
|
151 |
+
elif arch == 'D':
|
152 |
+
self.layer7 = None if layers[6] == 0 else \
|
153 |
+
self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm)
|
154 |
+
self.layer8 = None if layers[7] == 0 else \
|
155 |
+
self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm)
|
156 |
+
|
157 |
+
self._init_weight()
|
158 |
+
|
159 |
+
def _init_weight(self):
|
160 |
+
for m in self.modules():
|
161 |
+
if isinstance(m, nn.Conv2d):
|
162 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
163 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
164 |
+
elif isinstance(m, SynchronizedBatchNorm2d):
|
165 |
+
m.weight.data.fill_(1)
|
166 |
+
m.bias.data.zero_()
|
167 |
+
elif isinstance(m, nn.BatchNorm2d):
|
168 |
+
m.weight.data.fill_(1)
|
169 |
+
m.bias.data.zero_()
|
170 |
+
|
171 |
+
|
172 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
|
173 |
+
new_level=True, residual=True, BatchNorm=None):
|
174 |
+
assert dilation == 1 or dilation % 2 == 0
|
175 |
+
downsample = None
|
176 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
177 |
+
downsample = nn.Sequential(
|
178 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
179 |
+
kernel_size=1, stride=stride, bias=False),
|
180 |
+
BatchNorm(planes * block.expansion),
|
181 |
+
)
|
182 |
+
|
183 |
+
layers = list()
|
184 |
+
layers.append(block(
|
185 |
+
self.inplanes, planes, stride, downsample,
|
186 |
+
dilation=(1, 1) if dilation == 1 else (
|
187 |
+
dilation // 2 if new_level else dilation, dilation),
|
188 |
+
residual=residual, BatchNorm=BatchNorm))
|
189 |
+
self.inplanes = planes * block.expansion
|
190 |
+
for i in range(1, blocks):
|
191 |
+
layers.append(block(self.inplanes, planes, residual=residual,
|
192 |
+
dilation=(dilation, dilation), BatchNorm=BatchNorm))
|
193 |
+
|
194 |
+
return nn.Sequential(*layers)
|
195 |
+
|
196 |
+
def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None):
|
197 |
+
modules = []
|
198 |
+
for i in range(convs):
|
199 |
+
modules.extend([
|
200 |
+
nn.Conv2d(self.inplanes, channels, kernel_size=3,
|
201 |
+
stride=stride if i == 0 else 1,
|
202 |
+
padding=dilation, bias=False, dilation=dilation),
|
203 |
+
BatchNorm(channels),
|
204 |
+
nn.ReLU(inplace=True)])
|
205 |
+
self.inplanes = channels
|
206 |
+
return nn.Sequential(*modules)
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
if self.arch == 'C':
|
210 |
+
x = self.conv1(x)
|
211 |
+
x = self.bn1(x)
|
212 |
+
x = self.relu(x)
|
213 |
+
elif self.arch == 'D':
|
214 |
+
x = self.layer0(x)
|
215 |
+
|
216 |
+
x = self.layer1(x)
|
217 |
+
x = self.layer2(x)
|
218 |
+
|
219 |
+
x = self.layer3(x)
|
220 |
+
low_level_feat = x
|
221 |
+
|
222 |
+
x = self.layer4(x)
|
223 |
+
x = self.layer5(x)
|
224 |
+
|
225 |
+
if self.layer6 is not None:
|
226 |
+
x = self.layer6(x)
|
227 |
+
|
228 |
+
if self.layer7 is not None:
|
229 |
+
x = self.layer7(x)
|
230 |
+
|
231 |
+
if self.layer8 is not None:
|
232 |
+
x = self.layer8(x)
|
233 |
+
|
234 |
+
return x, low_level_feat
|
235 |
+
|
236 |
+
|
237 |
+
class DRN_A(nn.Module):
|
238 |
+
|
239 |
+
def __init__(self, block, layers, BatchNorm=None):
|
240 |
+
self.inplanes = 64
|
241 |
+
super(DRN_A, self).__init__()
|
242 |
+
self.out_dim = 512 * block.expansion
|
243 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
244 |
+
bias=False)
|
245 |
+
self.bn1 = BatchNorm(64)
|
246 |
+
self.relu = nn.ReLU(inplace=True)
|
247 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
248 |
+
self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
|
249 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
|
250 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
|
251 |
+
dilation=2, BatchNorm=BatchNorm)
|
252 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
|
253 |
+
dilation=4, BatchNorm=BatchNorm)
|
254 |
+
|
255 |
+
self._init_weight()
|
256 |
+
|
257 |
+
def _init_weight(self):
|
258 |
+
for m in self.modules():
|
259 |
+
if isinstance(m, nn.Conv2d):
|
260 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
261 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
262 |
+
elif isinstance(m, SynchronizedBatchNorm2d):
|
263 |
+
m.weight.data.fill_(1)
|
264 |
+
m.bias.data.zero_()
|
265 |
+
elif isinstance(m, nn.BatchNorm2d):
|
266 |
+
m.weight.data.fill_(1)
|
267 |
+
m.bias.data.zero_()
|
268 |
+
|
269 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
|
270 |
+
downsample = None
|
271 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
272 |
+
downsample = nn.Sequential(
|
273 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
274 |
+
kernel_size=1, stride=stride, bias=False),
|
275 |
+
BatchNorm(planes * block.expansion),
|
276 |
+
)
|
277 |
+
|
278 |
+
layers = []
|
279 |
+
layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm))
|
280 |
+
self.inplanes = planes * block.expansion
|
281 |
+
for i in range(1, blocks):
|
282 |
+
layers.append(block(self.inplanes, planes,
|
283 |
+
dilation=(dilation, dilation, ), BatchNorm=BatchNorm))
|
284 |
+
|
285 |
+
return nn.Sequential(*layers)
|
286 |
+
|
287 |
+
def forward(self, x):
|
288 |
+
x = self.conv1(x)
|
289 |
+
x = self.bn1(x)
|
290 |
+
x = self.relu(x)
|
291 |
+
x = self.maxpool(x)
|
292 |
+
|
293 |
+
x = self.layer1(x)
|
294 |
+
x = self.layer2(x)
|
295 |
+
x = self.layer3(x)
|
296 |
+
x = self.layer4(x)
|
297 |
+
|
298 |
+
return x
|
299 |
+
|
300 |
+
def drn_a_50(BatchNorm, pretrained=True):
|
301 |
+
model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
|
302 |
+
if pretrained:
|
303 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
|
304 |
+
return model
|
305 |
+
|
306 |
+
|
307 |
+
def drn_c_26(BatchNorm, pretrained=True):
|
308 |
+
model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm)
|
309 |
+
if pretrained:
|
310 |
+
pretrained = model_zoo.load_url(model_urls['drn-c-26'])
|
311 |
+
del pretrained['fc.weight']
|
312 |
+
del pretrained['fc.bias']
|
313 |
+
model.load_state_dict(pretrained)
|
314 |
+
return model
|
315 |
+
|
316 |
+
|
317 |
+
def drn_c_42(BatchNorm, pretrained=True):
|
318 |
+
model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
|
319 |
+
if pretrained:
|
320 |
+
pretrained = model_zoo.load_url(model_urls['drn-c-42'])
|
321 |
+
del pretrained['fc.weight']
|
322 |
+
del pretrained['fc.bias']
|
323 |
+
model.load_state_dict(pretrained)
|
324 |
+
return model
|
325 |
+
|
326 |
+
|
327 |
+
def drn_c_58(BatchNorm, pretrained=True):
|
328 |
+
model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
|
329 |
+
if pretrained:
|
330 |
+
pretrained = model_zoo.load_url(model_urls['drn-c-58'])
|
331 |
+
del pretrained['fc.weight']
|
332 |
+
del pretrained['fc.bias']
|
333 |
+
model.load_state_dict(pretrained)
|
334 |
+
return model
|
335 |
+
|
336 |
+
|
337 |
+
def drn_d_22(BatchNorm, pretrained=True):
|
338 |
+
model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm)
|
339 |
+
if pretrained:
|
340 |
+
pretrained = model_zoo.load_url(model_urls['drn-d-22'])
|
341 |
+
del pretrained['fc.weight']
|
342 |
+
del pretrained['fc.bias']
|
343 |
+
model.load_state_dict(pretrained)
|
344 |
+
return model
|
345 |
+
|
346 |
+
|
347 |
+
def drn_d_24(BatchNorm, pretrained=True):
|
348 |
+
model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
|
349 |
+
if pretrained:
|
350 |
+
pretrained = model_zoo.load_url(model_urls['drn-d-24'])
|
351 |
+
del pretrained['fc.weight']
|
352 |
+
del pretrained['fc.bias']
|
353 |
+
model.load_state_dict(pretrained)
|
354 |
+
return model
|
355 |
+
|
356 |
+
|
357 |
+
def drn_d_38(BatchNorm, pretrained=True):
|
358 |
+
model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
|
359 |
+
if pretrained:
|
360 |
+
pretrained = model_zoo.load_url(model_urls['drn-d-38'])
|
361 |
+
del pretrained['fc.weight']
|
362 |
+
del pretrained['fc.bias']
|
363 |
+
model.load_state_dict(pretrained)
|
364 |
+
return model
|
365 |
+
|
366 |
+
|
367 |
+
def drn_d_40(BatchNorm, pretrained=True):
|
368 |
+
model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm)
|
369 |
+
if pretrained:
|
370 |
+
pretrained = model_zoo.load_url(model_urls['drn-d-40'])
|
371 |
+
del pretrained['fc.weight']
|
372 |
+
del pretrained['fc.bias']
|
373 |
+
model.load_state_dict(pretrained)
|
374 |
+
return model
|
375 |
+
|
376 |
+
|
377 |
+
def drn_d_54(BatchNorm, pretrained=True):
|
378 |
+
model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
|
379 |
+
if pretrained:
|
380 |
+
pretrained = model_zoo.load_url(model_urls['drn-d-54'])
|
381 |
+
del pretrained['fc.weight']
|
382 |
+
del pretrained['fc.bias']
|
383 |
+
model.load_state_dict(pretrained)
|
384 |
+
return model
|
385 |
+
|
386 |
+
|
387 |
+
def drn_d_105(BatchNorm, pretrained=True):
|
388 |
+
model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
|
389 |
+
if pretrained:
|
390 |
+
pretrained = model_zoo.load_url(model_urls['drn-d-105'])
|
391 |
+
del pretrained['fc.weight']
|
392 |
+
del pretrained['fc.bias']
|
393 |
+
model.load_state_dict(pretrained)
|
394 |
+
return model
|
395 |
+
|
396 |
+
if __name__ == "__main__":
|
397 |
+
import torch
|
398 |
+
model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True)
|
399 |
+
input = torch.rand(1, 3, 512, 512)
|
400 |
+
output, low_level_feat = model(input)
|
401 |
+
print(output.size())
|
402 |
+
print(low_level_feat.size())
|
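The `drn_*` factory functions above differ only in the residual block type, the per-level depths, and the architecture variant ('C' vs 'D'). A minimal usage sketch, assuming the repo's `model.…` import root (as used by the other backbone files) and `pretrained=False` to avoid the weight download; the two returned tensors follow the `return x, low_level_feat` of the DRN forward pass above:

```python
import torch
import torch.nn as nn

# Sketch only: drn_d_54 builds a dilated residual network whose forward pass
# returns (features, low_level_features), as in the DRN class above.
from model.deep_lab_model.backbone.drn import drn_d_54

backbone = drn_d_54(BatchNorm=nn.BatchNorm2d, pretrained=False)  # no checkpoint download
x = torch.rand(1, 3, 512, 512)            # dummy RGB input
feat, low_level_feat = backbone(x)        # deep features + early-stage features
print(feat.size(), low_level_feat.size())
```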
data/MBD/model/deep_lab_model/backbone/mobilenet.py
ADDED
@@ -0,0 +1,151 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
import math
|
5 |
+
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
|
6 |
+
import torch.utils.model_zoo as model_zoo
|
7 |
+
|
8 |
+
def conv_bn(inp, oup, stride, BatchNorm):
|
9 |
+
return nn.Sequential(
|
10 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
11 |
+
BatchNorm(oup),
|
12 |
+
nn.ReLU6(inplace=True)
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def fixed_padding(inputs, kernel_size, dilation):
|
17 |
+
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
|
18 |
+
pad_total = kernel_size_effective - 1
|
19 |
+
pad_beg = pad_total // 2
|
20 |
+
pad_end = pad_total - pad_beg
|
21 |
+
padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
|
22 |
+
return padded_inputs
|
23 |
+
|
24 |
+
|
25 |
+
class InvertedResidual(nn.Module):
|
26 |
+
def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
|
27 |
+
super(InvertedResidual, self).__init__()
|
28 |
+
self.stride = stride
|
29 |
+
assert stride in [1, 2]
|
30 |
+
|
31 |
+
hidden_dim = round(inp * expand_ratio)
|
32 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
33 |
+
self.kernel_size = 3
|
34 |
+
self.dilation = dilation
|
35 |
+
|
36 |
+
if expand_ratio == 1:
|
37 |
+
self.conv = nn.Sequential(
|
38 |
+
# dw
|
39 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
|
40 |
+
BatchNorm(hidden_dim),
|
41 |
+
nn.ReLU6(inplace=True),
|
42 |
+
# pw-linear
|
43 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
|
44 |
+
BatchNorm(oup),
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
self.conv = nn.Sequential(
|
48 |
+
# pw
|
49 |
+
nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
|
50 |
+
BatchNorm(hidden_dim),
|
51 |
+
nn.ReLU6(inplace=True),
|
52 |
+
# dw
|
53 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
|
54 |
+
BatchNorm(hidden_dim),
|
55 |
+
nn.ReLU6(inplace=True),
|
56 |
+
# pw-linear
|
57 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
|
58 |
+
BatchNorm(oup),
|
59 |
+
)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
|
63 |
+
if self.use_res_connect:
|
64 |
+
x = x + self.conv(x_pad)
|
65 |
+
else:
|
66 |
+
x = self.conv(x_pad)
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
class MobileNetV2(nn.Module):
|
71 |
+
def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
|
72 |
+
super(MobileNetV2, self).__init__()
|
73 |
+
block = InvertedResidual
|
74 |
+
input_channel = 32
|
75 |
+
current_stride = 1
|
76 |
+
rate = 1
|
77 |
+
interverted_residual_setting = [
|
78 |
+
# t, c, n, s
|
79 |
+
[1, 16, 1, 1],
|
80 |
+
[6, 24, 2, 2],
|
81 |
+
[6, 32, 3, 2],
|
82 |
+
[6, 64, 4, 2],
|
83 |
+
[6, 96, 3, 1],
|
84 |
+
[6, 160, 3, 2],
|
85 |
+
[6, 320, 1, 1],
|
86 |
+
]
|
87 |
+
|
88 |
+
# building first layer
|
89 |
+
input_channel = int(input_channel * width_mult)
|
90 |
+
self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
|
91 |
+
current_stride *= 2
|
92 |
+
# building inverted residual blocks
|
93 |
+
for t, c, n, s in interverted_residual_setting:
|
94 |
+
if current_stride == output_stride:
|
95 |
+
stride = 1
|
96 |
+
dilation = rate
|
97 |
+
rate *= s
|
98 |
+
else:
|
99 |
+
stride = s
|
100 |
+
dilation = 1
|
101 |
+
current_stride *= s
|
102 |
+
output_channel = int(c * width_mult)
|
103 |
+
for i in range(n):
|
104 |
+
if i == 0:
|
105 |
+
self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
|
106 |
+
else:
|
107 |
+
self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
|
108 |
+
input_channel = output_channel
|
109 |
+
self.features = nn.Sequential(*self.features)
|
110 |
+
self._initialize_weights()
|
111 |
+
|
112 |
+
if pretrained:
|
113 |
+
self._load_pretrained_model()
|
114 |
+
|
115 |
+
self.low_level_features = self.features[0:4]
|
116 |
+
self.high_level_features = self.features[4:]
|
117 |
+
|
118 |
+
def forward(self, x):
|
119 |
+
low_level_feat = self.low_level_features(x)
|
120 |
+
x = self.high_level_features(low_level_feat)
|
121 |
+
return x, low_level_feat
|
122 |
+
|
123 |
+
def _load_pretrained_model(self):
|
124 |
+
pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
|
125 |
+
model_dict = {}
|
126 |
+
state_dict = self.state_dict()
|
127 |
+
for k, v in pretrain_dict.items():
|
128 |
+
if k in state_dict:
|
129 |
+
model_dict[k] = v
|
130 |
+
state_dict.update(model_dict)
|
131 |
+
self.load_state_dict(state_dict)
|
132 |
+
|
133 |
+
def _initialize_weights(self):
|
134 |
+
for m in self.modules():
|
135 |
+
if isinstance(m, nn.Conv2d):
|
136 |
+
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
137 |
+
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
138 |
+
torch.nn.init.kaiming_normal_(m.weight)
|
139 |
+
elif isinstance(m, SynchronizedBatchNorm2d):
|
140 |
+
m.weight.data.fill_(1)
|
141 |
+
m.bias.data.zero_()
|
142 |
+
elif isinstance(m, nn.BatchNorm2d):
|
143 |
+
m.weight.data.fill_(1)
|
144 |
+
m.bias.data.zero_()
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
input = torch.rand(1, 3, 512, 512)
|
148 |
+
model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
|
149 |
+
output, low_level_feat = model(input)
|
150 |
+
print(output.size())
|
151 |
+
print(low_level_feat.size())
|
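The `fixed_padding` helper pads the input so that the depthwise convolutions (built with `padding=0`) keep their spatial size even when dilated. A self-contained sketch of the same arithmetic:

```python
import torch
import torch.nn.functional as F

def fixed_padding(inputs, kernel_size, dilation):
    # effective kernel grows with dilation; pad (effective - 1), split front/back
    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    return F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))

x = torch.rand(1, 32, 64, 64)
print(fixed_padding(x, kernel_size=3, dilation=1).shape)  # padded by 1 per side -> 66x66
print(fixed_padding(x, kernel_size=3, dilation=2).shape)  # padded by 2 per side -> 68x68
```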
data/MBD/model/deep_lab_model/backbone/resnet.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
import math
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.utils.model_zoo as model_zoo
|
4 |
+
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
|
5 |
+
|
6 |
+
class Bottleneck(nn.Module):
|
7 |
+
expansion = 4
|
8 |
+
|
9 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
|
10 |
+
super(Bottleneck, self).__init__()
|
11 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
12 |
+
self.bn1 = BatchNorm(planes)
|
13 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
14 |
+
dilation=dilation, padding=dilation, bias=False)
|
15 |
+
self.bn2 = BatchNorm(planes)
|
16 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
17 |
+
self.bn3 = BatchNorm(planes * 4)
|
18 |
+
self.relu = nn.ReLU(inplace=True)
|
19 |
+
self.downsample = downsample
|
20 |
+
self.stride = stride
|
21 |
+
self.dilation = dilation
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
residual = x
|
25 |
+
|
26 |
+
out = self.conv1(x)
|
27 |
+
out = self.bn1(out)
|
28 |
+
out = self.relu(out)
|
29 |
+
|
30 |
+
out = self.conv2(out)
|
31 |
+
out = self.bn2(out)
|
32 |
+
out = self.relu(out)
|
33 |
+
|
34 |
+
out = self.conv3(out)
|
35 |
+
out = self.bn3(out)
|
36 |
+
|
37 |
+
if self.downsample is not None:
|
38 |
+
residual = self.downsample(x)
|
39 |
+
|
40 |
+
out += residual
|
41 |
+
out = self.relu(out)
|
42 |
+
|
43 |
+
return out
|
44 |
+
|
45 |
+
class ResNet(nn.Module):
|
46 |
+
|
47 |
+
def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
|
48 |
+
self.inplanes = 64
|
49 |
+
super(ResNet, self).__init__()
|
50 |
+
blocks = [1, 2, 4]
|
51 |
+
if output_stride == 16:
|
52 |
+
strides = [1, 2, 2, 1]
|
53 |
+
dilations = [1, 1, 1, 2]
|
54 |
+
elif output_stride == 8:
|
55 |
+
strides = [1, 2, 1, 1]
|
56 |
+
dilations = [1, 1, 2, 4]
|
57 |
+
else:
|
58 |
+
raise NotImplementedError
|
59 |
+
|
60 |
+
# Modules
|
61 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
62 |
+
bias=False)
|
63 |
+
self.bn1 = BatchNorm(64)
|
64 |
+
self.relu = nn.ReLU(inplace=True)
|
65 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
66 |
+
|
67 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
|
68 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
|
69 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
|
70 |
+
self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
|
71 |
+
# self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
|
72 |
+
self._init_weight()
|
73 |
+
|
74 |
+
# if pretrained:
|
75 |
+
# self._load_pretrained_model()
|
76 |
+
|
77 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
|
78 |
+
downsample = None
|
79 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
80 |
+
downsample = nn.Sequential(
|
81 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
82 |
+
kernel_size=1, stride=stride, bias=False),
|
83 |
+
BatchNorm(planes * block.expansion),
|
84 |
+
)
|
85 |
+
|
86 |
+
layers = []
|
87 |
+
layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
|
88 |
+
self.inplanes = planes * block.expansion
|
89 |
+
for i in range(1, blocks):
|
90 |
+
layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
|
91 |
+
|
92 |
+
return nn.Sequential(*layers)
|
93 |
+
|
94 |
+
def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
|
95 |
+
downsample = None
|
96 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
97 |
+
downsample = nn.Sequential(
|
98 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
99 |
+
kernel_size=1, stride=stride, bias=False),
|
100 |
+
BatchNorm(planes * block.expansion),
|
101 |
+
)
|
102 |
+
|
103 |
+
layers = []
|
104 |
+
layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
|
105 |
+
downsample=downsample, BatchNorm=BatchNorm))
|
106 |
+
self.inplanes = planes * block.expansion
|
107 |
+
for i in range(1, len(blocks)):
|
108 |
+
layers.append(block(self.inplanes, planes, stride=1,
|
109 |
+
dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
|
110 |
+
|
111 |
+
return nn.Sequential(*layers)
|
112 |
+
|
113 |
+
def forward(self, input):
|
114 |
+
x = self.conv1(input)
|
115 |
+
x = self.bn1(x)
|
116 |
+
x = self.relu(x)
|
117 |
+
x = self.maxpool(x)
|
118 |
+
|
119 |
+
x = self.layer1(x)
|
120 |
+
low_level_feat = x
|
121 |
+
x = self.layer2(x)
|
122 |
+
x = self.layer3(x)
|
123 |
+
x = self.layer4(x)
|
124 |
+
return x, low_level_feat
|
125 |
+
|
126 |
+
def _init_weight(self):
|
127 |
+
for m in self.modules():
|
128 |
+
if isinstance(m, nn.Conv2d):
|
129 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
130 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
131 |
+
elif isinstance(m, SynchronizedBatchNorm2d):
|
132 |
+
m.weight.data.fill_(1)
|
133 |
+
m.bias.data.zero_()
|
134 |
+
elif isinstance(m, nn.BatchNorm2d):
|
135 |
+
m.weight.data.fill_(1)
|
136 |
+
m.bias.data.zero_()
|
137 |
+
|
138 |
+
def _load_pretrained_model(self):
|
139 |
+
|
140 |
+
import urllib.request
|
141 |
+
import ssl
|
142 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
143 |
+
response = urllib.request.urlopen('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
|
144 |
+
|
145 |
+
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
|
146 |
+
model_dict = {}
|
147 |
+
state_dict = self.state_dict()
|
148 |
+
for k, v in pretrain_dict.items():
|
149 |
+
if k in state_dict:
|
150 |
+
# if 'conv1' in k:
|
151 |
+
# continue
|
152 |
+
model_dict[k] = v
|
153 |
+
state_dict.update(model_dict)
|
154 |
+
self.load_state_dict(state_dict)
|
155 |
+
|
156 |
+
def ResNet101(output_stride, BatchNorm, pretrained=True):
|
157 |
+
"""Constructs a ResNet-101 model.
|
158 |
+
Args:
|
159 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
160 |
+
"""
|
161 |
+
model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
|
162 |
+
return model
|
163 |
+
|
164 |
+
if __name__ == "__main__":
|
165 |
+
import torch
|
166 |
+
model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
|
167 |
+
input = torch.rand(1, 3, 512, 512)
|
168 |
+
output, low_level_feat = model(input)
|
169 |
+
print(output.size())
|
170 |
+
print(low_level_feat.size())
|
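In `_make_MG_unit`, each block of layer4 gets a dilation equal to the per-block multiplier times the base dilation chosen for the requested output stride. A quick sketch of the resulting values, following the `blocks = [1, 2, 4]` and `dilations` set in `__init__`:

```python
# Multi-grid dilations for layer4: per-block multiplier times the base dilation.
blocks = [1, 2, 4]

for output_stride, base_dilation in [(16, 2), (8, 4)]:
    mg = [m * base_dilation for m in blocks]
    print(f"output_stride={output_stride}: layer4 dilations {mg}")
# output_stride=16: layer4 dilations [2, 4, 8]
# output_stride=8:  layer4 dilations [4, 8, 16]
```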
data/MBD/model/deep_lab_model/backbone/xception.py
ADDED
@@ -0,0 +1,288 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
|
7 |
+
|
8 |
+
def fixed_padding(inputs, kernel_size, dilation):
|
9 |
+
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
|
10 |
+
pad_total = kernel_size_effective - 1
|
11 |
+
pad_beg = pad_total // 2
|
12 |
+
pad_end = pad_total - pad_beg
|
13 |
+
padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
|
14 |
+
return padded_inputs
|
15 |
+
|
16 |
+
|
17 |
+
class SeparableConv2d(nn.Module):
|
18 |
+
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
|
19 |
+
super(SeparableConv2d, self).__init__()
|
20 |
+
|
21 |
+
self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
|
22 |
+
groups=inplanes, bias=bias)
|
23 |
+
self.bn = BatchNorm(inplanes)
|
24 |
+
self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
|
28 |
+
x = self.conv1(x)
|
29 |
+
x = self.bn(x)
|
30 |
+
x = self.pointwise(x)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class Block(nn.Module):
|
35 |
+
def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
|
36 |
+
start_with_relu=True, grow_first=True, is_last=False):
|
37 |
+
super(Block, self).__init__()
|
38 |
+
|
39 |
+
if planes != inplanes or stride != 1:
|
40 |
+
self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
|
41 |
+
self.skipbn = BatchNorm(planes)
|
42 |
+
else:
|
43 |
+
self.skip = None
|
44 |
+
|
45 |
+
self.relu = nn.ReLU(inplace=True)
|
46 |
+
rep = []
|
47 |
+
|
48 |
+
filters = inplanes
|
49 |
+
if grow_first:
|
50 |
+
rep.append(self.relu)
|
51 |
+
rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
|
52 |
+
rep.append(BatchNorm(planes))
|
53 |
+
filters = planes
|
54 |
+
|
55 |
+
for i in range(reps - 1):
|
56 |
+
rep.append(self.relu)
|
57 |
+
rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
|
58 |
+
rep.append(BatchNorm(filters))
|
59 |
+
|
60 |
+
if not grow_first:
|
61 |
+
rep.append(self.relu)
|
62 |
+
rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
|
63 |
+
rep.append(BatchNorm(planes))
|
64 |
+
|
65 |
+
if stride != 1:
|
66 |
+
rep.append(self.relu)
|
67 |
+
rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
|
68 |
+
rep.append(BatchNorm(planes))
|
69 |
+
|
70 |
+
if stride == 1 and is_last:
|
71 |
+
rep.append(self.relu)
|
72 |
+
rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
|
73 |
+
rep.append(BatchNorm(planes))
|
74 |
+
|
75 |
+
if not start_with_relu:
|
76 |
+
rep = rep[1:]
|
77 |
+
|
78 |
+
self.rep = nn.Sequential(*rep)
|
79 |
+
|
80 |
+
def forward(self, inp):
|
81 |
+
x = self.rep(inp)
|
82 |
+
|
83 |
+
if self.skip is not None:
|
84 |
+
skip = self.skip(inp)
|
85 |
+
skip = self.skipbn(skip)
|
86 |
+
else:
|
87 |
+
skip = inp
|
88 |
+
|
89 |
+
x = x + skip
|
90 |
+
|
91 |
+
return x
|
92 |
+
|
93 |
+
|
94 |
+
class AlignedXception(nn.Module):
|
95 |
+
"""
|
96 |
+
Modified Alighed Xception
|
97 |
+
"""
|
98 |
+
def __init__(self, output_stride, BatchNorm,
|
99 |
+
pretrained=True):
|
100 |
+
super(AlignedXception, self).__init__()
|
101 |
+
|
102 |
+
if output_stride == 16:
|
103 |
+
entry_block3_stride = 2
|
104 |
+
middle_block_dilation = 1
|
105 |
+
exit_block_dilations = (1, 2)
|
106 |
+
elif output_stride == 8:
|
107 |
+
entry_block3_stride = 1
|
108 |
+
middle_block_dilation = 2
|
109 |
+
exit_block_dilations = (2, 4)
|
110 |
+
else:
|
111 |
+
raise NotImplementedError
|
112 |
+
|
113 |
+
|
114 |
+
# Entry flow
|
115 |
+
self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
|
116 |
+
self.bn1 = BatchNorm(32)
|
117 |
+
self.relu = nn.ReLU(inplace=True)
|
118 |
+
|
119 |
+
self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
|
120 |
+
self.bn2 = BatchNorm(64)
|
121 |
+
|
122 |
+
self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
|
123 |
+
self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
|
124 |
+
grow_first=True)
|
125 |
+
self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
|
126 |
+
start_with_relu=True, grow_first=True, is_last=True)
|
127 |
+
|
128 |
+
# Middle flow
|
129 |
+
self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
130 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
131 |
+
self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
132 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
133 |
+
self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
134 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
135 |
+
self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
136 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
137 |
+
self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
138 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
139 |
+
self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
140 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
141 |
+
self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
142 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
143 |
+
self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
144 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
145 |
+
self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
146 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
147 |
+
self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
148 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
149 |
+
self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
150 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
151 |
+
self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
152 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
153 |
+
self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
154 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
155 |
+
self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
156 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
157 |
+
self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
158 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
159 |
+
self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
|
160 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
|
161 |
+
|
162 |
+
# Exit flow
|
163 |
+
self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
|
164 |
+
BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)
|
165 |
+
|
166 |
+
self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
|
167 |
+
self.bn3 = BatchNorm(1536)
|
168 |
+
|
169 |
+
self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
|
170 |
+
self.bn4 = BatchNorm(1536)
|
171 |
+
|
172 |
+
self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
|
173 |
+
self.bn5 = BatchNorm(2048)
|
174 |
+
|
175 |
+
# Init weights
|
176 |
+
self._init_weight()
|
177 |
+
|
178 |
+
# Load pretrained model
|
179 |
+
if pretrained:
|
180 |
+
self._load_pretrained_model()
|
181 |
+
|
182 |
+
def forward(self, x):
|
183 |
+
# Entry flow
|
184 |
+
x = self.conv1(x)
|
185 |
+
x = self.bn1(x)
|
186 |
+
x = self.relu(x)
|
187 |
+
|
188 |
+
x = self.conv2(x)
|
189 |
+
x = self.bn2(x)
|
190 |
+
x = self.relu(x)
|
191 |
+
|
192 |
+
x = self.block1(x)
|
193 |
+
# add relu here
|
194 |
+
x = self.relu(x)
|
195 |
+
low_level_feat = x
|
196 |
+
x = self.block2(x)
|
197 |
+
x = self.block3(x)
|
198 |
+
|
199 |
+
# Middle flow
|
200 |
+
x = self.block4(x)
|
201 |
+
x = self.block5(x)
|
202 |
+
x = self.block6(x)
|
203 |
+
x = self.block7(x)
|
204 |
+
x = self.block8(x)
|
205 |
+
x = self.block9(x)
|
206 |
+
x = self.block10(x)
|
207 |
+
x = self.block11(x)
|
208 |
+
x = self.block12(x)
|
209 |
+
x = self.block13(x)
|
210 |
+
x = self.block14(x)
|
211 |
+
x = self.block15(x)
|
212 |
+
x = self.block16(x)
|
213 |
+
x = self.block17(x)
|
214 |
+
x = self.block18(x)
|
215 |
+
x = self.block19(x)
|
216 |
+
|
217 |
+
# Exit flow
|
218 |
+
x = self.block20(x)
|
219 |
+
x = self.relu(x)
|
220 |
+
x = self.conv3(x)
|
221 |
+
x = self.bn3(x)
|
222 |
+
x = self.relu(x)
|
223 |
+
|
224 |
+
x = self.conv4(x)
|
225 |
+
x = self.bn4(x)
|
226 |
+
x = self.relu(x)
|
227 |
+
|
228 |
+
x = self.conv5(x)
|
229 |
+
x = self.bn5(x)
|
230 |
+
x = self.relu(x)
|
231 |
+
|
232 |
+
return x, low_level_feat
|
233 |
+
|
234 |
+
def _init_weight(self):
|
235 |
+
for m in self.modules():
|
236 |
+
if isinstance(m, nn.Conv2d):
|
237 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
238 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
239 |
+
elif isinstance(m, SynchronizedBatchNorm2d):
|
240 |
+
m.weight.data.fill_(1)
|
241 |
+
m.bias.data.zero_()
|
242 |
+
elif isinstance(m, nn.BatchNorm2d):
|
243 |
+
m.weight.data.fill_(1)
|
244 |
+
m.bias.data.zero_()
|
245 |
+
|
246 |
+
|
247 |
+
def _load_pretrained_model(self):
|
248 |
+
pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
|
249 |
+
model_dict = {}
|
250 |
+
state_dict = self.state_dict()
|
251 |
+
|
252 |
+
for k, v in pretrain_dict.items():
|
253 |
+
if k in state_dict:
|
254 |
+
if 'pointwise' in k:
|
255 |
+
v = v.unsqueeze(-1).unsqueeze(-1)
|
256 |
+
if k.startswith('block11'):
|
257 |
+
model_dict[k] = v
|
258 |
+
model_dict[k.replace('block11', 'block12')] = v
|
259 |
+
model_dict[k.replace('block11', 'block13')] = v
|
260 |
+
model_dict[k.replace('block11', 'block14')] = v
|
261 |
+
model_dict[k.replace('block11', 'block15')] = v
|
262 |
+
model_dict[k.replace('block11', 'block16')] = v
|
263 |
+
model_dict[k.replace('block11', 'block17')] = v
|
264 |
+
model_dict[k.replace('block11', 'block18')] = v
|
265 |
+
model_dict[k.replace('block11', 'block19')] = v
|
266 |
+
elif k.startswith('block12'):
|
267 |
+
model_dict[k.replace('block12', 'block20')] = v
|
268 |
+
elif k.startswith('bn3'):
|
269 |
+
model_dict[k] = v
|
270 |
+
model_dict[k.replace('bn3', 'bn4')] = v
|
271 |
+
elif k.startswith('conv4'):
|
272 |
+
model_dict[k.replace('conv4', 'conv5')] = v
|
273 |
+
elif k.startswith('bn4'):
|
274 |
+
model_dict[k.replace('bn4', 'bn5')] = v
|
275 |
+
else:
|
276 |
+
model_dict[k] = v
|
277 |
+
state_dict.update(model_dict)
|
278 |
+
self.load_state_dict(state_dict)
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
if __name__ == "__main__":
|
283 |
+
import torch
|
284 |
+
model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
|
285 |
+
input = torch.rand(1, 3, 512, 512)
|
286 |
+
output, low_level_feat = model(input)
|
287 |
+
print(output.size())
|
288 |
+
print(low_level_feat.size())
|
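In `_load_pretrained_model`, the pointwise weights in the downloaded Xception checkpoint appear to be stored without the trailing 1x1 spatial dimensions, so they are unsqueezed before being copied into the 1x1 `pointwise` convolutions of `SeparableConv2d`. A minimal sketch of that reshape:

```python
import torch

# Checkpoint-style pointwise weight: [out_channels, in_channels] (assumed shape);
# a 1x1 nn.Conv2d expects [out_channels, in_channels, 1, 1], hence the two unsqueezes.
v = torch.rand(128, 64)
v = v.unsqueeze(-1).unsqueeze(-1)
print(v.shape)  # torch.Size([128, 64, 1, 1])
```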
data/MBD/model/deep_lab_model/decoder.py
ADDED
@@ -0,0 +1,59 @@
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

class Decoder(nn.Module):
    def __init__(self, num_classes, backbone, BatchNorm):
        super(Decoder, self).__init__()
        if backbone == 'resnet' or backbone == 'drn':
            low_level_inplanes = 256
        elif backbone == 'xception':
            low_level_inplanes = 128
        elif backbone == 'mobilenet':
            low_level_inplanes = 24
        else:
            raise NotImplementedError

        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.5),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.1),
                                       nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
                                       nn.Sigmoid()
                                       )
        self._init_weight()


    def forward(self, x, low_level_feat):
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)

        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, low_level_feat), dim=1)
        x = self.last_conv(x)

        return x

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

def build_decoder(num_classes, backbone, BatchNorm):
    return Decoder(num_classes, backbone, BatchNorm)
```
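The hard-coded 304 in `last_conv` is the 256-channel ASPP output concatenated with the 48 channels produced by the 1x1 reduction of the low-level features. A quick sketch of that bookkeeping with made-up tensor sizes:

```python
import torch
import torch.nn.functional as F

aspp_out = torch.rand(1, 256, 32, 32)     # ASPP output (256 channels)
low_level = torch.rand(1, 48, 128, 128)   # low-level features after the 1x1 conv -> 48 channels

x = F.interpolate(aspp_out, size=low_level.shape[2:], mode='bilinear', align_corners=True)
x = torch.cat((x, low_level), dim=1)
print(x.shape)                            # torch.Size([1, 304, 128, 128]) -> 256 + 48
```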
data/MBD/model/deep_lab_model/deeplab.py
ADDED
@@ -0,0 +1,81 @@
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from model.deep_lab_model.aspp import build_aspp
from model.deep_lab_model.decoder import build_decoder
from model.deep_lab_model.backbone import build_backbone

class DeepLab(nn.Module):
    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        super(DeepLab, self).__init__()
        if backbone == 'drn':
            output_stride = 8

        if sync_bn == True:
            BatchNorm = SynchronizedBatchNorm2d
        else:
            BatchNorm = nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)

        self.freeze_bn = freeze_bn

    def forward(self, input):
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, SynchronizedBatchNorm2d):
                m.eval()
            elif isinstance(m, nn.BatchNorm2d):
                m.eval()

    def get_1x_lr_params(self):
        modules = [self.backbone]
        for i in range(len(modules)):
            for m in modules[i].named_modules():
                if self.freeze_bn:
                    if isinstance(m[1], nn.Conv2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p
                else:
                    if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                            or isinstance(m[1], nn.BatchNorm2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p

    def get_10x_lr_params(self):
        modules = [self.aspp, self.decoder]
        for i in range(len(modules)):
            for m in modules[i].named_modules():
                if self.freeze_bn:
                    if isinstance(m[1], nn.Conv2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p
                else:
                    if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                            or isinstance(m[1], nn.BatchNorm2d):
                        for p in m[1].parameters():
                            if p.requires_grad:
                                yield p

if __name__ == "__main__":
    model = DeepLab(backbone='mobilenet', output_stride=16)
    model.eval()
    input = torch.rand(1, 3, 513, 513)
    output = model(input)
    print(output.size())
```
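The two generators `get_1x_lr_params` and `get_10x_lr_params` split the parameters so the backbone can train with a smaller learning rate than the ASPP head and decoder. A hypothetical training-side sketch (the learning rate, momentum, and weight decay below are assumed values, not taken from the training scripts):

```python
import torch
from model.deep_lab_model.deeplab import DeepLab

# Sketch only: 'resnet' is one of the backbone names accepted above; sync_bn=False
# keeps plain nn.BatchNorm2d for a single-GPU setup.
model = DeepLab(backbone='resnet', output_stride=16, sync_bn=False)

base_lr = 1e-3  # assumed value
optimizer = torch.optim.SGD(
    [
        {'params': model.get_1x_lr_params(), 'lr': base_lr},        # backbone
        {'params': model.get_10x_lr_params(), 'lr': base_lr * 10},  # ASPP + decoder
    ],
    momentum=0.9, weight_decay=5e-4)
```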
data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,12 @@
```python
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback
```
data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,282 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import collections
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
17 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
18 |
+
|
19 |
+
from .comm import SyncMaster
|
20 |
+
|
21 |
+
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
|
22 |
+
|
23 |
+
|
24 |
+
def _sum_ft(tensor):
|
25 |
+
"""sum over the first and last dimention"""
|
26 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
27 |
+
|
28 |
+
|
29 |
+
def _unsqueeze_ft(tensor):
|
30 |
+
"""add new dementions at the front and the tail"""
|
31 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
32 |
+
|
33 |
+
|
34 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
35 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
36 |
+
|
37 |
+
|
38 |
+
class _SynchronizedBatchNorm(_BatchNorm):
|
39 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
|
40 |
+
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
|
41 |
+
|
42 |
+
self._sync_master = SyncMaster(self._data_parallel_master)
|
43 |
+
|
44 |
+
self._is_parallel = False
|
45 |
+
self._parallel_id = None
|
46 |
+
self._slave_pipe = None
|
47 |
+
|
48 |
+
def forward(self, input):
|
49 |
+
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
50 |
+
if not (self._is_parallel and self.training):
|
51 |
+
return F.batch_norm(
|
52 |
+
input, self.running_mean, self.running_var, self.weight, self.bias,
|
53 |
+
self.training, self.momentum, self.eps)
|
54 |
+
|
55 |
+
# Resize the input to (B, C, -1).
|
56 |
+
input_shape = input.size()
|
57 |
+
input = input.view(input.size(0), self.num_features, -1)
|
58 |
+
|
59 |
+
# Compute the sum and square-sum.
|
60 |
+
sum_size = input.size(0) * input.size(2)
|
61 |
+
input_sum = _sum_ft(input)
|
62 |
+
input_ssum = _sum_ft(input ** 2)
|
63 |
+
|
64 |
+
# Reduce-and-broadcast the statistics.
|
65 |
+
if self._parallel_id == 0:
|
66 |
+
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
67 |
+
else:
|
68 |
+
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
69 |
+
|
70 |
+
# Compute the output.
|
71 |
+
if self.affine:
|
72 |
+
# MJY:: Fuse the multiplication for speed.
|
73 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
74 |
+
else:
|
75 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
76 |
+
|
77 |
+
# Reshape it.
|
78 |
+
return output.view(input_shape)
|
79 |
+
|
80 |
+
def __data_parallel_replicate__(self, ctx, copy_id):
|
81 |
+
self._is_parallel = True
|
82 |
+
self._parallel_id = copy_id
|
83 |
+
|
84 |
+
# parallel_id == 0 means master device.
|
85 |
+
if self._parallel_id == 0:
|
86 |
+
ctx.sync_master = self._sync_master
|
87 |
+
else:
|
88 |
+
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
89 |
+
|
90 |
+
def _data_parallel_master(self, intermediates):
|
91 |
+
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
92 |
+
|
93 |
+
# Always using same "device order" makes the ReduceAdd operation faster.
|
94 |
+
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
95 |
+
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
96 |
+
|
97 |
+
to_reduce = [i[1][:2] for i in intermediates]
|
98 |
+
to_reduce = [j for i in to_reduce for j in i] # flatten
|
99 |
+
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
100 |
+
|
101 |
+
sum_size = sum([i[1].sum_size for i in intermediates])
|
102 |
+
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
103 |
+
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
104 |
+
|
105 |
+
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
106 |
+
|
107 |
+
outputs = []
|
108 |
+
for i, rec in enumerate(intermediates):
|
109 |
+
outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
|
110 |
+
|
111 |
+
return outputs
|
112 |
+
|
113 |
+
def _compute_mean_std(self, sum_, ssum, size):
|
114 |
+
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
115 |
+
also maintains the moving average on the master device."""
|
116 |
+
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
117 |
+
mean = sum_ / size
|
118 |
+
sumvar = ssum - sum_ * mean
|
119 |
+
unbias_var = sumvar / (size - 1)
|
120 |
+
bias_var = sumvar / size
|
121 |
+
|
122 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
123 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
124 |
+
|
125 |
+
return mean, bias_var.clamp(self.eps) ** -0.5
|
126 |
+
|
127 |
+
|
128 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
129 |
+
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
130 |
+
mini-batch.
|
131 |
+
.. math::
|
132 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
133 |
+
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
134 |
+
standard-deviation are reduced across all devices during training.
|
135 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
136 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
137 |
+
the statistics only on that device, which accelerated the computation and
|
138 |
+
is also easy to implement, but the statistics might be inaccurate.
|
139 |
+
Instead, in this synchronized version, the statistics will be computed
|
140 |
+
over all training samples distributed on multiple devices.
|
141 |
+
|
142 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
143 |
+
as the built-in PyTorch implementation.
|
144 |
+
The mean and standard-deviation are calculated per-dimension over
|
145 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
146 |
+
of size C (where C is the input size).
|
147 |
+
During training, this layer keeps a running estimate of its computed mean
|
148 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
149 |
+
During evaluation, this running mean/variance is used for normalization.
|
150 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
151 |
+
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
152 |
+
Args:
|
153 |
+
num_features: num_features from an expected input of size
|
154 |
+
`batch_size x num_features [x width]`
|
155 |
+
eps: a value added to the denominator for numerical stability.
|
156 |
+
Default: 1e-5
|
157 |
+
momentum: the value used for the running_mean and running_var
|
158 |
+
computation. Default: 0.1
|
159 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
160 |
+
affine parameters. Default: ``True``
|
161 |
+
Shape:
|
162 |
+
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
163 |
+
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
164 |
+
Examples:
|
165 |
+
>>> # With Learnable Parameters
|
166 |
+
>>> m = SynchronizedBatchNorm1d(100)
|
167 |
+
>>> # Without Learnable Parameters
|
168 |
+
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
169 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
170 |
+
>>> output = m(input)
|
171 |
+
"""
|
172 |
+
|
173 |
+
def _check_input_dim(self, input):
|
174 |
+
if input.dim() != 2 and input.dim() != 3:
|
175 |
+
raise ValueError('expected 2D or 3D input (got {}D input)'
|
176 |
+
.format(input.dim()))
|
177 |
+
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
178 |
+
|
179 |
+
|
180 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
181 |
+
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
182 |
+
of 3d inputs
|
183 |
+
.. math::
|
184 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
185 |
+
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
186 |
+
standard-deviation are reduced across all devices during training.
|
187 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
188 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
189 |
+
the statistics only on that device, which accelerated the computation and
|
190 |
+
is also easy to implement, but the statistics might be inaccurate.
|
191 |
+
Instead, in this synchronized version, the statistics will be computed
|
192 |
+
over all training samples distributed on multiple devices.
|
193 |
+
|
194 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
195 |
+
as the built-in PyTorch implementation.
|
196 |
+
The mean and standard-deviation are calculated per-dimension over
|
197 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
198 |
+
of size C (where C is the input size).
|
199 |
+
During training, this layer keeps a running estimate of its computed mean
|
200 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
201 |
+
During evaluation, this running mean/variance is used for normalization.
|
202 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
203 |
+
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
204 |
+
Args:
|
205 |
+
num_features: num_features from an expected input of
|
206 |
+
size batch_size x num_features x height x width
|
207 |
+
eps: a value added to the denominator for numerical stability.
|
208 |
+
Default: 1e-5
|
209 |
+
momentum: the value used for the running_mean and running_var
|
210 |
+
computation. Default: 0.1
|
211 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
212 |
+
affine parameters. Default: ``True``
|
213 |
+
Shape:
|
214 |
+
- Input: :math:`(N, C, H, W)`
|
215 |
+
- Output: :math:`(N, C, H, W)` (same shape as input)
|
216 |
+
Examples:
|
217 |
+
>>> # With Learnable Parameters
|
218 |
+
>>> m = SynchronizedBatchNorm2d(100)
|
219 |
+
>>> # Without Learnable Parameters
|
220 |
+
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
221 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
222 |
+
>>> output = m(input)
|
223 |
+
"""
|
224 |
+
|
225 |
+
def _check_input_dim(self, input):
|
226 |
+
if input.dim() != 4:
|
227 |
+
raise ValueError('expected 4D input (got {}D input)'
|
228 |
+
.format(input.dim()))
|
229 |
+
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
|
230 |
+
|
231 |
+
|
232 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
233 |
+
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
234 |
+
of 4d inputs
|
235 |
+
.. math::
|
236 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
237 |
+
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
238 |
+
standard-deviation are reduced across all devices during training.
|
239 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
240 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
241 |
+
the statistics only on that device, which accelerated the computation and
|
242 |
+
is also easy to implement, but the statistics might be inaccurate.
|
243 |
+
Instead, in this synchronized version, the statistics will be computed
|
244 |
+
over all training samples distributed on multiple devices.
|
245 |
+
|
246 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
247 |
+
as the built-in PyTorch implementation.
|
248 |
+
The mean and standard-deviation are calculated per-dimension over
|
249 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
250 |
+
of size C (where C is the input size).
|
251 |
+
During training, this layer keeps a running estimate of its computed mean
|
252 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
253 |
+
During evaluation, this running mean/variance is used for normalization.
|
254 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
255 |
+
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
256 |
+
or Spatio-temporal BatchNorm
|
257 |
+
Args:
|
258 |
+
num_features: num_features from an expected input of
|
259 |
+
size batch_size x num_features x depth x height x width
|
260 |
+
eps: a value added to the denominator for numerical stability.
|
261 |
+
Default: 1e-5
|
262 |
+
momentum: the value used for the running_mean and running_var
|
263 |
+
computation. Default: 0.1
|
264 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
265 |
+
affine parameters. Default: ``True``
|
266 |
+
Shape:
|
267 |
+
- Input: :math:`(N, C, D, H, W)`
|
268 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
269 |
+
Examples:
|
270 |
+
>>> # With Learnable Parameters
|
271 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
272 |
+
>>> # Without Learnable Parameters
|
273 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
274 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
275 |
+
>>> output = m(input)
|
276 |
+
"""
|
277 |
+
|
278 |
+
def _check_input_dim(self, input):
|
279 |
+
if input.dim() != 5:
|
280 |
+
raise ValueError('expected 5D input (got {}D input)'
|
281 |
+
.format(input.dim()))
|
282 |
+
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
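A self-contained sketch of the statistics math in `_compute_mean_std`: each replica sends its sum, sum of squares, and element count, and the master turns them into a mean and `inv_std = 1/sqrt(var + eps)` (the real module additionally updates the running mean/variance on the master device):

```python
import torch

eps = 1e-5
x = torch.rand(4, 3, 8, 8)                      # pretend batch gathered from all devices
flat = x.transpose(0, 1).contiguous().view(3, -1)

size = flat.size(1)
sum_ = flat.sum(dim=1)
ssum = (flat ** 2).sum(dim=1)

mean = sum_ / size
sumvar = ssum - sum_ * mean
bias_var = sumvar / size                        # biased variance, used for normalization
inv_std = bias_var.clamp(eps) ** -0.5

print(torch.allclose(mean, flat.mean(dim=1)))
print(torch.allclose(inv_std,
                     1 / flat.var(dim=1, unbiased=False).clamp(min=eps).sqrt(),
                     atol=1e-5))
```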
data/MBD/model/deep_lab_model/sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : comm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import queue
|
12 |
+
import collections
|
13 |
+
import threading
|
14 |
+
|
15 |
+
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
16 |
+
|
17 |
+
|
18 |
+
class FutureResult(object):
|
19 |
+
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._result = None
|
23 |
+
self._lock = threading.Lock()
|
24 |
+
self._cond = threading.Condition(self._lock)
|
25 |
+
|
26 |
+
def put(self, result):
|
27 |
+
with self._lock:
|
28 |
+
assert self._result is None, 'Previous result has\'t been fetched.'
|
29 |
+
self._result = result
|
30 |
+
self._cond.notify()
|
31 |
+
|
32 |
+
def get(self):
|
33 |
+
with self._lock:
|
34 |
+
if self._result is None:
|
35 |
+
self._cond.wait()
|
36 |
+
|
37 |
+
res = self._result
|
38 |
+
self._result = None
|
39 |
+
return res
|
40 |
+
|
41 |
+
|
42 |
+
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
43 |
+
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
44 |
+
|
45 |
+
|
46 |
+
class SlavePipe(_SlavePipeBase):
|
47 |
+
"""Pipe for master-slave communication."""
|
48 |
+
|
49 |
+
def run_slave(self, msg):
|
50 |
+
self.queue.put((self.identifier, msg))
|
51 |
+
ret = self.result.get()
|
52 |
+
self.queue.put(True)
|
53 |
+
return ret
|
54 |
+
|
55 |
+
|
56 |
+
class SyncMaster(object):
|
57 |
+
"""An abstract `SyncMaster` object.
|
58 |
+
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
59 |
+
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
60 |
+
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
61 |
+
and passed to a registered callback.
|
62 |
+
- After receiving the messages, the master device should gather the information and determine to message passed
|
63 |
+
back to each slave devices.
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(self, master_callback):
|
67 |
+
"""
|
68 |
+
Args:
|
69 |
+
master_callback: a callback to be invoked after having collected messages from slave devices.
|
70 |
+
"""
|
71 |
+
self._master_callback = master_callback
|
72 |
+
self._queue = queue.Queue()
|
73 |
+
self._registry = collections.OrderedDict()
|
74 |
+
self._activated = False
|
75 |
+
|
76 |
+
def __getstate__(self):
|
77 |
+
return {'master_callback': self._master_callback}
|
78 |
+
|
79 |
+
def __setstate__(self, state):
|
80 |
+
self.__init__(state['master_callback'])
|
81 |
+
|
82 |
+
def register_slave(self, identifier):
|
83 |
+
"""
|
84 |
+
Register an slave device.
|
85 |
+
Args:
|
86 |
+
identifier: an identifier, usually is the device id.
|
87 |
+
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
88 |
+
"""
|
89 |
+
if self._activated:
|
90 |
+
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
91 |
+
self._activated = False
|
92 |
+
self._registry.clear()
|
93 |
+
future = FutureResult()
|
94 |
+
self._registry[identifier] = _MasterRegistry(future)
|
95 |
+
return SlavePipe(identifier, self._queue, future)
|
96 |
+
|
97 |
+
def run_master(self, master_msg):
|
98 |
+
"""
|
99 |
+
Main entry for the master device in each forward pass.
|
100 |
+
The messages were first collected from each devices (including the master device), and then
|
101 |
+
an callback will be invoked to compute the message to be sent back to each devices
|
102 |
+
(including the master device).
|
103 |
+
Args:
|
104 |
+
master_msg: the message that the master want to send to itself. This will be placed as the first
|
105 |
+
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
106 |
+
Returns: the message to be sent back to the master device.
|
107 |
+
"""
|
108 |
+
self._activated = True
|
109 |
+
|
110 |
+
intermediates = [(0, master_msg)]
|
111 |
+
for i in range(self.nr_slaves):
|
112 |
+
intermediates.append(self._queue.get())
|
113 |
+
|
114 |
+
results = self._master_callback(intermediates)
|
115 |
+
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
116 |
+
|
117 |
+
for i, res in results:
|
118 |
+
if i == 0:
|
119 |
+
continue
|
120 |
+
self._registry[i].result.put(res)
|
121 |
+
|
122 |
+
for i in range(self.nr_slaves):
|
123 |
+
assert self._queue.get() is True
|
124 |
+
|
125 |
+
return results[0][1]
|
126 |
+
|
127 |
+
@property
|
128 |
+
def nr_slaves(self):
|
129 |
+
return len(self._registry)
|
data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : replicate.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import functools
|
12 |
+
|
13 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'CallbackContext',
|
17 |
+
'execute_replication_callbacks',
|
18 |
+
'DataParallelWithCallback',
|
19 |
+
'patch_replication_callback'
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
class CallbackContext(object):
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
def execute_replication_callbacks(modules):
|
28 |
+
"""
|
29 |
+
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
|
30 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
31 |
+
Note that, as all modules are isomorphism, we assign each sub-module with a context
|
32 |
+
(shared among multiple copies of this module on different devices).
|
33 |
+
Through this context, different copies can share some information.
|
34 |
+
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
35 |
+
of any slave copies.
|
36 |
+
"""
|
37 |
+
master_copy = modules[0]
|
38 |
+
nr_modules = len(list(master_copy.modules()))
|
39 |
+
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
40 |
+
|
41 |
+
for i, module in enumerate(modules):
|
42 |
+
for j, m in enumerate(module.modules()):
|
43 |
+
if hasattr(m, '__data_parallel_replicate__'):
|
44 |
+
m.__data_parallel_replicate__(ctxs[j], i)
|
45 |
+
|
46 |
+
|
47 |
+
class DataParallelWithCallback(DataParallel):
|
48 |
+
"""
|
49 |
+
Data Parallel with a replication callback.
|
50 |
+
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
51 |
+
original `replicate` function.
|
52 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
53 |
+
Examples:
|
54 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
55 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
56 |
+
# sync_bn.__data_parallel_replicate__ will be invoked.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def replicate(self, module, device_ids):
|
60 |
+
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
61 |
+
execute_replication_callbacks(modules)
|
62 |
+
return modules
|
63 |
+
|
64 |
+
|
65 |
+
def patch_replication_callback(data_parallel):
|
66 |
+
"""
|
67 |
+
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
68 |
+
Useful when you have customized `DataParallel` implementation.
|
69 |
+
Examples:
|
70 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
71 |
+
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
72 |
+
> patch_replication_callback(sync_bn)
|
73 |
+
# this is equivalent to
|
74 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
75 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
76 |
+
"""
|
77 |
+
|
78 |
+
assert isinstance(data_parallel, DataParallel)
|
79 |
+
|
80 |
+
old_replicate = data_parallel.replicate
|
81 |
+
|
82 |
+
@functools.wraps(old_replicate)
|
83 |
+
def new_replicate(module, device_ids):
|
84 |
+
modules = old_replicate(module, device_ids)
|
85 |
+
execute_replication_callbacks(modules)
|
86 |
+
return modules
|
87 |
+
|
88 |
+
data_parallel.replicate = new_replicate
|
data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : unittest.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import unittest
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from torch.autograd import Variable
|
15 |
+
|
16 |
+
|
17 |
+
def as_numpy(v):
|
18 |
+
if isinstance(v, Variable):
|
19 |
+
v = v.data
|
20 |
+
return v.cpu().numpy()
|
21 |
+
|
22 |
+
|
23 |
+
class TorchTestCase(unittest.TestCase):
|
24 |
+
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
|
25 |
+
npa, npb = as_numpy(a), as_numpy(b)
|
26 |
+
self.assertTrue(
|
27 |
+
np.allclose(npa, npb, atol=atol),
|
28 |
+
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
|
29 |
+
)
|
data/MBD/model/densenetccnl.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Densenet decoder encoder with intermediate fully connected layers and dropout
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.backends.cudnn as cudnn
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import functools
|
8 |
+
from torch.autograd import gradcheck
|
9 |
+
from torch.autograd import Function
|
10 |
+
from torch.autograd import Variable
|
11 |
+
from torch.autograd import gradcheck
|
12 |
+
from torch.autograd import Function
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
|
16 |
+
def add_coordConv_channels(t):
|
17 |
+
n,c,h,w=t.size()
|
18 |
+
xx_channel=np.ones((h, w))
|
19 |
+
xx_range=np.array(range(h))
|
20 |
+
xx_range=np.expand_dims(xx_range,-1)
|
21 |
+
xx_coord=xx_channel*xx_range
|
22 |
+
yy_coord=xx_coord.transpose()
|
23 |
+
|
24 |
+
xx_coord=xx_coord/(h-1)
|
25 |
+
yy_coord=yy_coord/(h-1)
|
26 |
+
xx_coord=xx_coord*2 - 1
|
27 |
+
yy_coord=yy_coord*2 - 1
|
28 |
+
xx_coord=torch.from_numpy(xx_coord).float()
|
29 |
+
yy_coord=torch.from_numpy(yy_coord).float()
|
30 |
+
|
31 |
+
if t.is_cuda:
|
32 |
+
xx_coord=xx_coord.cuda()
|
33 |
+
yy_coord=yy_coord.cuda()
|
34 |
+
|
35 |
+
xx_coord=xx_coord.unsqueeze(0).unsqueeze(0).repeat(n,1,1,1)
|
36 |
+
yy_coord=yy_coord.unsqueeze(0).unsqueeze(0).repeat(n,1,1,1)
|
37 |
+
|
38 |
+
t_cc=torch.cat((t,xx_coord,yy_coord),dim=1)
|
39 |
+
|
40 |
+
return t_cc
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
class DenseBlockEncoder(nn.Module):
|
45 |
+
def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]):
|
46 |
+
super(DenseBlockEncoder, self).__init__()
|
47 |
+
assert(n_convs > 0)
|
48 |
+
|
49 |
+
self.n_channels = n_channels
|
50 |
+
self.n_convs = n_convs
|
51 |
+
self.layers = nn.ModuleList()
|
52 |
+
for i in range(n_convs):
|
53 |
+
self.layers.append(nn.Sequential(
|
54 |
+
nn.BatchNorm2d(n_channels),
|
55 |
+
activation(*args),
|
56 |
+
nn.Conv2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False),))
|
57 |
+
|
58 |
+
def forward(self, inputs):
|
59 |
+
outputs = []
|
60 |
+
|
61 |
+
for i, layer in enumerate(self.layers):
|
62 |
+
if i > 0:
|
63 |
+
next_output = 0
|
64 |
+
for no in outputs:
|
65 |
+
next_output = next_output + no
|
66 |
+
outputs.append(next_output)
|
67 |
+
else:
|
68 |
+
outputs.append(layer(inputs))
|
69 |
+
return outputs[-1]
|
70 |
+
|
71 |
+
# Dense block in encoder.
|
72 |
+
class DenseBlockDecoder(nn.Module):
|
73 |
+
def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]):
|
74 |
+
super(DenseBlockDecoder, self).__init__()
|
75 |
+
assert(n_convs > 0)
|
76 |
+
|
77 |
+
self.n_channels = n_channels
|
78 |
+
self.n_convs = n_convs
|
79 |
+
self.layers = nn.ModuleList()
|
80 |
+
for i in range(n_convs):
|
81 |
+
self.layers.append(nn.Sequential(
|
82 |
+
nn.BatchNorm2d(n_channels),
|
83 |
+
activation(*args),
|
84 |
+
nn.ConvTranspose2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False),))
|
85 |
+
|
86 |
+
def forward(self, inputs):
|
87 |
+
outputs = []
|
88 |
+
|
89 |
+
for i, layer in enumerate(self.layers):
|
90 |
+
if i > 0:
|
91 |
+
next_output = 0
|
92 |
+
for no in outputs:
|
93 |
+
next_output = next_output + no
|
94 |
+
outputs.append(next_output)
|
95 |
+
else:
|
96 |
+
outputs.append(layer(inputs))
|
97 |
+
return outputs[-1]
|
98 |
+
|
99 |
+
class DenseTransitionBlockEncoder(nn.Module):
|
100 |
+
def __init__(self, n_channels_in, n_channels_out, mp, activation=nn.ReLU, args=[False]):
|
101 |
+
super(DenseTransitionBlockEncoder, self).__init__()
|
102 |
+
self.n_channels_in = n_channels_in
|
103 |
+
self.n_channels_out = n_channels_out
|
104 |
+
self.mp = mp
|
105 |
+
self.main = nn.Sequential(
|
106 |
+
nn.BatchNorm2d(n_channels_in),
|
107 |
+
activation(*args),
|
108 |
+
nn.Conv2d(n_channels_in, n_channels_out, 1, stride=1, padding=0, bias=False),
|
109 |
+
nn.MaxPool2d(mp),
|
110 |
+
)
|
111 |
+
def forward(self, inputs):
|
112 |
+
# print(inputs.shape,'222222222222222',self.main(inputs).shape)
|
113 |
+
return self.main(inputs)
|
114 |
+
|
115 |
+
|
116 |
+
class DenseTransitionBlockDecoder(nn.Module):
|
117 |
+
def __init__(self, n_channels_in, n_channels_out, activation=nn.ReLU, args=[False]):
|
118 |
+
super(DenseTransitionBlockDecoder, self).__init__()
|
119 |
+
self.n_channels_in = n_channels_in
|
120 |
+
self.n_channels_out = n_channels_out
|
121 |
+
self.main = nn.Sequential(
|
122 |
+
nn.BatchNorm2d(n_channels_in),
|
123 |
+
activation(*args),
|
124 |
+
nn.ConvTranspose2d(n_channels_in, n_channels_out, 4, stride=2, padding=1, bias=False),
|
125 |
+
)
|
126 |
+
def forward(self, inputs):
|
127 |
+
# print(inputs.shape,'333333333333',self.main(inputs).shape)
|
128 |
+
return self.main(inputs)
|
129 |
+
|
130 |
+
## Dense encoders and decoders for image of size 128 128
|
131 |
+
class waspDenseEncoder128(nn.Module):
|
132 |
+
def __init__(self, nc=1, ndf = 32, ndim = 128, activation=nn.LeakyReLU, args=[0.2, False], f_activation=nn.Tanh, f_args=[]):
|
133 |
+
super(waspDenseEncoder128, self).__init__()
|
134 |
+
self.ndim = ndim
|
135 |
+
|
136 |
+
self.main = nn.Sequential(
|
137 |
+
# input is (nc) x 128 x 128
|
138 |
+
nn.BatchNorm2d(nc),
|
139 |
+
nn.ReLU(True),
|
140 |
+
nn.Conv2d(nc, ndf, 4, stride=2, padding=1),
|
141 |
+
|
142 |
+
# state size. (ndf) x 64 x 64
|
143 |
+
DenseBlockEncoder(ndf, 6),
|
144 |
+
DenseTransitionBlockEncoder(ndf, ndf*2, 2, activation=activation, args=args),
|
145 |
+
|
146 |
+
# state size. (ndf*2) x 32 x 32
|
147 |
+
DenseBlockEncoder(ndf*2, 12),
|
148 |
+
DenseTransitionBlockEncoder(ndf*2, ndf*4, 2, activation=activation, args=args),
|
149 |
+
|
150 |
+
# state size. (ndf*4) x 16 x 16
|
151 |
+
DenseBlockEncoder(ndf*4, 16),
|
152 |
+
DenseTransitionBlockEncoder(ndf*4, ndf*8, 2, activation=activation, args=args),
|
153 |
+
|
154 |
+
# state size. (ndf*4) x 8 x 8
|
155 |
+
DenseBlockEncoder(ndf*8, 16),
|
156 |
+
DenseTransitionBlockEncoder(ndf*8, ndf*8, 2, activation=activation, args=args),
|
157 |
+
|
158 |
+
# state size. (ndf*8) x 4 x 4
|
159 |
+
DenseBlockEncoder(ndf*8, 16),
|
160 |
+
DenseTransitionBlockEncoder(ndf*8, ndim, 4, activation=activation, args=args),
|
161 |
+
f_activation(*f_args),
|
162 |
+
)
|
163 |
+
|
164 |
+
def forward(self, input):
|
165 |
+
input=add_coordConv_channels(input)
|
166 |
+
output = self.main(input).view(-1,self.ndim)
|
167 |
+
#print(output.size())
|
168 |
+
return output
|
169 |
+
|
170 |
+
class waspDenseDecoder128(nn.Module):
|
171 |
+
def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False], f_activation=nn.Hardtanh, f_args=[]):
|
172 |
+
super(waspDenseDecoder128, self).__init__()
|
173 |
+
self.main = nn.Sequential(
|
174 |
+
# input is Z, going into convolution
|
175 |
+
nn.BatchNorm2d(nz),
|
176 |
+
activation(*args),
|
177 |
+
nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
|
178 |
+
|
179 |
+
# state size. (ngf*8) x 4 x 4
|
180 |
+
DenseBlockDecoder(ngf*8, 16),
|
181 |
+
DenseTransitionBlockDecoder(ngf*8, ngf*8),
|
182 |
+
|
183 |
+
# state size. (ngf*4) x 8 x 8
|
184 |
+
DenseBlockDecoder(ngf*8, 16),
|
185 |
+
DenseTransitionBlockDecoder(ngf*8, ngf*4),
|
186 |
+
|
187 |
+
# state size. (ngf*2) x 16 x 16
|
188 |
+
DenseBlockDecoder(ngf*4, 12),
|
189 |
+
DenseTransitionBlockDecoder(ngf*4, ngf*2),
|
190 |
+
|
191 |
+
# state size. (ngf) x 32 x 32
|
192 |
+
DenseBlockDecoder(ngf*2, 6),
|
193 |
+
DenseTransitionBlockDecoder(ngf*2, ngf),
|
194 |
+
|
195 |
+
# state size. (ngf) x 64 x 64
|
196 |
+
DenseBlockDecoder(ngf, 6),
|
197 |
+
DenseTransitionBlockDecoder(ngf, ngf),
|
198 |
+
|
199 |
+
# state size (ngf) x 128 x 128
|
200 |
+
nn.BatchNorm2d(ngf),
|
201 |
+
activation(*args),
|
202 |
+
nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False),
|
203 |
+
f_activation(*f_args),
|
204 |
+
)
|
205 |
+
# self.smooth=nn.Sequential(
|
206 |
+
# nn.Conv2d(nc, nc, 1, stride=1, padding=0, bias=False),
|
207 |
+
# f_activation(*f_args),
|
208 |
+
# )
|
209 |
+
def forward(self, inputs):
|
210 |
+
# return self.smooth(self.main(inputs))
|
211 |
+
return self.main(inputs)
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
## Dense encoders and decoders for image of size 512 512
|
216 |
+
class waspDenseEncoder512(nn.Module):
|
217 |
+
def __init__(self, nc=1, ndf = 32, ndim = 128, activation=nn.LeakyReLU, args=[0.2, False], f_activation=nn.Tanh, f_args=[]):
|
218 |
+
super(waspDenseEncoder512, self).__init__()
|
219 |
+
self.ndim = ndim
|
220 |
+
|
221 |
+
self.main = nn.Sequential(
|
222 |
+
# input is (nc) x 128 x 128 > *4
|
223 |
+
nn.BatchNorm2d(nc),
|
224 |
+
nn.ReLU(True),
|
225 |
+
nn.Conv2d(nc, ndf, 4, stride=2, padding=1),
|
226 |
+
|
227 |
+
# state size. (ndf) x 64 x 64 > *4
|
228 |
+
DenseBlockEncoder(ndf, 6),
|
229 |
+
DenseTransitionBlockEncoder(ndf, ndf*2, 2, activation=activation, args=args),
|
230 |
+
|
231 |
+
# state size. (ndf*2) x 32 x 32 > *4
|
232 |
+
DenseBlockEncoder(ndf*2, 12),
|
233 |
+
DenseTransitionBlockEncoder(ndf*2, ndf*4, 2, activation=activation, args=args),
|
234 |
+
|
235 |
+
# state size. (ndf*4) x 16 x 16 > *4
|
236 |
+
DenseBlockEncoder(ndf*4, 16),
|
237 |
+
DenseTransitionBlockEncoder(ndf*4, ndf*8, 2, activation=activation, args=args),
|
238 |
+
|
239 |
+
# state size. (ndf*8) x 8 x 8 *4
|
240 |
+
DenseBlockEncoder(ndf*8, 16),
|
241 |
+
DenseTransitionBlockEncoder(ndf*8, ndf*8, 2, activation=activation, args=args),
|
242 |
+
|
243 |
+
# state size. (ndf*8) x 4 x 4 > *4
|
244 |
+
DenseBlockEncoder(ndf*8, 16),
|
245 |
+
DenseTransitionBlockEncoder(ndf*8, ndf*8, 4, activation=activation, args=args),
|
246 |
+
f_activation(*f_args),
|
247 |
+
|
248 |
+
# state size. (ndf*8) x 2 x 2 > *4
|
249 |
+
DenseBlockEncoder(ndf*8, 16),
|
250 |
+
DenseTransitionBlockEncoder(ndf*8, ndim, 4, activation=activation, args=args),
|
251 |
+
f_activation(*f_args),
|
252 |
+
)
|
253 |
+
|
254 |
+
def forward(self, input):
|
255 |
+
input=add_coordConv_channels(input)
|
256 |
+
output = self.main(input).view(-1,self.ndim)
|
257 |
+
# output = self.main(input).view(8,-1)
|
258 |
+
# print(input.shape,'---------------------')
|
259 |
+
#print(output.size())
|
260 |
+
return output
|
261 |
+
|
262 |
+
class waspDenseDecoder512(nn.Module):
|
263 |
+
def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False], f_activation=nn.Tanh, f_args=[]):
|
264 |
+
super(waspDenseDecoder512, self).__init__()
|
265 |
+
self.main = nn.Sequential(
|
266 |
+
# input is Z, going into convolution
|
267 |
+
nn.BatchNorm2d(nz),
|
268 |
+
activation(*args),
|
269 |
+
nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
|
270 |
+
|
271 |
+
# state size. (ngf*8) x 4 x 4
|
272 |
+
DenseBlockDecoder(ngf*8, 16),
|
273 |
+
DenseTransitionBlockDecoder(ngf*8, ngf*8),
|
274 |
+
|
275 |
+
# state size. (ngf*8) x 8 x 8
|
276 |
+
DenseBlockDecoder(ngf*8, 16),
|
277 |
+
DenseTransitionBlockDecoder(ngf*8, ngf*8),
|
278 |
+
|
279 |
+
# state size. (ngf*4) x 16 x 16
|
280 |
+
DenseBlockDecoder(ngf*8, 16),
|
281 |
+
DenseTransitionBlockDecoder(ngf*8, ngf*4),
|
282 |
+
|
283 |
+
# state size. (ngf*2) x 32 x 32
|
284 |
+
DenseBlockDecoder(ngf*4, 12),
|
285 |
+
DenseTransitionBlockDecoder(ngf*4, ngf*2),
|
286 |
+
|
287 |
+
# state size. (ngf) x 64 x 64
|
288 |
+
DenseBlockDecoder(ngf*2, 6),
|
289 |
+
DenseTransitionBlockDecoder(ngf*2, ngf),
|
290 |
+
|
291 |
+
# state size. (ngf) x 128 x 128
|
292 |
+
DenseBlockDecoder(ngf, 6),
|
293 |
+
DenseTransitionBlockDecoder(ngf, ngf),
|
294 |
+
|
295 |
+
# state size. (ngf) x 256 x 256
|
296 |
+
DenseBlockDecoder(ngf, 6),
|
297 |
+
DenseTransitionBlockDecoder(ngf, ngf),
|
298 |
+
|
299 |
+
# state size (ngf) x 512 x 512
|
300 |
+
nn.BatchNorm2d(ngf),
|
301 |
+
activation(*args),
|
302 |
+
nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False),
|
303 |
+
f_activation(*f_args),
|
304 |
+
)
|
305 |
+
# self.smooth=nn.Sequential(
|
306 |
+
# nn.Conv2d(nc, nc, 1, stride=1, padding=0, bias=False),
|
307 |
+
# f_activation(*f_args),
|
308 |
+
# )
|
309 |
+
def forward(self, inputs):
|
310 |
+
# return self.smooth(self.main(inputs))
|
311 |
+
return self.main(inputs)
|
312 |
+
|
313 |
+
|
314 |
+
class dnetccnl(nn.Module):
|
315 |
+
#in_channels -> nc | encoder first layer
|
316 |
+
#filters -> ndf | encoder first layer
|
317 |
+
#img_size(h,w) -> ndim
|
318 |
+
#out_channels -> optical flow (x,y)
|
319 |
+
|
320 |
+
def __init__(self, img_size=448, in_channels=3, out_channels=2, filters=32,fc_units=100):
|
321 |
+
super(dnetccnl, self).__init__()
|
322 |
+
self.nc=in_channels
|
323 |
+
self.nf=filters
|
324 |
+
self.ndim=img_size
|
325 |
+
self.oc=out_channels
|
326 |
+
self.fcu=fc_units
|
327 |
+
|
328 |
+
self.encoder=waspDenseEncoder128(nc=self.nc+2,ndf=self.nf,ndim=self.ndim)
|
329 |
+
self.decoder=waspDenseDecoder128(nz=self.ndim,nc=self.oc,ngf=self.nf)
|
330 |
+
# self.fc_layers= nn.Sequential(nn.Linear(self.ndim, self.fcu),
|
331 |
+
# nn.ReLU(True),
|
332 |
+
# nn.Dropout(0.25),
|
333 |
+
# nn.Linear(self.fcu,self.ndim),
|
334 |
+
# nn.ReLU(True),
|
335 |
+
# nn.Dropout(0.25),
|
336 |
+
# )
|
337 |
+
|
338 |
+
def forward(self, inputs):
|
339 |
+
|
340 |
+
encoded=self.encoder(inputs)
|
341 |
+
encoded=encoded.unsqueeze(-1).unsqueeze(-1)
|
342 |
+
decoded=self.decoder(encoded)
|
343 |
+
# print torch.max(decoded)
|
344 |
+
# print torch.min(decoded)
|
345 |
+
# print(decoded.shape,'11111111111111111',encoded.shape)
|
346 |
+
|
347 |
+
return decoded
|
348 |
+
|
349 |
+
class dnetccnl512(nn.Module):
|
350 |
+
#in_channels -> nc | encoder first layer
|
351 |
+
#filters -> ndf | encoder first layer
|
352 |
+
#img_size(h,w) -> ndim
|
353 |
+
#out_channels -> optical flow (x,y)
|
354 |
+
|
355 |
+
def __init__(self, img_size=448, in_channels=3, out_channels=2, filters=32,fc_units=100):
|
356 |
+
super(dnetccnl512, self).__init__()
|
357 |
+
self.nc=in_channels
|
358 |
+
self.nf=filters
|
359 |
+
self.ndim=img_size
|
360 |
+
self.oc=out_channels
|
361 |
+
self.fcu=fc_units
|
362 |
+
|
363 |
+
self.encoder=waspDenseEncoder512(nc=self.nc+2,ndf=self.nf,ndim=self.ndim)
|
364 |
+
self.decoder=waspDenseDecoder512(nz=self.ndim,nc=self.oc,ngf=self.nf)
|
365 |
+
# self.fc_layers= nn.Sequential(nn.Linear(self.ndim, self.fcu),
|
366 |
+
# nn.ReLU(True),
|
367 |
+
# nn.Dropout(0.25),
|
368 |
+
# nn.Linear(self.fcu,self.ndim),
|
369 |
+
# nn.ReLU(True),
|
370 |
+
# nn.Dropout(0.25),
|
371 |
+
# )
|
372 |
+
|
373 |
+
def forward(self, inputs):
|
374 |
+
|
375 |
+
encoded=self.encoder(inputs)
|
376 |
+
encoded=encoded.unsqueeze(-1).unsqueeze(-1)
|
377 |
+
decoded=self.decoder(encoded)
|
378 |
+
# print torch.max(decoded)
|
379 |
+
# print torch.min(decoded)
|
380 |
+
# print(decoded.shape,'11111111111111111',encoded.shape)
|
381 |
+
|
382 |
+
return decoded
|
data/MBD/model/gienet.py
ADDED
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import log
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import init
|
5 |
+
import functools
|
6 |
+
from model.cbam import CBAM
|
7 |
+
# Defines the Unet generator.
|
8 |
+
# |num_downs|: number of downsamplings in UNet. For example,
|
9 |
+
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
|
10 |
+
# at the bottleneck
|
11 |
+
class SingleConv(nn.Module):
|
12 |
+
"""(convolution => [BN] => ReLU) * 2"""
|
13 |
+
|
14 |
+
def __init__(self, in_channels, out_channels):
|
15 |
+
super().__init__()
|
16 |
+
self.double_conv = nn.Sequential(
|
17 |
+
nn.ReflectionPad2d(1),
|
18 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0,stride=1),
|
19 |
+
nn.BatchNorm2d(out_channels),
|
20 |
+
nn.ReLU(inplace=True),
|
21 |
+
# nn.ReflectionPad2d(1),
|
22 |
+
# nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0,stride=1),
|
23 |
+
# nn.BatchNorm2d(out_channels),
|
24 |
+
# nn.ReLU(inplace=True)
|
25 |
+
)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
return self.double_conv(x)
|
29 |
+
class Down_single(nn.Module):
|
30 |
+
"""Downscaling with maxpool then double conv"""
|
31 |
+
|
32 |
+
def __init__(self, in_channels, out_channels):
|
33 |
+
super().__init__()
|
34 |
+
self.maxpool_conv = nn.Sequential(
|
35 |
+
nn.MaxPool2d(2),
|
36 |
+
SingleConv(in_channels, out_channels)
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
return self.maxpool_conv(x)
|
41 |
+
class Up_single(nn.Module):
|
42 |
+
"""Upscaling then double conv"""
|
43 |
+
def __init__(self, in_channels, out_channels, bilinear=True):
|
44 |
+
super().__init__()
|
45 |
+
self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
46 |
+
self.conv = SingleConv(in_channels, out_channels)
|
47 |
+
self.deconv = nn.ConvTranspose2d(in_channels, out_channels,kernel_size=4, stride=2,padding=1, bias=True)
|
48 |
+
def forward(self, x1, x2):
|
49 |
+
x1 = self.deconv(x1)
|
50 |
+
# input is BCHW
|
51 |
+
x = torch.cat([x2, x1], dim=1)
|
52 |
+
return self.conv(x)
|
53 |
+
class DoubleConv(nn.Module):
|
54 |
+
"""(convolution => [BN] => ReLU) * 2"""
|
55 |
+
|
56 |
+
def __init__(self, in_channels, out_channels):
|
57 |
+
super().__init__()
|
58 |
+
self.double_conv = nn.Sequential(
|
59 |
+
nn.ReflectionPad2d(1),
|
60 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0,stride=1),
|
61 |
+
nn.BatchNorm2d(out_channels),
|
62 |
+
nn.ReLU(inplace=True),
|
63 |
+
nn.ReflectionPad2d(1),
|
64 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0,stride=1),
|
65 |
+
nn.BatchNorm2d(out_channels),
|
66 |
+
nn.ReLU(inplace=True)
|
67 |
+
)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
return self.double_conv(x)
|
71 |
+
class Down(nn.Module):
|
72 |
+
"""Downscaling with maxpool then double conv"""
|
73 |
+
|
74 |
+
def __init__(self, in_channels, out_channels):
|
75 |
+
super().__init__()
|
76 |
+
self.maxpool_conv = nn.Sequential(
|
77 |
+
nn.MaxPool2d(2),
|
78 |
+
DoubleConv(in_channels, out_channels)
|
79 |
+
)
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
return self.maxpool_conv(x)
|
83 |
+
class Up(nn.Module):
|
84 |
+
"""Upscaling then double conv"""
|
85 |
+
def __init__(self, in_channels, out_channels, bilinear=True):
|
86 |
+
super().__init__()
|
87 |
+
self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
88 |
+
self.conv = DoubleConv(in_channels, out_channels)
|
89 |
+
self.deconv = nn.ConvTranspose2d(in_channels, out_channels,kernel_size=4, stride=2,padding=1, bias=True)
|
90 |
+
def forward(self, x1, x2):
|
91 |
+
x1 = self.deconv(x1)
|
92 |
+
# input is BCHW
|
93 |
+
x = torch.cat([x2, x1], dim=1)
|
94 |
+
return self.conv(x)
|
95 |
+
|
96 |
+
class OutConv(nn.Module):
|
97 |
+
def __init__(self, in_channels, out_channels):
|
98 |
+
super(OutConv, self).__init__()
|
99 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
100 |
+
self.tanh = nn.Tanh()
|
101 |
+
self.hardtanh = nn.Hardtanh()
|
102 |
+
self.sigmoid = nn.Sigmoid()
|
103 |
+
|
104 |
+
def forward(self, x1):
|
105 |
+
x = self.conv(x1)
|
106 |
+
# x = self.sigmoid(x)
|
107 |
+
# x = self.hardtanh(x)
|
108 |
+
# x = (x+1)/2
|
109 |
+
return x
|
110 |
+
class GiemaskGenerator(nn.Module):
|
111 |
+
"""Create a Unet-based generator"""
|
112 |
+
|
113 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
114 |
+
"""Construct a Unet generator
|
115 |
+
Parameters:
|
116 |
+
input_nc (int) -- the number of channels in input images
|
117 |
+
output_nc (int) -- the number of channels in output images
|
118 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
119 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
120 |
+
ngf (int) -- the number of filters in the last conv layer
|
121 |
+
norm_layer -- normalization layer
|
122 |
+
|
123 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
124 |
+
It is a recursive process.
|
125 |
+
"""
|
126 |
+
super(GiemaskGenerator, self).__init__()
|
127 |
+
self.init_channel =32
|
128 |
+
self.inc = DoubleConv(3,self.init_channel)
|
129 |
+
self.down1 = Down(self.init_channel, self.init_channel*2)
|
130 |
+
self.down2 = Down(self.init_channel*2, self.init_channel*4)
|
131 |
+
self.down3 = Down(self.init_channel*4, self.init_channel*8)
|
132 |
+
self.down4 = Down(self.init_channel*8, self.init_channel*16)
|
133 |
+
self.down5 = Down(self.init_channel*16, self.init_channel*32)
|
134 |
+
|
135 |
+
self.up1 = Up(self.init_channel*32, self.init_channel*16)
|
136 |
+
self.up2 = Up(self.init_channel*16, self.init_channel*8)
|
137 |
+
self.up3 = Up(self.init_channel*8, self.init_channel*4)
|
138 |
+
self.up4 = Up(self.init_channel*4,self.init_channel*2)
|
139 |
+
self.up5 = Up(self.init_channel*2, self.init_channel)
|
140 |
+
self.outc = OutConv(self.init_channel, 1)
|
141 |
+
self.up1_1 = Up_single(self.init_channel*32, self.init_channel*16)
|
142 |
+
self.up2_1 = Up_single(self.init_channel*16, self.init_channel*8)
|
143 |
+
self.up3_1 = Up_single(self.init_channel*8, self.init_channel*4)
|
144 |
+
self.up4_1 = Up_single(self.init_channel*4,self.init_channel*2)
|
145 |
+
self.up5_1 = Up_single(self.init_channel*2, self.init_channel)
|
146 |
+
self.outc_1 = OutConv(self.init_channel, 1)
|
147 |
+
# self.dropout = nn.Dropout(p=0.5)
|
148 |
+
def forward(self, input):
|
149 |
+
x1 = self.inc(input)
|
150 |
+
x2 = self.down1(x1)
|
151 |
+
x3 = self.down2(x2)
|
152 |
+
x4 = self.down3(x3)
|
153 |
+
x5 = self.down4(x4)
|
154 |
+
x6 = self.down5(x5)
|
155 |
+
|
156 |
+
|
157 |
+
x_1 = self.up1_1(x6, x5)
|
158 |
+
x_1 = self.up2_1(x_1, x4)
|
159 |
+
x_1 = self.up3_1(x_1, x3)
|
160 |
+
x_1 = self.up4_1(x_1, x2)
|
161 |
+
x_1 = self.up5_1(x_1, x1)
|
162 |
+
mask = self.outc_1(x_1)
|
163 |
+
|
164 |
+
x = self.up1(x6, x5)
|
165 |
+
# x = self.dropout(x)
|
166 |
+
x = self.up2(x, x4)
|
167 |
+
# x = self.dropout(x)
|
168 |
+
x = self.up3(x, x3)
|
169 |
+
# x = self.dropout(x)
|
170 |
+
x = self.up4(x, x2)
|
171 |
+
# x = self.dropout(x)
|
172 |
+
x = self.up5(x, x1)
|
173 |
+
# x = self.dropout(x)
|
174 |
+
depth = self.outc(x)
|
175 |
+
return depth,mask
|
176 |
+
"""Create a Unet-based generator"""
|
177 |
+
class Giemask2Generator(nn.Module):
|
178 |
+
"""Create a Unet-based generator"""
|
179 |
+
|
180 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
181 |
+
"""Construct a Unet generator
|
182 |
+
Parameters:
|
183 |
+
input_nc (int) -- the number of channels in input images
|
184 |
+
output_nc (int) -- the number of channels in output images
|
185 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
186 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
187 |
+
ngf (int) -- the number of filters in the last conv layer
|
188 |
+
norm_layer -- normalization layer
|
189 |
+
|
190 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
191 |
+
It is a recursive process.
|
192 |
+
"""
|
193 |
+
super(Giemask2Generator, self).__init__()
|
194 |
+
self.init_channel =32
|
195 |
+
self.inc = DoubleConv(3,self.init_channel)
|
196 |
+
self.down1 = Down(self.init_channel, self.init_channel*2)
|
197 |
+
self.down2 = Down(self.init_channel*2, self.init_channel*4)
|
198 |
+
self.down3 = Down(self.init_channel*4, self.init_channel*8)
|
199 |
+
self.down4 = Down(self.init_channel*8, self.init_channel*16)
|
200 |
+
self.down5 = Down(self.init_channel*16, self.init_channel*32)
|
201 |
+
|
202 |
+
self.up1 = Up(self.init_channel*32, self.init_channel*16)
|
203 |
+
self.up2 = Up(self.init_channel*16, self.init_channel*8)
|
204 |
+
self.up3 = Up(self.init_channel*8, self.init_channel*4)
|
205 |
+
self.up4 = Up(self.init_channel*4,self.init_channel*2)
|
206 |
+
self.up5 = Up(self.init_channel*2, self.init_channel)
|
207 |
+
self.outc = OutConv(self.init_channel, 1)
|
208 |
+
self.up1_1 = Up_single(self.init_channel*32, self.init_channel*16)
|
209 |
+
self.up2_1 = Up_single(self.init_channel*16, self.init_channel*8)
|
210 |
+
self.up3_1 = Up_single(self.init_channel*8, self.init_channel*4)
|
211 |
+
self.up4_1 = Up_single(self.init_channel*4,self.init_channel*2)
|
212 |
+
self.up5_1 = Up_single(self.init_channel*2, self.init_channel)
|
213 |
+
self.outc_1 = OutConv(self.init_channel, 1)
|
214 |
+
self.outc_2 = OutConv(self.init_channel, 1)
|
215 |
+
# self.dropout = nn.Dropout(p=0.5)
|
216 |
+
def forward(self, input):
|
217 |
+
x1 = self.inc(input)
|
218 |
+
x2 = self.down1(x1)
|
219 |
+
x3 = self.down2(x2)
|
220 |
+
x4 = self.down3(x3)
|
221 |
+
x5 = self.down4(x4)
|
222 |
+
x6 = self.down5(x5)
|
223 |
+
|
224 |
+
|
225 |
+
x_1 = self.up1_1(x6, x5)
|
226 |
+
x_1 = self.up2_1(x_1, x4)
|
227 |
+
x_1 = self.up3_1(x_1, x3)
|
228 |
+
x_1 = self.up4_1(x_1, x2)
|
229 |
+
x_1 = self.up5_1(x_1, x1)
|
230 |
+
mask = self.outc_1(x_1)
|
231 |
+
edge = self.outc_2(x_1)
|
232 |
+
|
233 |
+
x = self.up1(x6, x5)
|
234 |
+
# x = self.dropout(x)
|
235 |
+
x = self.up2(x, x4)
|
236 |
+
# x = self.dropout(x)
|
237 |
+
x = self.up3(x, x3)
|
238 |
+
# x = self.dropout(x)
|
239 |
+
x = self.up4(x, x2)
|
240 |
+
# x = self.dropout(x)
|
241 |
+
x = self.up5(x, x1)
|
242 |
+
# x = self.dropout(x)
|
243 |
+
depth = self.outc(x)
|
244 |
+
return depth,mask,edge
|
245 |
+
"""Create a Unet-based generator"""
|
246 |
+
class GieGenerator(nn.Module):
|
247 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
248 |
+
"""Construct a Unet generator
|
249 |
+
Parameters:
|
250 |
+
input_nc (int) -- the number of channels in input images
|
251 |
+
output_nc (int) -- the number of channels in output images
|
252 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
253 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
254 |
+
ngf (int) -- the number of filters in the last conv layer
|
255 |
+
norm_layer -- normalization layer
|
256 |
+
|
257 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
258 |
+
It is a recursive process.
|
259 |
+
"""
|
260 |
+
super(GieGenerator, self).__init__()
|
261 |
+
self.init_channel =32
|
262 |
+
self.inc = DoubleConv(input_nc,self.init_channel)
|
263 |
+
self.down1 = Down(self.init_channel, self.init_channel*2)
|
264 |
+
self.down2 = Down(self.init_channel*2, self.init_channel*4)
|
265 |
+
self.down3 = Down(self.init_channel*4, self.init_channel*8)
|
266 |
+
self.down4 = Down(self.init_channel*8, self.init_channel*16)
|
267 |
+
self.down5 = Down(self.init_channel*16, self.init_channel*32)
|
268 |
+
|
269 |
+
self.up1 = Up(self.init_channel*32, self.init_channel*16)
|
270 |
+
self.up2 = Up(self.init_channel*16, self.init_channel*8)
|
271 |
+
self.up3 = Up(self.init_channel*8, self.init_channel*4)
|
272 |
+
self.up4 = Up(self.init_channel*4,self.init_channel*2)
|
273 |
+
self.up5 = Up(self.init_channel*2, self.init_channel)
|
274 |
+
self.outc = OutConv(self.init_channel, 2)
|
275 |
+
# self.dropout = nn.Dropout(p=0.5)
|
276 |
+
def forward(self, input):
|
277 |
+
x1 = self.inc(input)
|
278 |
+
x2 = self.down1(x1)
|
279 |
+
x3 = self.down2(x2)
|
280 |
+
x4 = self.down3(x3)
|
281 |
+
x5 = self.down4(x4)
|
282 |
+
x6 = self.down5(x5)
|
283 |
+
|
284 |
+
x = self.up1(x6, x5)
|
285 |
+
# x = self.dropout(x)
|
286 |
+
x = self.up2(x, x4)
|
287 |
+
# x = self.dropout(x)
|
288 |
+
x = self.up3(x, x3)
|
289 |
+
# x = self.dropout(x)
|
290 |
+
x = self.up4(x, x2)
|
291 |
+
# x = self.dropout(x)
|
292 |
+
x = self.up5(x, x1)
|
293 |
+
# x = self.dropout(x)
|
294 |
+
logits1 = self.outc(x)
|
295 |
+
return logits1
|
296 |
+
|
297 |
+
|
298 |
+
class GiecbamGenerator(nn.Module):
|
299 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
300 |
+
"""Construct a Unet generator
|
301 |
+
Parameters:
|
302 |
+
input_nc (int) -- the number of channels in input images
|
303 |
+
output_nc (int) -- the number of channels in output images
|
304 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
305 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
306 |
+
ngf (int) -- the number of filters in the last conv layer
|
307 |
+
norm_layer -- normalization layer
|
308 |
+
|
309 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
310 |
+
It is a recursive process.
|
311 |
+
"""
|
312 |
+
super(GiecbamGenerator, self).__init__()
|
313 |
+
self.init_channel =32
|
314 |
+
self.inc = DoubleConv(input_nc,self.init_channel)
|
315 |
+
self.down1 = Down(self.init_channel, self.init_channel*2)
|
316 |
+
self.down2 = Down(self.init_channel*2, self.init_channel*4)
|
317 |
+
self.down3 = Down(self.init_channel*4, self.init_channel*8)
|
318 |
+
self.down4 = Down(self.init_channel*8, self.init_channel*16)
|
319 |
+
self.down5 = Down(self.init_channel*16, self.init_channel*32)
|
320 |
+
self.cbam = CBAM(gate_channels=self.init_channel*32)
|
321 |
+
self.up1 = Up(self.init_channel*32, self.init_channel*16)
|
322 |
+
self.up2 = Up(self.init_channel*16, self.init_channel*8)
|
323 |
+
self.up3 = Up(self.init_channel*8, self.init_channel*4)
|
324 |
+
self.up4 = Up(self.init_channel*4,self.init_channel*2)
|
325 |
+
self.up5 = Up(self.init_channel*2, self.init_channel)
|
326 |
+
self.outc = OutConv(self.init_channel, 2)
|
327 |
+
self.dropout = nn.Dropout(p=0.1)
|
328 |
+
def forward(self, input):
|
329 |
+
x1 = self.inc(input)
|
330 |
+
x2 = self.down1(x1)
|
331 |
+
x3 = self.down2(x2)
|
332 |
+
x4 = self.down3(x3)
|
333 |
+
x5 = self.down4(x4)
|
334 |
+
x6 = self.down5(x5)
|
335 |
+
x6 = self.cbam(x6)
|
336 |
+
x = self.up1(x6, x5)
|
337 |
+
x = self.up2(x, x4)
|
338 |
+
x = self.up3(x, x3)
|
339 |
+
x = self.up4(x, x2)
|
340 |
+
x = self.up5(x, x1)
|
341 |
+
x = self.dropout(x)
|
342 |
+
logits1 = self.outc(x)
|
343 |
+
return logits1
|
344 |
+
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
class Gie2headGenerator(nn.Module):
|
349 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
350 |
+
"""Construct a Unet generator
|
351 |
+
Parameters:
|
352 |
+
input_nc (int) -- the number of channels in input images
|
353 |
+
output_nc (int) -- the number of channels in output images
|
354 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
355 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
356 |
+
ngf (int) -- the number of filters in the last conv layer
|
357 |
+
norm_layer -- normalization layer
|
358 |
+
|
359 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
360 |
+
It is a recursive process.
|
361 |
+
"""
|
362 |
+
super(Gie2headGenerator, self).__init__()
|
363 |
+
self.init_channel =32
|
364 |
+
self.inc = DoubleConv(input_nc,self.init_channel)
|
365 |
+
self.down1 = Down(self.init_channel, self.init_channel*2)
|
366 |
+
self.down2 = Down(self.init_channel*2, self.init_channel*4)
|
367 |
+
self.down3 = Down(self.init_channel*4, self.init_channel*8)
|
368 |
+
self.down4 = Down(self.init_channel*8, self.init_channel*16)
|
369 |
+
self.down5 = Down(self.init_channel*16, self.init_channel*32)
|
370 |
+
|
371 |
+
self.up1_1 = Up(self.init_channel*32, self.init_channel*16)
|
372 |
+
self.up2_1 = Up(self.init_channel*16, self.init_channel*8)
|
373 |
+
self.up3_1 = Up(self.init_channel*8, self.init_channel*4)
|
374 |
+
self.up4_1 = Up(self.init_channel*4,self.init_channel*2)
|
375 |
+
self.up5_1 = Up(self.init_channel*2, self.init_channel)
|
376 |
+
self.outc_1 = OutConv(self.init_channel, 1)
|
377 |
+
|
378 |
+
self.up1_2 = Up(self.init_channel*32, self.init_channel*16)
|
379 |
+
self.up2_2 = Up(self.init_channel*16, self.init_channel*8)
|
380 |
+
self.up3_2 = Up(self.init_channel*8, self.init_channel*4)
|
381 |
+
self.up4_2 = Up(self.init_channel*4,self.init_channel*2)
|
382 |
+
self.up5_2 = Up(self.init_channel*2, self.init_channel)
|
383 |
+
self.outc_2 = OutConv(self.init_channel, 1)
|
384 |
+
|
385 |
+
def forward(self, input):
|
386 |
+
x1 = self.inc(input)
|
387 |
+
x2 = self.down1(x1)
|
388 |
+
x3 = self.down2(x2)
|
389 |
+
x4 = self.down3(x3)
|
390 |
+
x5 = self.down4(x4)
|
391 |
+
x6 = self.down5(x5)
|
392 |
+
|
393 |
+
x_1 = self.up1_1(x6, x5)
|
394 |
+
x_1 = self.up2_1(x_1, x4)
|
395 |
+
x_1 = self.up3_1(x_1, x3)
|
396 |
+
x_1 = self.up4_1(x_1, x2)
|
397 |
+
x_1 = self.up5_1(x_1, x1)
|
398 |
+
logits_1 = self.outc_1(x_1)
|
399 |
+
|
400 |
+
x_2 = self.up1_2(x6, x5)
|
401 |
+
x_2 = self.up2_2(x_2, x4)
|
402 |
+
x_2 = self.up3_2(x_2, x3)
|
403 |
+
x_2 = self.up4_2(x_2, x2)
|
404 |
+
x_2 = self.up5_2(x_2, x1)
|
405 |
+
logits_2 = self.outc_2(x_2)
|
406 |
+
|
407 |
+
logits = torch.cat((logits_1,logits_2),1)
|
408 |
+
|
409 |
+
return logits
|
410 |
+
|
411 |
+
|
412 |
+
|
413 |
+
class BmpGenerator(nn.Module):
|
414 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
415 |
+
"""Construct a Unet generator
|
416 |
+
Parameters:
|
417 |
+
input_nc (int) -- the number of channels in input images
|
418 |
+
output_nc (int) -- the number of channels in output images
|
419 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
420 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
421 |
+
ngf (int) -- the number of filters in the last conv layer
|
422 |
+
norm_layer -- normalization layer
|
423 |
+
|
424 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
425 |
+
It is a recursive process.
|
426 |
+
"""
|
427 |
+
super(BmpGenerator, self).__init__()
|
428 |
+
self.init_channel =32
|
429 |
+
self.output_nc = output_nc
|
430 |
+
self.inc = DoubleConv(input_nc,self.init_channel)
|
431 |
+
self.down1 = Down(self.init_channel, self.init_channel*2)
|
432 |
+
self.down2 = Down(self.init_channel*2, self.init_channel*4)
|
433 |
+
self.down3 = Down(self.init_channel*4, self.init_channel*8)
|
434 |
+
self.down4 = Down(self.init_channel*8, self.init_channel*16)
|
435 |
+
self.down5 = Down(self.init_channel*16, self.init_channel*32)
|
436 |
+
|
437 |
+
self.up1 = Up(self.init_channel*32, self.init_channel*16)
|
438 |
+
self.up2 = Up(self.init_channel*16, self.init_channel*8)
|
439 |
+
self.up3 = Up(self.init_channel*8, self.init_channel*4)
|
440 |
+
self.up4 = Up(self.init_channel*4,self.init_channel*2)
|
441 |
+
self.up5 = Up(self.init_channel*2, self.init_channel)
|
442 |
+
self.outc = OutConv(self.init_channel, self.output_nc)
|
443 |
+
# self.dropout = nn.Dropout(p=0.5)
|
444 |
+
def forward(self, input):
|
445 |
+
x1 = self.inc(input)
|
446 |
+
x2 = self.down1(x1)
|
447 |
+
x3 = self.down2(x2)
|
448 |
+
x4 = self.down3(x3)
|
449 |
+
x5 = self.down4(x4)
|
450 |
+
x6 = self.down5(x5)
|
451 |
+
|
452 |
+
x = self.up1(x6, x5)
|
453 |
+
# x = self.dropout(x)
|
454 |
+
x = self.up2(x, x4)
|
455 |
+
# x = self.dropout(x)
|
456 |
+
x = self.up3(x, x3)
|
457 |
+
# x = self.dropout(x)
|
458 |
+
x = self.up4(x, x2)
|
459 |
+
# x = self.dropout(x)
|
460 |
+
x = self.up5(x, x1)
|
461 |
+
# x = self.dropout(x)
|
462 |
+
logits1 = self.outc(x)
|
463 |
+
return logits1
|
464 |
+
class Bmp2Generator(nn.Module):
|
465 |
+
"""Create a Unet-based generator"""
|
466 |
+
|
467 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
468 |
+
"""Construct a Unet generator
|
469 |
+
Parameters:
|
470 |
+
input_nc (int) -- the number of channels in input images
|
471 |
+
output_nc (int) -- the number of channels in output images
|
472 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
473 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
474 |
+
ngf (int) -- the number of filters in the last conv layer
|
475 |
+
norm_layer -- normalization layer
|
476 |
+
|
477 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
478 |
+
It is a recursive process.
|
479 |
+
"""
|
480 |
+
super(Bmp2Generator, self).__init__()
|
481 |
+
#gienet
|
482 |
+
self.init_channel =32
|
483 |
+
self.inc = DoubleConv(3,self.init_channel)
|
484 |
+
self.down1 = Down(self.init_channel, self.init_channel*2)
|
485 |
+
self.down2 = Down(self.init_channel*2, self.init_channel*4)
|
486 |
+
self.down3 = Down(self.init_channel*4, self.init_channel*8)
|
487 |
+
self.down4 = Down(self.init_channel*8, self.init_channel*16)
|
488 |
+
self.down5 = Down(self.init_channel*16, self.init_channel*32)
|
489 |
+
|
490 |
+
self.up1 = Up(self.init_channel*32, self.init_channel*16)
|
491 |
+
self.up2 = Up(self.init_channel*16, self.init_channel*8)
|
492 |
+
self.up3 = Up(self.init_channel*8, self.init_channel*4)
|
493 |
+
self.up4 = Up(self.init_channel*4,self.init_channel*2)
|
494 |
+
self.up5 = Up(self.init_channel*2, self.init_channel)
|
495 |
+
self.outc = OutConv(self.init_channel, 1)
|
496 |
+
self.up1_1 = Up_single(self.init_channel*32, self.init_channel*16)
|
497 |
+
self.up2_1 = Up_single(self.init_channel*16, self.init_channel*8)
|
498 |
+
self.up3_1 = Up_single(self.init_channel*8, self.init_channel*4)
|
499 |
+
self.up4_1 = Up_single(self.init_channel*4,self.init_channel*2)
|
500 |
+
self.up5_1 = Up_single(self.init_channel*2, self.init_channel)
|
501 |
+
self.outc_1 = OutConv(self.init_channel, 1)
|
502 |
+
self.outc_2 = OutConv(self.init_channel, 1)
|
503 |
+
|
504 |
+
#bpm net
|
505 |
+
self.inc_b = DoubleConv(4,self.init_channel)
|
506 |
+
self.down1_b = Down(self.init_channel, self.init_channel*2)
|
507 |
+
self.down2_b = Down(self.init_channel*2, self.init_channel*4)
|
508 |
+
self.down3_b = Down(self.init_channel*4, self.init_channel*8)
|
509 |
+
self.down4_b = Down(self.init_channel*8, self.init_channel*16)
|
510 |
+
self.down5_b = Down(self.init_channel*16, self.init_channel*32)
|
511 |
+
|
512 |
+
self.up1_b = Up(self.init_channel*32, self.init_channel*16)
|
513 |
+
self.up2_b = Up(self.init_channel*16, self.init_channel*8)
|
514 |
+
self.up3_b = Up(self.init_channel*8, self.init_channel*4)
|
515 |
+
self.up4_b = Up(self.init_channel*4,self.init_channel*2)
|
516 |
+
self.up5_b = Up(self.init_channel*2, self.init_channel)
|
517 |
+
self.outc_b = OutConv(self.init_channel, 2)
|
518 |
+
# self.dropout = nn.Dropout(p=0.5)
|
519 |
+
def forward(self, input):
|
520 |
+
#gienet
|
521 |
+
x1 = self.inc(input)
|
522 |
+
x2 = self.down1(x1)
|
523 |
+
x3 = self.down2(x2)
|
524 |
+
x4 = self.down3(x3)
|
525 |
+
x5 = self.down4(x4)
|
526 |
+
x6 = self.down5(x5)
|
527 |
+
|
528 |
+
x_1 = self.up1_1(x6, x5)
|
529 |
+
x_1 = self.up2_1(x_1, x4)
|
530 |
+
x_1 = self.up3_1(x_1, x3)
|
531 |
+
x_1 = self.up4_1(x_1, x2)
|
532 |
+
x_1 = self.up5_1(x_1, x1)
|
533 |
+
mask = self.outc_1(x_1)
|
534 |
+
edge = self.outc_2(x_1)
|
535 |
+
|
536 |
+
x = self.up1(x6, x5)
|
537 |
+
x = self.up2(x, x4)
|
538 |
+
x = self.up3(x, x3)
|
539 |
+
x = self.up4(x, x2)
|
540 |
+
x = self.up5(x, x1)
|
541 |
+
depth = self.outc(x)
|
542 |
+
|
543 |
+
#bmpnet
|
544 |
+
mask[mask>0.5]=1.
|
545 |
+
mask[mask<=0.5]=0.
|
546 |
+
image_cat_depth = torch.cat((input*mask,depth*mask),dim=1)
|
547 |
+
x1_b = self.inc_b(image_cat_depth)
|
548 |
+
x2_b = self.down1_b(x1_b)
|
549 |
+
x3_b = self.down2_b(x2_b)
|
550 |
+
x4_b = self.down3_b(x3_b)
|
551 |
+
x5_b = self.down4_b(x4_b)
|
552 |
+
x6_b = self.down5_b(x5_b)
|
553 |
+
x_b = self.up1_b(x6_b, x5_b)
|
554 |
+
x_b = self.up2_b(x_b, x4_b)
|
555 |
+
x_b = self.up3_b(x_b, x3_b)
|
556 |
+
x_b = self.up4_b(x_b, x2_b)
|
557 |
+
x_b = self.up5_b(x_b, x1_b)
|
558 |
+
bm = self.outc_b(x_b)
|
559 |
+
# return depth,mask,edge,bm
|
560 |
+
return bm
|
561 |
+
class UnetGenerator(nn.Module):
|
562 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
563 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
564 |
+
super(UnetGenerator, self).__init__()
|
565 |
+
|
566 |
+
# construct unet structure
|
567 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
|
568 |
+
for i in range(num_downs - 5):
|
569 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
570 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
571 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
572 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
573 |
+
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
574 |
+
|
575 |
+
self.model = unet_block
|
576 |
+
|
577 |
+
def forward(self, input):
|
578 |
+
return self.model(input)
|
579 |
+
|
580 |
+
#class GieGenerator(nn.Module):
|
581 |
+
# def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
582 |
+
# norm_layer=nn.BatchNorm2d, use_dropout=False):
|
583 |
+
# super(GieGenerator, self).__init__()
|
584 |
+
#
|
585 |
+
# # construct unet structure
|
586 |
+
# unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
|
587 |
+
# for i in range(num_downs - 5):
|
588 |
+
# unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
589 |
+
# unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
590 |
+
# unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
591 |
+
# unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
592 |
+
# unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
593 |
+
#
|
594 |
+
# self.model = unet_block
|
595 |
+
#
|
596 |
+
# def forward(self, input):
|
597 |
+
# return self.model(input)
|
598 |
+
|
599 |
+
# Defines the submodule with skip connection.
|
600 |
+
# X -------------------identity---------------------- X
|
601 |
+
# |-- downsampling -- |submodule| -- upsampling --|
|
602 |
+
class UnetSkipConnectionBlock(nn.Module):
|
603 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
604 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
605 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
606 |
+
self.outermost = outermost
|
607 |
+
if type(norm_layer) == functools.partial:
|
608 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
609 |
+
else:
|
610 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
611 |
+
if input_nc is None:
|
612 |
+
input_nc = outer_nc
|
613 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
614 |
+
stride=2, padding=1, bias=use_bias)
|
615 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
616 |
+
downnorm = norm_layer(inner_nc)
|
617 |
+
uprelu = nn.ReLU(True)
|
618 |
+
upnorm = norm_layer(outer_nc)
|
619 |
+
|
620 |
+
if outermost:
|
621 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
622 |
+
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # resize = nn.Upsample(scale_factor=2)
            # conv = nn.Conv2d(inner_nc,outer_nc,kernel_size=4,stride=2,padding=1,bias=use_bias)
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            # up = [uprelu, resize, conv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)


##===================================================================================================
class DilatedDoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=4, stride=1, dilation=4),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=4, stride=1, dilation=4),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class DilatedDown(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DilatedDoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class DilatedUp(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = DilatedDoubleConv(in_channels, out_channels)

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=4, stride=1, dilation=4),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x1 = self.conv1(x1)
        # x1 = self.deconv(x1)
        # input is BCHW
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class DilatedSingleUnet(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64, biline=True, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(DilatedSingleUnet, self).__init__()
        self.init_channel = 32
        self.inc = DilatedDoubleConv(input_nc, self.init_channel)
        self.down1 = DilatedDown(self.init_channel, self.init_channel*2)
        self.down2 = DilatedDown(self.init_channel*2, self.init_channel*4)
        self.down3 = DilatedDown(self.init_channel*4, self.init_channel*8)
        self.down4 = DilatedDown(self.init_channel*8, self.init_channel*16)
        self.down5 = DilatedDown(self.init_channel*16, self.init_channel*32)
        self.cbam = CBAM(gate_channels=self.init_channel*32)

        self.up1 = DilatedUp(self.init_channel*32, self.init_channel*16)
        self.up2 = DilatedUp(self.init_channel*16, self.init_channel*8)
        self.up3 = DilatedUp(self.init_channel*8, self.init_channel*4)
        self.up4 = DilatedUp(self.init_channel*4, self.init_channel*2)
        self.up5 = DilatedUp(self.init_channel*2, self.init_channel)
        self.outc = OutConv(self.init_channel, output_nc)

    def forward(self, input):
        x1 = self.inc(input)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x6 = self.cbam(x6)
        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2)
        x = self.up5(x, x1)
        logits1 = self.outc(x)
        return logits1
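For orientation, here is a minimal sketch of how `DilatedSingleUnet` above might be exercised. It assumes `CBAM` and `OutConv` are importable from the same module (they are defined elsewhere in this repo); the import path, channel counts, and the 256x256 resolution are illustrative assumptions, not values fixed by the code:

```python
import torch
from data.MBD.model.gienet import DilatedSingleUnet  # hypothetical import path

net = DilatedSingleUnet(input_nc=3, output_nc=3, num_downs=6)
x = torch.randn(1, 3, 256, 256)   # illustrative input; five 2x poolings give an 8x8 bottleneck
y = net(x)                        # dilated convs preserve spatial size, so output matches input
print(y.shape)                    # expected: torch.Size([1, 3, 256, 256])
```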
data/MBD/model/unetnc.py
ADDED
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from torch.nn import init
import functools

# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, an image of size 128x128 becomes 1x1 at the bottleneck.
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()

        # construct unet structure from the innermost block outwards
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
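A minimal usage sketch for the `UnetGenerator` above; the channel counts, `num_downs=7`, and the 128x128 input are illustrative assumptions (with seven downsamplings a 128x128 image reaches 1x1 at the bottleneck, as the comment in the file notes):

```python
import torch
import torch.nn as nn
from data.MBD.model.unetnc import UnetGenerator  # hypothetical import path

net = UnetGenerator(input_nc=3, output_nc=3, num_downs=7, ngf=64,
                    norm_layer=nn.BatchNorm2d, use_dropout=False)
x = torch.randn(2, 3, 128, 128)   # batch of 2 so BatchNorm at the 1x1 bottleneck has statistics
y = net(x)                        # outermost block ends with Tanh, so values lie in [-1, 1]
print(y.shape)                    # torch.Size([2, 3, 128, 128])
```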
data/MBD/modify_stn_model/stn_head.py
ADDED
@@ -0,0 +1,123 @@
from __future__ import absolute_import

import math
import numpy as np
import sys

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init


def conv3x3_block(in_planes, out_planes, stride=1):
    """3x3 convolution with padding, followed by BN and ReLU"""
    conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1)

    block = nn.Sequential(
        conv_layer,
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
    return block


class STNHead(nn.Module):
    def __init__(self, in_planes, num_ctrlpoints, activation='none'):
        super(STNHead, self).__init__()

        self.in_planes = in_planes
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation
        self.stn_convnet = nn.Sequential(
            conv3x3_block(in_planes, 32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(32, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(64, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(128, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256))  # for a 256x256 input this yields 256x8x8 features

        self.stn_fc1 = nn.Sequential(
            # nn.Linear(2*256, 512),
            nn.Linear(8*8*256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2)

        self.init_weights(self.stn_convnet)
        self.init_weights(self.stn_fc1)
        self.init_stn(self.stn_fc2)

    def init_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()

    def init_stn(self, stn_fc2):
        # initialise the last fc layer so that it predicts control points spread
        # along the four sides of a rectangle inset by the margins below
        margin_x, margin_y = 0.35, 0.35
        num_ctrl_pts_per_side = (self.num_ctrlpoints - 4) // 4 + 2
        ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
        ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
        ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)

        ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
        ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_x)
        ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
        ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)

        ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0).astype(np.float32)

        if self.activation == 'none':
            pass
        elif self.activation == 'sigmoid':
            ctrl_points = -np.log(1. / ctrl_points - 1.)
        stn_fc2.weight.data.zero_()
        stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)

    def forward(self, x):
        x = self.stn_convnet(x)
        batch_size, _, h, w = x.size()
        x = x.view(batch_size, -1)
        img_feat = self.stn_fc1(x)
        x = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            x = F.sigmoid(x)
        x = x.view(-1, self.num_ctrlpoints, 2)
        return img_feat, x


if __name__ == "__main__":
    in_planes = 3
    num_ctrlpoints = 20
    activation = 'none'  # 'sigmoid'
    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
    # the fully connected head expects a 256x256 input (8x8x256 conv features)
    input = torch.randn(10, 3, 256, 256)
    img_feat, control_points = stn_head(input)  # forward returns (features, control points)
    print(control_points.size())
data/MBD/modify_stn_model/tps_spatial_transformer.py
ADDED
@@ -0,0 +1,194 @@
from __future__ import absolute_import

import numpy as np
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

def grid_sample(input, grid, canvas=None):
    output = F.grid_sample(input, grid)
    if canvas is None:
        return output
    else:
        input_mask = input.data.new(input.size()).fill_(1)
        output_mask = F.grid_sample(input_mask, grid)
        padded_output = output * output_mask + canvas * (1 - output_mask)
        return padded_output


# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
    N = input_points.size(0)
    M = control_points.size(0)
    pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
    # original implementation, very slow
    # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
    pairwise_diff_square = pairwise_diff * pairwise_diff
    pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
    repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
    # fix numerical error for 0 * log(0), substitute all nan with 0
    mask = repr_matrix != repr_matrix
    repr_matrix.masked_fill_(mask, 0)
    return repr_matrix


# output_ctrl_pts are specified, according to our task: the four corners plus the
# quarter, half and three-quarter points of every border of the unit square.
def build_output_control_points(num_control_points, margins):
    points = [0.25, 0.5, 0.75]
    pts2 = [[0, 0], [1, 0], [0, 1], [1, 1]]
    for ratio in points:
        pts2.append([1*ratio, 0])
    for ratio in points:
        pts2.append([1*ratio, 1])
    for ratio in points:
        pts2.append([0, 1*ratio])
    for ratio in points:
        pts2.append([1, 1*ratio])
    pts2 = np.float32(pts2)

    margin_x, margin_y = margins
    num_ctrl_pts_per_side = (num_control_points - 4) // 4 + 2
    ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
    ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
    ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
    ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
    ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)

    ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
    ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_x)
    ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
    ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)

    output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0)
    # the margin-based layout above is overridden by the fixed border layout pts2
    output_ctrl_pts_arr = pts2
    output_ctrl_pts = torch.FloatTensor(output_ctrl_pts_arr)
    return output_ctrl_pts


# demo: ~/test/models/test_tps_transformation.py
class TPSSpatialTransformer(nn.Module):

    def __init__(self, output_image_size=None, num_control_points=None, margins=None):
        super(TPSSpatialTransformer, self).__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins

        self.target_height, self.target_width = output_image_size
        target_control_points = build_output_control_points(num_control_points, margins)
        N = num_control_points

        # create padded kernel matrix
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel)

        # create target coordinate matrix
        HW = self.target_height * self.target_width
        target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width)))
        target_coordinate = torch.Tensor(target_coordinate)  # HW x 2
        Y, X = target_coordinate.split(1, dim=1)
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        target_coordinate = torch.cat([X, Y], dim=1)  # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)
        target_coordinate_repr = torch.cat([
            target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
        ], dim=1)

        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)
        self.register_buffer('target_control_points', target_control_points)

    def forward(self, input, source_control_points, direction='dewarp'):
        if direction == 'dewarp':
            assert source_control_points.ndimension() == 3
            assert source_control_points.size(1) == self.num_control_points
            assert source_control_points.size(2) == 2
            batch_size = source_control_points.size(0)

            Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1)
            mapping_matrix = torch.matmul(self.inverse_kernel, Y)
            source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix)

            grid = source_coordinate.view(-1, self.target_height, self.target_width, 2)
            grid = torch.clamp(grid, 0, 1)  # the source_control_points may be out of [0, 1]
            # the input to grid_sample is normalized to [-1, 1], but what we have is [0, 1]
            grid = 2.0 * grid - 1.0
            output = grid_sample(input, grid, canvas=None)
            return output, grid
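To see how the two modules in this folder are meant to fit together, here is a minimal sketch; the import paths, 256x256 resolution, batch size, and control-point count are assumptions, not values fixed by the repo. `STNHead` predicts the control points, which then drive the TPS warp:

```python
import torch
from data.MBD.modify_stn_model.stn_head import STNHead                        # hypothetical paths
from data.MBD.modify_stn_model.tps_spatial_transformer import TPSSpatialTransformer

num_ctrl = 16                          # (num_ctrl - 4) must be divisible by 4 for the 4-side layout
head = STNHead(in_planes=3, num_ctrlpoints=num_ctrl, activation='sigmoid')
tps = TPSSpatialTransformer(output_image_size=(256, 256),
                            num_control_points=num_ctrl,
                            margins=(0.35, 0.35))

x = torch.randn(2, 3, 256, 256)        # illustrative batch of distorted crops
img_feat, ctrl_pts = head(x)           # ctrl_pts: (B, num_ctrl, 2), in [0, 1] with sigmoid activation
dewarped, grid = tps(x, ctrl_pts)      # direction defaults to 'dewarp'
print(dewarped.shape)                  # torch.Size([2, 3, 256, 256])
```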
data/MBD/stn_model/stn_head.py
ADDED
@@ -0,0 +1,123 @@
from __future__ import absolute_import

import math
import numpy as np
import sys

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init


def conv3x3_block(in_planes, out_planes, stride=1):
    """3x3 convolution with padding, followed by BN and ReLU"""
    conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1)

    block = nn.Sequential(
        conv_layer,
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
    return block


class STNHead(nn.Module):
    def __init__(self, in_planes, num_ctrlpoints, activation='none'):
        super(STNHead, self).__init__()

        self.in_planes = in_planes
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation
        self.stn_convnet = nn.Sequential(
            conv3x3_block(in_planes, 32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(32, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(64, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(128, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256))  # for a 256x256 input this yields 256x8x8 features

        self.stn_fc1 = nn.Sequential(
            # nn.Linear(2*256, 512),
            nn.Linear(8*8*256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2)

        self.init_weights(self.stn_convnet)
        self.init_weights(self.stn_fc1)
        self.init_stn(self.stn_fc2)

    def init_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()

    def init_stn(self, stn_fc2):
        # initialise the last fc layer so that it predicts control points spread
        # along the four sides of a rectangle inset by the margins below
        margin_x, margin_y = 0.35, 0.35
        num_ctrl_pts_per_side = (self.num_ctrlpoints - 4) // 4 + 2
        ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
        ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
        ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)

        ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
        ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_x)
        ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
        ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)

        ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0).astype(np.float32)

        if self.activation == 'none':
            pass
        elif self.activation == 'sigmoid':
            ctrl_points = -np.log(1. / ctrl_points - 1.)
        stn_fc2.weight.data.zero_()
        stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)

    def forward(self, x):
        x = self.stn_convnet(x)
        batch_size, _, h, w = x.size()
        x = x.view(batch_size, -1)
        img_feat = self.stn_fc1(x)
        x = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            x = F.sigmoid(x)
        x = x.view(-1, self.num_ctrlpoints, 2)
        return img_feat, x


if __name__ == "__main__":
    in_planes = 3
    num_ctrlpoints = 20
    activation = 'none'  # 'sigmoid'
    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
    # the fully connected head expects a 256x256 input (8x8x256 conv features)
    input = torch.randn(10, 3, 256, 256)
    img_feat, control_points = stn_head(input)  # forward returns (features, control points)
    print(control_points.size())
data/MBD/stn_model/tps_spatial_transformer.py
ADDED
@@ -0,0 +1,155 @@
from __future__ import absolute_import

import numpy as np
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

def grid_sample(input, grid, canvas=None):
    output = F.grid_sample(input, grid)
    if canvas is None:
        return output
    else:
        input_mask = input.data.new(input.size()).fill_(1)
        output_mask = F.grid_sample(input_mask, grid)
        padded_output = output * output_mask + canvas * (1 - output_mask)
        return padded_output


# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
    N = input_points.size(0)
    M = control_points.size(0)
    pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
    # original implementation, very slow
    # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
    pairwise_diff_square = pairwise_diff * pairwise_diff
    pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
    repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
    # fix numerical error for 0 * log(0), substitute all nan with 0
    mask = repr_matrix != repr_matrix
    repr_matrix.masked_fill_(mask, 0)
    return repr_matrix


# output_ctrl_pts are specified, according to our task: points spread along the
# four borders of the unit square, inset by the given margins.
def build_output_control_points(num_control_points, margins):
    margin_x, margin_y = margins
    num_ctrl_pts_per_side = (num_control_points - 4) // 4 + 2
    ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
    ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
    ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
    ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
    ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)

    ctrl_pts_x_left = np.ones(num_ctrl_pts_per_side) * margin_x
    ctrl_pts_x_right = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_x)
    ctrl_pts_left = np.stack([ctrl_pts_x_left[1:-1], ctrl_pts_x[1:-1]], axis=1)
    ctrl_pts_right = np.stack([ctrl_pts_x_right[1:-1], ctrl_pts_x[1:-1]], axis=1)

    output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom, ctrl_pts_left, ctrl_pts_right], axis=0)
    output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
    return output_ctrl_pts

# demo: ~/test/models/test_tps_transformation.py
class TPSSpatialTransformer(nn.Module):

    def __init__(self, output_image_size=None, num_control_points=None, margins=None):
        super(TPSSpatialTransformer, self).__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins

        self.target_height, self.target_width = output_image_size
        target_control_points = build_output_control_points(num_control_points, margins)
        N = num_control_points

        # create padded kernel matrix
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel)

        # create target coordinate matrix
        HW = self.target_height * self.target_width
        target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width)))
        target_coordinate = torch.Tensor(target_coordinate)  # HW x 2
        Y, X = target_coordinate.split(1, dim=1)
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        target_coordinate = torch.cat([X, Y], dim=1)  # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)
        target_coordinate_repr = torch.cat([
            target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
        ], dim=1)

        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)
        self.register_buffer('target_control_points', target_control_points)

    def forward(self, input, source_control_points, direction='dewarp'):
        if direction == 'dewarp':
            assert source_control_points.ndimension() == 3
            assert source_control_points.size(1) == self.num_control_points
            assert source_control_points.size(2) == 2
            batch_size = source_control_points.size(0)

            Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1)
            mapping_matrix = torch.matmul(self.inverse_kernel, Y)
            source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix)

            grid = source_coordinate.view(-1, self.target_height, self.target_width, 2)
            grid = torch.clamp(grid, 0, 1)  # the source_control_points may be out of [0, 1]
            # the input to grid_sample is normalized to [-1, 1], but what we have is [0, 1]
            grid = 2.0 * grid - 1.0
            output_maps = grid_sample(input, grid, canvas=None)
            return output_maps, source_coordinate
data/MBD/tps_grid_gen.py
ADDED
@@ -0,0 +1,70 @@
# encoding: utf-8

import torch
import itertools
import torch.nn as nn
from torch.autograd import Function, Variable

class TPSGridGen(nn.Module):

    def __init__(self, target_height, target_width, target_control_points):
        super(TPSGridGen, self).__init__()
        assert target_control_points.ndimension() == 2
        assert target_control_points.size(1) == 2
        N = target_control_points.size(0)
        self.num_points = N
        target_control_points = target_control_points.float()

        # create padded kernel matrix
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = self.compute_partial_repr(target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel)

        # create target coordinate matrix
        HW = target_height * target_width
        target_coordinate = list(itertools.product(range(target_height), range(target_width)))
        target_coordinate = torch.Tensor(target_coordinate)  # HW x 2
        Y, X = target_coordinate.split(1, dim=1)
        Y = Y * 2 / (target_height - 1) - 1
        X = X * 2 / (target_width - 1) - 1
        target_coordinate = torch.cat([X, Y], dim=1)  # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = self.compute_partial_repr(target_coordinate, target_control_points)
        target_coordinate_repr = torch.cat([
            target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
        ], dim=1)

        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)

    def forward(self, source_control_points):
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_points
        assert source_control_points.size(2) == 2
        batch_size = source_control_points.size(0)

        Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
        mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
        source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
        return source_coordinate

    # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
    def compute_partial_repr(self, input_points, control_points):
        N = input_points.size(0)
        M = control_points.size(0)
        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
        # original implementation, very slow
        # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
        pairwise_diff_square = pairwise_diff * pairwise_diff
        pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
        # fix numerical error for 0 * log(0), substitute all nan with 0
        mask = repr_matrix != repr_matrix
        repr_matrix.masked_fill_(mask, 0)
        return repr_matrix
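A minimal sketch of driving `TPSGridGen` directly; here the target control points are the corners and edge midpoints of the [-1, 1] square, and all sizes and perturbations are illustrative assumptions:

```python
import torch
import torch.nn.functional as F
from data.MBD.tps_grid_gen import TPSGridGen  # hypothetical import path

# eight target control points in [-1, 1]: corners plus edge midpoints (illustrative)
target_pts = torch.tensor([[-1., -1.], [0., -1.], [1., -1.],
                           [-1.,  0.],            [1.,  0.],
                           [-1.,  1.], [0.,  1.], [1.,  1.]])
gridgen = TPSGridGen(target_height=64, target_width=64, target_control_points=target_pts)

# slightly perturbed source control points, one set per batch element
source_pts = (target_pts + 0.05 * torch.randn_like(target_pts)).unsqueeze(0)
coords = gridgen(source_pts)                       # (1, 64*64, 2) sampling coordinates
grid = coords.view(1, 64, 64, 2)
warped = F.grid_sample(torch.randn(1, 3, 64, 64), grid)
print(warped.shape)                                # torch.Size([1, 3, 64, 64])
```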
data/MBD/utils.py
ADDED
@@ -0,0 +1,234 @@
'''
Misc Utility functions
'''
from collections import OrderedDict
import os
import random

import matplotlib.pyplot as plt  # required by the plotting helpers below
import numpy as np
import torch
import torchvision

def recursive_glob(rootdir='.', suffix=''):
    """Performs recursive glob with given suffix and rootdir
        :param rootdir is the root directory
        :param suffix is the suffix to be searched
    """
    return [os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames if filename.endswith(suffix)]

def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9):
    """Polynomial decay of learning rate
        :param init_lr is base learning rate
        :param iter is a current iteration
        :param lr_decay_iter how frequently decay occurs, default is 1
        :param max_iter is number of maximum iterations
        :param power is a polynomial power
    """
    if iter % lr_decay_iter or iter > max_iter:
        return optimizer

    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr*(1 - iter/max_iter)**power


def adjust_learning_rate(optimizer, init_lr, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = init_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def alpha_blend(input_image, segmentation_mask, alpha=0.5):
    """Alpha blending utility to overlay RGB masks on RGB images
        :param input_image is a np.ndarray with 3 channels
        :param segmentation_mask is a np.ndarray with 3 channels
        :param alpha is a float value
    """
    blended = np.zeros(input_image.size, dtype=np.float32)
    blended = input_image * alpha + segmentation_mask * (1 - alpha)
    return blended

def convert_state_dict(state_dict):
    """Converts a state dict saved from a DataParallel module to a normal
       module state_dict inplace
       :param state_dict is the loaded DataParallel model_state
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    return new_state_dict


class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)
        return return_images


def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return float(param_group['lr'])

def visualize(epoch, model, layer):
    # get conv layers
    conv_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.modules.conv.Conv2d):
            conv_layers.append(m)

    tensor = conv_layers[layer].weight.data.cpu()
    vistensor(tensor, epoch, ch=0, allkernels=False, nrow=8, padding=1)


def vistensor(tensor, epoch, ch=0, allkernels=False, nrow=8, padding=1):
    '''
    vistensor: visualization of a weight tensor
        @ch: visualization channel
        @allkernels: visualize all tensors
        https://github.com/pedrodiamel/pytorchvision/blob/a14672fe4b07995e99f8af755de875daf8aababb/pytvision/visualization.py#L325
    '''
    n, c, w, h = tensor.shape
    if allkernels: tensor = tensor.view(n*c, -1, w, h)
    elif c != 3: tensor = tensor[:, ch, :, :].unsqueeze(dim=1)

    rows = np.min((tensor.shape[0]//nrow + 1, 64))
    grid = torchvision.utils.make_grid(tensor, nrow=8, normalize=True, padding=padding)
    plt.figure(figsize=(10, 10), dpi=200)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.savefig('./generated/filters_layer1_dwuv_'+str(epoch)+'.png')
    plt.close()


def show_uloss(uwpred, uworg, inp_img, samples=7):

    n, c, h, w = inp_img.shape
    uwpred = uwpred.detach().cpu().numpy()
    uworg = uworg.detach().cpu().numpy()
    inp_img = inp_img.detach().cpu().numpy()

    # NCHW -> NHWC
    uwpred = uwpred.transpose((0, 2, 3, 1))
    uworg = uworg.transpose((0, 2, 3, 1))

    choices = random.sample(range(n), min(n, samples))
    f, axarr = plt.subplots(samples, 3)
    for j in range(samples):
        img = inp_img[j].transpose(1, 2, 0)
        axarr[j][0].imshow(img[:, :, ::-1])
        axarr[j][1].imshow(uworg[j])
        axarr[j][2].imshow(uwpred[j])

    plt.savefig('./generated/unwarp.png')
    plt.close()


def show_uloss_visdom(vis, uwpred, uworg, labels_win, out_win, labelopts, outopts, args):
    samples = 7
    n, c, h, w = uwpred.shape
    uwpred = uwpred.detach().cpu().numpy()
    uworg = uworg.detach().cpu().numpy()
    out_arr = np.full((samples, 3, args.img_rows, args.img_cols), 0.0)
    label_arr = np.full((samples, 3, args.img_rows, args.img_cols), 0.0)
    choices = random.sample(range(n), min(n, samples))
    idx = 0
    for c in choices:
        out_arr[idx, :, :, :] = uwpred[c]
        label_arr[idx, :, :, :] = uworg[c]
        idx += 1

    vis.images(out_arr,
               win=out_win,
               opts=outopts)
    vis.images(label_arr,
               win=labels_win,
               opts=labelopts)

def show_unwarp_tnsboard(global_step, writer, uwpred, uworg, grid_samples, gt_tag, pred_tag):
    # sample indices from the prediction batch
    idxs = torch.LongTensor(random.sample(range(uwpred.shape[0]), min(grid_samples, uwpred.shape[0])))
    grid_uworg = torchvision.utils.make_grid(uworg[idxs], normalize=True, scale_each=True)
    writer.add_image(gt_tag, grid_uworg, global_step)
    grid_uwpr = torchvision.utils.make_grid(uwpred[idxs], normalize=True, scale_each=True)
    writer.add_image(pred_tag, grid_uwpr, global_step)

def show_wc_tnsboard(global_step, writer, images, labels, pred, grid_samples, inp_tag, gt_tag, pred_tag):
    idxs = torch.LongTensor(random.sample(range(images.shape[0]), min(grid_samples, images.shape[0])))
    grid_inp = torchvision.utils.make_grid(images[idxs], normalize=True, scale_each=True)
    writer.add_image(inp_tag, grid_inp, global_step)
    grid_lbl = torchvision.utils.make_grid(labels[idxs], normalize=True, scale_each=True)
    writer.add_image(gt_tag, grid_lbl, global_step)
    grid_pred = torchvision.utils.make_grid(pred[idxs], normalize=True, scale_each=True)
    writer.add_image(pred_tag, grid_pred, global_step)

def torch2cvimg(tensor, min=0, max=1):
    '''
    input:
        tensor -> torch.tensor BxCxHxW, C can be 1 or 3
    return
        im -> list of ndarray uint8 HxWxC
    '''
    im_list = []
    for i in range(tensor.shape[0]):
        im = tensor.detach().cpu().data.numpy()[i]
        im = im.transpose(1, 2, 0)
        im = np.clip(im, min, max)
        im = ((im-min)/(max-min)*255).astype(np.uint8)
        im_list.append(im)
    return im_list

def cvimg2torch(img, min=0, max=1):
    '''
    input:
        im -> ndarray uint8 HxWxC
    return
        tensor -> torch.tensor BxCxHxW
    '''
    img = img.astype(float) / 255.0
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()
    return img
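As a quick illustration of the two conversion helpers at the end of this file (the array content is synthetic):

```python
import numpy as np
import torch
from data.MBD.utils import cvimg2torch, torch2cvimg  # hypothetical import path

bgr = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)   # fake HxWxC uint8 image
tensor = cvimg2torch(bgr)                                   # 1x3x64x64 float tensor in [0, 1]
restored = torch2cvimg(tensor)[0]                           # back to HxWxC uint8
print(tensor.shape, restored.shape, restored.dtype)
```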
data/README.md
ADDED
@@ -0,0 +1,135 @@
# Dataset Preparation
The data file tree should look like:
```
data/
    eval/
        dir300/
            1_in.png
            1_gt.png
            ...
        kligler/
        jung/
        osr/
        realdae/
        docunet_docaligner/
        dibco18/
    train/
        dewarping/
            doc3d/
        deshadowing/
            fsdsrd/
            rdd/
        appearance/
            clean_pdfs/
            realdae/
        deblurring/
            tdd/
        binarization/
            bickly/
            dibco/
            noise_office/
            phibd/
            msi/
```

## Evaluation Dataset
You can find the download links for the datasets we used for evaluation (Tables 1 and 2) in [this](https://github.com/ZZZHANG-jx/Recommendations-Document-Image-Processing/tree/master) repository, including DIR300 (300 samples), Kligler (300 samples), Jung (87 samples), OSR (237 samples), RealDAE (150 samples), DocUNet_DocAligner (150 samples), TDD (16000 samples) and DIBCO18 (10 samples). After downloading, add the suffixes `_in` and `_gt` to the input image and the ground-truth image respectively, and place them in the folder of the corresponding dataset.


## Training Dataset
You can find the download links for the datasets we used for training in [this](https://github.com/ZZZHANG-jx/Recommendations-Document-Image-Processing/tree/master) repository.
### Dewarping
- Doc3D
    - Mask extraction: you should extract the mask for each image from the UV data in Doc3D
    - Background preparation: you can download the background data from [here](https://www.robots.ox.ac.uk/~vgg/data/dtd/) and specify it for self.background_paths in `loaders/docres_loader.py`
    - JSON preparation:
```
[
    ## you need to specify the paths of 'in_path', 'mask_path' and 'gt_path':
    {
        "in_path": "dewarping/doc3d/img/1/102_1-pp_Page_048-xov0001.png",
        "mask_path": "dewarping/doc3d/mask/1/102_1-pp_Page_048-xov0001.png",
        "gt_path": "dewarping/doc3d/bm/1/102_1-pp_Page_048-xov0001.npy"
    }
]
```
### Deshadowing
- RDD
- FSDSRD
- JSON preparation
```
[   ## you need to specify the paths of 'in_path' and 'gt_path', for example:
    {
        "in_path": "deshadowing/fsdsrd/im/00004.png",
        "gt_path": "deshadowing/fsdsrd/gt/00004.png"
    },
    {
        "in_path": "deshadowing/rdd/im/00004.png",
        "gt_path": "deshadowing/rdd/gt/00004.png"
    }
]
```
### Appearance enhancement
- Doc3DShade
    - Clean PDF collection: you should collect PDF files from the internet and convert them to images to serve as the source for synthesis.
    - Shadow extraction: extract shadows from Doc3DShade by using `data/preprocess/shadow_extraction.py` and dewarp the obtained shadows by using `data/MBD/infer.py`. Then you should specify self.shadow_paths in `loaders/docres_loader.py`
- RealDAE
- JSON preparation:
```
[
    ## for the Doc3DShade dataset, you only need to specify the path of the image rendered from a PDF, for example:
    {
        "gt_path": "appearance/clean_pdfs/1.jpg"
    },

    ## for the RealDAE dataset, you need to specify the paths of both input and gt, for example:
    {
        "in_path": "appearance/realdae/1_in.jpg",
        "gt_path": "appearance/realdae/1_gt.jpg"
    }
]
```

### Deblurring
- TDD
- JSON preparation
```
[   ## you need to specify the paths of 'in_path' and 'gt_path', for example:
    {
        "in_path": "debluring/tdd/im/00004.png",
        "gt_path": "debluring/tdd/gt/00004.png"
    }
]
```
### Binarization
- Bickly
    - DTSPrompt preparation: since the DTSPrompt for binarization is time-consuming to compute, we obtain it offline before training; use `data/preprocess/sauvola_binarize.py`
- DIBCO
    - DTSPrompt preparation: the same as Bickly
- Noise Office
    - DTSPrompt preparation: the same as Bickly
- PHIDB
    - DTSPrompt preparation: the same as Bickly
- MSI
    - DTSPrompt preparation: the same as Bickly
- JSON preparation
```
[
    ## you need to specify the paths of 'in_path', 'gt_path', 'bin_path', 'thr_path' and 'gradient_path', for example:
    {
        "in_path": "binarization/noise_office/imgs/1.png",
        "gt_path": "binarization/noise_office/gt_imgs/1.png",
        "bin_path": "binarization/noise_office/imgs/1_bin.png",
        "thr_path": "binarization/noise_office/imgs/1_thr.png",
        "gradient_path": "binarization/noise_office/imgs/1_gradient.png"
    }
]
```

After all the data are prepared, you should specify `dataset_setting` in `train.py`.
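As a small illustration of the `_in`/`_gt` naming convention described above for the evaluation sets; the source directory layout and filenames here are assumptions, adapt them to how each dataset is actually distributed:

```python
import os
import shutil

# hypothetical originals: eval/dir300_raw/input/1.png and eval/dir300_raw/gt/1.png
src_in, src_gt, dst = 'eval/dir300_raw/input', 'eval/dir300_raw/gt', 'data/eval/dir300'
os.makedirs(dst, exist_ok=True)
for name in os.listdir(src_in):
    stem, ext = os.path.splitext(name)
    shutil.copy(os.path.join(src_in, name), os.path.join(dst, f'{stem}_in{ext}'))
    shutil.copy(os.path.join(src_gt, name), os.path.join(dst, f'{stem}_gt{ext}'))
```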
data/preprocess/crop_merge_image.py
ADDED
@@ -0,0 +1,142 @@
import os

import cv2
import numpy as np

# SIZE = 256
# BATCH_SIZE = 32
# STRIDES = 256


def split_img(img, size_x, size_y, strides):
    """Split an image into (size_y, size_x) patches sampled every `strides` pixels."""
    max_y, max_x = img.shape[:2]
    border_y = 0
    if max_y % size_y != 0:
        border_y = size_y - (max_y % size_y)
        img = cv2.copyMakeBorder(img, border_y, 0, 0, 0, cv2.BORDER_REPLICATE)
        # img = cv2.copyMakeBorder(img, border_y, 0, 0, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255])
    border_x = 0
    if max_x % size_x != 0:
        border_x = size_x - (max_x % size_x)
        # img = cv2.copyMakeBorder(img, 0, 0, border_x, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255])
        img = cv2.copyMakeBorder(img, 0, 0, border_x, 0, cv2.BORDER_REPLICATE)
    # h, w after padding
    max_y, max_x = img.shape[:2]
    parts = []
    curr_y = 0
    x = 0
    y = 0
    # TODO: rewrite with generators.
    while (curr_y + size_y) <= max_y:
        curr_x = 0
        while (curr_x + size_x) <= max_x:
            parts.append(img[curr_y:curr_y + size_y, curr_x:curr_x + size_x])
            curr_x += strides
        y += 1
        curr_y += strides
    # parts is a list of (windows_number_x * windows_number_y) patches of shape (size_y, size_x, 3)
    return parts, border_x, border_y, max_x, max_y


def combine_imgs(border_x, border_y, imgs, max_y, max_x, size_x, size_y, strides):
    """Merge overlapping patches back into a (max_y, max_x) image using per-pixel overlap weights."""
    index = int(size_x / strides)
    weight_img = np.ones(shape=(max_y, max_x))
    weight_img[0:strides] = index
    weight_img[-strides:] = index
    weight_img[:, 0:strides] = index
    weight_img[:, -strides:] = index

    # border weights
    i = 0
    for j in range(1, index + 1):
        # top-left
        weight_img[0:strides, i:i + strides] = np.ones(shape=(strides, strides)) * j
        weight_img[i:i + strides, 0:strides] = np.ones(shape=(strides, strides)) * j
        # top-right
        weight_img[i:i + strides, -strides:] = np.ones(shape=(strides, strides)) * j
        if i == 0:
            weight_img[0:strides, -strides:] = np.ones(shape=(strides, strides)) * j
        else:
            weight_img[0:strides, -strides - i:-i] = np.ones(shape=(strides, strides)) * j
        # bottom-left
        weight_img[-strides:, i:i + strides] = np.ones(shape=(strides, strides)) * j
        if i == 0:
            weight_img[-strides:, 0:strides] = np.ones(shape=(strides, strides)) * j
        else:
            weight_img[-strides - i:-i:, 0:strides] = np.ones(shape=(strides, strides)) * j
        # bottom-right
        if i == 0:
            weight_img[-strides:, -strides:] = np.ones(shape=(strides, strides)) * j
        else:
            weight_img[-strides - i:-i, -strides:] = np.ones(shape=(strides, strides)) * j
            weight_img[-strides:, -strides - i:-i] = np.ones(shape=(strides, strides)) * j

        i += strides

    for i in range(strides, max_y - strides, strides):
        for j in range(strides, max_x - strides, strides):
            weight_img[i:i + strides, j:j + strides] = np.ones(shape=(strides, strides)) * weight_img[i][0] * weight_img[0][j]

    if len(imgs[0].shape) == 2:
        new_img = np.zeros(shape=(max_y, max_x))
        weight_img = (1 / weight_img)
    else:
        new_img = np.zeros(shape=(max_y, max_x, imgs[0].shape[-1]))
        weight_img = (1 / weight_img).reshape((max_y, max_x, 1))
        weight_img = np.tile(weight_img, (1, 1, imgs[0].shape[-1]))

    curr_y = 0
    x = 0
    y = 0
    i = 0
    # TODO: rewrite with generators.
    while (curr_y + size_y) <= max_y:
        curr_x = 0
        while (curr_x + size_x) <= max_x:
            new_img[curr_y:curr_y + size_y, curr_x:curr_x + size_x] += weight_img[curr_y:curr_y + size_y, curr_x:curr_x + size_x] * imgs[i]
            i += 1
            curr_x += strides
        y += 1
        curr_y += strides

    new_img = new_img[border_y:, border_x:]
    return new_img


def stride_integral(img, stride=32):
    """Pad an image (replicate border) so that height and width are multiples of `stride`."""
    h, w = img.shape[:2]

    if (h % stride) != 0:
        padding_h = stride - (h % stride)
        img = cv2.copyMakeBorder(img, padding_h, 0, 0, 0, borderType=cv2.BORDER_REPLICATE)
    else:
        padding_h = 0

    if (w % stride) != 0:
        padding_w = stride - (w % stride)
        img = cv2.copyMakeBorder(img, 0, 0, padding_w, 0, borderType=cv2.BORDER_REPLICATE)
    else:
        padding_w = 0

    return img, padding_h, padding_w


def mkdir_s(path: str):
    """Create directory in specified path, if not exists."""
    if not os.path.exists(path):
        os.makedirs(path)


if __name__ == '__main__':
    # `im` must be an image loaded beforehand, e.g. im = cv2.imread('some_document.png')
    parts, border_x, border_y, max_x, max_y = split_img(im, 512, 512, strides=512)
    result = combine_imgs(border_x, border_y, parts, max_y, max_x, 512, 512, 512)
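For reference, a round trip through the helpers above could look like the following sketch; the input path and patch size are placeholders, and the patches in `parts` would normally be replaced by model outputs before merging.

```python
import cv2
from data.preprocess.crop_merge_image import split_img, combine_imgs, stride_integral

im = cv2.imread('some_document.png')  # placeholder path

# cut the image into overlapping 256x256 patches sampled every 128 pixels
parts, border_x, border_y, max_x, max_y = split_img(im, 256, 256, strides=128)
# ... run a patch-based model over `parts` here ...
merged = combine_imgs(border_x, border_y, parts, max_y, max_x, 256, 256, 128)

# stride_integral is the variant used by eval.py/inference.py: pad so both sides are multiples of 8
padded, pad_h, pad_w = stride_integral(im, stride=8)
unpadded = padded[pad_h:, pad_w:]  # crop the padding back off afterwards
```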
data/preprocess/sauvola_binarize.py
ADDED
@@ -0,0 +1,91 @@
# required libraries
import glob
import os

import cv2
import numpy as np
from skimage import io
from skimage.filters import threshold_sauvola
from tqdm import tqdm


def SauvolaModBinarization(image, n1=51, n2=51, k1=0.3, k2=0.3, default=True):
    '''
    Binarization using Sauvola's algorithm
    @name : SauvolaModBinarization
    parameters
    @param image (numpy array of shape (3/1) of type np.uint8): color or grayscale image
    optional parameters
    @param n1 (int): window size for running Sauvola during the first pass
    @param n2 (int): window size for running Sauvola during the second pass
    @param k1 (float): k value corresponding to Sauvola during the first pass
    @param k2 (float): k value corresponding to Sauvola during the second pass
    @param default (bool): boolean variable that resets the above optional parameters.
        When default is True, (n1, n2, k1, k2) are set to
            n1 = 5% of min(image height, image width)
            n2 = 10% of min(image height, image width)
            k1 = 0.5
            k2 = 0.5
    Returns
    @return A binary image of the same size as @param image

    @cite https://drive.google.com/file/d/1D3CyI5vtodPJeZaD2UV5wdcaIMtkBbdZ/view?usp=sharing
    '''
    if (default):
        n1 = int(0.05 * min(image.shape[0], image.shape[1]))
        if (n1 % 2 == 0):
            n1 = n1 + 1
        n2 = int(0.1 * min(image.shape[0], image.shape[1]))
        if (n2 % 2 == 0):
            n2 = n2 + 1
        k1 = 0.5
        k2 = 0.5
    if (image.ndim == 3):
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = np.copy(image)
    T1 = threshold_sauvola(gray, window_size=n1, k=k1)
    max_val = np.amax(gray)
    min_val = np.amin(gray)
    C = np.copy(T1)
    C = C.astype(np.float32)
    C[gray > T1] = (gray[gray > T1] - T1[gray > T1]) / (max_val - T1[gray > T1])
    C[gray <= T1] = 0
    C = C * 255.0
    new_in = np.copy(C.astype(np.uint8))
    T2 = threshold_sauvola(new_in, window_size=n2, k=k2)
    binary = np.copy(gray)
    binary[new_in <= T2] = 0
    binary[new_in > T2] = 255
    return binary, T2


def dtprompt(img):
    # Sobel gradients in x and y, converted back to uint8 and averaged
    x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    absX = cv2.convertScaleAbs(x)
    absY = cv2.convertScaleAbs(y)
    high_frequency = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
    high_frequency = cv2.cvtColor(high_frequency, cv2.COLOR_BGR2GRAY)
    return high_frequency


im_paths = glob.glob('imgs/*')


for im_path in tqdm(im_paths):
    # skip DTSPrompt files produced by a previous run
    if '_bin.' in im_path:
        continue
    if '_thr.' in im_path:
        continue
    if '_gradient.' in im_path:
        continue

    im = cv2.imread(im_path)
    result, thresh = SauvolaModBinarization(im)
    gradient = dtprompt(im)
    thresh = thresh.astype(np.uint8)
    cv2.imwrite(im_path.replace('.', '_bin.'), result)
    cv2.imwrite(im_path.replace('.', '_thr.'), thresh)
    cv2.imwrite(im_path.replace('.', '_gradient.'), gradient)
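Note that the script has no `if __name__ == '__main__'` guard, so it is meant to be run directly from the folder containing `imgs/` rather than imported. A small follow-up sketch (directory layout and output name are assumptions) shows how the generated `_bin`/`_thr`/`_gradient` files map onto the binarization JSON fields described in `data/README.md`:

```python
import glob
import json

entries = []
for in_path in sorted(glob.glob('binarization/noise_office/imgs/*.png')):
    # skip the DTSPrompt images written by sauvola_binarize.py
    if any(tag in in_path for tag in ('_bin.', '_thr.', '_gradient.')):
        continue
    entries.append({
        'in_path': in_path,
        'gt_path': in_path.replace('/imgs/', '/gt_imgs/'),
        'bin_path': in_path.replace('.png', '_bin.png'),
        'thr_path': in_path.replace('.png', '_thr.png'),
        'gradient_path': in_path.replace('.png', '_gradient.png'),
    })

with open('binarization_noise_office.json', 'w') as f:
    json.dump(entries, f, indent=4)
```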
data/preprocess/shadow_extraction.py
ADDED
@@ -0,0 +1,68 @@
import glob
import os
import random

import cv2
import numpy as np
from tqdm import tqdm

# Extract shadow layers from Doc3DShade renderings: for every rendered image the
# shadow map is the per-pixel ratio between the rendering and its albedo.
im_paths = glob.glob('./img/*/*')

random.shuffle(im_paths)

for im_path in tqdm(im_paths):
    # im_path = './img/1/23-180_5-y4_Page_034-wVO0001-L1_3-T_6600-I_5535.png'
    if '-L1_' in im_path:
        alb_path = im_path.split('-L1_')[0].replace('img/', 'alb/') + '.png'
    else:
        alb_path = im_path.split('-L2_')[0].replace('img/', 'alb/') + '.png'

    if not os.path.exists(alb_path):
        print(im_path)
        print(alb_path)

    im = cv2.imread(im_path)
    alb = cv2.imread(alb_path)
    _, mask = cv2.threshold(cv2.cvtColor(alb, cv2.COLOR_BGR2GRAY), 1, 255, cv2.THRESH_BINARY)

    ## clean
    # std = np.max(np.std(alb,axis=-1))
    # print(std)
    im_min = np.min(im, axis=-1)
    kernel = np.ones((3, 3))
    mask_erode = cv2.dilate(mask, kernel=kernel)
    mask_erode = cv2.erode(mask_erode, kernel=kernel)
    mask_erode = cv2.erode(mask_erode, iterations=4, kernel=kernel)
    metric = np.min(im_min[mask_erode == 255])
    metric_num = 0
    if metric == 0 or metric == 1:
        metric_num = np.sum(im_min[mask_erode == 255] == metric)
    if metric_num >= 20:
        # too many near-black pixels inside the document region: treat the sample as unreliable and park it in temp/
        alb_temp = alb.astype(np.float64)
        alb_temp[alb_temp == 0] = alb_temp[alb_temp == 0] + 1e-5
        shadow = np.clip(im.astype(np.float64) / alb_temp, 0, 1)
        shadow = (shadow * 255).astype(np.uint8)

        shadow_path = im_path.replace('img/', 'temp/')
        cv2.imwrite(shadow_path, shadow)
        continue

    alb_temp = alb.astype(np.float64)
    alb_temp[alb_temp == 0] = alb_temp[alb_temp == 0] + 1e-5
    shadow = np.clip(im.astype(np.float64) / alb_temp, 0, 1)
    shadow = (shadow * 255).astype(np.uint8)

    shadow_path = im_path.replace('img/', 'shadow/')
    cv2.imwrite(shadow_path, shadow)

    mask_path = im_path.replace('img/', 'mask/')
    cv2.imwrite(mask_path, mask)

    # cv2.imshow('im', im)
    # cv2.imshow('alb', alb)
    # cv2.imshow('shadow', shadow)
    # cv2.imshow('mask_erode', mask_erode)
    # print(im_min[mask_erode==255])
    # print(metric, metric_num)
    # cv2.waitKey(0)
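The extracted layer is a multiplicative shadow (the rendering divided by its albedo, clipped to [0, 1]), so it can later be re-applied to a clean page when synthesizing appearance-enhancement training data. A rough sketch is below; the file names are placeholders, and the actual augmentation used for training lives in `loaders/docres_loader.py`.

```python
import cv2
import numpy as np

clean = cv2.imread('appearance/clean_pdfs/1.jpg').astype(np.float64) / 255.0  # clean PDF page
shadow = cv2.imread('extracted_shadow.png').astype(np.float64) / 255.0        # output of this script
shadow = cv2.resize(shadow, (clean.shape[1], clean.shape[0]))

# image ≈ albedo * shadow, so multiplying the clean page by the shadow layer degrades it
degraded = np.clip(clean * shadow, 0, 1)
cv2.imwrite('synth_input.jpg', (degraded * 255).astype(np.uint8))
```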
eval.py
ADDED
@@ -0,0 +1,369 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import glob
|
4 |
+
import utils
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
from skimage.metrics import structural_similarity,peak_signal_noise_ratio
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from utils import convert_state_dict
|
13 |
+
from models import restormer_arch
|
14 |
+
from data.preprocess.crop_merge_image import stride_integral
|
15 |
+
|
16 |
+
os.sys.path.append('./data/MBD/')
|
17 |
+
from data.MBD.infer import net1_net2_infer_single_im
|
18 |
+
|
19 |
+
|
20 |
+
def dewarp_prompt(img):
|
21 |
+
mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
|
22 |
+
base_coord = utils.getBasecoord(256,256)/256
|
23 |
+
img[mask==0]=0
|
24 |
+
mask = cv2.resize(mask,(256,256))/255
|
25 |
+
return img,np.concatenate((base_coord,np.expand_dims(mask,-1)),-1)
|
26 |
+
|
27 |
+
def deshadow_prompt(img):
|
28 |
+
h,w = img.shape[:2]
|
29 |
+
# img = cv2.resize(img,(128,128))
|
30 |
+
img = cv2.resize(img,(1024,1024))
|
31 |
+
rgb_planes = cv2.split(img)
|
32 |
+
result_planes = []
|
33 |
+
result_norm_planes = []
|
34 |
+
bg_imgs = []
|
35 |
+
for plane in rgb_planes:
|
36 |
+
dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
|
37 |
+
bg_img = cv2.medianBlur(dilated_img, 21)
|
38 |
+
bg_imgs.append(bg_img)
|
39 |
+
diff_img = 255 - cv2.absdiff(plane, bg_img)
|
40 |
+
norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
|
41 |
+
result_planes.append(diff_img)
|
42 |
+
result_norm_planes.append(norm_img)
|
43 |
+
bg_imgs = cv2.merge(bg_imgs)
|
44 |
+
bg_imgs = cv2.resize(bg_imgs,(w,h))
|
45 |
+
# result = cv2.merge(result_planes)
|
46 |
+
result_norm = cv2.merge(result_norm_planes)
|
47 |
+
result_norm[result_norm==0]=1
|
48 |
+
shadow_map = np.clip(img.astype(float)/result_norm.astype(float)*255,0,255).astype(np.uint8)
|
49 |
+
shadow_map = cv2.resize(shadow_map,(w,h))
|
50 |
+
shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_BGR2GRAY)
|
51 |
+
shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_GRAY2BGR)
|
52 |
+
# return shadow_map
|
53 |
+
return bg_imgs
|
54 |
+
|
55 |
+
def deblur_prompt(img):
|
56 |
+
x = cv2.Sobel(img,cv2.CV_16S,1,0)
|
57 |
+
y = cv2.Sobel(img,cv2.CV_16S,0,1)
|
58 |
+
absX = cv2.convertScaleAbs(x) # 转回uint8
|
59 |
+
absY = cv2.convertScaleAbs(y)
|
60 |
+
high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
|
61 |
+
high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
|
62 |
+
high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
|
63 |
+
return high_frequency
|
64 |
+
|
65 |
+
def appearance_prompt(img):
|
66 |
+
h,w = img.shape[:2]
|
67 |
+
# img = cv2.resize(img,(128,128))
|
68 |
+
img = cv2.resize(img,(1024,1024))
|
69 |
+
rgb_planes = cv2.split(img)
|
70 |
+
result_planes = []
|
71 |
+
result_norm_planes = []
|
72 |
+
for plane in rgb_planes:
|
73 |
+
dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
|
74 |
+
bg_img = cv2.medianBlur(dilated_img, 21)
|
75 |
+
diff_img = 255 - cv2.absdiff(plane, bg_img)
|
76 |
+
norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
|
77 |
+
result_planes.append(diff_img)
|
78 |
+
result_norm_planes.append(norm_img)
|
79 |
+
result_norm = cv2.merge(result_norm_planes)
|
80 |
+
result_norm = cv2.resize(result_norm,(w,h))
|
81 |
+
return result_norm
|
82 |
+
|
83 |
+
def binarization_promptv2(img):
|
84 |
+
result,thresh = utils.SauvolaModBinarization(img)
|
85 |
+
thresh = thresh.astype(np.uint8)
|
86 |
+
result[result>155]=255
|
87 |
+
result[result<=155]=0
|
88 |
+
|
89 |
+
x = cv2.Sobel(img,cv2.CV_16S,1,0)
|
90 |
+
y = cv2.Sobel(img,cv2.CV_16S,0,1)
|
91 |
+
absX = cv2.convertScaleAbs(x) # 转回uint8
|
92 |
+
absY = cv2.convertScaleAbs(y)
|
93 |
+
high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
|
94 |
+
high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
|
95 |
+
return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
|
96 |
+
|
97 |
+
def dewarping(model,im_path):
|
98 |
+
INPUT_SIZE=256
|
99 |
+
im_org = cv2.imread(im_path)
|
100 |
+
im_masked, prompt_org = dewarp_prompt(im_org.copy())
|
101 |
+
|
102 |
+
h,w = im_masked.shape[:2]
|
103 |
+
im_masked = im_masked.copy()
|
104 |
+
im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
|
105 |
+
im_masked = im_masked / 255.0
|
106 |
+
im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
|
107 |
+
im_masked = im_masked.float().to(DEVICE)
|
108 |
+
|
109 |
+
prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
|
110 |
+
prompt = prompt.float().to(DEVICE)
|
111 |
+
|
112 |
+
in_im = torch.cat((im_masked,prompt),dim=1)
|
113 |
+
|
114 |
+
# inference
|
115 |
+
base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
|
116 |
+
model = model.float()
|
117 |
+
with torch.no_grad():
|
118 |
+
pred = model(in_im)
|
119 |
+
pred = pred[0][:2].permute(1,2,0).cpu().numpy()
|
120 |
+
pred = pred+base_coord
|
121 |
+
## smooth
|
122 |
+
for i in range(15):
|
123 |
+
pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
|
124 |
+
pred = cv2.resize(pred,(w,h))*(w,h)
|
125 |
+
pred = pred.astype(np.float32)
|
126 |
+
out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
|
127 |
+
|
128 |
+
prompt_org = (prompt_org*255).astype(np.uint8)
|
129 |
+
prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
|
130 |
+
|
131 |
+
return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
|
132 |
+
|
133 |
+
def appearance(model,im_path):
|
134 |
+
MAX_SIZE=1600
|
135 |
+
# obtain im and prompt
|
136 |
+
im_org = cv2.imread(im_path)
|
137 |
+
h,w = im_org.shape[:2]
|
138 |
+
prompt = appearance_prompt(im_org)
|
139 |
+
in_im = np.concatenate((im_org,prompt),-1)
|
140 |
+
|
141 |
+
# constrain the max resolution
|
142 |
+
if max(w,h) < MAX_SIZE:
|
143 |
+
in_im,padding_h,padding_w = stride_integral(in_im,8)
|
144 |
+
else:
|
145 |
+
in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
|
146 |
+
|
147 |
+
# normalize
|
148 |
+
in_im = in_im / 255.0
|
149 |
+
in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
|
150 |
+
|
151 |
+
# inference
|
152 |
+
in_im = in_im.half().to(DEVICE)
|
153 |
+
model = model.half()
|
154 |
+
with torch.no_grad():
|
155 |
+
pred = model(in_im)
|
156 |
+
pred = torch.clamp(pred,0,1)
|
157 |
+
pred = pred[0].permute(1,2,0).cpu().numpy()
|
158 |
+
pred = (pred*255).astype(np.uint8)
|
159 |
+
|
160 |
+
if max(w,h) < MAX_SIZE:
|
161 |
+
out_im = pred[padding_h:,padding_w:]
|
162 |
+
else:
|
163 |
+
pred[pred==0] = 1
|
164 |
+
shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
|
165 |
+
shadow_map = cv2.resize(shadow_map,(w,h))
|
166 |
+
shadow_map[shadow_map==0]=0.00001
|
167 |
+
out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
|
168 |
+
|
169 |
+
return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
|
170 |
+
|
171 |
+
|
172 |
+
def deshadowing(model,im_path):
|
173 |
+
MAX_SIZE=1600
|
174 |
+
# obtain im and prompt
|
175 |
+
im_org = cv2.imread(im_path)
|
176 |
+
h,w = im_org.shape[:2]
|
177 |
+
prompt = deshadow_prompt(im_org)
|
178 |
+
in_im = np.concatenate((im_org,prompt),-1)
|
179 |
+
|
180 |
+
# constrain the max resolution
|
181 |
+
if max(w,h) < MAX_SIZE:
|
182 |
+
in_im,padding_h,padding_w = stride_integral(in_im,8)
|
183 |
+
else:
|
184 |
+
in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
|
185 |
+
|
186 |
+
# normalize
|
187 |
+
in_im = in_im / 255.0
|
188 |
+
in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
|
189 |
+
|
190 |
+
# inference
|
191 |
+
in_im = in_im.half().to(DEVICE)
|
192 |
+
model = model.half()
|
193 |
+
with torch.no_grad():
|
194 |
+
pred = model(in_im)
|
195 |
+
pred = torch.clamp(pred,0,1)
|
196 |
+
pred = pred[0].permute(1,2,0).cpu().numpy()
|
197 |
+
pred = (pred*255).astype(np.uint8)
|
198 |
+
|
199 |
+
if max(w,h) < MAX_SIZE:
|
200 |
+
out_im = pred[padding_h:,padding_w:]
|
201 |
+
else:
|
202 |
+
pred[pred==0]=1
|
203 |
+
shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
|
204 |
+
shadow_map = cv2.resize(shadow_map,(w,h))
|
205 |
+
shadow_map[shadow_map==0]=0.00001
|
206 |
+
out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
|
207 |
+
|
208 |
+
return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
|
209 |
+
|
210 |
+
|
211 |
+
def deblurring(model,im_path):
|
212 |
+
# setup image
|
213 |
+
im_org = cv2.imread(im_path)
|
214 |
+
in_im,padding_h,padding_w = stride_integral(im_org,8)
|
215 |
+
prompt = deblur_prompt(in_im)
|
216 |
+
in_im = np.concatenate((in_im,prompt),-1)
|
217 |
+
in_im = in_im / 255.0
|
218 |
+
in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
|
219 |
+
in_im = in_im.half().to(DEVICE)
|
220 |
+
# inference
|
221 |
+
model.to(DEVICE)
|
222 |
+
model.eval()
|
223 |
+
model = model.half()
|
224 |
+
with torch.no_grad():
|
225 |
+
pred = model(in_im)
|
226 |
+
pred = torch.clamp(pred,0,1)
|
227 |
+
pred = pred[0].permute(1,2,0).cpu().numpy()
|
228 |
+
pred = (pred*255).astype(np.uint8)
|
229 |
+
out_im = pred[padding_h:,padding_w:]
|
230 |
+
|
231 |
+
return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
def binarization(model,im_path):
|
236 |
+
im_org = cv2.imread(im_path)
|
237 |
+
im,padding_h,padding_w = stride_integral(im_org,8)
|
238 |
+
prompt = binarization_promptv2(im)
|
239 |
+
h,w = im.shape[:2]
|
240 |
+
in_im = np.concatenate((im,prompt),-1)
|
241 |
+
|
242 |
+
in_im = in_im / 255.0
|
243 |
+
in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
|
244 |
+
in_im = in_im.to(DEVICE)
|
245 |
+
model = model.half()
|
246 |
+
in_im = in_im.half()
|
247 |
+
with torch.no_grad():
|
248 |
+
pred = model(in_im,'binarization')
|
249 |
+
pred = pred[:,:2,:,:]
|
250 |
+
pred = torch.max(torch.softmax(pred,1),1)[1]
|
251 |
+
pred = pred[0].cpu().numpy()
|
252 |
+
pred = (pred*255).astype(np.uint8)
|
253 |
+
pred = cv2.resize(pred,(w,h))
|
254 |
+
out_im = pred[padding_h:,padding_w:]
|
255 |
+
|
256 |
+
return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
|
262 |
+
def get_args():
|
263 |
+
parser = argparse.ArgumentParser(description='Params')
|
264 |
+
parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
|
265 |
+
parser.add_argument('--dataset', nargs='?', type=str, default='./distorted/',help='Path of input document image')
|
266 |
+
args = parser.parse_args()
|
267 |
+
assert args.dataset in all_datasets.keys(), 'Unregisted dataset, dataset must be one of '+', '.join(all_datasets)
|
268 |
+
return args
|
269 |
+
|
270 |
+
def model_init(args):
|
271 |
+
# prepare model
|
272 |
+
model = restormer_arch.Restormer(
|
273 |
+
inp_channels=6,
|
274 |
+
out_channels=3,
|
275 |
+
dim = 48,
|
276 |
+
num_blocks = [2,3,3,4],
|
277 |
+
num_refinement_blocks = 4,
|
278 |
+
heads = [1,2,4,8],
|
279 |
+
ffn_expansion_factor = 2.66,
|
280 |
+
bias = False,
|
281 |
+
LayerNorm_type = 'WithBias',
|
282 |
+
dual_pixel_task = True
|
283 |
+
)
|
284 |
+
|
285 |
+
if DEVICE.type == 'cpu':
|
286 |
+
state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
|
287 |
+
else:
|
288 |
+
state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
|
289 |
+
model.load_state_dict(state)
|
290 |
+
|
291 |
+
model.eval()
|
292 |
+
model = model.to(DEVICE)
|
293 |
+
return model
|
294 |
+
|
295 |
+
def inference_one_im(model,im_path,task):
|
296 |
+
if task=='dewarping':
|
297 |
+
prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
|
298 |
+
elif task=='deshadowing':
|
299 |
+
prompt1,prompt2,prompt3,restorted = deshadowing(model,im_path)
|
300 |
+
elif task=='appearance':
|
301 |
+
prompt1,prompt2,prompt3,restorted = appearance(model,im_path)
|
302 |
+
elif task=='deblurring':
|
303 |
+
prompt1,prompt2,prompt3,restorted = deblurring(model,im_path)
|
304 |
+
elif task=='binarization':
|
305 |
+
prompt1,prompt2,prompt3,restorted = binarization(model,im_path)
|
306 |
+
elif task=='end2end':
|
307 |
+
prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
|
308 |
+
cv2.imwrite('./temp.jpg',restorted)
|
309 |
+
prompt1,prompt2,prompt3,restorted = deshadowing(model,'./temp.jpg')
|
310 |
+
cv2.imwrite('./temp.jpg',restorted)
|
311 |
+
prompt1,prompt2,prompt3,restorted = appearance(model,'./temp.jpg')
|
312 |
+
os.remove('./temp.jpg')
|
313 |
+
|
314 |
+
return prompt1,prompt2,prompt3,restorted
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
+
if __name__ == '__main__':
|
319 |
+
all_datasets = {'dir300':'dewarping','kligler':'deshadowing','jung':'deshadowing','osr':'deshadowing','docunet_docaligner':'appearance','realdae':'appearance','tdd':'deblurring','dibco18':'binarization'}
|
320 |
+
|
321 |
+
## model init
|
322 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
323 |
+
args = get_args()
|
324 |
+
model = model_init(args)
|
325 |
+
|
326 |
+
## inference
|
327 |
+
print('Predicting')
|
328 |
+
task = all_datasets[args.dataset]
|
329 |
+
im_paths = glob.glob(os.path.join('./data/eval/',args.dataset,'*_in.*'))
|
330 |
+
for im_path in tqdm(im_paths):
|
331 |
+
_,_,_,restorted = inference_one_im(model,im_path,task)
|
332 |
+
cv2.imwrite(im_path.replace('_in','_docres'),restorted)
|
333 |
+
|
334 |
+
## obtain metric
|
335 |
+
print('Metric calculating')
|
336 |
+
if task == 'dewarping':
|
337 |
+
exit()
|
338 |
+
elif task=='deshadowing' or task=='appearance' or task=='deblurring':
|
339 |
+
psnr = []
|
340 |
+
ssim = []
|
341 |
+
for im_path in tqdm(im_paths):
|
342 |
+
pred = cv2.imread(im_path.replace('_in','_docres'))
|
343 |
+
gt = cv2.imread(im_path.replace('_in','_gt'))
|
344 |
+
ssim.append(structural_similarity(pred,gt,multichannel=True))
|
345 |
+
psnr.append(peak_signal_noise_ratio(pred, gt))
|
346 |
+
print(args.dataset)
|
347 |
+
print('ssim:',np.mean(ssim))
|
348 |
+
print('psnr:',np.mean(psnr))
|
349 |
+
elif task=='binarization':
|
350 |
+
fmeasures, pfmeasures,psnrs = [],[],[]
|
351 |
+
for im_path in tqdm(im_paths):
|
352 |
+
pred = cv2.imread(im_path.replace('_in','_docres'))
|
353 |
+
gt = cv2.imread(im_path.replace('_in','_gt'))
|
354 |
+
pred = cv2.cvtColor(pred,cv2.COLOR_BGR2GRAY)
|
355 |
+
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2GRAY)
|
356 |
+
pred[pred>155]=255
|
357 |
+
pred[pred<=155]=0
|
358 |
+
gt[gt>155]=255
|
359 |
+
gt[gt<=155]=0
|
360 |
+
fmeasure, pfmeasure,psnr,_,_,_ = utils.bin_metric(pred,gt)
|
361 |
+
fmeasures.append(fmeasure)
|
362 |
+
pfmeasures.append(pfmeasure)
|
363 |
+
psnrs.append(psnr)
|
364 |
+
print(args.dataset)
|
365 |
+
print('fmeasure:',np.mean(fmeasures))
|
366 |
+
print('pfmeasure:',np.mean(pfmeasures))
|
367 |
+
print('psnr:',np.mean(psnrs))
|
368 |
+
|
369 |
+
|
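For evaluation, `eval.py` expects the benchmark images under `./data/eval/<dataset>/` with `*_in.*` inputs and matching `*_gt.*` ground truth, and `--dataset` must be one of the keys registered in `all_datasets`. A typical run (the dataset choice here is only an example) might look like this:

```bash
# restores every *_in.* image to *_docres.*, then reports SSIM/PSNR
# (or F-measure/pF-measure/PSNR for binarization datasets)
python eval.py --model_path ./checkpoints/docres.pkl --dataset kligler
```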
inference.py
ADDED
@@ -0,0 +1,341 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import utils
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from utils import convert_state_dict
|
10 |
+
from models import restormer_arch
|
11 |
+
from data.preprocess.crop_merge_image import stride_integral
|
12 |
+
|
13 |
+
os.sys.path.append('./data/MBD/')
|
14 |
+
from data.MBD.infer import net1_net2_infer_single_im
|
15 |
+
|
16 |
+
|
17 |
+
def dewarp_prompt(img):
|
18 |
+
mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
|
19 |
+
base_coord = utils.getBasecoord(256,256)/256
|
20 |
+
img[mask==0]=0
|
21 |
+
mask = cv2.resize(mask,(256,256))/255
|
22 |
+
return img,np.concatenate((base_coord,np.expand_dims(mask,-1)),-1)
|
23 |
+
|
24 |
+
def deshadow_prompt(img):
|
25 |
+
h,w = img.shape[:2]
|
26 |
+
# img = cv2.resize(img,(128,128))
|
27 |
+
img = cv2.resize(img,(1024,1024))
|
28 |
+
rgb_planes = cv2.split(img)
|
29 |
+
result_planes = []
|
30 |
+
result_norm_planes = []
|
31 |
+
bg_imgs = []
|
32 |
+
for plane in rgb_planes:
|
33 |
+
dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
|
34 |
+
bg_img = cv2.medianBlur(dilated_img, 21)
|
35 |
+
bg_imgs.append(bg_img)
|
36 |
+
diff_img = 255 - cv2.absdiff(plane, bg_img)
|
37 |
+
norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
|
38 |
+
result_planes.append(diff_img)
|
39 |
+
result_norm_planes.append(norm_img)
|
40 |
+
bg_imgs = cv2.merge(bg_imgs)
|
41 |
+
bg_imgs = cv2.resize(bg_imgs,(w,h))
|
42 |
+
# result = cv2.merge(result_planes)
|
43 |
+
result_norm = cv2.merge(result_norm_planes)
|
44 |
+
result_norm[result_norm==0]=1
|
45 |
+
shadow_map = np.clip(img.astype(float)/result_norm.astype(float)*255,0,255).astype(np.uint8)
|
46 |
+
shadow_map = cv2.resize(shadow_map,(w,h))
|
47 |
+
shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_BGR2GRAY)
|
48 |
+
shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_GRAY2BGR)
|
49 |
+
# return shadow_map
|
50 |
+
return bg_imgs
|
51 |
+
|
52 |
+
def deblur_prompt(img):
|
53 |
+
x = cv2.Sobel(img,cv2.CV_16S,1,0)
|
54 |
+
y = cv2.Sobel(img,cv2.CV_16S,0,1)
|
55 |
+
absX = cv2.convertScaleAbs(x) # 转回uint8
|
56 |
+
absY = cv2.convertScaleAbs(y)
|
57 |
+
high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
|
58 |
+
high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
|
59 |
+
high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
|
60 |
+
return high_frequency
|
61 |
+
|
62 |
+
def appearance_prompt(img):
|
63 |
+
h,w = img.shape[:2]
|
64 |
+
# img = cv2.resize(img,(128,128))
|
65 |
+
img = cv2.resize(img,(1024,1024))
|
66 |
+
rgb_planes = cv2.split(img)
|
67 |
+
result_planes = []
|
68 |
+
result_norm_planes = []
|
69 |
+
for plane in rgb_planes:
|
70 |
+
dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
|
71 |
+
bg_img = cv2.medianBlur(dilated_img, 21)
|
72 |
+
diff_img = 255 - cv2.absdiff(plane, bg_img)
|
73 |
+
norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
|
74 |
+
result_planes.append(diff_img)
|
75 |
+
result_norm_planes.append(norm_img)
|
76 |
+
result_norm = cv2.merge(result_norm_planes)
|
77 |
+
result_norm = cv2.resize(result_norm,(w,h))
|
78 |
+
return result_norm
|
79 |
+
|
80 |
+
def binarization_promptv2(img):
|
81 |
+
result,thresh = utils.SauvolaModBinarization(img)
|
82 |
+
thresh = thresh.astype(np.uint8)
|
83 |
+
result[result>155]=255
|
84 |
+
result[result<=155]=0
|
85 |
+
|
86 |
+
x = cv2.Sobel(img,cv2.CV_16S,1,0)
|
87 |
+
y = cv2.Sobel(img,cv2.CV_16S,0,1)
|
88 |
+
absX = cv2.convertScaleAbs(x) # 转回uint8
|
89 |
+
absY = cv2.convertScaleAbs(y)
|
90 |
+
high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
|
91 |
+
high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
|
92 |
+
return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
|
93 |
+
|
94 |
+
def dewarping(model,im_path):
|
95 |
+
INPUT_SIZE=256
|
96 |
+
im_org = cv2.imread(im_path)
|
97 |
+
im_masked, prompt_org = dewarp_prompt(im_org.copy())
|
98 |
+
|
99 |
+
h,w = im_masked.shape[:2]
|
100 |
+
im_masked = im_masked.copy()
|
101 |
+
im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
|
102 |
+
im_masked = im_masked / 255.0
|
103 |
+
im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
|
104 |
+
im_masked = im_masked.float().to(DEVICE)
|
105 |
+
|
106 |
+
prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
|
107 |
+
prompt = prompt.float().to(DEVICE)
|
108 |
+
|
109 |
+
in_im = torch.cat((im_masked,prompt),dim=1)
|
110 |
+
|
111 |
+
# inference
|
112 |
+
base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
|
113 |
+
model = model.float()
|
114 |
+
with torch.no_grad():
|
115 |
+
pred = model(in_im)
|
116 |
+
pred = pred[0][:2].permute(1,2,0).cpu().numpy()
|
117 |
+
pred = pred+base_coord
|
118 |
+
## smooth
|
119 |
+
for i in range(15):
|
120 |
+
pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
|
121 |
+
pred = cv2.resize(pred,(w,h))*(w,h)
|
122 |
+
pred = pred.astype(np.float32)
|
123 |
+
out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
|
124 |
+
|
125 |
+
prompt_org = (prompt_org*255).astype(np.uint8)
|
126 |
+
prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
|
127 |
+
|
128 |
+
return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
|
129 |
+
|
130 |
+
def appearance(model,im_path):
|
131 |
+
MAX_SIZE=1600
|
132 |
+
# obtain im and prompt
|
133 |
+
im_org = cv2.imread(im_path)
|
134 |
+
h,w = im_org.shape[:2]
|
135 |
+
prompt = appearance_prompt(im_org)
|
136 |
+
in_im = np.concatenate((im_org,prompt),-1)
|
137 |
+
|
138 |
+
# constrain the max resolution
|
139 |
+
if max(w,h) < MAX_SIZE:
|
140 |
+
in_im,padding_h,padding_w = stride_integral(in_im,8)
|
141 |
+
else:
|
142 |
+
in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
|
143 |
+
|
144 |
+
# normalize
|
145 |
+
in_im = in_im / 255.0
|
146 |
+
in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
|
147 |
+
|
148 |
+
# inference
|
149 |
+
in_im = in_im.half().to(DEVICE)
|
150 |
+
model = model.half()
|
151 |
+
with torch.no_grad():
|
152 |
+
pred = model(in_im)
|
153 |
+
pred = torch.clamp(pred,0,1)
|
154 |
+
pred = pred[0].permute(1,2,0).cpu().numpy()
|
155 |
+
pred = (pred*255).astype(np.uint8)
|
156 |
+
|
157 |
+
if max(w,h) < MAX_SIZE:
|
158 |
+
out_im = pred[padding_h:,padding_w:]
|
159 |
+
else:
|
160 |
+
pred[pred==0] = 1
|
161 |
+
shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
|
162 |
+
shadow_map = cv2.resize(shadow_map,(w,h))
|
163 |
+
shadow_map[shadow_map==0]=0.00001
|
164 |
+
out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
|
165 |
+
|
166 |
+
return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
|
167 |
+
|
168 |
+
|
169 |
+
def deshadowing(model,im_path):
|
170 |
+
MAX_SIZE=1600
|
171 |
+
# obtain im and prompt
|
172 |
+
im_org = cv2.imread(im_path)
|
173 |
+
h,w = im_org.shape[:2]
|
174 |
+
prompt = deshadow_prompt(im_org)
|
175 |
+
in_im = np.concatenate((im_org,prompt),-1)
|
176 |
+
|
177 |
+
# constrain the max resolution
|
178 |
+
if max(w,h) < MAX_SIZE:
|
179 |
+
in_im,padding_h,padding_w = stride_integral(in_im,8)
|
180 |
+
else:
|
181 |
+
in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
|
182 |
+
|
183 |
+
# normalize
|
184 |
+
in_im = in_im / 255.0
|
185 |
+
in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
|
186 |
+
|
187 |
+
# inference
|
188 |
+
in_im = in_im.half().to(DEVICE)
|
189 |
+
model = model.half()
|
190 |
+
with torch.no_grad():
|
191 |
+
pred = model(in_im)
|
192 |
+
pred = torch.clamp(pred,0,1)
|
193 |
+
pred = pred[0].permute(1,2,0).cpu().numpy()
|
194 |
+
pred = (pred*255).astype(np.uint8)
|
195 |
+
|
196 |
+
if max(w,h) < MAX_SIZE:
|
197 |
+
out_im = pred[padding_h:,padding_w:]
|
198 |
+
else:
|
199 |
+
pred[pred==0]=1
|
200 |
+
shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
|
201 |
+
shadow_map = cv2.resize(shadow_map,(w,h))
|
202 |
+
shadow_map[shadow_map==0]=0.00001
|
203 |
+
out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
|
204 |
+
|
205 |
+
return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
|
206 |
+
|
207 |
+
|
208 |
+
def deblurring(model,im_path):
|
209 |
+
# setup image
|
210 |
+
im_org = cv2.imread(im_path)
|
211 |
+
in_im,padding_h,padding_w = stride_integral(im_org,8)
|
212 |
+
prompt = deblur_prompt(in_im)
|
213 |
+
in_im = np.concatenate((in_im,prompt),-1)
|
214 |
+
in_im = in_im / 255.0
|
215 |
+
in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
|
216 |
+
in_im = in_im.half().to(DEVICE)
|
217 |
+
# inference
|
218 |
+
model.to(DEVICE)
|
219 |
+
model.eval()
|
220 |
+
model = model.half()
|
221 |
+
with torch.no_grad():
|
222 |
+
pred = model(in_im)
|
223 |
+
pred = torch.clamp(pred,0,1)
|
224 |
+
pred = pred[0].permute(1,2,0).cpu().numpy()
|
225 |
+
pred = (pred*255).astype(np.uint8)
|
226 |
+
out_im = pred[padding_h:,padding_w:]
|
227 |
+
|
228 |
+
return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
def binarization(model,im_path):
|
233 |
+
im_org = cv2.imread(im_path)
|
234 |
+
im,padding_h,padding_w = stride_integral(im_org,8)
|
235 |
+
prompt = binarization_promptv2(im)
|
236 |
+
h,w = im.shape[:2]
|
237 |
+
in_im = np.concatenate((im,prompt),-1)
|
238 |
+
|
239 |
+
in_im = in_im / 255.0
|
240 |
+
in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
|
241 |
+
in_im = in_im.to(DEVICE)
|
242 |
+
model = model.half()
|
243 |
+
in_im = in_im.half()
|
244 |
+
with torch.no_grad():
|
245 |
+
pred = model(in_im)
|
246 |
+
pred = pred[:,:2,:,:]
|
247 |
+
pred = torch.max(torch.softmax(pred,1),1)[1]
|
248 |
+
pred = pred[0].cpu().numpy()
|
249 |
+
pred = (pred*255).astype(np.uint8)
|
250 |
+
pred = cv2.resize(pred,(w,h))
|
251 |
+
out_im = pred[padding_h:,padding_w:]
|
252 |
+
|
253 |
+
return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
def get_args():
|
260 |
+
parser = argparse.ArgumentParser(description='Params')
|
261 |
+
parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
|
262 |
+
parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/',
|
263 |
+
help='Path of input document image')
|
264 |
+
parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/',
|
265 |
+
help='Folder of the output images')
|
266 |
+
parser.add_argument('--task', nargs='?', type=str, default='dewarping',
|
267 |
+
help='task that need to be executed')
|
268 |
+
parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0,
|
269 |
+
help='Width of the input image')
|
270 |
+
args = parser.parse_args()
|
271 |
+
possible_tasks = ['dewarping','deshadowing','appearance','deblurring','binarization','end2end']
|
272 |
+
assert args.task in possible_tasks, 'Unsupported task, task must be one of '+', '.join(possible_tasks)
|
273 |
+
return args
|
274 |
+
|
275 |
+
def model_init(args):
|
276 |
+
# prepare model
|
277 |
+
model = restormer_arch.Restormer(
|
278 |
+
inp_channels=6,
|
279 |
+
out_channels=3,
|
280 |
+
dim = 48,
|
281 |
+
num_blocks = [2,3,3,4],
|
282 |
+
num_refinement_blocks = 4,
|
283 |
+
heads = [1,2,4,8],
|
284 |
+
ffn_expansion_factor = 2.66,
|
285 |
+
bias = False,
|
286 |
+
LayerNorm_type = 'WithBias',
|
287 |
+
dual_pixel_task = True
|
288 |
+
)
|
289 |
+
|
290 |
+
if DEVICE.type == 'cpu':
|
291 |
+
state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
|
292 |
+
else:
|
293 |
+
state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
|
294 |
+
model.load_state_dict(state)
|
295 |
+
|
296 |
+
model.eval()
|
297 |
+
model = model.to(DEVICE)
|
298 |
+
return model
|
299 |
+
|
300 |
+
def inference_one_im(model,im_path,task):
|
301 |
+
if task=='dewarping':
|
302 |
+
prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
|
303 |
+
elif task=='deshadowing':
|
304 |
+
prompt1,prompt2,prompt3,restorted = deshadowing(model,im_path)
|
305 |
+
elif task=='appearance':
|
306 |
+
prompt1,prompt2,prompt3,restorted = appearance(model,im_path)
|
307 |
+
elif task=='deblurring':
|
308 |
+
prompt1,prompt2,prompt3,restorted = deblurring(model,im_path)
|
309 |
+
elif task=='binarization':
|
310 |
+
prompt1,prompt2,prompt3,restorted = binarization(model,im_path)
|
311 |
+
elif task=='end2end':
|
312 |
+
prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
|
313 |
+
cv2.imwrite('restorted/step1.jpg',restorted)
|
314 |
+
prompt1,prompt2,prompt3,restorted = deshadowing(model,'restorted/step1.jpg')
|
315 |
+
cv2.imwrite('restorted/step2.jpg',restorted)
|
316 |
+
prompt1,prompt2,prompt3,restorted = appearance(model,'restorted/step2.jpg')
|
317 |
+
# os.remove('restorted/step1.jpg')
|
318 |
+
# os.remove('restorted/step2.jpg')
|
319 |
+
|
320 |
+
return prompt1,prompt2,prompt3,restorted
|
321 |
+
|
322 |
+
|
323 |
+
|
324 |
+
if __name__ == '__main__':
|
325 |
+
## model init
|
326 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
327 |
+
args = get_args()
|
328 |
+
model = model_init(args)
|
329 |
+
|
330 |
+
## inference
|
331 |
+
prompt1,prompt2,prompt3,restorted = inference_one_im(model,args.im_path,args.task)
|
332 |
+
|
333 |
+
## results saving
|
334 |
+
im_name = os.path.split(args.im_path)[-1]
|
335 |
+
im_format = '.'+im_name.split('.')[-1]
|
336 |
+
save_path = os.path.join(args.out_folder,im_name.replace(im_format,'_'+args.task+im_format))
|
337 |
+
cv2.imwrite(save_path,restorted)
|
338 |
+
if args.save_dtsprompt:
|
339 |
+
cv2.imwrite(save_path.replace(im_format,'_prompt1'+im_format),prompt1)
|
340 |
+
cv2.imwrite(save_path.replace(im_format,'_prompt2'+im_format),prompt2)
|
341 |
+
cv2.imwrite(save_path.replace(im_format,'_prompt3'+im_format),prompt3)
|
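Besides the command line, the functions in `inference.py` can be driven from Python. Below is a minimal sketch: the `Args` stand-in, the device assignment, and the image path are assumptions for illustration, and the MBD and DocRes checkpoints still need to be in their expected locations.

```python
import cv2
import torch

import inference

# DEVICE is normally set in the __main__ block, so assign it before calling the helpers
inference.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Args:  # stand-in for the argparse namespace consumed by model_init
    model_path = './checkpoints/docres.pkl'

model = inference.model_init(Args())
_, _, _, restored = inference.inference_one_im(model, 'input.png', 'deshadowing')
cv2.imwrite('restored.png', restored)
```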
loaders/docres_loader.py
ADDED
@@ -0,0 +1,558 @@
1 |
+
import os
|
2 |
+
from os.path import join as pjoin
|
3 |
+
import collections
|
4 |
+
import json
|
5 |
+
from numpy.lib.histograms import histogram_bin_edges
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
import random
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.utils import data
|
12 |
+
import glob
|
13 |
+
|
14 |
+
class DocResTrainDataset(data.Dataset):
|
15 |
+
def __init__(self, dataset={}, img_size=512,):
|
16 |
+
json_paths = dataset['json_paths']
|
17 |
+
self.task = dataset['task']
|
18 |
+
self.size = img_size
|
19 |
+
self.im_path = dataset['im_path']
|
20 |
+
|
21 |
+
self.datas = []
|
22 |
+
for json_path in json_paths:
|
23 |
+
with open(json_path,'r') as f:
|
24 |
+
data = json.load(f)
|
25 |
+
self.datas += data
|
26 |
+
|
27 |
+
self.background_paths = glob.glob('/data2/jiaxin/Training_Data/dewarping/doc_3d/background/*/*/*')
|
28 |
+
self.shadow_paths = glob.glob('/data2/jiaxin/Training_Data/illumination/doc3dshadow/new_shadow/*/*')
|
29 |
+
|
30 |
+
def __len__(self):
|
31 |
+
return len(self.datas)
|
32 |
+
|
33 |
+
def __getitem__(self, index):
|
34 |
+
data = self.datas[index]
|
35 |
+
in_im,gt_im,dtsprompt = self.data_processing(self.task,data)
|
36 |
+
|
37 |
+
return torch.cat((in_im,dtsprompt),0), gt_im
|
38 |
+
|
39 |
+
def data_processing(self,task,data):
|
40 |
+
|
41 |
+
if task=='deblurring':
|
42 |
+
## image prepare
|
43 |
+
in_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
|
44 |
+
gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
|
45 |
+
dtsprompt = self.deblur_dtsprompt(in_im)
|
46 |
+
## get prompt
|
47 |
+
in_im, gt_im,dtsprompt = self.randomcrop([in_im,gt_im,dtsprompt])
|
48 |
+
in_im = self.rgbim_transform(in_im)
|
49 |
+
gt_im = self.rgbim_transform(gt_im)
|
50 |
+
dtsprompt = self.rgbim_transform(dtsprompt)
|
51 |
+
elif task =='dewarping':
|
52 |
+
## image prepare
|
53 |
+
in_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
|
54 |
+
mask = cv2.imread(os.path.join(self.im_path,data['mask_path']))[:,:,0]
|
55 |
+
bm = np.load(os.path.join(self.im_path,data['gt_path'])).astype(np.float) #-> 0-448
|
56 |
+
bm = cv2.resize(bm,(448,448))
|
57 |
+
## add background
|
58 |
+
background = cv2.imread(random.choice(self.background_paths))
|
59 |
+
min_length = min(background.shape[:2])
|
60 |
+
crop_size = random.randint(int(min_length*0.5),min_length-1)
|
61 |
+
shift_y = np.random.randint(0,background.shape[1]-crop_size)
|
62 |
+
shift_x = np.random.randint(0,background.shape[0]-crop_size)
|
63 |
+
background = background[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
|
64 |
+
background = cv2.resize(background,(448,448))
|
65 |
+
if np.mean(in_im[mask==0])<10:
|
66 |
+
in_im[mask==0]=background[mask==0]
|
67 |
+
## random crop and get prompt
|
68 |
+
in_im,mask,bm = self.random_margin_bm(in_im,mask,bm) # bm-> 0-1
|
69 |
+
in_im = cv2.resize(in_im,(self.size,self.size))
|
70 |
+
mask = cv2.resize(mask,(self.size,self.size))
|
71 |
+
mask_aug = self.mask_augment(mask)
|
72 |
+
in_im[mask_aug==0]=0
|
73 |
+
bm = cv2.resize(bm,(self.size,self.size)) # bm-> 0-1
|
74 |
+
bm_shift = (bm*self.size - self.getBasecoord(self.size,self.size))/self.size
|
75 |
+
base_coord = self.getBasecoord(self.size,self.size)/self.size
|
76 |
+
|
77 |
+
in_im = self.rgbim_transform(in_im)
|
78 |
+
base_coord = base_coord.transpose(2, 0, 1)
|
79 |
+
base_coord = torch.from_numpy(base_coord)
|
80 |
+
|
81 |
+
bm_shift = bm_shift.transpose(2, 0, 1)
|
82 |
+
bm_shift = torch.from_numpy(bm_shift)
|
83 |
+
|
84 |
+
mask[mask>155] = 255
|
85 |
+
mask[mask<=155] = 0
|
86 |
+
mask = mask/255
|
87 |
+
mask = np.expand_dims(mask,-1)
|
88 |
+
mask = mask.transpose(2, 0, 1)
|
89 |
+
mask = torch.from_numpy(mask)
|
90 |
+
|
91 |
+
mask_aug[mask_aug>155] = 255
|
92 |
+
mask_aug[mask_aug<=155] = 0
|
93 |
+
mask_aug = mask_aug/255
|
94 |
+
mask_aug = np.expand_dims(mask_aug,-1)
|
95 |
+
mask_aug = mask_aug.transpose(2, 0, 1)
|
96 |
+
mask_aug = torch.from_numpy(mask_aug)
|
97 |
+
|
98 |
+
in_im = in_im
|
99 |
+
gt_im = torch.cat((bm_shift,mask),0)
|
100 |
+
dtsprompt = torch.cat((base_coord,mask_aug),0)
|
101 |
+
|
102 |
+
elif task == 'binarization':
|
103 |
+
## image prepare
|
104 |
+
in_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
|
105 |
+
gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
|
106 |
+
## get prompt
|
107 |
+
thr = cv2.imread(os.path.join(self.im_path,data['thr_path']))
|
108 |
+
bin_map = cv2.imread(os.path.join(self.im_path,data['bin_path']))
|
109 |
+
gradient = cv2.imread(os.path.join(self.im_path,data['gradient_path']))
|
110 |
+
bin_map[bin_map>155]=255
|
111 |
+
bin_map[bin_map<=155]=0
|
112 |
+
in_im, gt_im,thr,bin_map,gradient = self.randomcrop([in_im,gt_im,thr,bin_map,gradient])
|
113 |
+
in_im = self.randomAugment_binarization(in_im)
|
114 |
+
gt_im[gt_im>155]=255
|
115 |
+
gt_im[gt_im<=155]=0
|
116 |
+
gt_im = gt_im[:,:,0]
|
117 |
+
## transform
|
118 |
+
in_im = self.rgbim_transform(in_im)
|
119 |
+
thr = self.rgbim_transform(thr)
|
120 |
+
gradient = self.rgbim_transform(gradient)
|
121 |
+
bin_map = self.rgbim_transform(bin_map)
|
122 |
+
gt_im = gt_im.astype(np.float)/255.
|
123 |
+
gt_im = torch.from_numpy(gt_im)
|
124 |
+
gt_im = gt_im.unsqueeze(0)
|
125 |
+
dtsprompt = torch.cat((thr[0].unsqueeze(0),gradient[0].unsqueeze(0),bin_map[0].unsqueeze(0)),0)
|
126 |
+
elif task == 'deshadowing':
|
127 |
+
|
128 |
+
in_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
|
129 |
+
gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
|
130 |
+
shadow_im = self.deshadow_dtsprompt(in_im)
|
131 |
+
if 'fsdsrd' in data['in_path']:
|
132 |
+
in_im = cv2.resize(in_im,(512,512))
|
133 |
+
gt_im = cv2.resize(gt_im,(512,512))
|
134 |
+
shadow_im = cv2.resize(shadow_im,(512,512))
|
135 |
+
in_im, gt_im,shadow_im = self.randomcrop([in_im,gt_im,shadow_im])
|
136 |
+
else:
|
137 |
+
in_im, gt_im,shadow_im = self.randomcrop([in_im,gt_im,shadow_im])
|
138 |
+
in_im = self.rgbim_transform(in_im)
|
139 |
+
gt_im = self.rgbim_transform(gt_im)
|
140 |
+
shadow_im = self.rgbim_transform(shadow_im)
|
141 |
+
dtsprompt = shadow_im
|
142 |
+
|
143 |
+
elif task == 'appearance':
|
144 |
+
if 'in_path' in data.keys():
|
145 |
+
cap_im = cv2.imread(os.path.join(self.im_path,data['in_path']))
|
146 |
+
gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
|
147 |
+
gt_im,cap_im = self.randomcrop_realdae(gt_im,cap_im)
|
148 |
+
cap_im = self.appearance_randomAugmentv1(cap_im)
|
149 |
+
enhance_result = self.appearance_dtsprompt(cap_im)
|
150 |
+
else:
|
151 |
+
gt_im = cv2.imread(os.path.join(self.im_path,data['gt_path']))
|
152 |
+
bleed_im = cv2.imread(os.path.join(self.im_path,random.choice(self.datas)['gt_path']))
|
153 |
+
bleed_im = cv2.resize(bleed_im,gt_im.shape[:2][::-1])
|
154 |
+
gt_im = self.randomcrop([gt_im])[0]
|
155 |
+
bleed_im = self.randomcrop([bleed_im])[0]
|
156 |
+
cap_im = self.bleed_trough(gt_im,bleed_im)
|
157 |
+
|
158 |
+
shadow_path = random.choice(self.shadow_paths)
|
159 |
+
shadow_im = cv2.imread(shadow_path)
|
160 |
+
cap_im = self.appearance_randomAugmentv2(cap_im,shadow_im)
|
161 |
+
enhance_result = self.appearance_dtsprompt(cap_im)
|
162 |
+
|
163 |
+
|
164 |
+
in_im = self.rgbim_transform(cap_im)
|
165 |
+
gt_im = self.rgbim_transform(gt_im)
|
166 |
+
dtsprompt = self.rgbim_transform(enhance_result)
|
167 |
+
|
168 |
+
return in_im, gt_im,dtsprompt
|
169 |
+
|
170 |
+
def randomcrop(self,im_list):
|
171 |
+
im_num = len(im_list)
|
172 |
+
## random scale rotate
|
173 |
+
if random.uniform(0,1) <= 0.8:
|
174 |
+
y,x = im_list[0].shape[:2]
|
175 |
+
angle = random.uniform(-180,180)
|
176 |
+
scale = random.uniform(0.7,1.5)
|
177 |
+
M = cv2.getRotationMatrix2D((int(x/2),int(y/2)),angle,scale)
|
178 |
+
for i in range(im_num):
|
179 |
+
im_list[i] = cv2.warpAffine(im_list[i],M,(x,y),borderValue=(255,255,255))
|
180 |
+
|
181 |
+
## random crop
|
182 |
+
crop_size = self.size
|
183 |
+
for i in range(im_num):
|
184 |
+
h,w = im_list[i].shape[:2]
|
185 |
+
h = max(h,crop_size)
|
186 |
+
w = max(w,crop_size)
|
187 |
+
im_list[i] = cv2.resize(im_list[i],(w,h))
|
188 |
+
|
189 |
+
if h==crop_size:
|
190 |
+
shift_y=0
|
191 |
+
else:
|
192 |
+
shift_y = np.random.randint(0,h-crop_size)
|
193 |
+
if w==crop_size:
|
194 |
+
shift_x=0
|
195 |
+
else:
|
196 |
+
shift_x = np.random.randint(0,w-crop_size)
|
197 |
+
for i in range(im_num):
|
198 |
+
im_list[i] = im_list[i][shift_y:shift_y+crop_size,shift_x:shift_x+crop_size,:]
|
199 |
+
return im_list
|
200 |
+
|
201 |
+
def deblur_dtsprompt(self,img):
|
202 |
+
x = cv2.Sobel(img,cv2.CV_16S,1,0)
|
203 |
+
y = cv2.Sobel(img,cv2.CV_16S,0,1)
|
204 |
+
absX = cv2.convertScaleAbs(x) # 转回uint8
|
205 |
+
absY = cv2.convertScaleAbs(y)
|
206 |
+
high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
|
207 |
+
high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
|
208 |
+
high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
|
209 |
+
return high_frequency
|
210 |
+
|
211 |
+
|
212 |
+
def appearance_dtsprompt(self,img):
|
213 |
+
h,w = img.shape[:2]
|
214 |
+
img = cv2.resize(img,(1024,1024))
|
215 |
+
rgb_planes = cv2.split(img)
|
216 |
+
result_planes = []
|
217 |
+
result_norm_planes = []
|
218 |
+
for plane in rgb_planes:
|
219 |
+
dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
|
220 |
+
bg_img = cv2.medianBlur(dilated_img, 21)
|
221 |
+
diff_img = 255 - cv2.absdiff(plane, bg_img)
|
222 |
+
norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
|
223 |
+
result_planes.append(diff_img)
|
224 |
+
result_norm_planes.append(norm_img)
|
225 |
+
result_norm = cv2.merge(result_norm_planes)
|
226 |
+
result_norm = cv2.resize(result_norm,(w,h))
|
227 |
+
return result_norm
|
228 |
+
|
229 |
+
|
230 |
+
def rgbim_transform(self,im):
|
231 |
+
im = im.astype(np.float)/255.
|
232 |
+
im = im.transpose(2, 0, 1)
|
233 |
+
im = torch.from_numpy(im)
|
234 |
+
return im
|
235 |
+
|
236 |
+
|
237 |
+
def random_margin_bm(self,in_im,msk,bm):
|
238 |
+
size = in_im.shape[:2]
|
239 |
+
[y, x] = (msk).nonzero()
|
240 |
+
minx = min(x)
|
241 |
+
maxx = max(x)
|
242 |
+
miny = min(y)
|
243 |
+
maxy = max(y)
|
244 |
+
|
245 |
+
s = 20
|
246 |
+
s = int(20*size[0]/128)
|
247 |
+
difference = int(5*size[0]/128)
|
248 |
+
cx1 = random.randint(0, s - difference)
|
249 |
+
cx2 = random.randint(0, s - difference) + 1
|
250 |
+
cy1 = random.randint(0, s - difference)
|
251 |
+
cy2 = random.randint(0, s - difference) + 1
|
252 |
+
|
253 |
+
t = miny-s+cy1
|
254 |
+
b = size[0]-maxy-s+cy2
|
255 |
+
l = minx-s+cx1
|
256 |
+
r = size[1]-maxx-s+cx2
|
257 |
+
|
258 |
+
t = max(0,t)
|
259 |
+
b = max(0,b)
|
260 |
+
l = max(0,l)
|
261 |
+
r = max(0,r)
|
262 |
+
|
263 |
+
in_im = in_im[t:size[0]-b,l:size[1]-r]
|
264 |
+
msk = msk[t:size[0]-b,l:size[1]-r]
|
265 |
+
bm[:,:,1]=bm[:,:,1]-t
|
266 |
+
bm[:,:,0]=bm[:,:,0]-l
|
267 |
+
bm=bm/np.array([448-l-r, 448-t-b])
|
268 |
+
|
269 |
+
return in_im,msk,bm
|
270 |
+
|
271 |
+
def mask_augment(self,mask):
|
272 |
+
if random.uniform(0,1) <= 0.6:
|
273 |
+
if random.uniform(0,1) <= 0.5:
|
274 |
+
mask = cv2.resize(mask,(64,64))
|
275 |
+
else:
|
276 |
+
mask = cv2.resize(mask,(128,128))
|
277 |
+
mask = cv2.resize(mask,(256,256))
|
278 |
+
mask[mask>155] = 255
|
279 |
+
mask[mask<=155] = 0
|
280 |
+
return mask
|
281 |
+
|
282 |
+
def bleed_trough(self, in_im, bleed_im):
|
283 |
+
if random.uniform(0,1) <= 0.5:
|
284 |
+
if random.uniform(0,1) <= 0.8:
|
285 |
+
ksize = np.random.randint(1,2)*2 + 1
|
286 |
+
bleed_im = cv2.blur(bleed_im,(ksize,ksize))
|
287 |
+
bleed_im = cv2.flip(bleed_im,1)
|
288 |
+
alpha = random.uniform(0.75,1)
|
289 |
+
in_im = cv2.addWeighted(in_im,alpha,bleed_im,1-alpha,0)
|
290 |
+
return in_im
|
291 |
+
|
292 |
+
def getBasecoord(self,h,w):
|
293 |
+
base_coord0 = np.tile(np.arange(h).reshape(h,1),(1,w)).astype(np.float32)
|
294 |
+
base_coord1 = np.tile(np.arange(w).reshape(1,w),(h,1)).astype(np.float32)
|
295 |
+
base_coord = np.concatenate((np.expand_dims(base_coord1,-1),np.expand_dims(base_coord0,-1)),-1)
|
296 |
+
return base_coord
|
297 |
+
|
298 |
+
|
299 |
+
def randomcrop_realdae(self,gt_im,cap_im):
|
300 |
+
if random.uniform(0,1) <= 0.5:
|
301 |
+
y,x = gt_im.shape[:2]
|
302 |
+
angle = random.uniform(-30,30)
|
303 |
+
scale = random.uniform(0.8,1.5)
|
304 |
+
M = cv2.getRotationMatrix2D((int(x/2),int(y/2)),angle,scale)
|
305 |
+
gt_im = cv2.warpAffine(gt_im,M,(x,y),borderValue=(255,255,255))
|
306 |
+
cap_im = cv2.warpAffine(cap_im,M,(x,y),borderValue=(255,255,255))
|
307 |
+
crop_size = self.size
|
308 |
+
if gt_im.shape[0] <= crop_size:
|
309 |
+
gt_im = cv2.copyMakeBorder(gt_im,crop_size-gt_im.shape[0]+1,0,0,0,borderType=cv2.BORDER_CONSTANT,value=(255,255,255))
|
310 |
+
cap_im = cv2.copyMakeBorder(cap_im,crop_size-cap_im.shape[0]+1,0,0,0,borderType=cv2.BORDER_CONSTANT,value=(255,255,255))
|
311 |
+
if gt_im.shape[1] <= crop_size:
|
312 |
+
gt_im = cv2.copyMakeBorder(gt_im,0,0,crop_size-gt_im.shape[1]+1,0,borderType=cv2.BORDER_CONSTANT,value=(255,255,255))
|
313 |
+
cap_im = cv2.copyMakeBorder(cap_im,0,0,crop_size-cap_im.shape[1]+1,0,borderType=cv2.BORDER_CONSTANT,value=(255,255,255))
|
314 |
+
shift_y = np.random.randint(0,gt_im.shape[1]-crop_size)
|
315 |
+
shift_x = np.random.randint(0,gt_im.shape[0]-crop_size)
|
316 |
+
gt_im = gt_im[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
|
317 |
+
cap_im = cap_im[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
|
318 |
+
return gt_im,cap_im
|
319 |
+
|
320 |
+
|
321 |
+
def randomAugment_binarization(self,in_img):
|
322 |
+
h,w = in_img.shape[:2]
|
323 |
+
## brightness
|
324 |
+
if random.uniform(0,1) <= 0.5:
|
325 |
+
high = 1.3
|
326 |
+
low = 0.8
|
327 |
+
ratio = np.random.uniform(low,high)
|
328 |
+
in_img = in_img.astype(np.float64)*ratio
|
329 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
330 |
+
## contrast
|
331 |
+
if random.uniform(0,1) <= 0.5:
|
332 |
+
high = 1.3
|
333 |
+
low = 0.8
|
334 |
+
ratio = np.random.uniform(low,high)
|
335 |
+
gray = cv2.cvtColor(in_img,cv2.COLOR_BGR2GRAY)
|
336 |
+
mean = np.mean(gray)
|
337 |
+
mean_array = np.ones_like(in_img).astype(np.float64)*mean
|
338 |
+
in_img = in_img.astype(np.float64)*ratio + mean_array*(1-ratio)
|
339 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
340 |
+
## color
|
341 |
+
if random.uniform(0,1) <= 0.5:
|
342 |
+
high = 0.2
|
343 |
+
low = 0.1
|
344 |
+
ratio = np.random.uniform(0.1,0.3)
|
345 |
+
random_color = np.random.randint(50,200,3).reshape(1,1,3)
|
346 |
+
random_color = (random_color*ratio).astype(np.uint8)
|
347 |
+
random_color = np.tile(random_color,(self.size,self.size,1))
|
348 |
+
in_img = in_img.astype(np.float64)*(1-ratio) + random_color
|
349 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
350 |
+
return in_img
|
351 |
+
|
352 |
+
|
353 |
+
def deshadow_dtsprompt(self,img):
|
354 |
+
h,w = img.shape[:2]
|
355 |
+
img = cv2.resize(img,(1024,1024))
|
356 |
+
rgb_planes = cv2.split(img)
|
357 |
+
result_planes = []
|
358 |
+
result_norm_planes = []
|
359 |
+
bg_imgs = []
|
360 |
+
for plane in rgb_planes:
|
361 |
+
dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
|
362 |
+
bg_img = cv2.medianBlur(dilated_img, 21)
|
363 |
+
bg_imgs.append(bg_img)
|
364 |
+
diff_img = 255 - cv2.absdiff(plane, bg_img)
|
365 |
+
norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
|
366 |
+
result_planes.append(diff_img)
|
367 |
+
result_norm_planes.append(norm_img)
|
368 |
+
result_norm = cv2.merge(result_norm_planes)
|
369 |
+
bg_imgs = cv2.merge(bg_imgs)
|
370 |
+
bg_imgs = cv2.resize(bg_imgs,(w,h))
|
371 |
+
return bg_imgs
|
372 |
+
|
373 |
+
|
374 |
+
|
375 |
+
|
376 |
+
|
377 |
+
|
378 |
+
|
379 |
+
|
380 |
+
|
381 |
+
def randomAugment(self,in_img,gt_img,shadow_img):
|
382 |
+
h,w = in_img.shape[:2]
|
383 |
+
# random crop
|
384 |
+
crop_size = random.randint(128,1024)
|
385 |
+
if shadow_img.shape[0] <= crop_size:
|
386 |
+
shadow_img = cv2.copyMakeBorder(shadow_img,crop_size-shadow_img.shape[0]+1,0,0,0,borderType=cv2.BORDER_CONSTANT,value=(128,128,128))
|
387 |
+
if shadow_img.shape[1] <= crop_size:
|
388 |
+
shadow_img = cv2.copyMakeBorder(shadow_img,0,0,crop_size-shadow_img.shape[1]+1,0,borderType=cv2.BORDER_CONSTANT,value=(128,128,128))
|
389 |
+
|
390 |
+
shift_y = np.random.randint(0,shadow_img.shape[1]-crop_size)
|
391 |
+
shift_x = np.random.randint(0,shadow_img.shape[0]-crop_size)
|
392 |
+
shadow_img = shadow_img[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
|
393 |
+
shadow_img = cv2.resize(shadow_img,(w,h))
|
394 |
+
in_img = in_img.astype(np.float64)*(shadow_img.astype(np.float64)+1)/255
|
395 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
396 |
+
|
397 |
+
## brightness
|
398 |
+
if random.uniform(0,1) <= 0.5:
|
399 |
+
high = 1.3
|
400 |
+
low = 0.8
|
401 |
+
ratio = np.random.uniform(low,high)
|
402 |
+
in_img = in_img.astype(np.float64)*ratio
|
403 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
404 |
+
## contrast
|
405 |
+
if random.uniform(0,1) <= 0.5:
|
406 |
+
high = 1.3
|
407 |
+
low = 0.8
|
408 |
+
ratio = np.random.uniform(low,high)
|
409 |
+
gray = cv2.cvtColor(in_img,cv2.COLOR_BGR2GRAY)
|
410 |
+
mean = np.mean(gray)
|
411 |
+
mean_array = np.ones_like(in_img).astype(np.float64)*mean
|
412 |
+
in_img = in_img.astype(np.float64)*ratio + mean_array*(1-ratio)
|
413 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
414 |
+
## color
|
415 |
+
if random.uniform(0,1) <= 0.5:
|
416 |
+
high = 0.2
|
417 |
+
low = 0.1
|
418 |
+
ratio = np.random.uniform(0.1,0.3)
|
419 |
+
random_color = np.random.randint(50,200,3).reshape(1,1,3)
|
420 |
+
random_color = (random_color*ratio).astype(np.uint8)
|
421 |
+
random_color = np.tile(random_color,(self.img_size[0],self.img_size[1],1))
|
422 |
+
in_img = in_img.astype(np.float64)*(1-ratio) + random_color
|
423 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
424 |
+
## scale and rotate
|
425 |
+
if random.uniform(0,1) <= 0:
|
426 |
+
y,x = self.img_size
|
427 |
+
angle = random.uniform(-180,180)
|
428 |
+
scale = random.uniform(0.5,1.5)
|
429 |
+
M = cv2.getRotationMatrix2D((int(x/2),int(y/2)),angle,scale)
|
430 |
+
in_img = cv2.warpAffine(in_img,M,(x,y),borderValue=0)
|
431 |
+
gt_img = cv2.warpAffine(gt_img,M,(x,y),borderValue=0)
|
432 |
+
# add noise
|
433 |
+
## jpegcompression
|
434 |
+
quality_high = 95
|
435 |
+
quality_low = 45
|
436 |
+
quality = int(np.random.randint(quality_low,quality_high))
|
437 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),quality]
|
438 |
+
result, encimg = cv2.imencode('.jpg',in_img,encode_param)
|
439 |
+
in_img = cv2.imdecode(encimg,1).astype(np.uint8)
|
440 |
+
## gaussiannoise
|
441 |
+
mean = 0
|
442 |
+
sigma = 0.02
|
443 |
+
noise_ratio = 0.004
|
444 |
+
num_noise = int(np.ceil(noise_ratio*w))
|
445 |
+
coords = [np.random.randint(0,i-1,int(num_noise)) for i in [h,w]]
|
446 |
+
gauss = np.random.normal(mean,sigma,num_noise*3)*255
|
447 |
+
gauss = np.reshape(gauss,(-1,3))
|
448 |
+
in_img = in_img.astype(np.float64)
|
449 |
+
in_img[tuple(coords)] += gauss
|
450 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
451 |
+
## blur
|
452 |
+
ksize = np.random.randint(1,2)*2 + 1
|
453 |
+
in_img = cv2.blur(in_img,(ksize,ksize))
|
454 |
+
|
455 |
+
## erase
|
456 |
+
if random.uniform(0,1) <= 0.7:
|
457 |
+
for i in range(100):
|
458 |
+
area = int(np.random.uniform(0.01,0.05)*h*w)
|
459 |
+
ration = np.random.uniform(0.3,1/0.3)
|
460 |
+
h_shift = int(np.sqrt(area*ration))
|
461 |
+
w_shift = int(np.sqrt(area/ration))
|
462 |
+
if (h_shift<h) and (w_shift<w):
|
463 |
+
break
|
464 |
+
h_start = np.random.randint(0,h-h_shift)
|
465 |
+
w_start = np.random.randint(0,w-w_shift)
|
466 |
+
randm_area = np.random.randint(low=0,high=255,size=(h_shift,w_shift,3))
|
467 |
+
in_img[h_start:h_start+h_shift,w_start:w_start+w_shift,:] = randm_area
|
468 |
+
|
469 |
+
|
470 |
+
return in_img, gt_img
|
471 |
+
|
472 |
+
|
473 |
+
def appearance_randomAugmentv1(self,in_img):
|
474 |
+
|
475 |
+
## brightness
|
476 |
+
if random.uniform(0,1) <= 0.8:
|
477 |
+
high = 1.3
|
478 |
+
low = 0.5
|
479 |
+
ratio = np.random.uniform(low,high)
|
480 |
+
in_img = in_img.astype(np.float64)*ratio
|
481 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
482 |
+
## contrast
|
483 |
+
if random.uniform(0,1) <= 0.8:
|
484 |
+
high = 1.3
|
485 |
+
low = 0.5
|
486 |
+
ratio = np.random.uniform(low,high)
|
487 |
+
gray = cv2.cvtColor(in_img,cv2.COLOR_BGR2GRAY)
|
488 |
+
mean = np.mean(gray)
|
489 |
+
mean_array = np.ones_like(in_img).astype(np.float64)*mean
|
490 |
+
in_img = in_img.astype(np.float64)*ratio + mean_array*(1-ratio)
|
491 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
492 |
+
## color
|
493 |
+
if random.uniform(0,1) <= 0.8:
|
494 |
+
high = 0.2
|
495 |
+
low = 0.1
|
496 |
+
ratio = np.random.uniform(0.1,0.3)
|
497 |
+
random_color = np.random.randint(50,200,3).reshape(1,1,3)
|
498 |
+
random_color = (random_color*ratio).astype(np.uint8)
|
499 |
+
random_color = np.tile(random_color,(self.size,self.size,1))
|
500 |
+
in_img = in_img.astype(np.float64)*(1-ratio) + random_color
|
501 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
502 |
+
|
503 |
+
return in_img
|
504 |
+
|
505 |
+
|
506 |
+
def appearance_randomAugmentv2(self,in_img,shadow_img):
|
507 |
+
h,w = in_img.shape[:2]
|
508 |
+
# random crop
|
509 |
+
crop_size = random.randint(96,1024)
|
510 |
+
if shadow_img.shape[0] <= crop_size:
|
511 |
+
shadow_img = cv2.resize(shadow_img,(crop_size+1,crop_size+1))
|
512 |
+
if shadow_img.shape[1] <= crop_size:
|
513 |
+
shadow_img = cv2.resize(shadow_img,(crop_size+1,crop_size+1))
|
514 |
+
|
515 |
+
shift_y = np.random.randint(0,shadow_img.shape[1]-crop_size)
|
516 |
+
shift_x = np.random.randint(0,shadow_img.shape[0]-crop_size)
|
517 |
+
shadow_img = shadow_img[shift_x:shift_x+crop_size,shift_y:shift_y+crop_size,:]
|
518 |
+
shadow_img = cv2.resize(shadow_img,(w,h))
|
519 |
+
in_img = in_img.astype(np.float64)*(shadow_img.astype(np.float64)+1)/255
|
520 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
521 |
+
|
522 |
+
## brightness
|
523 |
+
if random.uniform(0,1) <= 0.8:
|
524 |
+
high = 1.3
|
525 |
+
low = 0.5
|
526 |
+
ratio = np.random.uniform(low,high)
|
527 |
+
in_img = in_img.astype(np.float64)*ratio
|
528 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
529 |
+
## contrast
|
530 |
+
if random.uniform(0,1) <= 0.8:
|
531 |
+
high = 1.3
|
532 |
+
low = 0.5
|
533 |
+
ratio = np.random.uniform(low,high)
|
534 |
+
gray = cv2.cvtColor(in_img,cv2.COLOR_BGR2GRAY)
|
535 |
+
mean = np.mean(gray)
|
536 |
+
mean_array = np.ones_like(in_img).astype(np.float64)*mean
|
537 |
+
in_img = in_img.astype(np.float64)*ratio + mean_array*(1-ratio)
|
538 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
539 |
+
## color
|
540 |
+
if random.uniform(0,1) <= 0.8:
|
541 |
+
high = 0.2
|
542 |
+
low = 0.1
|
543 |
+
ratio = np.random.uniform(0.1,0.3)
|
544 |
+
random_color = np.random.randint(50,200,3).reshape(1,1,3)
|
545 |
+
random_color = (random_color*ratio).astype(np.uint8)
|
546 |
+
random_color = np.tile(random_color,(h,w,1))
|
547 |
+
in_img = in_img.astype(np.float64)*(1-ratio) + random_color
|
548 |
+
in_img = np.clip(in_img,0,255).astype(np.uint8)
|
549 |
+
|
550 |
+
if random.uniform(0,1) <= 0.8:
|
551 |
+
quality_high = 95
|
552 |
+
quality_low = 45
|
553 |
+
quality = int(np.random.randint(quality_low,quality_high))
|
554 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY),quality]
|
555 |
+
result, encimg = cv2.imencode('.jpg',in_img,encode_param)
|
556 |
+
in_img = cv2.imdecode(encimg,1).astype(np.uint8)
|
557 |
+
|
558 |
+
return in_img
|
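The loader above derives a 3-channel DTSPrompt per task (background estimate, shadow map, or high-frequency map) and returns it alongside the RGB crop. Below is a minimal sketch of how such a prompt could be stacked with the image to form the 6-channel input that the Restormer configuration in train.py (`inp_channels=6`) expects; the concatenation step and the names `dataset` and `input.png` are illustrative assumptions, not code from the loader.

```python
# Illustrative sketch only: stack an RGB crop with its appearance DTSPrompt.
# `dataset` is assumed to be a DocResTrainDataset instance; 'input.png' is a placeholder path.
import cv2
import torch

img = cv2.imread('input.png')                     # HxWx3 uint8 document crop
prompt = dataset.appearance_dtsprompt(img)        # HxWx3 background estimate
six_channel = torch.cat([dataset.rgbim_transform(img),
                         dataset.rgbim_transform(prompt)], dim=0)  # 6xHxW float tensor
```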
models/restormer_arch.py
ADDED
@@ -0,0 +1,308 @@
1 |
+
## Restormer: Efficient Transformer for High-Resolution Image Restoration
|
2 |
+
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
|
3 |
+
## https://arxiv.org/abs/2111.09881
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from pdb import set_trace as stx
|
10 |
+
import numbers
|
11 |
+
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
##########################################################################
|
17 |
+
## Layer Norm
|
18 |
+
|
19 |
+
def to_3d(x):
|
20 |
+
return rearrange(x, 'b c h w -> b (h w) c')
|
21 |
+
|
22 |
+
def to_4d(x,h,w):
|
23 |
+
return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
|
24 |
+
|
25 |
+
class BiasFree_LayerNorm(nn.Module):
|
26 |
+
def __init__(self, normalized_shape):
|
27 |
+
super(BiasFree_LayerNorm, self).__init__()
|
28 |
+
if isinstance(normalized_shape, numbers.Integral):
|
29 |
+
normalized_shape = (normalized_shape,)
|
30 |
+
normalized_shape = torch.Size(normalized_shape)
|
31 |
+
|
32 |
+
assert len(normalized_shape) == 1
|
33 |
+
|
34 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
35 |
+
self.normalized_shape = normalized_shape
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
39 |
+
return x / torch.sqrt(sigma+1e-5) * self.weight
|
40 |
+
|
41 |
+
class WithBias_LayerNorm(nn.Module):
|
42 |
+
def __init__(self, normalized_shape):
|
43 |
+
super(WithBias_LayerNorm, self).__init__()
|
44 |
+
if isinstance(normalized_shape, numbers.Integral):
|
45 |
+
normalized_shape = (normalized_shape,)
|
46 |
+
normalized_shape = torch.Size(normalized_shape)
|
47 |
+
|
48 |
+
assert len(normalized_shape) == 1
|
49 |
+
|
50 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
51 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
52 |
+
self.normalized_shape = normalized_shape
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
mu = x.mean(-1, keepdim=True)
|
56 |
+
sigma = x.var(-1, keepdim=True, unbiased=False)
|
57 |
+
return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
|
58 |
+
|
59 |
+
|
60 |
+
class LayerNorm(nn.Module):
|
61 |
+
def __init__(self, dim, LayerNorm_type):
|
62 |
+
super(LayerNorm, self).__init__()
|
63 |
+
if LayerNorm_type =='BiasFree':
|
64 |
+
self.body = BiasFree_LayerNorm(dim)
|
65 |
+
else:
|
66 |
+
self.body = WithBias_LayerNorm(dim)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
h, w = x.shape[-2:]
|
70 |
+
return to_4d(self.body(to_3d(x)), h, w)
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
##########################################################################
|
75 |
+
## Gated-Dconv Feed-Forward Network (GDFN)
|
76 |
+
class FeedForward(nn.Module):
|
77 |
+
def __init__(self, dim, ffn_expansion_factor, bias):
|
78 |
+
super(FeedForward, self).__init__()
|
79 |
+
|
80 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
81 |
+
|
82 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
83 |
+
|
84 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
|
85 |
+
|
86 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
x = self.project_in(x)
|
90 |
+
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
91 |
+
x = F.gelu(x1) * x2
|
92 |
+
x = self.project_out(x)
|
93 |
+
return x
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
##########################################################################
|
98 |
+
## Multi-DConv Head Transposed Self-Attention (MDTA)
|
99 |
+
class Attention(nn.Module):
|
100 |
+
def __init__(self, dim, num_heads, bias):
|
101 |
+
super(Attention, self).__init__()
|
102 |
+
self.num_heads = num_heads
|
103 |
+
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
|
104 |
+
|
105 |
+
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
|
106 |
+
self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
|
107 |
+
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
b,c,h,w = x.shape
|
113 |
+
|
114 |
+
qkv = self.qkv_dwconv(self.qkv(x))
|
115 |
+
q,k,v = qkv.chunk(3, dim=1)
|
116 |
+
|
117 |
+
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
118 |
+
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
119 |
+
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
120 |
+
|
121 |
+
q = torch.nn.functional.normalize(q, dim=-1)
|
122 |
+
k = torch.nn.functional.normalize(k, dim=-1)
|
123 |
+
|
124 |
+
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
125 |
+
attn = attn.softmax(dim=-1)
|
126 |
+
|
127 |
+
out = (attn @ v)
|
128 |
+
|
129 |
+
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
|
130 |
+
|
131 |
+
out = self.project_out(out)
|
132 |
+
return out
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
##########################################################################
|
137 |
+
class TransformerBlock(nn.Module):
|
138 |
+
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
|
139 |
+
super(TransformerBlock, self).__init__()
|
140 |
+
|
141 |
+
self.norm1 = LayerNorm(dim, LayerNorm_type)
|
142 |
+
self.attn = Attention(dim, num_heads, bias)
|
143 |
+
self.norm2 = LayerNorm(dim, LayerNorm_type)
|
144 |
+
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
|
145 |
+
|
146 |
+
def forward(self, x):
|
147 |
+
x = x + self.attn(self.norm1(x))
|
148 |
+
x = x + self.ffn(self.norm2(x))
|
149 |
+
|
150 |
+
return x
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
##########################################################################
|
155 |
+
## Overlapped image patch embedding with 3x3 Conv
|
156 |
+
class OverlapPatchEmbed(nn.Module):
|
157 |
+
def __init__(self, in_c=3, embed_dim=48, bias=False):
|
158 |
+
super(OverlapPatchEmbed, self).__init__()
|
159 |
+
|
160 |
+
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
x = self.proj(x)
|
164 |
+
|
165 |
+
return x
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
##########################################################################
|
170 |
+
## Resizing modules
|
171 |
+
class Downsample(nn.Module):
|
172 |
+
def __init__(self, n_feat):
|
173 |
+
super(Downsample, self).__init__()
|
174 |
+
|
175 |
+
self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
|
176 |
+
nn.PixelUnshuffle(2))
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
return self.body(x)
|
180 |
+
|
181 |
+
class Upsample(nn.Module):
|
182 |
+
def __init__(self, n_feat):
|
183 |
+
super(Upsample, self).__init__()
|
184 |
+
|
185 |
+
self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
|
186 |
+
nn.PixelShuffle(2))
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
return self.body(x)
|
190 |
+
|
191 |
+
##########################################################################
|
192 |
+
##---------- Restormer -----------------------
|
193 |
+
class Restormer(nn.Module):
|
194 |
+
def __init__(self,
|
195 |
+
inp_channels=3,
|
196 |
+
out_channels=3,
|
197 |
+
dim = 48,
|
198 |
+
num_blocks = [4,6,6,8],
|
199 |
+
num_refinement_blocks = 4,
|
200 |
+
heads = [1,2,4,8],
|
201 |
+
ffn_expansion_factor = 2.66,
|
202 |
+
bias = False,
|
203 |
+
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
|
204 |
+
dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
|
205 |
+
):
|
206 |
+
|
207 |
+
super(Restormer, self).__init__()
|
208 |
+
|
209 |
+
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
|
210 |
+
|
211 |
+
self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
212 |
+
|
213 |
+
self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
|
214 |
+
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
215 |
+
|
216 |
+
self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
|
217 |
+
self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
218 |
+
|
219 |
+
self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
|
220 |
+
self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
|
221 |
+
|
222 |
+
self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
|
223 |
+
self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
|
224 |
+
self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
|
225 |
+
|
226 |
+
|
227 |
+
self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
|
228 |
+
self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
|
229 |
+
self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
230 |
+
|
231 |
+
self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
|
232 |
+
|
233 |
+
self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
234 |
+
|
235 |
+
self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
|
236 |
+
|
237 |
+
#### For Dual-Pixel Defocus Deblurring Task ####
|
238 |
+
self.dual_pixel_task = dual_pixel_task
|
239 |
+
if self.dual_pixel_task:
|
240 |
+
self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
|
241 |
+
###########################
|
242 |
+
|
243 |
+
|
244 |
+
self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
|
245 |
+
|
246 |
+
def forward(self, inp_img,task=''):
|
247 |
+
|
248 |
+
inp_enc_level1 = self.patch_embed(inp_img)
|
249 |
+
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
250 |
+
|
251 |
+
inp_enc_level2 = self.down1_2(out_enc_level1)
|
252 |
+
out_enc_level2 = self.encoder_level2(inp_enc_level2)
|
253 |
+
|
254 |
+
inp_enc_level3 = self.down2_3(out_enc_level2)
|
255 |
+
out_enc_level3 = self.encoder_level3(inp_enc_level3)
|
256 |
+
|
257 |
+
inp_enc_level4 = self.down3_4(out_enc_level3)
|
258 |
+
latent = self.latent(inp_enc_level4)
|
259 |
+
|
260 |
+
|
261 |
+
inp_dec_level3 = self.up4_3(latent)
|
262 |
+
inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
|
263 |
+
inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
|
264 |
+
out_dec_level3 = self.decoder_level3(inp_dec_level3)
|
265 |
+
|
266 |
+
inp_dec_level2 = self.up3_2(out_dec_level3)
|
267 |
+
inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
|
268 |
+
inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
|
269 |
+
out_dec_level2 = self.decoder_level2(inp_dec_level2)
|
270 |
+
|
271 |
+
inp_dec_level1 = self.up2_1(out_dec_level2)
|
272 |
+
inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
|
273 |
+
out_dec_level1 = self.decoder_level1(inp_dec_level1)
|
274 |
+
|
275 |
+
out_dec_level1 = self.refinement(out_dec_level1)
|
276 |
+
|
277 |
+
out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
|
278 |
+
out_dec_level1 = self.output(out_dec_level1)
|
279 |
+
|
280 |
+
return out_dec_level1
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
if __name__ == '__main__':
|
285 |
+
from torchtoolbox.tools import summary
|
286 |
+
model = Restormer(
|
287 |
+
inp_channels=6,
|
288 |
+
out_channels=3,
|
289 |
+
dim = 48,
|
290 |
+
# num_blocks = [4,6,6,8],
|
291 |
+
num_blocks = [2,3,3,4],
|
292 |
+
num_refinement_blocks = 4,
|
293 |
+
heads = [1,2,4,8],
|
294 |
+
ffn_expansion_factor = 2.66,
|
295 |
+
bias = False,
|
296 |
+
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
|
297 |
+
dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
|
298 |
+
)
|
299 |
+
# model = Restormer(num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 4, 8], channels=[48, 96, 192, 384], num_refinement=4, expansion_factor=2.66)
|
300 |
+
print(summary(model,torch.rand((1, 6, 256, 256))))
|
301 |
+
|
302 |
+
from thop import profile
|
303 |
+
input = torch.rand((1, 6, 256, 256))
|
304 |
+
gflops,params = profile(model,inputs=(input,))
|
305 |
+
gflops = gflops*2 / 10**9
|
306 |
+
params = params / 10**6
|
307 |
+
print(gflops,'==============')
|
308 |
+
print(params,'==============')
|
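A quick smoke-test sketch for the architecture above, using the same hyper-parameters that train.py passes (6 input channels: RGB image plus a 3-channel DTSPrompt). The random input is only a placeholder.

```python
# Minimal forward-pass check; assumes this file is importable as models.restormer_arch.
import torch
from models.restormer_arch import Restormer

model = Restormer(inp_channels=6, out_channels=3, dim=48,
                  num_blocks=[2, 3, 3, 4], num_refinement_blocks=4,
                  heads=[1, 2, 4, 8], ffn_expansion_factor=2.66,
                  bias=False, LayerNorm_type='WithBias', dual_pixel_task=True)

x = torch.rand(1, 6, 256, 256)   # H and W must be divisible by 8 (three downsampling stages)
with torch.no_grad():
    y = model(x)
print(y.shape)                   # torch.Size([1, 3, 256, 256])
```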
requirements.txt
ADDED
@@ -0,0 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cu113
numpy==1.21.6
opencv-python-headless>=4.2.0
scikit-image>=0.19.3
torch==1.11.0+cu113
torchvision==0.12.0+cu113
einops
tqdm
gradio
Pillow
|
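Note that torch and torchvision are pinned to CUDA 11.3 wheels via the extra index on the first line; on a CPU-only machine or one with a different CUDA version these two pins will likely need to be swapped for matching builds before `pip install -r requirements.txt` succeeds.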
start_train.sh
ADDED
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 --master_port 26413 train.py
|
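This launcher assumes eight visible GPUs; to train on fewer, reduce both `CUDA_VISIBLE_DEVICES` and `--nproc_per_node` to the same count. The `--master_port` value is arbitrary as long as the port is free.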
train.py
ADDED
@@ -0,0 +1,221 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import time
|
4 |
+
import random
|
5 |
+
import datetime
|
6 |
+
import argparse
|
7 |
+
import numpy as np
|
8 |
+
from tqdm import tqdm
|
9 |
+
from piq import ssim,psnr
|
10 |
+
from itertools import cycle
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch.utils import data
|
15 |
+
import torch.distributed as dist
|
16 |
+
from torch.utils.data.distributed import DistributedSampler
|
17 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
|
18 |
+
|
19 |
+
|
20 |
+
from utils import dict2string,mkdir,get_lr,torch2cvimg,second2hours
|
21 |
+
from loaders import docres_loader
|
22 |
+
from models import restormer_arch
|
23 |
+
|
24 |
+
|
25 |
+
def seed_torch(seed=1029):
|
26 |
+
random.seed(seed)
|
27 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
28 |
+
np.random.seed(seed)
|
29 |
+
torch.manual_seed(seed)
|
30 |
+
torch.cuda.manual_seed(seed)
|
31 |
+
torch.cuda.manual_seed_all(seed)
|
32 |
+
torch.backends.cudnn.benchmark = False
|
33 |
+
torch.backends.cudnn.deterministic = True
|
34 |
+
#torch.use_deterministic_algorithms(True)
|
35 |
+
# seed_torch()
|
36 |
+
|
37 |
+
|
38 |
+
def getBasecoord(h,w):
|
39 |
+
base_coord0 = np.tile(np.arange(h).reshape(h,1),(1,w)).astype(np.float32)
|
40 |
+
base_coord1 = np.tile(np.arange(w).reshape(1,w),(h,1)).astype(np.float32)
|
41 |
+
base_coord = np.concatenate((np.expand_dims(base_coord1,-1),np.expand_dims(base_coord0,-1)),-1)
|
42 |
+
return base_coord
|
43 |
+
|
44 |
+
def train(args):
|
45 |
+
|
46 |
+
## DDP init
|
47 |
+
dist.init_process_group(backend='nccl',init_method='env://',timeout=datetime.timedelta(seconds=36000))
|
48 |
+
torch.cuda.set_device(args.local_rank)
|
49 |
+
device = torch.device('cuda',args.local_rank)
|
50 |
+
torch.cuda.manual_seed_all(42)
|
51 |
+
|
52 |
+
### Log file:
|
53 |
+
mkdir(args.logdir)
|
54 |
+
mkdir(os.path.join(args.logdir,args.experiment_name))
|
55 |
+
log_file_path=os.path.join(args.logdir,args.experiment_name,'log.txt')
|
56 |
+
log_file=open(log_file_path,'a')
|
57 |
+
log_file.write('\n--------------- '+args.experiment_name+' ---------------\n')
|
58 |
+
log_file.close()
|
59 |
+
|
60 |
+
### Setup tensorboard for visualization
|
61 |
+
if args.tboard:
|
62 |
+
writer = SummaryWriter(os.path.join(args.logdir,args.experiment_name,'runs'),args.experiment_name)
|
63 |
+
|
64 |
+
### Setup Dataloader
|
65 |
+
datasets_setting = [
|
66 |
+
{'task':'deblurring','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deblurring/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deblurring/tdd/train.json']},
|
67 |
+
{'task':'dewarping','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/dewarping/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/dewarping/doc3d/train_1_19.json']},
|
68 |
+
{'task':'binarization','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/binarization/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/binarization/train.json']},
|
69 |
+
{'task':'deshadowing','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/train.json']},
|
70 |
+
{'task':'appearance','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/appearance/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/appearance/trainv2.json']}
|
71 |
+
]
|
72 |
+
|
73 |
+
|
74 |
+
ratios = [dataset_setting['ratio'] for dataset_setting in datasets_setting]
|
75 |
+
datasets = [docres_loader.DocResTrainDataset(dataset=dataset_setting,img_size=args.im_size) for dataset_setting in datasets_setting]
|
76 |
+
trainloaders = [{'task':datasets_setting[i],'loader':data.DataLoader(dataset=datasets[i], sampler=DistributedSampler(datasets[i]), batch_size=args.batch_size, num_workers=2, pin_memory=True,drop_last=True),'iter_loader':iter(data.DataLoader(dataset=datasets[i], sampler=DistributedSampler(datasets[i]), batch_size=args.batch_size, num_workers=2, pin_memory=True,drop_last=True))} for i in range(len(datasets))]
|
77 |
+
|
78 |
+
|
79 |
+
### test loader
|
80 |
+
# for i in tqdm(range(args.total_iter)):
|
81 |
+
# loader_index = random.choices(list(range(len(trainloaders))),ratios)[0]
|
82 |
+
# in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
|
83 |
+
|
84 |
+
|
85 |
+
### Setup Model
|
86 |
+
model = restormer_arch.Restormer(
|
87 |
+
inp_channels=6,
|
88 |
+
out_channels=3,
|
89 |
+
dim = 48,
|
90 |
+
num_blocks = [2,3,3,4],
|
91 |
+
num_refinement_blocks = 4,
|
92 |
+
heads = [1,2,4,8],
|
93 |
+
ffn_expansion_factor = 2.66,
|
94 |
+
bias = False,
|
95 |
+
LayerNorm_type = 'WithBias',
|
96 |
+
dual_pixel_task = True
|
97 |
+
)
|
98 |
+
model=DDP(model.cuda(),device_ids=[args.local_rank],output_device=args.local_rank)
|
99 |
+
|
100 |
+
### Optimizer
|
101 |
+
optimizer= torch.optim.AdamW(model.parameters(),lr=args.l_rate,weight_decay=5e-4)
|
102 |
+
|
103 |
+
### LR Scheduler
|
104 |
+
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.total_iter, eta_min=1e-6, last_epoch=-1)
|
105 |
+
|
106 |
+
### load checkpoint
|
107 |
+
iter_start=0
|
108 |
+
if args.resume is not None:
|
109 |
+
print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
|
110 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
x = checkpoint['model_state']
|
111 |
+
model.load_state_dict(x,strict=False)
|
112 |
+
iter_start=checkpoint['iters']
|
113 |
+
print("Loaded checkpoint '{}' (iter {})".format(args.resume, iter_start))
|
114 |
+
|
115 |
+
###-----------------------------------------Training-----------------------------------------
|
116 |
+
##initialize
|
117 |
+
scaler = torch.cuda.amp.GradScaler()
|
118 |
+
loss_dict = {}
|
119 |
+
total_step = 0
|
120 |
+
l2 = nn.MSELoss()
|
121 |
+
l1 = nn.L1Loss()
|
122 |
+
ce = nn.CrossEntropyLoss()
|
123 |
+
bce = nn.BCEWithLogitsLoss()
|
124 |
+
m = nn.Sigmoid()
|
125 |
+
best = 0
|
126 |
+
best_ce = 999
|
127 |
+
|
128 |
+
## total_steps
|
129 |
+
for iters in range(iter_start,args.total_iter):
|
130 |
+
start_time = time.time()
|
131 |
+
loader_index = random.choices(list(range(len(trainloaders))),ratios)[0]
|
132 |
+
|
133 |
+
try:
|
134 |
+
in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
|
135 |
+
except StopIteration:
|
136 |
+
trainloaders[loader_index]['iter_loader']=iter(trainloaders[loader_index]['loader'])
|
137 |
+
in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
|
138 |
+
in_im = in_im.float().cuda()
|
139 |
+
gt_im = gt_im.float().cuda()
|
140 |
+
|
141 |
+
binarization_loss,appearance_loss,dewarping_loss,deblurring_loss,deshadowing_loss = 0,0,0,0,0
|
142 |
+
with torch.cuda.amp.autocast():
|
143 |
+
pred_im = model(in_im,trainloaders[loader_index]['task']['task'])
|
144 |
+
if trainloaders[loader_index]['task']['task'] == 'binarization':
|
145 |
+
gt_im = gt_im.long()
|
146 |
+
binarization_loss = ce(pred_im[:,:2,:,:], gt_im[:,0,:,:])
|
147 |
+
loss = binarization_loss
|
148 |
+
elif trainloaders[loader_index]['task']['task'] == 'dewarping':
|
149 |
+
dewarping_loss = l1(pred_im[:,:2,:,:], gt_im[:,:2,:,:])
|
150 |
+
loss = dewarping_loss
|
151 |
+
elif trainloaders[loader_index]['task']['task'] == 'appearance':
|
152 |
+
appearance_loss = l1(pred_im, gt_im)
|
153 |
+
loss = appearance_loss
|
154 |
+
elif trainloaders[loader_index]['task']['task'] == 'deblurring':
|
155 |
+
deblurring_loss = l1(pred_im, gt_im)
|
156 |
+
loss = deblurring_loss
|
157 |
+
elif trainloaders[loader_index]['task']['task'] == 'deshadowing':
|
158 |
+
deshadowing_loss = l1(pred_im, gt_im)
|
159 |
+
loss = deshadowing_loss
|
160 |
+
|
161 |
+
optimizer.zero_grad()
|
162 |
+
scaler.scale(loss).backward()
|
163 |
+
scaler.step(optimizer)
|
164 |
+
scaler.update()
|
165 |
+
|
166 |
+
loss_dict['dew_loss']=dewarping_loss.item() if isinstance(dewarping_loss,torch.Tensor) else 0
|
167 |
+
loss_dict['app_loss']=appearance_loss.item() if isinstance(appearance_loss,torch.Tensor) else 0
|
168 |
+
loss_dict['des_loss']=deshadowing_loss.item() if isinstance(deshadowing_loss,torch.Tensor) else 0
|
169 |
+
loss_dict['deb_loss']=deblurring_loss.item() if isinstance(deblurring_loss,torch.Tensor) else 0
|
170 |
+
loss_dict['bin_loss']=binarization_loss.item() if isinstance(binarization_loss,torch.Tensor) else 0
|
171 |
+
end_time = time.time()
|
172 |
+
duration = end_time-start_time
|
173 |
+
## log
|
174 |
+
if (iters+1) % 10 == 0:
|
175 |
+
## print
|
176 |
+
print('iters [{}/{}] -- '.format(iters+1,args.total_iter)+dict2string(loss_dict)+' --lr {:6f}'.format(get_lr(optimizer))+' -- time {}'.format(second2hours(duration*(args.total_iter-iters))))
|
177 |
+
## tbord
|
178 |
+
if args.tboard:
|
179 |
+
for key,value in loss_dict.items():
|
180 |
+
writer.add_scalar('Train '+key+'/Iterations', value, iters+1)
|
181 |
+
## logfile
|
182 |
+
with open(log_file_path,'a') as f:
|
183 |
+
f.write('iters [{}/{}] -- '.format(iters+1,args.total_iter)+dict2string(loss_dict)+' --lr {:6f}'.format(get_lr(optimizer))+' -- time {}'.format(second2hours(duration*(args.total_iter-iters)))+'\n')
|
184 |
+
|
185 |
+
|
186 |
+
if (iters+1) % 5000 == 0:
|
187 |
+
state = {'iters': iters+1,
|
188 |
+
'model_state': model.state_dict(),
|
189 |
+
'optimizer_state' : optimizer.state_dict(),}
|
190 |
+
if not os.path.exists(os.path.join(args.logdir,args.experiment_name)):
|
191 |
+
os.system('mkdir ' + os.path.join(args.logdir,args.experiment_name))
|
192 |
+
if torch.distributed.get_rank()==0:
|
193 |
+
torch.save(state, os.path.join(args.logdir,args.experiment_name,"{}.pkl".format(iters+1)))
|
194 |
+
|
195 |
+
sched.step()
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
if __name__ == '__main__':
|
200 |
+
parser = argparse.ArgumentParser(description='Hyperparams')
|
201 |
+
parser.add_argument('--im_size', nargs='?', type=int, default=256,
|
202 |
+
help='Height of the input image')
|
203 |
+
parser.add_argument('--total_iter', nargs='?', type=int, default=100000,
|
204 |
+
help='Total number of training iterations')
|
205 |
+
parser.add_argument('--batch_size', nargs='?', type=int, default=10,
|
206 |
+
help='Batch Size')
|
207 |
+
parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
|
208 |
+
help='Learning Rate')
|
209 |
+
parser.add_argument('--resume', nargs='?', type=str, default=None,
|
210 |
+
help='Path to previous saved model to restart from')
|
211 |
+
parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
|
212 |
+
help='Path to store the loss logs')
|
213 |
+
parser.add_argument('--tboard', dest='tboard', action='store_true',
|
214 |
+
help='Enable visualization(s) on tensorboard | False by default')
|
215 |
+
parser.add_argument('--local_rank',type=int,default=0,metavar='N')
|
216 |
+
parser.add_argument('--experiment_name', nargs='?', type=str,default='experiment_name',
|
217 |
+
help='the name of this experiment')
|
218 |
+
parser.set_defaults(tboard=False)
|
219 |
+
args = parser.parse_args()
|
220 |
+
|
221 |
+
train(args)
|
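Checkpoints are written every 5000 iterations as a dict with keys `iters`, `model_state` and `optimizer_state`. A hedged sketch of reloading one outside the DDP training loop follows; the checkpoint path is a placeholder, and `convert_state_dict` from utils.py strips the `module.` prefixes that DDP adds to parameter names.

```python
# Sketch: restore a saved checkpoint into a non-DDP model instance.
import torch
from utils import convert_state_dict
from models.restormer_arch import Restormer

model = Restormer(inp_channels=6, out_channels=3, dim=48, num_blocks=[2, 3, 3, 4],
                  num_refinement_blocks=4, heads=[1, 2, 4, 8], ffn_expansion_factor=2.66,
                  bias=False, LayerNorm_type='WithBias', dual_pixel_task=True)

checkpoint = torch.load('./checkpoints/experiment_name/5000.pkl', map_location='cpu')  # placeholder path
print('checkpoint written at iteration', checkpoint['iters'])
model.load_state_dict(convert_state_dict(checkpoint['model_state']))  # strip DDP 'module.' prefixes
```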
utils.py
ADDED
@@ -0,0 +1,464 @@
1 |
+
from collections import OrderedDict
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import os
|
7 |
+
from skimage.filters import threshold_sauvola
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
def second2hours(seconds):
|
11 |
+
h = seconds//3600
|
12 |
+
seconds %= 3600
|
13 |
+
m = seconds//60
|
14 |
+
seconds %= 60
|
15 |
+
|
16 |
+
hms = '{:d} H : {:d} Min'.format(int(h),int(m))
|
17 |
+
return hms
|
18 |
+
|
19 |
+
|
20 |
+
def dict2string(loss_dict):
|
21 |
+
loss_string = ''
|
22 |
+
for key, value in loss_dict.items():
|
23 |
+
loss_string += key+' {:.4f}, '.format(value)
|
24 |
+
return loss_string[:-2]
|
25 |
+
def mkdir(dir):
|
26 |
+
if not os.path.exists(dir):
|
27 |
+
os.makedirs(dir)
|
28 |
+
|
29 |
+
def convert_state_dict(state_dict):
|
30 |
+
"""Converts a state dict saved from a dataParallel module to normal
|
31 |
+
module state_dict inplace
|
32 |
+
:param state_dict is the loaded DataParallel model_state
|
33 |
+
|
34 |
+
"""
|
35 |
+
new_state_dict = OrderedDict()
|
36 |
+
for k, v in state_dict.items():
|
37 |
+
name = k[7:] # remove `module.`
|
38 |
+
new_state_dict[name] = v
|
39 |
+
return new_state_dict
|
40 |
+
|
41 |
+
|
42 |
+
def get_lr(optimizer):
|
43 |
+
for param_group in optimizer.param_groups:
|
44 |
+
return float(param_group['lr'])
|
45 |
+
|
46 |
+
|
47 |
+
def torch2cvimg(tensor,min=0,max=1):
|
48 |
+
'''
|
49 |
+
input:
|
50 |
+
tensor -> torch.tensor BxCxHxW C can be 1,3
|
51 |
+
return
|
52 |
+
im -> ndarray uint8 HxWxC
|
53 |
+
'''
|
54 |
+
im_list = []
|
55 |
+
for i in range(tensor.shape[0]):
|
56 |
+
im = tensor.detach().cpu().data.numpy()[i]
|
57 |
+
im = im.transpose(1,2,0)
|
58 |
+
im = np.clip(im,min,max)
|
59 |
+
im = ((im-min)/(max-min)*255).astype(np.uint8)
|
60 |
+
im_list.append(im)
|
61 |
+
return im_list
|
62 |
+
def cvimg2torch(img,min=0,max=1):
|
63 |
+
'''
|
64 |
+
input:
|
65 |
+
im -> ndarray uint8 HxWxC
|
66 |
+
return
|
67 |
+
tensor -> torch.tensor BxCxHxW
|
68 |
+
'''
|
69 |
+
img = img.astype(float) / 255.0
|
70 |
+
img = img.transpose(2, 0, 1) # HWC -> CHW
|
71 |
+
img = np.expand_dims(img, 0)
|
72 |
+
img = torch.from_numpy(img).float()
|
73 |
+
return img
|
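A small round-trip sketch for the two converters above; the file name is a placeholder.

```python
# Sketch: image -> normalized tensor -> image round trip with the helpers above.
import cv2

img = cv2.imread('doc.png')            # HxWx3 uint8
tensor = cvimg2torch(img)              # 1x3xHxW float tensor in [0,1]
restored = torch2cvimg(tensor)[0]      # list of HxWx3 uint8 images, one per batch item
cv2.imwrite('doc_roundtrip.png', restored)
```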
74 |
+
|
75 |
+
|
76 |
+
def setup_seed(seed):
|
77 |
+
# np.random.seed(seed)
|
78 |
+
# random.seed(seed)
|
79 |
+
# torch.manual_seed(seed) #cpu
|
80 |
+
# torch.cuda.manual_seed_all(seed) # for multi-GPU (parallel) runs
|
81 |
+
torch.backends.cudnn.deterministic = True # keep CPU/GPU results consistent
|
82 |
+
# torch.backends.cudnn.benchmark = False # speeds up training when input sizes barely change
|
83 |
+
|
84 |
+
def SauvolaModBinarization(image,n1=51,n2=51,k1=0.3,k2=0.3,default=True):
|
85 |
+
'''
|
86 |
+
Binarization using Sauvola's algorithm
|
87 |
+
@name : SauvolaModBinarization
|
88 |
+
parameters
|
89 |
+
@param image (numpy array, 3-channel or single-channel, of type np.uint8): color or grayscale image
|
90 |
+
optional parameters
|
91 |
+
@param n1 (int) : window size for running sauvola during the first pass
|
92 |
+
@param n2 (int): window size for running sauvola during the second pass
|
93 |
+
@param k1 (float): k value corresponding to sauvola during the first pass
|
94 |
+
@param k2 (float): k value corresponding to sauvola during the second pass
|
95 |
+
@param default (bool) : boolean flag that resets the above optional parameters to their default values.
|
96 |
+
@param default is set to True : thus default values of the above optional parameters (n1,n2,k1,k2) are set to
|
97 |
+
n1 = 5 % of min(image height, image width)
|
98 |
+
n2 = 10 % of min(image height, image width)
|
99 |
+
k1 = 0.5
|
100 |
+
k2 = 0.5
|
101 |
+
Returns
|
102 |
+
@return A binary image of same size as @param image
|
103 |
+
|
104 |
+
@cite https://drive.google.com/file/d/1D3CyI5vtodPJeZaD2UV5wdcaIMtkBbdZ/view?usp=sharing
|
105 |
+
'''
|
106 |
+
|
107 |
+
if(default):
|
108 |
+
n1 = int(0.05*min(image.shape[0],image.shape[1]))
|
109 |
+
if (n1%2==0):
|
110 |
+
n1 = n1+1
|
111 |
+
n2 = int(0.1*min(image.shape[0],image.shape[1]))
|
112 |
+
if (n2%2==0):
|
113 |
+
n2 = n2+1
|
114 |
+
k1 = 0.5
|
115 |
+
k2 = 0.5
|
116 |
+
if(image.ndim==3):
|
117 |
+
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
118 |
+
else:
|
119 |
+
gray = np.copy(image)
|
120 |
+
T1 = threshold_sauvola(gray, window_size=n1,k=k1)
|
121 |
+
max_val = np.amax(gray)
|
122 |
+
min_val = np.amin(gray)
|
123 |
+
C = np.copy(T1)
|
124 |
+
C = C.astype(np.float32)
|
125 |
+
C[gray > T1] = (gray[gray > T1] - T1[gray > T1])/(max_val - T1[gray > T1])
|
126 |
+
C[gray <= T1] = 0
|
127 |
+
C = C * 255.0
|
128 |
+
new_in = np.copy(C.astype(np.uint8))
|
129 |
+
T2 = threshold_sauvola(new_in, window_size=n2,k=k2)
|
130 |
+
binary = np.copy(gray)
|
131 |
+
binary[new_in <= T2] = 0
|
132 |
+
binary[new_in > T2] = 255
|
133 |
+
return binary,T2
|
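A short usage sketch for SauvolaModBinarization above; `doc.png` is a placeholder path and the default two-pass window/k parameters are used.

```python
# Sketch: binarize a document photo with the two-pass Sauvola routine above.
import cv2

image = cv2.imread('doc.png')                       # BGR or grayscale uint8 image
binary, threshold = SauvolaModBinarization(image)   # binary is 0/255, same size as input
cv2.imwrite('doc_bin.png', binary)
```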
134 |
+
|
135 |
+
|
136 |
+
def getBasecoord(h,w):
|
137 |
+
base_coord0 = np.tile(np.arange(h).reshape(h,1),(1,w)).astype(np.float32)
|
138 |
+
base_coord1 = np.tile(np.arange(w).reshape(1,w),(h,1)).astype(np.float32)
|
139 |
+
base_coord = np.concatenate((np.expand_dims(base_coord1,-1),np.expand_dims(base_coord0,-1)),-1)
|
140 |
+
return base_coord
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
import numpy as np
|
148 |
+
from scipy import ndimage as ndi
|
149 |
+
|
150 |
+
# lookup tables for bwmorph_thin
|
151 |
+
|
152 |
+
G123_LUT = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
|
153 |
+
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
154 |
+
0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
|
155 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0,
|
156 |
+
1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
|
157 |
+
0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
158 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
159 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
160 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
161 |
+
0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
162 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
|
163 |
+
0, 0, 0], dtype=np.bool_)
|
164 |
+
|
165 |
+
G123P_LUT = np.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
|
166 |
+
0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
167 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
|
168 |
+
1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
169 |
+
0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
170 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0,
|
171 |
+
0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
|
172 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
173 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0,
|
174 |
+
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1,
|
175 |
+
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
176 |
+
0, 0, 0], dtype=np.bool_)
|
177 |
+
|
178 |
+
def bwmorph(image, n_iter=None):
|
179 |
+
"""
|
180 |
+
Perform morphological thinning of a binary image
|
181 |
+
|
182 |
+
Parameters
|
183 |
+
----------
|
184 |
+
image : binary (M, N) ndarray
|
185 |
+
The image to be thinned.
|
186 |
+
|
187 |
+
n_iter : int, number of iterations, optional
|
188 |
+
Regardless of the value of this parameter, the thinned image
|
189 |
+
is returned immediately if an iteration produces no change.
|
190 |
+
If this parameter is specified it thus sets an upper bound on
|
191 |
+
the number of iterations performed.
|
192 |
+
|
193 |
+
Returns
|
194 |
+
-------
|
195 |
+
out : ndarray of bools
|
196 |
+
Thinned image.
|
197 |
+
|
198 |
+
See also
|
199 |
+
--------
|
200 |
+
skeletonize
|
201 |
+
|
202 |
+
Notes
|
203 |
+
-----
|
204 |
+
This algorithm [1]_ works by making multiple passes over the image,
|
205 |
+
removing pixels matching a set of criteria designed to thin
|
206 |
+
connected regions while preserving eight-connected components and
|
207 |
+
2 x 2 squares [2]_. In each of the two sub-iterations the algorithm
|
208 |
+
correlates the intermediate skeleton image with a neighborhood mask,
|
209 |
+
then looks up each neighborhood in a lookup table indicating whether
|
210 |
+
the central pixel should be deleted in that sub-iteration.
|
211 |
+
|
212 |
+
References
|
213 |
+
----------
|
214 |
+
.. [1] Z. Guo and R. W. Hall, "Parallel thinning with
|
215 |
+
two-subiteration algorithms," Comm. ACM, vol. 32, no. 3,
|
216 |
+
pp. 359-373, 1989.
|
217 |
+
.. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning
|
218 |
+
Methodologies-A Comprehensive Survey," IEEE Transactions on
|
219 |
+
Pattern Analysis and Machine Intelligence, Vol 14, No. 9,
|
220 |
+
September 1992, p. 879
|
221 |
+
|
222 |
+
Examples
|
223 |
+
--------
|
224 |
+
>>> square = np.zeros((7, 7), dtype=np.uint8)
|
225 |
+
>>> square[1:-1, 2:-2] = 1
|
226 |
+
>>> square[0,1] = 1
|
227 |
+
>>> square
|
228 |
+
array([[0, 1, 0, 0, 0, 0, 0],
|
229 |
+
[0, 0, 1, 1, 1, 0, 0],
|
230 |
+
[0, 0, 1, 1, 1, 0, 0],
|
231 |
+
[0, 0, 1, 1, 1, 0, 0],
|
232 |
+
[0, 0, 1, 1, 1, 0, 0],
|
233 |
+
[0, 0, 1, 1, 1, 0, 0],
|
234 |
+
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
|
235 |
+
>>> skel = bwmorph_thin(square)
|
236 |
+
>>> skel.astype(np.uint8)
|
237 |
+
array([[0, 1, 0, 0, 0, 0, 0],
|
238 |
+
[0, 0, 1, 0, 0, 0, 0],
|
239 |
+
[0, 0, 0, 1, 0, 0, 0],
|
240 |
+
[0, 0, 0, 1, 0, 0, 0],
|
241 |
+
[0, 0, 0, 1, 0, 0, 0],
|
242 |
+
[0, 0, 0, 0, 0, 0, 0],
|
243 |
+
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
|
244 |
+
"""
|
245 |
+
# check parameters
|
246 |
+
if n_iter is None:
|
247 |
+
n = -1
|
248 |
+
elif n_iter <= 0:
|
249 |
+
raise ValueError('n_iter must be > 0')
|
250 |
+
else:
|
251 |
+
n = n_iter
|
252 |
+
|
253 |
+
# check that we have a 2d binary image, and convert it
|
254 |
+
# to uint8
|
255 |
+
skel = np.array(image).astype(np.uint8)
|
256 |
+
|
257 |
+
if skel.ndim != 2:
|
258 |
+
raise ValueError('2D array required')
|
259 |
+
if not np.all(np.in1d(image.flat,(0,1))):
|
260 |
+
raise ValueError('Image contains values other than 0 and 1')
|
261 |
+
|
262 |
+
# neighborhood mask
|
263 |
+
mask = np.array([[ 8, 4, 2],
|
264 |
+
[16, 0, 1],
|
265 |
+
[32, 64,128]],dtype=np.uint8)
|
266 |
+
|
267 |
+
# iterate either 1) indefinitely or 2) up to iteration limit
|
268 |
+
while n != 0:
|
269 |
+
before = np.sum(skel) # count points before thinning
|
270 |
+
|
271 |
+
# for each subiteration
|
272 |
+
for lut in [G123_LUT, G123P_LUT]:
|
273 |
+
# correlate image with neighborhood mask
|
274 |
+
N = ndi.correlate(skel, mask, mode='constant')
|
275 |
+
# take deletion decision from this subiteration's LUT
|
276 |
+
D = np.take(lut, N)
|
277 |
+
# perform deletion
|
278 |
+
skel[D] = 0
|
279 |
+
|
280 |
+
after = np.sum(skel) # count points after thinning
|
281 |
+
|
282 |
+
if before == after:
|
283 |
+
# iteration had no effect: finish
|
284 |
+
break
|
285 |
+
|
286 |
+
# count down to iteration limit (or endlessly negative)
|
287 |
+
n -= 1
|
288 |
+
|
289 |
+
return skel.astype(np.bool_)
|
290 |
+
|
291 |
+
"""
|
292 |
+
# here's how to make the LUTs
|
293 |
+
def nabe(n):
|
294 |
+
return np.array([n>>i&1 for i in range(0,9)]).astype(np.bool_)
|
295 |
+
def hood(n):
|
296 |
+
return np.take(nabe(n), np.array([[3, 2, 1],
|
297 |
+
[4, 8, 0],
|
298 |
+
[5, 6, 7]]))
|
299 |
+
def G1(n):
|
300 |
+
s = 0
|
301 |
+
bits = nabe(n)
|
302 |
+
for i in (0,2,4,6):
|
303 |
+
if not(bits[i]) and (bits[i+1] or bits[(i+2) % 8]):
|
304 |
+
s += 1
|
305 |
+
return s==1
|
306 |
+
|
307 |
+
g1_lut = np.array([G1(n) for n in range(256)])
|
308 |
+
def G2(n):
|
309 |
+
n1, n2 = 0, 0
|
310 |
+
bits = nabe(n)
|
311 |
+
for k in (1,3,5,7):
|
312 |
+
if bits[k] or bits[k-1]:
|
313 |
+
n1 += 1
|
314 |
+
if bits[k] or bits[(k+1) % 8]:
|
315 |
+
n2 += 1
|
316 |
+
return min(n1,n2) in [2,3]
|
317 |
+
g2_lut = np.array([G2(n) for n in range(256)])
|
318 |
+
g12_lut = g1_lut & g2_lut
|
319 |
+
def G3(n):
|
320 |
+
bits = nabe(n)
|
321 |
+
return not((bits[1] or bits[2] or not(bits[7])) and bits[0])
|
322 |
+
def G3p(n):
|
323 |
+
bits = nabe(n)
|
324 |
+
return not((bits[5] or bits[6] or not(bits[3])) and bits[4])
|
325 |
+
g3_lut = np.array([G3(n) for n in range(256)])
|
326 |
+
g3p_lut = np.array([G3p(n) for n in range(256)])
|
327 |
+
g123_lut = g12_lut & g3_lut
|
328 |
+
g123p_lut = g12_lut & g3p_lut
|
329 |
+
"""
|
330 |
+
|
331 |
+
"""
|
332 |
+
author : Peb Ruswono Aryan
|
333 |
+
|
334 |
+
metric for evaluating binarization algorithms
|
335 |
+
implemented :
|
336 |
+
|
337 |
+
* F-Measure
|
338 |
+
* pseudo F-Measure (as in H-DIBCO 2010 & 2012)
|
339 |
+
* Peak Signal to Noise Ratio (PSNR)
|
340 |
+
* Negative Rate Measure (NRM)
|
341 |
+
* Misclassification Penalty Measure (MPM)
|
342 |
+
* Distance Reciprocal Distortion (DRD)
|
343 |
+
|
344 |
+
usage:
|
345 |
+
python metric.py test-image.png ground-truth-image.png
|
346 |
+
"""
|
347 |
+
|
348 |
+
|
349 |
+
def drd_fn(im, im_gt):
|
350 |
+
height, width = im.shape
|
351 |
+
neg = np.zeros(im.shape)
|
352 |
+
neg[im_gt!=im] = 1
|
353 |
+
y, x = np.unravel_index(np.flatnonzero(neg), im.shape)
|
354 |
+
|
355 |
+
n = 2
|
356 |
+
m = n*2+1
|
357 |
+
W = np.zeros((m,m), dtype=np.uint8)
|
358 |
+
W[n,n] = 1.
|
359 |
+
W = cv2.distanceTransform(1-W, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
|
360 |
+
W[n,n] = 1.
|
361 |
+
W = 1./W
|
362 |
+
W[n,n] = 0.
|
363 |
+
W /= W.sum()
|
364 |
+
|
365 |
+
nubn = 0.
|
366 |
+
block_size = 8
|
367 |
+
for y1 in range(0, height, block_size):
|
368 |
+
for x1 in range(0, width, block_size):
|
369 |
+
y2 = min(y1+block_size-1,height-1)
|
370 |
+
x2 = min(x1+block_size-1,width-1)
|
371 |
+
block_dim = (x2-x1+1)*(y2-y1+1)
|
372 |
+
block = 1-im_gt[y1:y2, x1:x2]
|
373 |
+
block_sum = np.sum(block)
|
374 |
+
if block_sum>0 and block_sum<block_dim:
|
375 |
+
nubn += 1
|
376 |
+
|
377 |
+
drd_sum= 0.
|
378 |
+
tmp = np.zeros(W.shape)
|
379 |
+
for i in range(min(1,len(y))):
|
380 |
+
tmp[:,:] = 0
|
381 |
+
|
382 |
+
x1 = max(0, x[i]-n)
|
383 |
+
y1 = max(0, y[i]-n)
|
384 |
+
x2 = min(width-1, x[i]+n)
|
385 |
+
y2 = min(height-1, y[i]+n)
|
386 |
+
|
387 |
+
yy1 = y1-y[i]+n
|
388 |
+
yy2 = y2-y[i]+n
|
389 |
+
xx1 = x1-x[i]+n
|
390 |
+
xx2 = x2-x[i]+n
|
391 |
+
|
392 |
+
tmp[yy1:yy2+1,xx1:xx2+1] = np.abs(im[y[i],x[i]]-im_gt[y1:y2+1,x1:x2+1])
|
393 |
+
tmp *= W
|
394 |
+
|
395 |
+
drd_sum += np.sum(tmp)
|
396 |
+
return drd_sum/nubn
|
397 |
+
|
398 |
+
def bin_metric(im,im_gt):
|
399 |
+
height, width = im.shape
|
400 |
+
npixel = height*width
|
401 |
+
|
402 |
+
im[im>0] = 1
|
403 |
+
gt_mask = im_gt==0
|
404 |
+
im_gt[im_gt>0] = 1
|
405 |
+
|
406 |
+
sk = bwmorph(1-im_gt)
|
407 |
+
im_sk = np.ones(im_gt.shape)
|
408 |
+
im_sk[sk] = 0
|
409 |
+
|
410 |
+
kernel = np.ones((3,3), dtype=np.uint8)
|
411 |
+
im_dil = cv2.erode(im_gt, kernel)
|
412 |
+
im_gtb = im_gt-im_dil
|
413 |
+
im_gtbd = cv2.distanceTransform(1-im_gtb, cv2.DIST_L2, 3)
|
414 |
+
|
415 |
+
nd = im_gtbd.sum()
|
416 |
+
|
417 |
+
ptp = np.zeros(im_gt.shape)
|
418 |
+
ptp[(im==0) & (im_sk==0)] = 1
|
419 |
+
numptp = ptp.sum()
|
420 |
+
|
421 |
+
tp = np.zeros(im_gt.shape)
|
422 |
+
tp[(im==0) & (im_gt==0)] = 1
|
423 |
+
numtp = tp.sum()
|
424 |
+
|
425 |
+
tn = np.zeros(im_gt.shape)
|
426 |
+
tn[(im==1) & (im_gt==1)] = 1
|
427 |
+
numtn = tn.sum()
|
428 |
+
|
429 |
+
fp = np.zeros(im_gt.shape)
|
430 |
+
fp[(im==0) & (im_gt==1)] = 1
|
431 |
+
numfp = fp.sum()
|
432 |
+
|
433 |
+
fn = np.zeros(im_gt.shape)
|
434 |
+
fn[(im==1) & (im_gt==0)] = 1
|
435 |
+
numfn = fn.sum()
|
436 |
+
|
437 |
+
precision = numtp / (numtp + numfp)
|
438 |
+
recall = numtp / (numtp + numfn)
|
439 |
+
precall = numptp / np.sum(1-im_sk)
|
440 |
+
fmeasure = (2*recall*precision)/(recall+precision)
|
441 |
+
pfmeasure = (2*precall*precision)/(precall+precision)
|
442 |
+
|
443 |
+
mse = (numfp+numfn)/npixel
|
444 |
+
psnr = 10.*np.log10(1./mse)
|
445 |
+
|
446 |
+
nrfn = numfn / (numfn + numtp)
|
447 |
+
nrfp = numfp / (numfp + numtn)
|
448 |
+
nrm = (nrfn + nrfp)/2
|
449 |
+
|
450 |
+
im_dn = im_gtbd.copy()
|
451 |
+
im_dn[fn==0] = 0
|
452 |
+
dn = np.sum(im_dn)
|
453 |
+
mpfn = dn / nd
|
454 |
+
|
455 |
+
im_dp = im_gtbd.copy()
|
456 |
+
im_dp[fp==0] = 0
|
457 |
+
dp = np.sum(im_dp)
|
458 |
+
mpfp = dp / nd
|
459 |
+
|
460 |
+
mpm = (mpfp + mpfn) / 2
|
461 |
+
drd = drd_fn(im, im_gt)
|
462 |
+
|
463 |
+
return fmeasure, pfmeasure,psnr,nrm, mpm,drd
|
464 |
+
# print("F-measure\t: {0}\npF-measure\t: {1}\nPSNR\t\t: {2}\nNRM\t\t: {3}\nMPM\t\t: {4}\nDRD\t\t: {5}".format(fmeasure, pfmeasure, psnr, nrm, mpm, drd))
|
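A hedged usage sketch for bin_metric: both maps are read as single-channel images where 0 is ink and 255 is background, matching the conventions of the DIBCO-style metrics above; the file names are placeholders.

```python
# Sketch: score a predicted binarization against its ground truth.
import cv2

pred = cv2.imread('pred_bin.png', cv2.IMREAD_GRAYSCALE)
gt = cv2.imread('gt_bin.png', cv2.IMREAD_GRAYSCALE)
fmeasure, pfmeasure, psnr, nrm, mpm, drd = bin_metric(pred.copy(), gt.copy())
print('F {:.4f}  pF {:.4f}  PSNR {:.2f}  DRD {:.4f}'.format(fmeasure, pfmeasure, psnr, drd))
```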