Spaces:
Build error
Build error
Commit
·
0483f57
0
Parent(s):
Duplicate from PKUWilliamYang/StyleGANEX
Browse filesCo-authored-by: Shuai Yang <[email protected]>
This view is limited to 50 files because it contains too many changes.
See raw diff
- packages.txt +2 -0
- .gitattributes +34 -0
- README.md +10 -0
- app.py +112 -0
- configs/__init__.py +0 -0
- configs/data_configs.py +48 -0
- configs/dataset_config.yml +60 -0
- configs/paths_config.py +25 -0
- configs/transforms_config.py +242 -0
- datasets/__init__.py +0 -0
- datasets/augmentations.py +110 -0
- datasets/ffhq_degradation_dataset.py +235 -0
- datasets/gt_res_dataset.py +32 -0
- datasets/images_dataset.py +33 -0
- datasets/inference_dataset.py +22 -0
- latent_optimization.py +107 -0
- models/__init__.py +0 -0
- models/bisenet/LICENSE +21 -0
- models/bisenet/README.md +68 -0
- models/bisenet/model.py +283 -0
- models/bisenet/resnet.py +109 -0
- models/encoders/__init__.py +0 -0
- models/encoders/helpers.py +119 -0
- models/encoders/model_irse.py +84 -0
- models/encoders/psp_encoders.py +357 -0
- models/mtcnn/__init__.py +0 -0
- models/mtcnn/mtcnn.py +156 -0
- models/mtcnn/mtcnn_pytorch/__init__.py +0 -0
- models/mtcnn/mtcnn_pytorch/src/__init__.py +2 -0
- models/mtcnn/mtcnn_pytorch/src/align_trans.py +304 -0
- models/mtcnn/mtcnn_pytorch/src/box_utils.py +238 -0
- models/mtcnn/mtcnn_pytorch/src/detector.py +126 -0
- models/mtcnn/mtcnn_pytorch/src/first_stage.py +101 -0
- models/mtcnn/mtcnn_pytorch/src/get_nets.py +171 -0
- models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py +350 -0
- models/mtcnn/mtcnn_pytorch/src/visualization_utils.py +31 -0
- models/mtcnn/mtcnn_pytorch/src/weights/onet.npy +3 -0
- models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy +3 -0
- models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy +3 -0
- models/psp.py +148 -0
- models/stylegan2/__init__.py +0 -0
- models/stylegan2/lpips/__init__.py +161 -0
- models/stylegan2/lpips/base_model.py +58 -0
- models/stylegan2/lpips/dist_model.py +284 -0
- models/stylegan2/lpips/networks_basic.py +187 -0
- models/stylegan2/lpips/pretrained_networks.py +181 -0
- models/stylegan2/lpips/weights/v0.0/alex.pth +3 -0
- models/stylegan2/lpips/weights/v0.0/squeeze.pth +3 -0
- models/stylegan2/lpips/weights/v0.0/vgg.pth +3 -0
- models/stylegan2/lpips/weights/v0.1/alex.pth +3 -0
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
bzip2
|
2 |
+
cmake
|
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: StyleGANEX
|
3 |
+
sdk: gradio
|
4 |
+
emoji: 🐨
|
5 |
+
colorFrom: pink
|
6 |
+
colorTo: yellow
|
7 |
+
app_file: app.py
|
8 |
+
pinned: false
|
9 |
+
duplicated_from: PKUWilliamYang/StyleGANEX
|
10 |
+
---
|
app.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import pathlib
|
5 |
+
import torch
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
from webUI.app_task import *
|
9 |
+
from webUI.styleganex_model import Model
|
10 |
+
|
11 |
+
def parse_args() -> argparse.Namespace:
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument('--device', type=str, default='cpu')
|
14 |
+
parser.add_argument('--theme', type=str)
|
15 |
+
parser.add_argument('--share', action='store_true')
|
16 |
+
parser.add_argument('--port', type=int)
|
17 |
+
parser.add_argument('--disable-queue',
|
18 |
+
dest='enable_queue',
|
19 |
+
action='store_false')
|
20 |
+
return parser.parse_args()
|
21 |
+
|
22 |
+
DESCRIPTION = '''
|
23 |
+
<div align=center>
|
24 |
+
<h1 style="font-weight: 900; margin-bottom: 7px;">
|
25 |
+
Face Manipulation with <a href="https://github.com/williamyang1991/StyleGANEX">StyleGANEX</a>
|
26 |
+
</h1>
|
27 |
+
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
28 |
+
<a href="https://huggingface.co/spaces/PKUWilliamYang/StyleGANEX?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>
|
29 |
+
<p/>
|
30 |
+
<img style="margin-top: 0em" src="https://raw.githubusercontent.com/williamyang1991/tmpfile/master/imgs/example.jpg" alt="example">
|
31 |
+
</div>
|
32 |
+
'''
|
33 |
+
ARTICLE = r"""
|
34 |
+
If StyleGANEX is helpful, please help to ⭐ the <a href='https://github.com/williamyang1991/StyleGANEX' target='_blank'>Github Repo</a>. Thanks!
|
35 |
+
[![GitHub Stars](https://img.shields.io/github/stars/williamyang1991/StyleGANEX?style=social)](https://github.com/williamyang1991/StyleGANEX)
|
36 |
+
---
|
37 |
+
📝 **Citation**
|
38 |
+
If our work is useful for your research, please consider citing:
|
39 |
+
```bibtex
|
40 |
+
@article{yang2023styleganex,
|
41 |
+
title = {StyleGANEX: StyleGAN-Based Manipulation Beyond Cropped Aligned Faces},
|
42 |
+
author = {Yang, Shuai and Jiang, Liming and Liu, Ziwei and and Loy, Chen Change},
|
43 |
+
journal = {arXiv preprint arXiv:2303.06146},
|
44 |
+
year={2023},
|
45 |
+
}
|
46 |
+
```
|
47 |
+
📋 **License**
|
48 |
+
This project is licensed under <a rel="license" href="https://github.com/williamyang1991/VToonify/blob/main/LICENSE.md">S-Lab License 1.0</a>.
|
49 |
+
Redistribution and use for non-commercial purposes should follow this license.
|
50 |
+
|
51 |
+
📧 **Contact**
|
52 |
+
If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
|
53 |
+
"""
|
54 |
+
|
55 |
+
FOOTER = '<div align=center><img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.laobi.icu/badge?page_id=williamyang1991/styleganex" /></div>'
|
56 |
+
|
57 |
+
def main():
|
58 |
+
args = parse_args()
|
59 |
+
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
60 |
+
print('*** Now using %s.'%(args.device))
|
61 |
+
model = Model(device=args.device)
|
62 |
+
|
63 |
+
|
64 |
+
torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/234_sketch.jpg',
|
65 |
+
'234_sketch.jpg')
|
66 |
+
torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/output/ILip77SbmOE_inversion.pt',
|
67 |
+
'ILip77SbmOE_inversion.pt')
|
68 |
+
torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/ILip77SbmOE.png',
|
69 |
+
'ILip77SbmOE.png')
|
70 |
+
torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/ILip77SbmOE_mask.png',
|
71 |
+
'ILip77SbmOE_mask.png')
|
72 |
+
torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/pexels-daniel-xavier-1239291.jpg',
|
73 |
+
'pexels-daniel-xavier-1239291.jpg')
|
74 |
+
torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/529_2.mp4',
|
75 |
+
'529_2.mp4')
|
76 |
+
torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/684.mp4',
|
77 |
+
'684.mp4')
|
78 |
+
torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/pexels-anthony-shkraba-production-8136210.mp4',
|
79 |
+
'pexels-anthony-shkraba-production-8136210.mp4')
|
80 |
+
|
81 |
+
|
82 |
+
with gr.Blocks(css='style.css') as demo:
|
83 |
+
gr.Markdown(DESCRIPTION)
|
84 |
+
with gr.Tabs():
|
85 |
+
with gr.TabItem('Inversion for Editing'):
|
86 |
+
create_demo_inversion(model.process_inversion, allow_optimization=False)
|
87 |
+
with gr.TabItem('Image Face Toonify'):
|
88 |
+
create_demo_toonify(model.process_toonify)
|
89 |
+
with gr.TabItem('Video Face Toonify'):
|
90 |
+
create_demo_vtoonify(model.process_vtoonify, max_frame_num=12)
|
91 |
+
with gr.TabItem('Image Face Editing'):
|
92 |
+
create_demo_editing(model.process_editing)
|
93 |
+
with gr.TabItem('Video Face Editing'):
|
94 |
+
create_demo_vediting(model.process_vediting, max_frame_num=12)
|
95 |
+
with gr.TabItem('Sketch2Face'):
|
96 |
+
create_demo_s2f(model.process_s2f)
|
97 |
+
with gr.TabItem('Mask2Face'):
|
98 |
+
create_demo_m2f(model.process_m2f)
|
99 |
+
with gr.TabItem('SR'):
|
100 |
+
create_demo_sr(model.process_sr)
|
101 |
+
gr.Markdown(ARTICLE)
|
102 |
+
gr.Markdown(FOOTER)
|
103 |
+
|
104 |
+
demo.launch(
|
105 |
+
enable_queue=args.enable_queue,
|
106 |
+
server_port=args.port,
|
107 |
+
share=args.share,
|
108 |
+
)
|
109 |
+
|
110 |
+
if __name__ == '__main__':
|
111 |
+
main()
|
112 |
+
|
configs/__init__.py
ADDED
File without changes
|
configs/data_configs.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from configs import transforms_config
|
2 |
+
from configs.paths_config import dataset_paths
|
3 |
+
|
4 |
+
|
5 |
+
DATASETS = {
|
6 |
+
'ffhq_encode': {
|
7 |
+
'transforms': transforms_config.EncodeTransforms,
|
8 |
+
'train_source_root': dataset_paths['ffhq'],
|
9 |
+
'train_target_root': dataset_paths['ffhq'],
|
10 |
+
'test_source_root': dataset_paths['ffhq_test'],
|
11 |
+
'test_target_root': dataset_paths['ffhq_test'],
|
12 |
+
},
|
13 |
+
'ffhq_sketch_to_face': {
|
14 |
+
'transforms': transforms_config.SketchToImageTransforms,
|
15 |
+
'train_source_root': dataset_paths['ffhq_train_sketch'],
|
16 |
+
'train_target_root': dataset_paths['ffhq'],
|
17 |
+
'test_source_root': dataset_paths['ffhq_test_sketch'],
|
18 |
+
'test_target_root': dataset_paths['ffhq_test'],
|
19 |
+
},
|
20 |
+
'ffhq_seg_to_face': {
|
21 |
+
'transforms': transforms_config.SegToImageTransforms,
|
22 |
+
'train_source_root': dataset_paths['ffhq_train_segmentation'],
|
23 |
+
'train_target_root': dataset_paths['ffhq'],
|
24 |
+
'test_source_root': dataset_paths['ffhq_test_segmentation'],
|
25 |
+
'test_target_root': dataset_paths['ffhq_test'],
|
26 |
+
},
|
27 |
+
'ffhq_super_resolution': {
|
28 |
+
'transforms': transforms_config.SuperResTransforms,
|
29 |
+
'train_source_root': dataset_paths['ffhq'],
|
30 |
+
'train_target_root': dataset_paths['ffhq1280'],
|
31 |
+
'test_source_root': dataset_paths['ffhq_test'],
|
32 |
+
'test_target_root': dataset_paths['ffhq1280_test'],
|
33 |
+
},
|
34 |
+
'toonify': {
|
35 |
+
'transforms': transforms_config.ToonifyTransforms,
|
36 |
+
'train_source_root': dataset_paths['toonify_in'],
|
37 |
+
'train_target_root': dataset_paths['toonify_out'],
|
38 |
+
'test_source_root': dataset_paths['toonify_test_in'],
|
39 |
+
'test_target_root': dataset_paths['toonify_test_out'],
|
40 |
+
},
|
41 |
+
'ffhq_edit': {
|
42 |
+
'transforms': transforms_config.EditingTransforms,
|
43 |
+
'train_source_root': dataset_paths['ffhq'],
|
44 |
+
'train_target_root': dataset_paths['ffhq'],
|
45 |
+
'test_source_root': dataset_paths['ffhq_test'],
|
46 |
+
'test_target_root': dataset_paths['ffhq_test'],
|
47 |
+
},
|
48 |
+
}
|
configs/dataset_config.yml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dataset and data loader settings
|
2 |
+
datasets:
|
3 |
+
train:
|
4 |
+
name: FFHQ
|
5 |
+
type: FFHQDegradationDataset
|
6 |
+
# dataroot_gt: datasets/ffhq/ffhq_512.lmdb
|
7 |
+
dataroot_gt: ../../../../share/shuaiyang/ffhq/realign1280x1280test/
|
8 |
+
io_backend:
|
9 |
+
# type: lmdb
|
10 |
+
type: disk
|
11 |
+
|
12 |
+
use_hflip: true
|
13 |
+
mean: [0.5, 0.5, 0.5]
|
14 |
+
std: [0.5, 0.5, 0.5]
|
15 |
+
out_size: 1280
|
16 |
+
scale: 4
|
17 |
+
|
18 |
+
blur_kernel_size: 41
|
19 |
+
kernel_list: ['iso', 'aniso']
|
20 |
+
kernel_prob: [0.5, 0.5]
|
21 |
+
blur_sigma: [0.1, 10]
|
22 |
+
downsample_range: [4, 40]
|
23 |
+
noise_range: [0, 20]
|
24 |
+
jpeg_range: [60, 100]
|
25 |
+
|
26 |
+
# color jitter and gray
|
27 |
+
#color_jitter_prob: 0.3
|
28 |
+
#color_jitter_shift: 20
|
29 |
+
#color_jitter_pt_prob: 0.3
|
30 |
+
#gray_prob: 0.01
|
31 |
+
|
32 |
+
# If you do not want colorization, please set
|
33 |
+
color_jitter_prob: ~
|
34 |
+
color_jitter_pt_prob: ~
|
35 |
+
gray_prob: 0.01
|
36 |
+
gt_gray: True
|
37 |
+
|
38 |
+
crop_components: true
|
39 |
+
component_path: ./pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
|
40 |
+
eye_enlarge_ratio: 1.4
|
41 |
+
|
42 |
+
# data loader
|
43 |
+
use_shuffle: true
|
44 |
+
num_worker_per_gpu: 6
|
45 |
+
batch_size_per_gpu: 4
|
46 |
+
dataset_enlarge_ratio: 1
|
47 |
+
prefetch_mode: ~
|
48 |
+
|
49 |
+
val:
|
50 |
+
# Please modify accordingly to use your own validation
|
51 |
+
# Or comment the val block if do not need validation during training
|
52 |
+
name: validation
|
53 |
+
type: PairedImageDataset
|
54 |
+
dataroot_lq: datasets/faces/validation/input
|
55 |
+
dataroot_gt: datasets/faces/validation/reference
|
56 |
+
io_backend:
|
57 |
+
type: disk
|
58 |
+
mean: [0.5, 0.5, 0.5]
|
59 |
+
std: [0.5, 0.5, 0.5]
|
60 |
+
scale: 1
|
configs/paths_config.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_paths = {
|
2 |
+
'ffhq': 'data/train/ffhq/realign320x320/',
|
3 |
+
'ffhq_test': 'data/train/ffhq/realign320x320test/',
|
4 |
+
'ffhq1280': 'data/train/ffhq/realign1280x1280/',
|
5 |
+
'ffhq1280_test': 'data/train/ffhq/realign1280x1280test/',
|
6 |
+
'ffhq_train_sketch': 'data/train/ffhq/realign640x640sketch/',
|
7 |
+
'ffhq_test_sketch': 'data/train/ffhq/realign640x640sketchtest/',
|
8 |
+
'ffhq_train_segmentation': 'data/train/ffhq/realign320x320mask/',
|
9 |
+
'ffhq_test_segmentation': 'data/train/ffhq/realign320x320masktest/',
|
10 |
+
'toonify_in': 'data/train/pixar/trainA/',
|
11 |
+
'toonify_out': 'data/train/pixar/trainB/',
|
12 |
+
'toonify_test_in': 'data/train/pixar/testA/',
|
13 |
+
'toonify_test_out': 'data/train/testB/',
|
14 |
+
}
|
15 |
+
|
16 |
+
model_paths = {
|
17 |
+
'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
|
18 |
+
'ir_se50': 'pretrained_models/model_ir_se50.pth',
|
19 |
+
'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
|
20 |
+
'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
|
21 |
+
'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
|
22 |
+
'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
|
23 |
+
'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
|
24 |
+
'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar'
|
25 |
+
}
|
configs/transforms_config.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
from datasets import augmentations
|
4 |
+
|
5 |
+
|
6 |
+
class TransformsConfig(object):
|
7 |
+
|
8 |
+
def __init__(self, opts):
|
9 |
+
self.opts = opts
|
10 |
+
|
11 |
+
@abstractmethod
|
12 |
+
def get_transforms(self):
|
13 |
+
pass
|
14 |
+
|
15 |
+
|
16 |
+
class EncodeTransforms(TransformsConfig):
|
17 |
+
|
18 |
+
def __init__(self, opts):
|
19 |
+
super(EncodeTransforms, self).__init__(opts)
|
20 |
+
|
21 |
+
def get_transforms(self):
|
22 |
+
transforms_dict = {
|
23 |
+
'transform_gt_train': transforms.Compose([
|
24 |
+
transforms.Resize((320, 320)),
|
25 |
+
transforms.RandomHorizontalFlip(0.5),
|
26 |
+
transforms.ToTensor(),
|
27 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
28 |
+
'transform_source': None,
|
29 |
+
'transform_test': transforms.Compose([
|
30 |
+
transforms.Resize((320, 320)),
|
31 |
+
transforms.ToTensor(),
|
32 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
33 |
+
'transform_inference': transforms.Compose([
|
34 |
+
transforms.Resize((320, 320)),
|
35 |
+
transforms.ToTensor(),
|
36 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
37 |
+
}
|
38 |
+
return transforms_dict
|
39 |
+
|
40 |
+
|
41 |
+
class FrontalizationTransforms(TransformsConfig):
|
42 |
+
|
43 |
+
def __init__(self, opts):
|
44 |
+
super(FrontalizationTransforms, self).__init__(opts)
|
45 |
+
|
46 |
+
def get_transforms(self):
|
47 |
+
transforms_dict = {
|
48 |
+
'transform_gt_train': transforms.Compose([
|
49 |
+
transforms.Resize((256, 256)),
|
50 |
+
transforms.RandomHorizontalFlip(0.5),
|
51 |
+
transforms.ToTensor(),
|
52 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
53 |
+
'transform_source': transforms.Compose([
|
54 |
+
transforms.Resize((256, 256)),
|
55 |
+
transforms.RandomHorizontalFlip(0.5),
|
56 |
+
transforms.ToTensor(),
|
57 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
58 |
+
'transform_test': transforms.Compose([
|
59 |
+
transforms.Resize((256, 256)),
|
60 |
+
transforms.ToTensor(),
|
61 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
62 |
+
'transform_inference': transforms.Compose([
|
63 |
+
transforms.Resize((256, 256)),
|
64 |
+
transforms.ToTensor(),
|
65 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
66 |
+
}
|
67 |
+
return transforms_dict
|
68 |
+
|
69 |
+
|
70 |
+
class SketchToImageTransforms(TransformsConfig):
|
71 |
+
|
72 |
+
def __init__(self, opts):
|
73 |
+
super(SketchToImageTransforms, self).__init__(opts)
|
74 |
+
|
75 |
+
def get_transforms(self):
|
76 |
+
transforms_dict = {
|
77 |
+
'transform_gt_train': transforms.Compose([
|
78 |
+
transforms.Resize((320, 320)),
|
79 |
+
transforms.ToTensor(),
|
80 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
81 |
+
'transform_source': transforms.Compose([
|
82 |
+
transforms.Resize((320, 320)),
|
83 |
+
transforms.ToTensor()]),
|
84 |
+
'transform_test': transforms.Compose([
|
85 |
+
transforms.Resize((320, 320)),
|
86 |
+
transforms.ToTensor(),
|
87 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
88 |
+
'transform_inference': transforms.Compose([
|
89 |
+
transforms.Resize((320, 320)),
|
90 |
+
transforms.ToTensor()]),
|
91 |
+
}
|
92 |
+
return transforms_dict
|
93 |
+
|
94 |
+
|
95 |
+
class SegToImageTransforms(TransformsConfig):
|
96 |
+
|
97 |
+
def __init__(self, opts):
|
98 |
+
super(SegToImageTransforms, self).__init__(opts)
|
99 |
+
|
100 |
+
def get_transforms(self):
|
101 |
+
transforms_dict = {
|
102 |
+
'transform_gt_train': transforms.Compose([
|
103 |
+
transforms.Resize((320, 320)),
|
104 |
+
transforms.ToTensor(),
|
105 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
106 |
+
'transform_source': transforms.Compose([
|
107 |
+
transforms.Resize((320, 320)),
|
108 |
+
augmentations.ToOneHot(self.opts.label_nc),
|
109 |
+
transforms.ToTensor()]),
|
110 |
+
'transform_test': transforms.Compose([
|
111 |
+
transforms.Resize((320, 320)),
|
112 |
+
transforms.ToTensor(),
|
113 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
114 |
+
'transform_inference': transforms.Compose([
|
115 |
+
transforms.Resize((320, 320)),
|
116 |
+
augmentations.ToOneHot(self.opts.label_nc),
|
117 |
+
transforms.ToTensor()])
|
118 |
+
}
|
119 |
+
return transforms_dict
|
120 |
+
|
121 |
+
|
122 |
+
class SuperResTransforms(TransformsConfig):
|
123 |
+
|
124 |
+
def __init__(self, opts):
|
125 |
+
super(SuperResTransforms, self).__init__(opts)
|
126 |
+
|
127 |
+
def get_transforms(self):
|
128 |
+
if self.opts.resize_factors is None:
|
129 |
+
self.opts.resize_factors = '1,2,4,8,16,32'
|
130 |
+
factors = [int(f) for f in self.opts.resize_factors.split(",")]
|
131 |
+
print("Performing down-sampling with factors: {}".format(factors))
|
132 |
+
transforms_dict = {
|
133 |
+
'transform_gt_train': transforms.Compose([
|
134 |
+
transforms.Resize((1280, 1280)),
|
135 |
+
transforms.ToTensor(),
|
136 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
137 |
+
'transform_source': transforms.Compose([
|
138 |
+
transforms.Resize((320, 320)),
|
139 |
+
augmentations.BilinearResize(factors=factors),
|
140 |
+
transforms.Resize((320, 320)),
|
141 |
+
transforms.ToTensor(),
|
142 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
143 |
+
'transform_test': transforms.Compose([
|
144 |
+
transforms.Resize((1280, 1280)),
|
145 |
+
transforms.ToTensor(),
|
146 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
147 |
+
'transform_inference': transforms.Compose([
|
148 |
+
transforms.Resize((320, 320)),
|
149 |
+
augmentations.BilinearResize(factors=factors),
|
150 |
+
transforms.Resize((320, 320)),
|
151 |
+
transforms.ToTensor(),
|
152 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
153 |
+
}
|
154 |
+
return transforms_dict
|
155 |
+
|
156 |
+
|
157 |
+
class SuperResTransforms_320(TransformsConfig):
|
158 |
+
|
159 |
+
def __init__(self, opts):
|
160 |
+
super(SuperResTransforms_320, self).__init__(opts)
|
161 |
+
|
162 |
+
def get_transforms(self):
|
163 |
+
if self.opts.resize_factors is None:
|
164 |
+
self.opts.resize_factors = '1,2,4,8,16,32'
|
165 |
+
factors = [int(f) for f in self.opts.resize_factors.split(",")]
|
166 |
+
print("Performing down-sampling with factors: {}".format(factors))
|
167 |
+
transforms_dict = {
|
168 |
+
'transform_gt_train': transforms.Compose([
|
169 |
+
transforms.Resize((320, 320)),
|
170 |
+
transforms.ToTensor(),
|
171 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
172 |
+
'transform_source': transforms.Compose([
|
173 |
+
transforms.Resize((320, 320)),
|
174 |
+
augmentations.BilinearResize(factors=factors),
|
175 |
+
transforms.Resize((320, 320)),
|
176 |
+
transforms.ToTensor(),
|
177 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
178 |
+
'transform_test': transforms.Compose([
|
179 |
+
transforms.Resize((320, 320)),
|
180 |
+
transforms.ToTensor(),
|
181 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
182 |
+
'transform_inference': transforms.Compose([
|
183 |
+
transforms.Resize((320, 320)),
|
184 |
+
augmentations.BilinearResize(factors=factors),
|
185 |
+
transforms.Resize((320, 320)),
|
186 |
+
transforms.ToTensor(),
|
187 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
188 |
+
}
|
189 |
+
return transforms_dict
|
190 |
+
|
191 |
+
|
192 |
+
class ToonifyTransforms(TransformsConfig):
|
193 |
+
|
194 |
+
def __init__(self, opts):
|
195 |
+
super(ToonifyTransforms, self).__init__(opts)
|
196 |
+
|
197 |
+
def get_transforms(self):
|
198 |
+
transforms_dict = {
|
199 |
+
'transform_gt_train': transforms.Compose([
|
200 |
+
transforms.Resize((1024, 1024)),
|
201 |
+
transforms.ToTensor(),
|
202 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
203 |
+
'transform_source': transforms.Compose([
|
204 |
+
transforms.Resize((256, 256)),
|
205 |
+
transforms.ToTensor(),
|
206 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
207 |
+
'transform_test': transforms.Compose([
|
208 |
+
transforms.Resize((1024, 1024)),
|
209 |
+
transforms.ToTensor(),
|
210 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
211 |
+
'transform_inference': transforms.Compose([
|
212 |
+
transforms.Resize((256, 256)),
|
213 |
+
transforms.ToTensor(),
|
214 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
215 |
+
}
|
216 |
+
return transforms_dict
|
217 |
+
|
218 |
+
class EditingTransforms(TransformsConfig):
|
219 |
+
|
220 |
+
def __init__(self, opts):
|
221 |
+
super(EditingTransforms, self).__init__(opts)
|
222 |
+
|
223 |
+
def get_transforms(self):
|
224 |
+
transforms_dict = {
|
225 |
+
'transform_gt_train': transforms.Compose([
|
226 |
+
transforms.Resize((1280, 1280)),
|
227 |
+
transforms.ToTensor(),
|
228 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
229 |
+
'transform_source': transforms.Compose([
|
230 |
+
transforms.Resize((320, 320)),
|
231 |
+
transforms.ToTensor(),
|
232 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
233 |
+
'transform_test': transforms.Compose([
|
234 |
+
transforms.Resize((1280, 1280)),
|
235 |
+
transforms.ToTensor(),
|
236 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
|
237 |
+
'transform_inference': transforms.Compose([
|
238 |
+
transforms.Resize((320, 320)),
|
239 |
+
transforms.ToTensor(),
|
240 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
241 |
+
}
|
242 |
+
return transforms_dict
|
datasets/__init__.py
ADDED
File without changes
|
datasets/augmentations.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from torchvision import transforms
|
6 |
+
|
7 |
+
|
8 |
+
class ToOneHot(object):
|
9 |
+
""" Convert the input PIL image to a one-hot torch tensor """
|
10 |
+
def __init__(self, n_classes=None):
|
11 |
+
self.n_classes = n_classes
|
12 |
+
|
13 |
+
def onehot_initialization(self, a):
|
14 |
+
if self.n_classes is None:
|
15 |
+
self.n_classes = len(np.unique(a))
|
16 |
+
out = np.zeros(a.shape + (self.n_classes, ), dtype=int)
|
17 |
+
out[self.__all_idx(a, axis=2)] = 1
|
18 |
+
return out
|
19 |
+
|
20 |
+
def __all_idx(self, idx, axis):
|
21 |
+
grid = np.ogrid[tuple(map(slice, idx.shape))]
|
22 |
+
grid.insert(axis, idx)
|
23 |
+
return tuple(grid)
|
24 |
+
|
25 |
+
def __call__(self, img):
|
26 |
+
img = np.array(img)
|
27 |
+
one_hot = self.onehot_initialization(img)
|
28 |
+
return one_hot
|
29 |
+
|
30 |
+
|
31 |
+
class BilinearResize(object):
|
32 |
+
def __init__(self, factors=[1, 2, 4, 8, 16, 32]):
|
33 |
+
self.factors = factors
|
34 |
+
|
35 |
+
def __call__(self, image):
|
36 |
+
factor = np.random.choice(self.factors, size=1)[0]
|
37 |
+
D = BicubicDownSample(factor=factor, cuda=False)
|
38 |
+
img_tensor = transforms.ToTensor()(image).unsqueeze(0)
|
39 |
+
img_tensor_lr = D(img_tensor)[0].clamp(0, 1)
|
40 |
+
img_low_res = transforms.ToPILImage()(img_tensor_lr)
|
41 |
+
return img_low_res
|
42 |
+
|
43 |
+
|
44 |
+
class BicubicDownSample(nn.Module):
|
45 |
+
def bicubic_kernel(self, x, a=-0.50):
|
46 |
+
"""
|
47 |
+
This equation is exactly copied from the website below:
|
48 |
+
https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
|
49 |
+
"""
|
50 |
+
abs_x = torch.abs(x)
|
51 |
+
if abs_x <= 1.:
|
52 |
+
return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
|
53 |
+
elif 1. < abs_x < 2.:
|
54 |
+
return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
|
55 |
+
else:
|
56 |
+
return 0.0
|
57 |
+
|
58 |
+
def __init__(self, factor=4, cuda=True, padding='reflect'):
|
59 |
+
super().__init__()
|
60 |
+
self.factor = factor
|
61 |
+
size = factor * 4
|
62 |
+
k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
|
63 |
+
for i in range(size)], dtype=torch.float32)
|
64 |
+
k = k / torch.sum(k)
|
65 |
+
k1 = torch.reshape(k, shape=(1, 1, size, 1))
|
66 |
+
self.k1 = torch.cat([k1, k1, k1], dim=0)
|
67 |
+
k2 = torch.reshape(k, shape=(1, 1, 1, size))
|
68 |
+
self.k2 = torch.cat([k2, k2, k2], dim=0)
|
69 |
+
self.cuda = '.cuda' if cuda else ''
|
70 |
+
self.padding = padding
|
71 |
+
for param in self.parameters():
|
72 |
+
param.requires_grad = False
|
73 |
+
|
74 |
+
def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
|
75 |
+
filter_height = self.factor * 4
|
76 |
+
filter_width = self.factor * 4
|
77 |
+
stride = self.factor
|
78 |
+
|
79 |
+
pad_along_height = max(filter_height - stride, 0)
|
80 |
+
pad_along_width = max(filter_width - stride, 0)
|
81 |
+
filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
|
82 |
+
filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
|
83 |
+
|
84 |
+
# compute actual padding values for each side
|
85 |
+
pad_top = pad_along_height // 2
|
86 |
+
pad_bottom = pad_along_height - pad_top
|
87 |
+
pad_left = pad_along_width // 2
|
88 |
+
pad_right = pad_along_width - pad_left
|
89 |
+
|
90 |
+
# apply mirror padding
|
91 |
+
if nhwc:
|
92 |
+
x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW
|
93 |
+
|
94 |
+
# downscaling performed by 1-d convolution
|
95 |
+
x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
|
96 |
+
x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
|
97 |
+
if clip_round:
|
98 |
+
x = torch.clamp(torch.round(x), 0.0, 255.)
|
99 |
+
|
100 |
+
x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
|
101 |
+
x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
|
102 |
+
if clip_round:
|
103 |
+
x = torch.clamp(torch.round(x), 0.0, 255.)
|
104 |
+
|
105 |
+
if nhwc:
|
106 |
+
x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
|
107 |
+
if byte_output:
|
108 |
+
return x.type('torch.ByteTensor'.format(self.cuda))
|
109 |
+
else:
|
110 |
+
return x
|
datasets/ffhq_degradation_dataset.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import os.path as osp
|
5 |
+
import torch
|
6 |
+
import torch.utils.data as data
|
7 |
+
from basicsr.data import degradations as degradations
|
8 |
+
from basicsr.data.data_util import paths_from_folder
|
9 |
+
from basicsr.data.transforms import augment
|
10 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
11 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
12 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
13 |
+
normalize)
|
14 |
+
|
15 |
+
|
16 |
+
@DATASET_REGISTRY.register()
|
17 |
+
class FFHQDegradationDataset(data.Dataset):
|
18 |
+
"""FFHQ dataset for GFPGAN.
|
19 |
+
It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.
|
20 |
+
Args:
|
21 |
+
opt (dict): Config for train datasets. It contains the following keys:
|
22 |
+
dataroot_gt (str): Data root path for gt.
|
23 |
+
io_backend (dict): IO backend type and other kwarg.
|
24 |
+
mean (list | tuple): Image mean.
|
25 |
+
std (list | tuple): Image std.
|
26 |
+
use_hflip (bool): Whether to horizontally flip.
|
27 |
+
Please see more options in the codes.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, opt):
|
31 |
+
super(FFHQDegradationDataset, self).__init__()
|
32 |
+
self.opt = opt
|
33 |
+
# file client (io backend)
|
34 |
+
self.file_client = None
|
35 |
+
self.io_backend_opt = opt['io_backend']
|
36 |
+
|
37 |
+
self.gt_folder = opt['dataroot_gt']
|
38 |
+
self.mean = opt['mean']
|
39 |
+
self.std = opt['std']
|
40 |
+
self.out_size = opt['out_size']
|
41 |
+
|
42 |
+
self.crop_components = opt.get('crop_components', False) # facial components
|
43 |
+
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions
|
44 |
+
|
45 |
+
if self.crop_components:
|
46 |
+
# load component list from a pre-process pth files
|
47 |
+
self.components_list = torch.load(opt.get('component_path'))
|
48 |
+
|
49 |
+
# file client (lmdb io backend)
|
50 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
51 |
+
self.io_backend_opt['db_paths'] = self.gt_folder
|
52 |
+
if not self.gt_folder.endswith('.lmdb'):
|
53 |
+
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
|
54 |
+
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
55 |
+
self.paths = [line.split('.')[0] for line in fin]
|
56 |
+
else:
|
57 |
+
# disk backend: scan file list from a folder
|
58 |
+
self.paths = paths_from_folder(self.gt_folder)
|
59 |
+
|
60 |
+
# degradation configurations
|
61 |
+
self.blur_kernel_size = opt['blur_kernel_size']
|
62 |
+
self.kernel_list = opt['kernel_list']
|
63 |
+
self.kernel_prob = opt['kernel_prob']
|
64 |
+
self.blur_sigma = opt['blur_sigma']
|
65 |
+
self.downsample_range = opt['downsample_range']
|
66 |
+
self.noise_range = opt['noise_range']
|
67 |
+
self.jpeg_range = opt['jpeg_range']
|
68 |
+
|
69 |
+
# color jitter
|
70 |
+
self.color_jitter_prob = opt.get('color_jitter_prob')
|
71 |
+
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
|
72 |
+
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
|
73 |
+
# to gray
|
74 |
+
self.gray_prob = opt.get('gray_prob')
|
75 |
+
|
76 |
+
logger = get_root_logger()
|
77 |
+
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
78 |
+
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
|
79 |
+
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
|
80 |
+
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
|
81 |
+
|
82 |
+
if self.color_jitter_prob is not None:
|
83 |
+
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
|
84 |
+
if self.gray_prob is not None:
|
85 |
+
logger.info(f'Use random gray. Prob: {self.gray_prob}')
|
86 |
+
self.color_jitter_shift /= 255.
|
87 |
+
|
88 |
+
@staticmethod
|
89 |
+
def color_jitter(img, shift):
|
90 |
+
"""jitter color: randomly jitter the RGB values, in numpy formats"""
|
91 |
+
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
92 |
+
img = img + jitter_val
|
93 |
+
img = np.clip(img, 0, 1)
|
94 |
+
return img
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def color_jitter_pt(img, brightness, contrast, saturation, hue):
|
98 |
+
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
|
99 |
+
fn_idx = torch.randperm(4)
|
100 |
+
for fn_id in fn_idx:
|
101 |
+
if fn_id == 0 and brightness is not None:
|
102 |
+
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
|
103 |
+
img = adjust_brightness(img, brightness_factor)
|
104 |
+
|
105 |
+
if fn_id == 1 and contrast is not None:
|
106 |
+
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
|
107 |
+
img = adjust_contrast(img, contrast_factor)
|
108 |
+
|
109 |
+
if fn_id == 2 and saturation is not None:
|
110 |
+
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
|
111 |
+
img = adjust_saturation(img, saturation_factor)
|
112 |
+
|
113 |
+
if fn_id == 3 and hue is not None:
|
114 |
+
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
|
115 |
+
img = adjust_hue(img, hue_factor)
|
116 |
+
return img
|
117 |
+
|
118 |
+
def get_component_coordinates(self, index, status):
|
119 |
+
"""Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
|
120 |
+
components_bbox = self.components_list[f'{index:08d}']
|
121 |
+
if status[0]: # hflip
|
122 |
+
# exchange right and left eye
|
123 |
+
tmp = components_bbox['left_eye']
|
124 |
+
components_bbox['left_eye'] = components_bbox['right_eye']
|
125 |
+
components_bbox['right_eye'] = tmp
|
126 |
+
# modify the width coordinate
|
127 |
+
components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
|
128 |
+
components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
|
129 |
+
components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
|
130 |
+
|
131 |
+
# get coordinates
|
132 |
+
locations = []
|
133 |
+
for part in ['left_eye', 'right_eye', 'mouth']:
|
134 |
+
mean = components_bbox[part][0:2]
|
135 |
+
mean[0] = mean[0] * 2 + 128 ########
|
136 |
+
mean[1] = mean[1] * 2 + 128 ########
|
137 |
+
half_len = components_bbox[part][2] * 2 ########
|
138 |
+
if 'eye' in part:
|
139 |
+
half_len *= self.eye_enlarge_ratio
|
140 |
+
loc = np.hstack((mean - half_len + 1, mean + half_len))
|
141 |
+
loc = torch.from_numpy(loc).float()
|
142 |
+
locations.append(loc)
|
143 |
+
return locations
|
144 |
+
|
145 |
+
def __getitem__(self, index):
|
146 |
+
if self.file_client is None:
|
147 |
+
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
148 |
+
|
149 |
+
# load gt image
|
150 |
+
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
151 |
+
gt_path = self.paths[index]
|
152 |
+
img_bytes = self.file_client.get(gt_path)
|
153 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
154 |
+
|
155 |
+
# random horizontal flip
|
156 |
+
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
|
157 |
+
h, w, _ = img_gt.shape
|
158 |
+
|
159 |
+
# get facial component coordinates
|
160 |
+
if self.crop_components:
|
161 |
+
locations = self.get_component_coordinates(index, status)
|
162 |
+
loc_left_eye, loc_right_eye, loc_mouth = locations
|
163 |
+
|
164 |
+
# ------------------------ generate lq image ------------------------ #
|
165 |
+
# blur
|
166 |
+
kernel = degradations.random_mixed_kernels(
|
167 |
+
self.kernel_list,
|
168 |
+
self.kernel_prob,
|
169 |
+
self.blur_kernel_size,
|
170 |
+
self.blur_sigma,
|
171 |
+
self.blur_sigma, [-math.pi, math.pi],
|
172 |
+
noise_range=None)
|
173 |
+
img_lq = cv2.filter2D(img_gt, -1, kernel)
|
174 |
+
# downsample
|
175 |
+
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
|
176 |
+
img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
|
177 |
+
# noise
|
178 |
+
if self.noise_range is not None:
|
179 |
+
img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
|
180 |
+
# jpeg compression
|
181 |
+
if self.jpeg_range is not None:
|
182 |
+
img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
|
183 |
+
|
184 |
+
# resize to original size
|
185 |
+
img_lq = cv2.resize(img_lq, (int(w // self.opt['scale']), int(h // self.opt['scale'])), interpolation=cv2.INTER_LINEAR)
|
186 |
+
|
187 |
+
# random color jitter (only for lq)
|
188 |
+
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
|
189 |
+
img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
|
190 |
+
# random to gray (only for lq)
|
191 |
+
if self.gray_prob and np.random.uniform() < self.gray_prob:
|
192 |
+
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
|
193 |
+
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
|
194 |
+
if self.opt.get('gt_gray'): # whether convert GT to gray images
|
195 |
+
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
|
196 |
+
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
|
197 |
+
|
198 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
199 |
+
#img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
|
200 |
+
img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
|
201 |
+
img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
|
202 |
+
|
203 |
+
# random color jitter (pytorch version) (only for lq)
|
204 |
+
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
|
205 |
+
brightness = self.opt.get('brightness', (0.5, 1.5))
|
206 |
+
contrast = self.opt.get('contrast', (0.5, 1.5))
|
207 |
+
saturation = self.opt.get('saturation', (0, 1.5))
|
208 |
+
hue = self.opt.get('hue', (-0.1, 0.1))
|
209 |
+
img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
|
210 |
+
|
211 |
+
# round and clip
|
212 |
+
img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
|
213 |
+
|
214 |
+
# normalize
|
215 |
+
normalize(img_gt, self.mean, self.std, inplace=True)
|
216 |
+
normalize(img_lq, self.mean, self.std, inplace=True)
|
217 |
+
|
218 |
+
'''
|
219 |
+
if self.crop_components:
|
220 |
+
return_dict = {
|
221 |
+
'lq': img_lq,
|
222 |
+
'gt': img_gt,
|
223 |
+
'gt_path': gt_path,
|
224 |
+
'loc_left_eye': loc_left_eye,
|
225 |
+
'loc_right_eye': loc_right_eye,
|
226 |
+
'loc_mouth': loc_mouth
|
227 |
+
}
|
228 |
+
return return_dict
|
229 |
+
else:
|
230 |
+
return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
|
231 |
+
'''
|
232 |
+
return img_lq, img_gt
|
233 |
+
|
234 |
+
def __len__(self):
|
235 |
+
return len(self.paths)
|
datasets/gt_res_dataset.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# encoding: utf-8
|
3 |
+
import os
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
class GTResDataset(Dataset):
|
9 |
+
|
10 |
+
def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
|
11 |
+
self.pairs = []
|
12 |
+
for f in os.listdir(root_path):
|
13 |
+
image_path = os.path.join(root_path, f)
|
14 |
+
gt_path = os.path.join(gt_dir, f)
|
15 |
+
if f.endswith(".jpg") or f.endswith(".png"):
|
16 |
+
self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
|
17 |
+
self.transform = transform
|
18 |
+
self.transform_train = transform_train
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return len(self.pairs)
|
22 |
+
|
23 |
+
def __getitem__(self, index):
|
24 |
+
from_path, to_path, _ = self.pairs[index]
|
25 |
+
from_im = Image.open(from_path).convert('RGB')
|
26 |
+
to_im = Image.open(to_path).convert('RGB')
|
27 |
+
|
28 |
+
if self.transform:
|
29 |
+
to_im = self.transform(to_im)
|
30 |
+
from_im = self.transform(from_im)
|
31 |
+
|
32 |
+
return from_im, to_im
|
datasets/images_dataset.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from PIL import Image
|
3 |
+
from utils import data_utils
|
4 |
+
|
5 |
+
|
6 |
+
class ImagesDataset(Dataset):
|
7 |
+
|
8 |
+
def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
|
9 |
+
self.source_paths = sorted(data_utils.make_dataset(source_root))
|
10 |
+
self.target_paths = sorted(data_utils.make_dataset(target_root))
|
11 |
+
self.source_transform = source_transform
|
12 |
+
self.target_transform = target_transform
|
13 |
+
self.opts = opts
|
14 |
+
|
15 |
+
def __len__(self):
|
16 |
+
return len(self.source_paths)
|
17 |
+
|
18 |
+
def __getitem__(self, index):
|
19 |
+
from_path = self.source_paths[index]
|
20 |
+
from_im = Image.open(from_path)
|
21 |
+
from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
|
22 |
+
|
23 |
+
to_path = self.target_paths[index]
|
24 |
+
to_im = Image.open(to_path).convert('RGB')
|
25 |
+
if self.target_transform:
|
26 |
+
to_im = self.target_transform(to_im)
|
27 |
+
|
28 |
+
if self.source_transform:
|
29 |
+
from_im = self.source_transform(from_im)
|
30 |
+
else:
|
31 |
+
from_im = to_im
|
32 |
+
|
33 |
+
return from_im, to_im
|
datasets/inference_dataset.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from PIL import Image
|
3 |
+
from utils import data_utils
|
4 |
+
|
5 |
+
|
6 |
+
class InferenceDataset(Dataset):
|
7 |
+
|
8 |
+
def __init__(self, root, opts, transform=None):
|
9 |
+
self.paths = sorted(data_utils.make_dataset(root))
|
10 |
+
self.transform = transform
|
11 |
+
self.opts = opts
|
12 |
+
|
13 |
+
def __len__(self):
|
14 |
+
return len(self.paths)
|
15 |
+
|
16 |
+
def __getitem__(self, index):
|
17 |
+
from_path = self.paths[index]
|
18 |
+
from_im = Image.open(from_path)
|
19 |
+
from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
|
20 |
+
if self.transform:
|
21 |
+
from_im = self.transform(from_im)
|
22 |
+
return from_im
|
latent_optimization.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import models.stylegan2.lpips as lpips
|
2 |
+
from torch import autograd, optim
|
3 |
+
from torchvision import transforms, utils
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch
|
6 |
+
from scripts.align_all_parallel import align_face
|
7 |
+
from utils.inference_utils import noise_regularize, noise_normalize_, get_lr, latent_noise, visualize
|
8 |
+
|
9 |
+
def latent_optimization(frame, pspex, landmarkpredictor, step=500, device='cuda'):
|
10 |
+
percept = lpips.PerceptualLoss(
|
11 |
+
model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
|
12 |
+
)
|
13 |
+
|
14 |
+
transform = transforms.Compose([
|
15 |
+
transforms.ToTensor(),
|
16 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5,0.5,0.5]),
|
17 |
+
])
|
18 |
+
|
19 |
+
with torch.no_grad():
|
20 |
+
|
21 |
+
noise_sample = torch.randn(1000, 512, device=device)
|
22 |
+
latent_out = pspex.decoder.style(noise_sample)
|
23 |
+
latent_mean = latent_out.mean(0)
|
24 |
+
latent_std = ((latent_out - latent_mean).pow(2).sum() / 1000) ** 0.5
|
25 |
+
|
26 |
+
y = transform(frame).unsqueeze(dim=0).to(device)
|
27 |
+
I_ = align_face(frame, landmarkpredictor)
|
28 |
+
I_ = transform(I_).unsqueeze(dim=0).to(device)
|
29 |
+
wplus = pspex.encoder(I_) + pspex.latent_avg.unsqueeze(0)
|
30 |
+
_, f = pspex.encoder(y, return_feat=True)
|
31 |
+
latent_in = wplus.detach().clone()
|
32 |
+
feat = [f[0].detach().clone(), f[1].detach().clone()]
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
# wplus and f to optimize
|
37 |
+
latent_in.requires_grad = True
|
38 |
+
feat[0].requires_grad = True
|
39 |
+
feat[1].requires_grad = True
|
40 |
+
|
41 |
+
noises_single = pspex.decoder.make_noise()
|
42 |
+
basic_height, basic_width = int(y.shape[2]*32/256), int(y.shape[3]*32/256)
|
43 |
+
noises = []
|
44 |
+
for noise in noises_single:
|
45 |
+
noises.append(noise.new_empty(y.shape[0], 1, max(basic_height, int(y.shape[2]*noise.shape[2]/256)),
|
46 |
+
max(basic_width, int(y.shape[3]*noise.shape[2]/256))).normal_())
|
47 |
+
for noise in noises:
|
48 |
+
noise.requires_grad = True
|
49 |
+
|
50 |
+
init_lr=0.05
|
51 |
+
optimizer = optim.Adam(feat + noises, lr=init_lr)
|
52 |
+
optimizer2 = optim.Adam([latent_in], lr=init_lr)
|
53 |
+
noise_weight = 0.05 * 0.2
|
54 |
+
|
55 |
+
pbar = tqdm(range(step))
|
56 |
+
latent_path = []
|
57 |
+
|
58 |
+
for i in pbar:
|
59 |
+
t = i / step
|
60 |
+
lr = get_lr(t, init_lr)
|
61 |
+
optimizer.param_groups[0]["lr"] = lr
|
62 |
+
optimizer2.param_groups[0]["lr"] = get_lr(t, init_lr)
|
63 |
+
|
64 |
+
noise_strength = latent_std * noise_weight * max(0, 1 - t / 0.75) ** 2
|
65 |
+
latent_n = latent_noise(latent_in, noise_strength.item())
|
66 |
+
|
67 |
+
y_hat, _ = pspex.decoder([latent_n], input_is_latent=True, randomize_noise=False,
|
68 |
+
first_layer_feature=feat, noise=noises)
|
69 |
+
|
70 |
+
|
71 |
+
batch, channel, height, width = y_hat.shape
|
72 |
+
|
73 |
+
if height > y.shape[2]:
|
74 |
+
factor = height // y.shape[2]
|
75 |
+
|
76 |
+
y_hat = y_hat.reshape(
|
77 |
+
batch, channel, height // factor, factor, width // factor, factor
|
78 |
+
)
|
79 |
+
y_hat = y_hat.mean([3, 5])
|
80 |
+
|
81 |
+
p_loss = percept(y_hat, y).sum()
|
82 |
+
n_loss = noise_regularize(noises) * 1e3
|
83 |
+
|
84 |
+
loss = p_loss + n_loss
|
85 |
+
|
86 |
+
optimizer.zero_grad()
|
87 |
+
optimizer2.zero_grad()
|
88 |
+
loss.backward()
|
89 |
+
optimizer.step()
|
90 |
+
optimizer2.step()
|
91 |
+
|
92 |
+
noise_normalize_(noises)
|
93 |
+
|
94 |
+
''' for visualization
|
95 |
+
if (i + 1) % 100 == 0 or i == 0:
|
96 |
+
viz = torch.cat((y_hat,y,y_hat-y), dim=3)
|
97 |
+
visualize(torch.clamp(viz[0].cpu(),-1,1), 60)
|
98 |
+
'''
|
99 |
+
|
100 |
+
pbar.set_description(
|
101 |
+
(
|
102 |
+
f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
|
103 |
+
f" lr: {lr:.4f}"
|
104 |
+
)
|
105 |
+
)
|
106 |
+
|
107 |
+
return latent_n, feat, noises, wplus, f
|
models/__init__.py
ADDED
File without changes
|
models/bisenet/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2019 zll
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
models/bisenet/README.md
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# face-parsing.PyTorch
|
2 |
+
|
3 |
+
<p align="center">
|
4 |
+
<a href="https://github.com/zllrunning/face-parsing.PyTorch">
|
5 |
+
<img class="page-image" src="https://github.com/zllrunning/face-parsing.PyTorch/blob/master/6.jpg" >
|
6 |
+
</a>
|
7 |
+
</p>
|
8 |
+
|
9 |
+
### Contents
|
10 |
+
- [Training](#training)
|
11 |
+
- [Demo](#Demo)
|
12 |
+
- [References](#references)
|
13 |
+
|
14 |
+
## Training
|
15 |
+
|
16 |
+
1. Prepare training data:
|
17 |
+
-- download [CelebAMask-HQ dataset](https://github.com/switchablenorms/CelebAMask-HQ)
|
18 |
+
|
19 |
+
-- change file path in the `prepropess_data.py` and run
|
20 |
+
```Shell
|
21 |
+
python prepropess_data.py
|
22 |
+
```
|
23 |
+
|
24 |
+
2. Train the model using CelebAMask-HQ dataset:
|
25 |
+
Just run the train script:
|
26 |
+
```
|
27 |
+
$ CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
|
28 |
+
```
|
29 |
+
|
30 |
+
If you do not wish to train the model, you can download [our pre-trained model](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812) and save it in `res/cp`.
|
31 |
+
|
32 |
+
|
33 |
+
## Demo
|
34 |
+
1. Evaluate the trained model using:
|
35 |
+
```Shell
|
36 |
+
# evaluate using GPU
|
37 |
+
python test.py
|
38 |
+
```
|
39 |
+
|
40 |
+
## Face makeup using parsing maps
|
41 |
+
[**face-makeup.PyTorch**](https://github.com/zllrunning/face-makeup.PyTorch)
|
42 |
+
<table>
|
43 |
+
|
44 |
+
<tr>
|
45 |
+
<th> </th>
|
46 |
+
<th>Hair</th>
|
47 |
+
<th>Lip</th>
|
48 |
+
</tr>
|
49 |
+
|
50 |
+
<!-- Line 1: Original Input -->
|
51 |
+
<tr>
|
52 |
+
<td><em>Original Input</em></td>
|
53 |
+
<td><img src="makeup/116_ori.png" height="256" width="256" alt="Original Input"></td>
|
54 |
+
<td><img src="makeup/116_lip_ori.png" height="256" width="256" alt="Original Input"></td>
|
55 |
+
</tr>
|
56 |
+
|
57 |
+
<!-- Line 3: Color -->
|
58 |
+
<tr>
|
59 |
+
<td>Color</td>
|
60 |
+
<td><img src="makeup/116_1.png" height="256" width="256" alt="Color"></td>
|
61 |
+
<td><img src="makeup/116_3.png" height="256" width="256" alt="Color"></td>
|
62 |
+
</tr>
|
63 |
+
|
64 |
+
</table>
|
65 |
+
|
66 |
+
|
67 |
+
## References
|
68 |
+
- [BiSeNet](https://github.com/CoinCheung/BiSeNet)
|
models/bisenet/model.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from models.bisenet.resnet import Resnet18
|
11 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
12 |
+
|
13 |
+
|
14 |
+
class ConvBNReLU(nn.Module):
|
15 |
+
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
16 |
+
super(ConvBNReLU, self).__init__()
|
17 |
+
self.conv = nn.Conv2d(in_chan,
|
18 |
+
out_chan,
|
19 |
+
kernel_size = ks,
|
20 |
+
stride = stride,
|
21 |
+
padding = padding,
|
22 |
+
bias = False)
|
23 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
24 |
+
self.init_weight()
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.conv(x)
|
28 |
+
x = F.relu(self.bn(x))
|
29 |
+
return x
|
30 |
+
|
31 |
+
def init_weight(self):
|
32 |
+
for ly in self.children():
|
33 |
+
if isinstance(ly, nn.Conv2d):
|
34 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
35 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
36 |
+
|
37 |
+
class BiSeNetOutput(nn.Module):
|
38 |
+
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
39 |
+
super(BiSeNetOutput, self).__init__()
|
40 |
+
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
41 |
+
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
42 |
+
self.init_weight()
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.conv(x)
|
46 |
+
x = self.conv_out(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
def init_weight(self):
|
50 |
+
for ly in self.children():
|
51 |
+
if isinstance(ly, nn.Conv2d):
|
52 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
53 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
54 |
+
|
55 |
+
def get_params(self):
|
56 |
+
wd_params, nowd_params = [], []
|
57 |
+
for name, module in self.named_modules():
|
58 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
59 |
+
wd_params.append(module.weight)
|
60 |
+
if not module.bias is None:
|
61 |
+
nowd_params.append(module.bias)
|
62 |
+
elif isinstance(module, nn.BatchNorm2d):
|
63 |
+
nowd_params += list(module.parameters())
|
64 |
+
return wd_params, nowd_params
|
65 |
+
|
66 |
+
|
67 |
+
class AttentionRefinementModule(nn.Module):
|
68 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
69 |
+
super(AttentionRefinementModule, self).__init__()
|
70 |
+
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
71 |
+
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
72 |
+
self.bn_atten = nn.BatchNorm2d(out_chan)
|
73 |
+
self.sigmoid_atten = nn.Sigmoid()
|
74 |
+
self.init_weight()
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
feat = self.conv(x)
|
78 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
79 |
+
atten = self.conv_atten(atten)
|
80 |
+
atten = self.bn_atten(atten)
|
81 |
+
atten = self.sigmoid_atten(atten)
|
82 |
+
out = torch.mul(feat, atten)
|
83 |
+
return out
|
84 |
+
|
85 |
+
def init_weight(self):
|
86 |
+
for ly in self.children():
|
87 |
+
if isinstance(ly, nn.Conv2d):
|
88 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
89 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
90 |
+
|
91 |
+
|
92 |
+
class ContextPath(nn.Module):
|
93 |
+
def __init__(self, *args, **kwargs):
|
94 |
+
super(ContextPath, self).__init__()
|
95 |
+
self.resnet = Resnet18()
|
96 |
+
self.arm16 = AttentionRefinementModule(256, 128)
|
97 |
+
self.arm32 = AttentionRefinementModule(512, 128)
|
98 |
+
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
99 |
+
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
100 |
+
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
101 |
+
|
102 |
+
self.init_weight()
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
H0, W0 = x.size()[2:]
|
106 |
+
feat8, feat16, feat32 = self.resnet(x)
|
107 |
+
H8, W8 = feat8.size()[2:]
|
108 |
+
H16, W16 = feat16.size()[2:]
|
109 |
+
H32, W32 = feat32.size()[2:]
|
110 |
+
|
111 |
+
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
112 |
+
avg = self.conv_avg(avg)
|
113 |
+
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
114 |
+
|
115 |
+
feat32_arm = self.arm32(feat32)
|
116 |
+
feat32_sum = feat32_arm + avg_up
|
117 |
+
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
118 |
+
feat32_up = self.conv_head32(feat32_up)
|
119 |
+
|
120 |
+
feat16_arm = self.arm16(feat16)
|
121 |
+
feat16_sum = feat16_arm + feat32_up
|
122 |
+
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
123 |
+
feat16_up = self.conv_head16(feat16_up)
|
124 |
+
|
125 |
+
return feat8, feat16_up, feat32_up # x8, x8, x16
|
126 |
+
|
127 |
+
def init_weight(self):
|
128 |
+
for ly in self.children():
|
129 |
+
if isinstance(ly, nn.Conv2d):
|
130 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
131 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
132 |
+
|
133 |
+
def get_params(self):
|
134 |
+
wd_params, nowd_params = [], []
|
135 |
+
for name, module in self.named_modules():
|
136 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
137 |
+
wd_params.append(module.weight)
|
138 |
+
if not module.bias is None:
|
139 |
+
nowd_params.append(module.bias)
|
140 |
+
elif isinstance(module, nn.BatchNorm2d):
|
141 |
+
nowd_params += list(module.parameters())
|
142 |
+
return wd_params, nowd_params
|
143 |
+
|
144 |
+
|
145 |
+
### This is not used, since I replace this with the resnet feature with the same size
|
146 |
+
class SpatialPath(nn.Module):
|
147 |
+
def __init__(self, *args, **kwargs):
|
148 |
+
super(SpatialPath, self).__init__()
|
149 |
+
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
150 |
+
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
151 |
+
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
152 |
+
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
153 |
+
self.init_weight()
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
feat = self.conv1(x)
|
157 |
+
feat = self.conv2(feat)
|
158 |
+
feat = self.conv3(feat)
|
159 |
+
feat = self.conv_out(feat)
|
160 |
+
return feat
|
161 |
+
|
162 |
+
def init_weight(self):
|
163 |
+
for ly in self.children():
|
164 |
+
if isinstance(ly, nn.Conv2d):
|
165 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
166 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
167 |
+
|
168 |
+
def get_params(self):
|
169 |
+
wd_params, nowd_params = [], []
|
170 |
+
for name, module in self.named_modules():
|
171 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
172 |
+
wd_params.append(module.weight)
|
173 |
+
if not module.bias is None:
|
174 |
+
nowd_params.append(module.bias)
|
175 |
+
elif isinstance(module, nn.BatchNorm2d):
|
176 |
+
nowd_params += list(module.parameters())
|
177 |
+
return wd_params, nowd_params
|
178 |
+
|
179 |
+
|
180 |
+
class FeatureFusionModule(nn.Module):
|
181 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
182 |
+
super(FeatureFusionModule, self).__init__()
|
183 |
+
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
184 |
+
self.conv1 = nn.Conv2d(out_chan,
|
185 |
+
out_chan//4,
|
186 |
+
kernel_size = 1,
|
187 |
+
stride = 1,
|
188 |
+
padding = 0,
|
189 |
+
bias = False)
|
190 |
+
self.conv2 = nn.Conv2d(out_chan//4,
|
191 |
+
out_chan,
|
192 |
+
kernel_size = 1,
|
193 |
+
stride = 1,
|
194 |
+
padding = 0,
|
195 |
+
bias = False)
|
196 |
+
self.relu = nn.ReLU(inplace=True)
|
197 |
+
self.sigmoid = nn.Sigmoid()
|
198 |
+
self.init_weight()
|
199 |
+
|
200 |
+
def forward(self, fsp, fcp):
|
201 |
+
fcat = torch.cat([fsp, fcp], dim=1)
|
202 |
+
feat = self.convblk(fcat)
|
203 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
204 |
+
atten = self.conv1(atten)
|
205 |
+
atten = self.relu(atten)
|
206 |
+
atten = self.conv2(atten)
|
207 |
+
atten = self.sigmoid(atten)
|
208 |
+
feat_atten = torch.mul(feat, atten)
|
209 |
+
feat_out = feat_atten + feat
|
210 |
+
return feat_out
|
211 |
+
|
212 |
+
def init_weight(self):
|
213 |
+
for ly in self.children():
|
214 |
+
if isinstance(ly, nn.Conv2d):
|
215 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
216 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
217 |
+
|
218 |
+
def get_params(self):
|
219 |
+
wd_params, nowd_params = [], []
|
220 |
+
for name, module in self.named_modules():
|
221 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
222 |
+
wd_params.append(module.weight)
|
223 |
+
if not module.bias is None:
|
224 |
+
nowd_params.append(module.bias)
|
225 |
+
elif isinstance(module, nn.BatchNorm2d):
|
226 |
+
nowd_params += list(module.parameters())
|
227 |
+
return wd_params, nowd_params
|
228 |
+
|
229 |
+
|
230 |
+
class BiSeNet(nn.Module):
|
231 |
+
def __init__(self, n_classes, *args, **kwargs):
|
232 |
+
super(BiSeNet, self).__init__()
|
233 |
+
self.cp = ContextPath()
|
234 |
+
## here self.sp is deleted
|
235 |
+
self.ffm = FeatureFusionModule(256, 256)
|
236 |
+
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
237 |
+
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
238 |
+
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
239 |
+
self.init_weight()
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
H, W = x.size()[2:]
|
243 |
+
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
244 |
+
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
245 |
+
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
246 |
+
|
247 |
+
feat_out = self.conv_out(feat_fuse)
|
248 |
+
feat_out16 = self.conv_out16(feat_cp8)
|
249 |
+
feat_out32 = self.conv_out32(feat_cp16)
|
250 |
+
|
251 |
+
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
252 |
+
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
253 |
+
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
254 |
+
return feat_out, feat_out16, feat_out32
|
255 |
+
|
256 |
+
def init_weight(self):
|
257 |
+
for ly in self.children():
|
258 |
+
if isinstance(ly, nn.Conv2d):
|
259 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
260 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
261 |
+
|
262 |
+
def get_params(self):
|
263 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
264 |
+
for name, child in self.named_children():
|
265 |
+
child_wd_params, child_nowd_params = child.get_params()
|
266 |
+
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
267 |
+
lr_mul_wd_params += child_wd_params
|
268 |
+
lr_mul_nowd_params += child_nowd_params
|
269 |
+
else:
|
270 |
+
wd_params += child_wd_params
|
271 |
+
nowd_params += child_nowd_params
|
272 |
+
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
273 |
+
|
274 |
+
|
275 |
+
if __name__ == "__main__":
|
276 |
+
net = BiSeNet(19)
|
277 |
+
net.cuda()
|
278 |
+
net.eval()
|
279 |
+
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
280 |
+
out, out16, out32 = net(in_ten)
|
281 |
+
print(out.shape)
|
282 |
+
|
283 |
+
net.get_params()
|
models/bisenet/resnet.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.utils.model_zoo as modelzoo
|
8 |
+
|
9 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
10 |
+
|
11 |
+
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
12 |
+
|
13 |
+
|
14 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
15 |
+
"""3x3 convolution with padding"""
|
16 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
17 |
+
padding=1, bias=False)
|
18 |
+
|
19 |
+
|
20 |
+
class BasicBlock(nn.Module):
|
21 |
+
def __init__(self, in_chan, out_chan, stride=1):
|
22 |
+
super(BasicBlock, self).__init__()
|
23 |
+
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
24 |
+
self.bn1 = nn.BatchNorm2d(out_chan)
|
25 |
+
self.conv2 = conv3x3(out_chan, out_chan)
|
26 |
+
self.bn2 = nn.BatchNorm2d(out_chan)
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.downsample = None
|
29 |
+
if in_chan != out_chan or stride != 1:
|
30 |
+
self.downsample = nn.Sequential(
|
31 |
+
nn.Conv2d(in_chan, out_chan,
|
32 |
+
kernel_size=1, stride=stride, bias=False),
|
33 |
+
nn.BatchNorm2d(out_chan),
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = self.conv1(x)
|
38 |
+
residual = F.relu(self.bn1(residual))
|
39 |
+
residual = self.conv2(residual)
|
40 |
+
residual = self.bn2(residual)
|
41 |
+
|
42 |
+
shortcut = x
|
43 |
+
if self.downsample is not None:
|
44 |
+
shortcut = self.downsample(x)
|
45 |
+
|
46 |
+
out = shortcut + residual
|
47 |
+
out = self.relu(out)
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
52 |
+
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
53 |
+
for i in range(bnum-1):
|
54 |
+
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
55 |
+
return nn.Sequential(*layers)
|
56 |
+
|
57 |
+
|
58 |
+
class Resnet18(nn.Module):
|
59 |
+
def __init__(self):
|
60 |
+
super(Resnet18, self).__init__()
|
61 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
62 |
+
bias=False)
|
63 |
+
self.bn1 = nn.BatchNorm2d(64)
|
64 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
65 |
+
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
66 |
+
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
67 |
+
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
68 |
+
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
69 |
+
self.init_weight()
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.conv1(x)
|
73 |
+
x = F.relu(self.bn1(x))
|
74 |
+
x = self.maxpool(x)
|
75 |
+
|
76 |
+
x = self.layer1(x)
|
77 |
+
feat8 = self.layer2(x) # 1/8
|
78 |
+
feat16 = self.layer3(feat8) # 1/16
|
79 |
+
feat32 = self.layer4(feat16) # 1/32
|
80 |
+
return feat8, feat16, feat32
|
81 |
+
|
82 |
+
def init_weight(self):
|
83 |
+
state_dict = modelzoo.load_url(resnet18_url)
|
84 |
+
self_state_dict = self.state_dict()
|
85 |
+
for k, v in state_dict.items():
|
86 |
+
if 'fc' in k: continue
|
87 |
+
self_state_dict.update({k: v})
|
88 |
+
self.load_state_dict(self_state_dict)
|
89 |
+
|
90 |
+
def get_params(self):
|
91 |
+
wd_params, nowd_params = [], []
|
92 |
+
for name, module in self.named_modules():
|
93 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
94 |
+
wd_params.append(module.weight)
|
95 |
+
if not module.bias is None:
|
96 |
+
nowd_params.append(module.bias)
|
97 |
+
elif isinstance(module, nn.BatchNorm2d):
|
98 |
+
nowd_params += list(module.parameters())
|
99 |
+
return wd_params, nowd_params
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
net = Resnet18()
|
104 |
+
x = torch.randn(16, 3, 224, 224)
|
105 |
+
out = net(x)
|
106 |
+
print(out[0].size())
|
107 |
+
print(out[1].size())
|
108 |
+
print(out[2].size())
|
109 |
+
net.get_params()
|
models/encoders/__init__.py
ADDED
File without changes
|
models/encoders/helpers.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
4 |
+
|
5 |
+
"""
|
6 |
+
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
7 |
+
"""
|
8 |
+
|
9 |
+
|
10 |
+
class Flatten(Module):
|
11 |
+
def forward(self, input):
|
12 |
+
return input.view(input.size(0), -1)
|
13 |
+
|
14 |
+
|
15 |
+
def l2_norm(input, axis=1):
|
16 |
+
norm = torch.norm(input, 2, axis, True)
|
17 |
+
output = torch.div(input, norm)
|
18 |
+
return output
|
19 |
+
|
20 |
+
|
21 |
+
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
22 |
+
""" A named tuple describing a ResNet block. """
|
23 |
+
|
24 |
+
|
25 |
+
def get_block(in_channel, depth, num_units, stride=2):
|
26 |
+
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
27 |
+
|
28 |
+
|
29 |
+
def get_blocks(num_layers):
|
30 |
+
if num_layers == 50:
|
31 |
+
blocks = [
|
32 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
33 |
+
get_block(in_channel=64, depth=128, num_units=4),
|
34 |
+
get_block(in_channel=128, depth=256, num_units=14),
|
35 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
36 |
+
]
|
37 |
+
elif num_layers == 100:
|
38 |
+
blocks = [
|
39 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
40 |
+
get_block(in_channel=64, depth=128, num_units=13),
|
41 |
+
get_block(in_channel=128, depth=256, num_units=30),
|
42 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
43 |
+
]
|
44 |
+
elif num_layers == 152:
|
45 |
+
blocks = [
|
46 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
47 |
+
get_block(in_channel=64, depth=128, num_units=8),
|
48 |
+
get_block(in_channel=128, depth=256, num_units=36),
|
49 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
50 |
+
]
|
51 |
+
else:
|
52 |
+
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
|
53 |
+
return blocks
|
54 |
+
|
55 |
+
|
56 |
+
class SEModule(Module):
|
57 |
+
def __init__(self, channels, reduction):
|
58 |
+
super(SEModule, self).__init__()
|
59 |
+
self.avg_pool = AdaptiveAvgPool2d(1)
|
60 |
+
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
61 |
+
self.relu = ReLU(inplace=True)
|
62 |
+
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
63 |
+
self.sigmoid = Sigmoid()
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
module_input = x
|
67 |
+
x = self.avg_pool(x)
|
68 |
+
x = self.fc1(x)
|
69 |
+
x = self.relu(x)
|
70 |
+
x = self.fc2(x)
|
71 |
+
x = self.sigmoid(x)
|
72 |
+
return module_input * x
|
73 |
+
|
74 |
+
|
75 |
+
class bottleneck_IR(Module):
|
76 |
+
def __init__(self, in_channel, depth, stride):
|
77 |
+
super(bottleneck_IR, self).__init__()
|
78 |
+
if in_channel == depth:
|
79 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
80 |
+
else:
|
81 |
+
self.shortcut_layer = Sequential(
|
82 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
83 |
+
BatchNorm2d(depth)
|
84 |
+
)
|
85 |
+
self.res_layer = Sequential(
|
86 |
+
BatchNorm2d(in_channel),
|
87 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
88 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
shortcut = self.shortcut_layer(x)
|
93 |
+
res = self.res_layer(x)
|
94 |
+
return res + shortcut
|
95 |
+
|
96 |
+
|
97 |
+
class bottleneck_IR_SE(Module):
|
98 |
+
def __init__(self, in_channel, depth, stride):
|
99 |
+
super(bottleneck_IR_SE, self).__init__()
|
100 |
+
if in_channel == depth:
|
101 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
102 |
+
else:
|
103 |
+
self.shortcut_layer = Sequential(
|
104 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
105 |
+
BatchNorm2d(depth)
|
106 |
+
)
|
107 |
+
self.res_layer = Sequential(
|
108 |
+
BatchNorm2d(in_channel),
|
109 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
110 |
+
PReLU(depth),
|
111 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
112 |
+
BatchNorm2d(depth),
|
113 |
+
SEModule(depth, 16)
|
114 |
+
)
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
shortcut = self.shortcut_layer(x)
|
118 |
+
res = self.res_layer(x)
|
119 |
+
return res + shortcut
|
models/encoders/model_irse.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
2 |
+
from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
3 |
+
|
4 |
+
"""
|
5 |
+
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
+
class Backbone(Module):
|
10 |
+
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
11 |
+
super(Backbone, self).__init__()
|
12 |
+
assert input_size in [112, 224], "input_size should be 112 or 224"
|
13 |
+
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
14 |
+
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
15 |
+
blocks = get_blocks(num_layers)
|
16 |
+
if mode == 'ir':
|
17 |
+
unit_module = bottleneck_IR
|
18 |
+
elif mode == 'ir_se':
|
19 |
+
unit_module = bottleneck_IR_SE
|
20 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
21 |
+
BatchNorm2d(64),
|
22 |
+
PReLU(64))
|
23 |
+
if input_size == 112:
|
24 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
25 |
+
Dropout(drop_ratio),
|
26 |
+
Flatten(),
|
27 |
+
Linear(512 * 7 * 7, 512),
|
28 |
+
BatchNorm1d(512, affine=affine))
|
29 |
+
else:
|
30 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
31 |
+
Dropout(drop_ratio),
|
32 |
+
Flatten(),
|
33 |
+
Linear(512 * 14 * 14, 512),
|
34 |
+
BatchNorm1d(512, affine=affine))
|
35 |
+
|
36 |
+
modules = []
|
37 |
+
for block in blocks:
|
38 |
+
for bottleneck in block:
|
39 |
+
modules.append(unit_module(bottleneck.in_channel,
|
40 |
+
bottleneck.depth,
|
41 |
+
bottleneck.stride))
|
42 |
+
self.body = Sequential(*modules)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.input_layer(x)
|
46 |
+
x = self.body(x)
|
47 |
+
x = self.output_layer(x)
|
48 |
+
return l2_norm(x)
|
49 |
+
|
50 |
+
|
51 |
+
def IR_50(input_size):
|
52 |
+
"""Constructs a ir-50 model."""
|
53 |
+
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
54 |
+
return model
|
55 |
+
|
56 |
+
|
57 |
+
def IR_101(input_size):
|
58 |
+
"""Constructs a ir-101 model."""
|
59 |
+
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
60 |
+
return model
|
61 |
+
|
62 |
+
|
63 |
+
def IR_152(input_size):
|
64 |
+
"""Constructs a ir-152 model."""
|
65 |
+
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
66 |
+
return model
|
67 |
+
|
68 |
+
|
69 |
+
def IR_SE_50(input_size):
|
70 |
+
"""Constructs a ir_se-50 model."""
|
71 |
+
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
72 |
+
return model
|
73 |
+
|
74 |
+
|
75 |
+
def IR_SE_101(input_size):
|
76 |
+
"""Constructs a ir_se-101 model."""
|
77 |
+
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
78 |
+
return model
|
79 |
+
|
80 |
+
|
81 |
+
def IR_SE_152(input_size):
|
82 |
+
"""Constructs a ir_se-152 model."""
|
83 |
+
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
84 |
+
return model
|
models/encoders/psp_encoders.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module
|
6 |
+
|
7 |
+
from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
|
8 |
+
from models.stylegan2.model import EqualLinear
|
9 |
+
|
10 |
+
|
11 |
+
class GradualStyleBlock(Module):
|
12 |
+
def __init__(self, in_c, out_c, spatial, max_pooling=False):
|
13 |
+
super(GradualStyleBlock, self).__init__()
|
14 |
+
self.out_c = out_c
|
15 |
+
self.spatial = spatial
|
16 |
+
self.max_pooling = max_pooling
|
17 |
+
num_pools = int(np.log2(spatial))
|
18 |
+
modules = []
|
19 |
+
modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
|
20 |
+
nn.LeakyReLU()]
|
21 |
+
for i in range(num_pools - 1):
|
22 |
+
modules += [
|
23 |
+
Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
|
24 |
+
nn.LeakyReLU()
|
25 |
+
]
|
26 |
+
self.convs = nn.Sequential(*modules)
|
27 |
+
self.linear = EqualLinear(out_c, out_c, lr_mul=1)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
x = self.convs(x)
|
31 |
+
# To make E accept more general H*W images, we add global average pooling to
|
32 |
+
# resize all features to 1*1*512 before mapping to latent codes
|
33 |
+
if self.max_pooling:
|
34 |
+
x = F.adaptive_max_pool2d(x, 1) ##### modified
|
35 |
+
else:
|
36 |
+
x = F.adaptive_avg_pool2d(x, 1) ##### modified
|
37 |
+
x = x.view(-1, self.out_c)
|
38 |
+
x = self.linear(x)
|
39 |
+
return x
|
40 |
+
|
41 |
+
class AdaptiveInstanceNorm(nn.Module):
|
42 |
+
def __init__(self, fin, style_dim=512):
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
self.norm = nn.InstanceNorm2d(fin, affine=False)
|
46 |
+
self.style = nn.Linear(style_dim, fin * 2)
|
47 |
+
|
48 |
+
self.style.bias.data[:fin] = 1
|
49 |
+
self.style.bias.data[fin:] = 0
|
50 |
+
|
51 |
+
def forward(self, input, style):
|
52 |
+
style = self.style(style).unsqueeze(2).unsqueeze(3)
|
53 |
+
gamma, beta = style.chunk(2, 1)
|
54 |
+
out = self.norm(input)
|
55 |
+
out = gamma * out + beta
|
56 |
+
return out
|
57 |
+
|
58 |
+
|
59 |
+
class FusionLayer(Module): ##### modified
|
60 |
+
def __init__(self, inchannel, outchannel, use_skip_torgb=True, use_att=0):
|
61 |
+
super(FusionLayer, self).__init__()
|
62 |
+
|
63 |
+
self.transform = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=1, padding=1),
|
64 |
+
nn.LeakyReLU())
|
65 |
+
self.fusion_out = nn.Conv2d(outchannel*2, outchannel, kernel_size=3, stride=1, padding=1)
|
66 |
+
self.fusion_out.weight.data *= 0.01
|
67 |
+
self.fusion_out.weight[:,0:outchannel,1,1].data += torch.eye(outchannel)
|
68 |
+
|
69 |
+
self.use_skip_torgb = use_skip_torgb
|
70 |
+
if use_skip_torgb:
|
71 |
+
self.fusion_skip = nn.Conv2d(3+outchannel, 3, kernel_size=3, stride=1, padding=1)
|
72 |
+
self.fusion_skip.weight.data *= 0.01
|
73 |
+
self.fusion_skip.weight[:,0:3,1,1].data += torch.eye(3)
|
74 |
+
|
75 |
+
self.use_att = use_att
|
76 |
+
if use_att:
|
77 |
+
modules = []
|
78 |
+
modules.append(nn.Linear(512, outchannel))
|
79 |
+
for _ in range(use_att):
|
80 |
+
modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
|
81 |
+
modules.append(nn.Linear(outchannel, outchannel))
|
82 |
+
modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
|
83 |
+
self.linear = Sequential(*modules)
|
84 |
+
self.norm = AdaptiveInstanceNorm(outchannel*2, outchannel)
|
85 |
+
self.conv = nn.Conv2d(outchannel*2, 1, 3, 1, 1, bias=True)
|
86 |
+
|
87 |
+
def forward(self, feat, out, skip, editing_w=None):
|
88 |
+
x = self.transform(feat)
|
89 |
+
# similar to VToonify, use editing vector as condition
|
90 |
+
# fuse encoder feature and decoder feature with a predicted attention mask m_E
|
91 |
+
# if self.use_att = False, just fuse them with a simple conv layer
|
92 |
+
if self.use_att and editing_w is not None:
|
93 |
+
label = self.linear(editing_w)
|
94 |
+
m_E = (F.relu(self.conv(self.norm(torch.cat([out, abs(out-x)], dim=1), label)))).tanh()
|
95 |
+
x = x * m_E
|
96 |
+
out = self.fusion_out(torch.cat((out, x), dim=1))
|
97 |
+
if self.use_skip_torgb:
|
98 |
+
skip = self.fusion_skip(torch.cat((skip, x), dim=1))
|
99 |
+
return out, skip
|
100 |
+
|
101 |
+
|
102 |
+
class ResnetBlock(nn.Module):
|
103 |
+
def __init__(self, dim):
|
104 |
+
super(ResnetBlock, self).__init__()
|
105 |
+
|
106 |
+
self.conv_block = nn.Sequential(Conv2d(dim, dim, 3, 1, 1),
|
107 |
+
nn.LeakyReLU(),
|
108 |
+
Conv2d(dim, dim, 3, 1, 1))
|
109 |
+
self.relu = nn.LeakyReLU()
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
out = x + self.conv_block(x)
|
113 |
+
return self.relu(out)
|
114 |
+
|
115 |
+
# trainable light-weight translation network T
|
116 |
+
# for sketch/mask-to-face translation,
|
117 |
+
# we add a trainable T to map y to an intermediate domain where E can more easily extract features.
|
118 |
+
class ResnetGenerator(nn.Module):
|
119 |
+
def __init__(self, in_channel=19, res_num=2):
|
120 |
+
super(ResnetGenerator, self).__init__()
|
121 |
+
|
122 |
+
modules = []
|
123 |
+
modules.append(Conv2d(in_channel, 16, 3, 2, 1))
|
124 |
+
modules.append(nn.LeakyReLU())
|
125 |
+
modules.append(Conv2d(16, 16, 3, 2, 1))
|
126 |
+
modules.append(nn.LeakyReLU())
|
127 |
+
for _ in range(res_num):
|
128 |
+
modules.append(ResnetBlock(16))
|
129 |
+
for _ in range(2):
|
130 |
+
modules.append(nn.ConvTranspose2d(16, 16, 3, 2, 1, output_padding=1))
|
131 |
+
modules.append(nn.LeakyReLU())
|
132 |
+
modules.append(Conv2d(16, 64, 3, 1, 1, bias=False))
|
133 |
+
modules.append(BatchNorm2d(64))
|
134 |
+
modules.append(PReLU(64))
|
135 |
+
self.model = Sequential(*modules)
|
136 |
+
|
137 |
+
def forward(self, input):
|
138 |
+
return self.model(input)
|
139 |
+
|
140 |
+
class GradualStyleEncoder(Module):
|
141 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
142 |
+
super(GradualStyleEncoder, self).__init__()
|
143 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
144 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
145 |
+
blocks = get_blocks(num_layers)
|
146 |
+
if mode == 'ir':
|
147 |
+
unit_module = bottleneck_IR
|
148 |
+
elif mode == 'ir_se':
|
149 |
+
unit_module = bottleneck_IR_SE
|
150 |
+
|
151 |
+
# for sketch/mask-to-face translation, add a new network T
|
152 |
+
if opts.input_nc != 3:
|
153 |
+
self.input_label_layer = ResnetGenerator(opts.input_nc, opts.res_num)
|
154 |
+
|
155 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
156 |
+
BatchNorm2d(64),
|
157 |
+
PReLU(64))
|
158 |
+
modules = []
|
159 |
+
for block in blocks:
|
160 |
+
for bottleneck in block:
|
161 |
+
modules.append(unit_module(bottleneck.in_channel,
|
162 |
+
bottleneck.depth,
|
163 |
+
bottleneck.stride))
|
164 |
+
self.body = Sequential(*modules)
|
165 |
+
|
166 |
+
self.styles = nn.ModuleList()
|
167 |
+
self.style_count = opts.n_styles
|
168 |
+
self.coarse_ind = 3
|
169 |
+
self.middle_ind = 7
|
170 |
+
for i in range(self.style_count):
|
171 |
+
if i < self.coarse_ind:
|
172 |
+
style = GradualStyleBlock(512, 512, 16, 'max_pooling' in opts and opts.max_pooling)
|
173 |
+
elif i < self.middle_ind:
|
174 |
+
style = GradualStyleBlock(512, 512, 32, 'max_pooling' in opts and opts.max_pooling)
|
175 |
+
else:
|
176 |
+
style = GradualStyleBlock(512, 512, 64, 'max_pooling' in opts and opts.max_pooling)
|
177 |
+
self.styles.append(style)
|
178 |
+
self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
|
179 |
+
self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
|
180 |
+
|
181 |
+
# we concatenate pSp features in the middle layers and
|
182 |
+
# add a convolution layer to map the concatenated features to the first-layer input feature f of G.
|
183 |
+
self.featlayer = nn.Conv2d(768, 512, kernel_size=1, stride=1, padding=0) ##### modified
|
184 |
+
self.skiplayer = nn.Conv2d(768, 3, kernel_size=1, stride=1, padding=0) ##### modified
|
185 |
+
|
186 |
+
# skip connection
|
187 |
+
if 'use_skip' in opts and opts.use_skip: ##### modified
|
188 |
+
self.fusion = nn.ModuleList()
|
189 |
+
channels = [[256,512], [256,512], [256,512], [256,512], [128,512], [64,256], [64,128]]
|
190 |
+
# opts.skip_max_layer: how many layers are skipped to the decoder
|
191 |
+
for inc, outc in channels[:max(1, min(7, opts.skip_max_layer))]: # from 4 to 256
|
192 |
+
self.fusion.append(FusionLayer(inc, outc, opts.use_skip_torgb, opts.use_att))
|
193 |
+
|
194 |
+
def _upsample_add(self, x, y):
|
195 |
+
'''Upsample and add two feature maps.
|
196 |
+
Args:
|
197 |
+
x: (Variable) top feature map to be upsampled.
|
198 |
+
y: (Variable) lateral feature map.
|
199 |
+
Returns:
|
200 |
+
(Variable) added feature map.
|
201 |
+
Note in PyTorch, when input size is odd, the upsampled feature map
|
202 |
+
with `F.upsample(..., scale_factor=2, mode='nearest')`
|
203 |
+
maybe not equal to the lateral feature map size.
|
204 |
+
e.g.
|
205 |
+
original input size: [N,_,15,15] ->
|
206 |
+
conv2d feature map size: [N,_,8,8] ->
|
207 |
+
upsampled feature map size: [N,_,16,16]
|
208 |
+
So we choose bilinear upsample which supports arbitrary output sizes.
|
209 |
+
'''
|
210 |
+
_, _, H, W = y.size()
|
211 |
+
return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
|
212 |
+
|
213 |
+
# return_feat: return f
|
214 |
+
# return_full: return f and the skipped encoder features
|
215 |
+
# return [out, feats]
|
216 |
+
# out is the style latent code w+
|
217 |
+
# feats[0] is f for the 1st conv layer, feats[1] is f for the 1st torgb layer
|
218 |
+
# feats[2-8] is the skipped encoder features
|
219 |
+
def forward(self, x, return_feat=False, return_full=False): ##### modified
|
220 |
+
if x.shape[1] != 3:
|
221 |
+
x = self.input_label_layer(x)
|
222 |
+
else:
|
223 |
+
x = self.input_layer(x)
|
224 |
+
c256 = x ##### modified
|
225 |
+
|
226 |
+
latents = []
|
227 |
+
modulelist = list(self.body._modules.values())
|
228 |
+
for i, l in enumerate(modulelist):
|
229 |
+
x = l(x)
|
230 |
+
if i == 2: ##### modified
|
231 |
+
c128 = x
|
232 |
+
elif i == 6:
|
233 |
+
c1 = x
|
234 |
+
elif i == 10: ##### modified
|
235 |
+
c21 = x ##### modified
|
236 |
+
elif i == 15: ##### modified
|
237 |
+
c22 = x ##### modified
|
238 |
+
elif i == 20:
|
239 |
+
c2 = x
|
240 |
+
elif i == 23:
|
241 |
+
c3 = x
|
242 |
+
|
243 |
+
for j in range(self.coarse_ind):
|
244 |
+
latents.append(self.styles[j](c3))
|
245 |
+
|
246 |
+
p2 = self._upsample_add(c3, self.latlayer1(c2))
|
247 |
+
for j in range(self.coarse_ind, self.middle_ind):
|
248 |
+
latents.append(self.styles[j](p2))
|
249 |
+
|
250 |
+
p1 = self._upsample_add(p2, self.latlayer2(c1))
|
251 |
+
for j in range(self.middle_ind, self.style_count):
|
252 |
+
latents.append(self.styles[j](p1))
|
253 |
+
|
254 |
+
out = torch.stack(latents, dim=1)
|
255 |
+
|
256 |
+
if not return_feat:
|
257 |
+
return out
|
258 |
+
|
259 |
+
feats = [self.featlayer(torch.cat((c21, c22, c2), dim=1)), self.skiplayer(torch.cat((c21, c22, c2), dim=1))]
|
260 |
+
|
261 |
+
if return_full: ##### modified
|
262 |
+
feats += [c2, c2, c22, c21, c1, c128, c256]
|
263 |
+
|
264 |
+
return out, feats
|
265 |
+
|
266 |
+
|
267 |
+
# only compute the first-layer feature f
|
268 |
+
# E_F in the paper
|
269 |
+
def get_feat(self, x): ##### modified
|
270 |
+
# for sketch/mask-to-face translation
|
271 |
+
# use a trainable light-weight translation network T
|
272 |
+
if x.shape[1] != 3:
|
273 |
+
x = self.input_label_layer(x)
|
274 |
+
else:
|
275 |
+
x = self.input_layer(x)
|
276 |
+
|
277 |
+
latents = []
|
278 |
+
modulelist = list(self.body._modules.values())
|
279 |
+
for i, l in enumerate(modulelist):
|
280 |
+
x = l(x)
|
281 |
+
if i == 10: ##### modified
|
282 |
+
c21 = x ##### modified
|
283 |
+
elif i == 15: ##### modified
|
284 |
+
c22 = x ##### modified
|
285 |
+
elif i == 20:
|
286 |
+
c2 = x
|
287 |
+
break
|
288 |
+
return self.featlayer(torch.cat((c21, c22, c2), dim=1))
|
289 |
+
|
290 |
+
class BackboneEncoderUsingLastLayerIntoW(Module):
|
291 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
292 |
+
super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
|
293 |
+
print('Using BackboneEncoderUsingLastLayerIntoW')
|
294 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
295 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
296 |
+
blocks = get_blocks(num_layers)
|
297 |
+
if mode == 'ir':
|
298 |
+
unit_module = bottleneck_IR
|
299 |
+
elif mode == 'ir_se':
|
300 |
+
unit_module = bottleneck_IR_SE
|
301 |
+
self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
|
302 |
+
BatchNorm2d(64),
|
303 |
+
PReLU(64))
|
304 |
+
self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
|
305 |
+
self.linear = EqualLinear(512, 512, lr_mul=1)
|
306 |
+
modules = []
|
307 |
+
for block in blocks:
|
308 |
+
for bottleneck in block:
|
309 |
+
modules.append(unit_module(bottleneck.in_channel,
|
310 |
+
bottleneck.depth,
|
311 |
+
bottleneck.stride))
|
312 |
+
self.body = Sequential(*modules)
|
313 |
+
|
314 |
+
def forward(self, x):
|
315 |
+
x = self.input_layer(x)
|
316 |
+
x = self.body(x)
|
317 |
+
x = self.output_pool(x)
|
318 |
+
x = x.view(-1, 512)
|
319 |
+
x = self.linear(x)
|
320 |
+
return x
|
321 |
+
|
322 |
+
|
323 |
+
class BackboneEncoderUsingLastLayerIntoWPlus(Module):
|
324 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
325 |
+
super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
|
326 |
+
print('Using BackboneEncoderUsingLastLayerIntoWPlus')
|
327 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
328 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
329 |
+
blocks = get_blocks(num_layers)
|
330 |
+
if mode == 'ir':
|
331 |
+
unit_module = bottleneck_IR
|
332 |
+
elif mode == 'ir_se':
|
333 |
+
unit_module = bottleneck_IR_SE
|
334 |
+
self.n_styles = opts.n_styles
|
335 |
+
self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
|
336 |
+
BatchNorm2d(64),
|
337 |
+
PReLU(64))
|
338 |
+
self.output_layer_2 = Sequential(BatchNorm2d(512),
|
339 |
+
torch.nn.AdaptiveAvgPool2d((7, 7)),
|
340 |
+
Flatten(),
|
341 |
+
Linear(512 * 7 * 7, 512))
|
342 |
+
self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1)
|
343 |
+
modules = []
|
344 |
+
for block in blocks:
|
345 |
+
for bottleneck in block:
|
346 |
+
modules.append(unit_module(bottleneck.in_channel,
|
347 |
+
bottleneck.depth,
|
348 |
+
bottleneck.stride))
|
349 |
+
self.body = Sequential(*modules)
|
350 |
+
|
351 |
+
def forward(self, x):
|
352 |
+
x = self.input_layer(x)
|
353 |
+
x = self.body(x)
|
354 |
+
x = self.output_layer_2(x)
|
355 |
+
x = self.linear(x)
|
356 |
+
x = x.view(-1, self.n_styles, 512)
|
357 |
+
return x
|
models/mtcnn/__init__.py
ADDED
File without changes
|
models/mtcnn/mtcnn.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet
|
5 |
+
from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
|
6 |
+
from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage
|
7 |
+
from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face
|
8 |
+
|
9 |
+
device = 'cuda:0'
|
10 |
+
|
11 |
+
|
12 |
+
class MTCNN():
|
13 |
+
def __init__(self):
|
14 |
+
print(device)
|
15 |
+
self.pnet = PNet().to(device)
|
16 |
+
self.rnet = RNet().to(device)
|
17 |
+
self.onet = ONet().to(device)
|
18 |
+
self.pnet.eval()
|
19 |
+
self.rnet.eval()
|
20 |
+
self.onet.eval()
|
21 |
+
self.refrence = get_reference_facial_points(default_square=True)
|
22 |
+
|
23 |
+
def align(self, img):
|
24 |
+
_, landmarks = self.detect_faces(img)
|
25 |
+
if len(landmarks) == 0:
|
26 |
+
return None, None
|
27 |
+
facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)]
|
28 |
+
warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
|
29 |
+
return Image.fromarray(warped_face), tfm
|
30 |
+
|
31 |
+
def align_multi(self, img, limit=None, min_face_size=30.0):
|
32 |
+
boxes, landmarks = self.detect_faces(img, min_face_size)
|
33 |
+
if limit:
|
34 |
+
boxes = boxes[:limit]
|
35 |
+
landmarks = landmarks[:limit]
|
36 |
+
faces = []
|
37 |
+
tfms = []
|
38 |
+
for landmark in landmarks:
|
39 |
+
facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)]
|
40 |
+
warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
|
41 |
+
faces.append(Image.fromarray(warped_face))
|
42 |
+
tfms.append(tfm)
|
43 |
+
return boxes, faces, tfms
|
44 |
+
|
45 |
+
def detect_faces(self, image, min_face_size=20.0,
|
46 |
+
thresholds=[0.15, 0.25, 0.35],
|
47 |
+
nms_thresholds=[0.7, 0.7, 0.7]):
|
48 |
+
"""
|
49 |
+
Arguments:
|
50 |
+
image: an instance of PIL.Image.
|
51 |
+
min_face_size: a float number.
|
52 |
+
thresholds: a list of length 3.
|
53 |
+
nms_thresholds: a list of length 3.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
|
57 |
+
bounding boxes and facial landmarks.
|
58 |
+
"""
|
59 |
+
|
60 |
+
# BUILD AN IMAGE PYRAMID
|
61 |
+
width, height = image.size
|
62 |
+
min_length = min(height, width)
|
63 |
+
|
64 |
+
min_detection_size = 12
|
65 |
+
factor = 0.707 # sqrt(0.5)
|
66 |
+
|
67 |
+
# scales for scaling the image
|
68 |
+
scales = []
|
69 |
+
|
70 |
+
# scales the image so that
|
71 |
+
# minimum size that we can detect equals to
|
72 |
+
# minimum face size that we want to detect
|
73 |
+
m = min_detection_size / min_face_size
|
74 |
+
min_length *= m
|
75 |
+
|
76 |
+
factor_count = 0
|
77 |
+
while min_length > min_detection_size:
|
78 |
+
scales.append(m * factor ** factor_count)
|
79 |
+
min_length *= factor
|
80 |
+
factor_count += 1
|
81 |
+
|
82 |
+
# STAGE 1
|
83 |
+
|
84 |
+
# it will be returned
|
85 |
+
bounding_boxes = []
|
86 |
+
|
87 |
+
with torch.no_grad():
|
88 |
+
# run P-Net on different scales
|
89 |
+
for s in scales:
|
90 |
+
boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0])
|
91 |
+
bounding_boxes.append(boxes)
|
92 |
+
|
93 |
+
# collect boxes (and offsets, and scores) from different scales
|
94 |
+
bounding_boxes = [i for i in bounding_boxes if i is not None]
|
95 |
+
bounding_boxes = np.vstack(bounding_boxes)
|
96 |
+
|
97 |
+
keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
|
98 |
+
bounding_boxes = bounding_boxes[keep]
|
99 |
+
|
100 |
+
# use offsets predicted by pnet to transform bounding boxes
|
101 |
+
bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
|
102 |
+
# shape [n_boxes, 5]
|
103 |
+
|
104 |
+
bounding_boxes = convert_to_square(bounding_boxes)
|
105 |
+
bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
|
106 |
+
|
107 |
+
# STAGE 2
|
108 |
+
|
109 |
+
img_boxes = get_image_boxes(bounding_boxes, image, size=24)
|
110 |
+
img_boxes = torch.FloatTensor(img_boxes).to(device)
|
111 |
+
|
112 |
+
output = self.rnet(img_boxes)
|
113 |
+
offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4]
|
114 |
+
probs = output[1].cpu().data.numpy() # shape [n_boxes, 2]
|
115 |
+
|
116 |
+
keep = np.where(probs[:, 1] > thresholds[1])[0]
|
117 |
+
bounding_boxes = bounding_boxes[keep]
|
118 |
+
bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
|
119 |
+
offsets = offsets[keep]
|
120 |
+
|
121 |
+
keep = nms(bounding_boxes, nms_thresholds[1])
|
122 |
+
bounding_boxes = bounding_boxes[keep]
|
123 |
+
bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
|
124 |
+
bounding_boxes = convert_to_square(bounding_boxes)
|
125 |
+
bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
|
126 |
+
|
127 |
+
# STAGE 3
|
128 |
+
|
129 |
+
img_boxes = get_image_boxes(bounding_boxes, image, size=48)
|
130 |
+
if len(img_boxes) == 0:
|
131 |
+
return [], []
|
132 |
+
img_boxes = torch.FloatTensor(img_boxes).to(device)
|
133 |
+
output = self.onet(img_boxes)
|
134 |
+
landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10]
|
135 |
+
offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4]
|
136 |
+
probs = output[2].cpu().data.numpy() # shape [n_boxes, 2]
|
137 |
+
|
138 |
+
keep = np.where(probs[:, 1] > thresholds[2])[0]
|
139 |
+
bounding_boxes = bounding_boxes[keep]
|
140 |
+
bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
|
141 |
+
offsets = offsets[keep]
|
142 |
+
landmarks = landmarks[keep]
|
143 |
+
|
144 |
+
# compute landmark points
|
145 |
+
width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
|
146 |
+
height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
|
147 |
+
xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
|
148 |
+
landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
|
149 |
+
landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
|
150 |
+
|
151 |
+
bounding_boxes = calibrate_box(bounding_boxes, offsets)
|
152 |
+
keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
|
153 |
+
bounding_boxes = bounding_boxes[keep]
|
154 |
+
landmarks = landmarks[keep]
|
155 |
+
|
156 |
+
return bounding_boxes, landmarks
|
models/mtcnn/mtcnn_pytorch/__init__.py
ADDED
File without changes
|
models/mtcnn/mtcnn_pytorch/src/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .visualization_utils import show_bboxes
|
2 |
+
from .detector import detect_faces
|
models/mtcnn/mtcnn_pytorch/src/align_trans.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Mon Apr 24 15:43:29 2017
|
4 |
+
@author: zhaoy
|
5 |
+
"""
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
# from scipy.linalg import lstsq
|
10 |
+
# from scipy.ndimage import geometric_transform # , map_coordinates
|
11 |
+
|
12 |
+
from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2
|
13 |
+
|
14 |
+
# reference facial points, a list of coordinates (x,y)
|
15 |
+
REFERENCE_FACIAL_POINTS = [
|
16 |
+
[30.29459953, 51.69630051],
|
17 |
+
[65.53179932, 51.50139999],
|
18 |
+
[48.02519989, 71.73660278],
|
19 |
+
[33.54930115, 92.3655014],
|
20 |
+
[62.72990036, 92.20410156]
|
21 |
+
]
|
22 |
+
|
23 |
+
DEFAULT_CROP_SIZE = (96, 112)
|
24 |
+
|
25 |
+
|
26 |
+
class FaceWarpException(Exception):
|
27 |
+
def __str__(self):
|
28 |
+
return 'In File {}:{}'.format(
|
29 |
+
__file__, super.__str__(self))
|
30 |
+
|
31 |
+
|
32 |
+
def get_reference_facial_points(output_size=None,
|
33 |
+
inner_padding_factor=0.0,
|
34 |
+
outer_padding=(0, 0),
|
35 |
+
default_square=False):
|
36 |
+
"""
|
37 |
+
Function:
|
38 |
+
----------
|
39 |
+
get reference 5 key points according to crop settings:
|
40 |
+
0. Set default crop_size:
|
41 |
+
if default_square:
|
42 |
+
crop_size = (112, 112)
|
43 |
+
else:
|
44 |
+
crop_size = (96, 112)
|
45 |
+
1. Pad the crop_size by inner_padding_factor in each side;
|
46 |
+
2. Resize crop_size into (output_size - outer_padding*2),
|
47 |
+
pad into output_size with outer_padding;
|
48 |
+
3. Output reference_5point;
|
49 |
+
Parameters:
|
50 |
+
----------
|
51 |
+
@output_size: (w, h) or None
|
52 |
+
size of aligned face image
|
53 |
+
@inner_padding_factor: (w_factor, h_factor)
|
54 |
+
padding factor for inner (w, h)
|
55 |
+
@outer_padding: (w_pad, h_pad)
|
56 |
+
each row is a pair of coordinates (x, y)
|
57 |
+
@default_square: True or False
|
58 |
+
if True:
|
59 |
+
default crop_size = (112, 112)
|
60 |
+
else:
|
61 |
+
default crop_size = (96, 112);
|
62 |
+
!!! make sure, if output_size is not None:
|
63 |
+
(output_size - outer_padding)
|
64 |
+
= some_scale * (default crop_size * (1.0 + inner_padding_factor))
|
65 |
+
Returns:
|
66 |
+
----------
|
67 |
+
@reference_5point: 5x2 np.array
|
68 |
+
each row is a pair of transformed coordinates (x, y)
|
69 |
+
"""
|
70 |
+
# print('\n===> get_reference_facial_points():')
|
71 |
+
|
72 |
+
# print('---> Params:')
|
73 |
+
# print(' output_size: ', output_size)
|
74 |
+
# print(' inner_padding_factor: ', inner_padding_factor)
|
75 |
+
# print(' outer_padding:', outer_padding)
|
76 |
+
# print(' default_square: ', default_square)
|
77 |
+
|
78 |
+
tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
|
79 |
+
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
|
80 |
+
|
81 |
+
# 0) make the inner region a square
|
82 |
+
if default_square:
|
83 |
+
size_diff = max(tmp_crop_size) - tmp_crop_size
|
84 |
+
tmp_5pts += size_diff / 2
|
85 |
+
tmp_crop_size += size_diff
|
86 |
+
|
87 |
+
# print('---> default:')
|
88 |
+
# print(' crop_size = ', tmp_crop_size)
|
89 |
+
# print(' reference_5pts = ', tmp_5pts)
|
90 |
+
|
91 |
+
if (output_size and
|
92 |
+
output_size[0] == tmp_crop_size[0] and
|
93 |
+
output_size[1] == tmp_crop_size[1]):
|
94 |
+
# print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
|
95 |
+
return tmp_5pts
|
96 |
+
|
97 |
+
if (inner_padding_factor == 0 and
|
98 |
+
outer_padding == (0, 0)):
|
99 |
+
if output_size is None:
|
100 |
+
# print('No paddings to do: return default reference points')
|
101 |
+
return tmp_5pts
|
102 |
+
else:
|
103 |
+
raise FaceWarpException(
|
104 |
+
'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
|
105 |
+
|
106 |
+
# check output size
|
107 |
+
if not (0 <= inner_padding_factor <= 1.0):
|
108 |
+
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
|
109 |
+
|
110 |
+
if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
|
111 |
+
and output_size is None):
|
112 |
+
output_size = tmp_crop_size * \
|
113 |
+
(1 + inner_padding_factor * 2).astype(np.int32)
|
114 |
+
output_size += np.array(outer_padding)
|
115 |
+
# print(' deduced from paddings, output_size = ', output_size)
|
116 |
+
|
117 |
+
if not (outer_padding[0] < output_size[0]
|
118 |
+
and outer_padding[1] < output_size[1]):
|
119 |
+
raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
|
120 |
+
'and outer_padding[1] < output_size[1])')
|
121 |
+
|
122 |
+
# 1) pad the inner region according inner_padding_factor
|
123 |
+
# print('---> STEP1: pad the inner region according inner_padding_factor')
|
124 |
+
if inner_padding_factor > 0:
|
125 |
+
size_diff = tmp_crop_size * inner_padding_factor * 2
|
126 |
+
tmp_5pts += size_diff / 2
|
127 |
+
tmp_crop_size += np.round(size_diff).astype(np.int32)
|
128 |
+
|
129 |
+
# print(' crop_size = ', tmp_crop_size)
|
130 |
+
# print(' reference_5pts = ', tmp_5pts)
|
131 |
+
|
132 |
+
# 2) resize the padded inner region
|
133 |
+
# print('---> STEP2: resize the padded inner region')
|
134 |
+
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
|
135 |
+
# print(' crop_size = ', tmp_crop_size)
|
136 |
+
# print(' size_bf_outer_pad = ', size_bf_outer_pad)
|
137 |
+
|
138 |
+
if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
|
139 |
+
raise FaceWarpException('Must have (output_size - outer_padding)'
|
140 |
+
'= some_scale * (crop_size * (1.0 + inner_padding_factor)')
|
141 |
+
|
142 |
+
scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
|
143 |
+
# print(' resize scale_factor = ', scale_factor)
|
144 |
+
tmp_5pts = tmp_5pts * scale_factor
|
145 |
+
# size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
|
146 |
+
# tmp_5pts = tmp_5pts + size_diff / 2
|
147 |
+
tmp_crop_size = size_bf_outer_pad
|
148 |
+
# print(' crop_size = ', tmp_crop_size)
|
149 |
+
# print(' reference_5pts = ', tmp_5pts)
|
150 |
+
|
151 |
+
# 3) add outer_padding to make output_size
|
152 |
+
reference_5point = tmp_5pts + np.array(outer_padding)
|
153 |
+
tmp_crop_size = output_size
|
154 |
+
# print('---> STEP3: add outer_padding to make output_size')
|
155 |
+
# print(' crop_size = ', tmp_crop_size)
|
156 |
+
# print(' reference_5pts = ', tmp_5pts)
|
157 |
+
|
158 |
+
# print('===> end get_reference_facial_points\n')
|
159 |
+
|
160 |
+
return reference_5point
|
161 |
+
|
162 |
+
|
163 |
+
def get_affine_transform_matrix(src_pts, dst_pts):
|
164 |
+
"""
|
165 |
+
Function:
|
166 |
+
----------
|
167 |
+
get affine transform matrix 'tfm' from src_pts to dst_pts
|
168 |
+
Parameters:
|
169 |
+
----------
|
170 |
+
@src_pts: Kx2 np.array
|
171 |
+
source points matrix, each row is a pair of coordinates (x, y)
|
172 |
+
@dst_pts: Kx2 np.array
|
173 |
+
destination points matrix, each row is a pair of coordinates (x, y)
|
174 |
+
Returns:
|
175 |
+
----------
|
176 |
+
@tfm: 2x3 np.array
|
177 |
+
transform matrix from src_pts to dst_pts
|
178 |
+
"""
|
179 |
+
|
180 |
+
tfm = np.float32([[1, 0, 0], [0, 1, 0]])
|
181 |
+
n_pts = src_pts.shape[0]
|
182 |
+
ones = np.ones((n_pts, 1), src_pts.dtype)
|
183 |
+
src_pts_ = np.hstack([src_pts, ones])
|
184 |
+
dst_pts_ = np.hstack([dst_pts, ones])
|
185 |
+
|
186 |
+
# #print(('src_pts_:\n' + str(src_pts_))
|
187 |
+
# #print(('dst_pts_:\n' + str(dst_pts_))
|
188 |
+
|
189 |
+
A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
|
190 |
+
|
191 |
+
# #print(('np.linalg.lstsq return A: \n' + str(A))
|
192 |
+
# #print(('np.linalg.lstsq return res: \n' + str(res))
|
193 |
+
# #print(('np.linalg.lstsq return rank: \n' + str(rank))
|
194 |
+
# #print(('np.linalg.lstsq return s: \n' + str(s))
|
195 |
+
|
196 |
+
if rank == 3:
|
197 |
+
tfm = np.float32([
|
198 |
+
[A[0, 0], A[1, 0], A[2, 0]],
|
199 |
+
[A[0, 1], A[1, 1], A[2, 1]]
|
200 |
+
])
|
201 |
+
elif rank == 2:
|
202 |
+
tfm = np.float32([
|
203 |
+
[A[0, 0], A[1, 0], 0],
|
204 |
+
[A[0, 1], A[1, 1], 0]
|
205 |
+
])
|
206 |
+
|
207 |
+
return tfm
|
208 |
+
|
209 |
+
|
210 |
+
def warp_and_crop_face(src_img,
|
211 |
+
facial_pts,
|
212 |
+
reference_pts=None,
|
213 |
+
crop_size=(96, 112),
|
214 |
+
align_type='smilarity'):
|
215 |
+
"""
|
216 |
+
Function:
|
217 |
+
----------
|
218 |
+
apply affine transform 'trans' to uv
|
219 |
+
Parameters:
|
220 |
+
----------
|
221 |
+
@src_img: 3x3 np.array
|
222 |
+
input image
|
223 |
+
@facial_pts: could be
|
224 |
+
1)a list of K coordinates (x,y)
|
225 |
+
or
|
226 |
+
2) Kx2 or 2xK np.array
|
227 |
+
each row or col is a pair of coordinates (x, y)
|
228 |
+
@reference_pts: could be
|
229 |
+
1) a list of K coordinates (x,y)
|
230 |
+
or
|
231 |
+
2) Kx2 or 2xK np.array
|
232 |
+
each row or col is a pair of coordinates (x, y)
|
233 |
+
or
|
234 |
+
3) None
|
235 |
+
if None, use default reference facial points
|
236 |
+
@crop_size: (w, h)
|
237 |
+
output face image size
|
238 |
+
@align_type: transform type, could be one of
|
239 |
+
1) 'similarity': use similarity transform
|
240 |
+
2) 'cv2_affine': use the first 3 points to do affine transform,
|
241 |
+
by calling cv2.getAffineTransform()
|
242 |
+
3) 'affine': use all points to do affine transform
|
243 |
+
Returns:
|
244 |
+
----------
|
245 |
+
@face_img: output face image with size (w, h) = @crop_size
|
246 |
+
"""
|
247 |
+
|
248 |
+
if reference_pts is None:
|
249 |
+
if crop_size[0] == 96 and crop_size[1] == 112:
|
250 |
+
reference_pts = REFERENCE_FACIAL_POINTS
|
251 |
+
else:
|
252 |
+
default_square = False
|
253 |
+
inner_padding_factor = 0
|
254 |
+
outer_padding = (0, 0)
|
255 |
+
output_size = crop_size
|
256 |
+
|
257 |
+
reference_pts = get_reference_facial_points(output_size,
|
258 |
+
inner_padding_factor,
|
259 |
+
outer_padding,
|
260 |
+
default_square)
|
261 |
+
|
262 |
+
ref_pts = np.float32(reference_pts)
|
263 |
+
ref_pts_shp = ref_pts.shape
|
264 |
+
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
|
265 |
+
raise FaceWarpException(
|
266 |
+
'reference_pts.shape must be (K,2) or (2,K) and K>2')
|
267 |
+
|
268 |
+
if ref_pts_shp[0] == 2:
|
269 |
+
ref_pts = ref_pts.T
|
270 |
+
|
271 |
+
src_pts = np.float32(facial_pts)
|
272 |
+
src_pts_shp = src_pts.shape
|
273 |
+
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
|
274 |
+
raise FaceWarpException(
|
275 |
+
'facial_pts.shape must be (K,2) or (2,K) and K>2')
|
276 |
+
|
277 |
+
if src_pts_shp[0] == 2:
|
278 |
+
src_pts = src_pts.T
|
279 |
+
|
280 |
+
# #print('--->src_pts:\n', src_pts
|
281 |
+
# #print('--->ref_pts\n', ref_pts
|
282 |
+
|
283 |
+
if src_pts.shape != ref_pts.shape:
|
284 |
+
raise FaceWarpException(
|
285 |
+
'facial_pts and reference_pts must have the same shape')
|
286 |
+
|
287 |
+
if align_type is 'cv2_affine':
|
288 |
+
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
|
289 |
+
# #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm))
|
290 |
+
elif align_type is 'affine':
|
291 |
+
tfm = get_affine_transform_matrix(src_pts, ref_pts)
|
292 |
+
# #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm))
|
293 |
+
else:
|
294 |
+
tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
|
295 |
+
# #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm))
|
296 |
+
|
297 |
+
# #print('--->Transform matrix: '
|
298 |
+
# #print(('type(tfm):' + str(type(tfm)))
|
299 |
+
# #print(('tfm.dtype:' + str(tfm.dtype))
|
300 |
+
# #print( tfm
|
301 |
+
|
302 |
+
face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
|
303 |
+
|
304 |
+
return face_img, tfm
|
models/mtcnn/mtcnn_pytorch/src/box_utils.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
|
5 |
+
def nms(boxes, overlap_threshold=0.5, mode='union'):
|
6 |
+
"""Non-maximum suppression.
|
7 |
+
|
8 |
+
Arguments:
|
9 |
+
boxes: a float numpy array of shape [n, 5],
|
10 |
+
where each row is (xmin, ymin, xmax, ymax, score).
|
11 |
+
overlap_threshold: a float number.
|
12 |
+
mode: 'union' or 'min'.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
list with indices of the selected boxes
|
16 |
+
"""
|
17 |
+
|
18 |
+
# if there are no boxes, return the empty list
|
19 |
+
if len(boxes) == 0:
|
20 |
+
return []
|
21 |
+
|
22 |
+
# list of picked indices
|
23 |
+
pick = []
|
24 |
+
|
25 |
+
# grab the coordinates of the bounding boxes
|
26 |
+
x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]
|
27 |
+
|
28 |
+
area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
|
29 |
+
ids = np.argsort(score) # in increasing order
|
30 |
+
|
31 |
+
while len(ids) > 0:
|
32 |
+
|
33 |
+
# grab index of the largest value
|
34 |
+
last = len(ids) - 1
|
35 |
+
i = ids[last]
|
36 |
+
pick.append(i)
|
37 |
+
|
38 |
+
# compute intersections
|
39 |
+
# of the box with the largest score
|
40 |
+
# with the rest of boxes
|
41 |
+
|
42 |
+
# left top corner of intersection boxes
|
43 |
+
ix1 = np.maximum(x1[i], x1[ids[:last]])
|
44 |
+
iy1 = np.maximum(y1[i], y1[ids[:last]])
|
45 |
+
|
46 |
+
# right bottom corner of intersection boxes
|
47 |
+
ix2 = np.minimum(x2[i], x2[ids[:last]])
|
48 |
+
iy2 = np.minimum(y2[i], y2[ids[:last]])
|
49 |
+
|
50 |
+
# width and height of intersection boxes
|
51 |
+
w = np.maximum(0.0, ix2 - ix1 + 1.0)
|
52 |
+
h = np.maximum(0.0, iy2 - iy1 + 1.0)
|
53 |
+
|
54 |
+
# intersections' areas
|
55 |
+
inter = w * h
|
56 |
+
if mode == 'min':
|
57 |
+
overlap = inter / np.minimum(area[i], area[ids[:last]])
|
58 |
+
elif mode == 'union':
|
59 |
+
# intersection over union (IoU)
|
60 |
+
overlap = inter / (area[i] + area[ids[:last]] - inter)
|
61 |
+
|
62 |
+
# delete all boxes where overlap is too big
|
63 |
+
ids = np.delete(
|
64 |
+
ids,
|
65 |
+
np.concatenate([[last], np.where(overlap > overlap_threshold)[0]])
|
66 |
+
)
|
67 |
+
|
68 |
+
return pick
|
69 |
+
|
70 |
+
|
71 |
+
def convert_to_square(bboxes):
|
72 |
+
"""Convert bounding boxes to a square form.
|
73 |
+
|
74 |
+
Arguments:
|
75 |
+
bboxes: a float numpy array of shape [n, 5].
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
a float numpy array of shape [n, 5],
|
79 |
+
squared bounding boxes.
|
80 |
+
"""
|
81 |
+
|
82 |
+
square_bboxes = np.zeros_like(bboxes)
|
83 |
+
x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
|
84 |
+
h = y2 - y1 + 1.0
|
85 |
+
w = x2 - x1 + 1.0
|
86 |
+
max_side = np.maximum(h, w)
|
87 |
+
square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
|
88 |
+
square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
|
89 |
+
square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
|
90 |
+
square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
|
91 |
+
return square_bboxes
|
92 |
+
|
93 |
+
|
94 |
+
def calibrate_box(bboxes, offsets):
|
95 |
+
"""Transform bounding boxes to be more like true bounding boxes.
|
96 |
+
'offsets' is one of the outputs of the nets.
|
97 |
+
|
98 |
+
Arguments:
|
99 |
+
bboxes: a float numpy array of shape [n, 5].
|
100 |
+
offsets: a float numpy array of shape [n, 4].
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
a float numpy array of shape [n, 5].
|
104 |
+
"""
|
105 |
+
x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
|
106 |
+
w = x2 - x1 + 1.0
|
107 |
+
h = y2 - y1 + 1.0
|
108 |
+
w = np.expand_dims(w, 1)
|
109 |
+
h = np.expand_dims(h, 1)
|
110 |
+
|
111 |
+
# this is what happening here:
|
112 |
+
# tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
|
113 |
+
# x1_true = x1 + tx1*w
|
114 |
+
# y1_true = y1 + ty1*h
|
115 |
+
# x2_true = x2 + tx2*w
|
116 |
+
# y2_true = y2 + ty2*h
|
117 |
+
# below is just more compact form of this
|
118 |
+
|
119 |
+
# are offsets always such that
|
120 |
+
# x1 < x2 and y1 < y2 ?
|
121 |
+
|
122 |
+
translation = np.hstack([w, h, w, h]) * offsets
|
123 |
+
bboxes[:, 0:4] = bboxes[:, 0:4] + translation
|
124 |
+
return bboxes
|
125 |
+
|
126 |
+
|
127 |
+
def get_image_boxes(bounding_boxes, img, size=24):
|
128 |
+
"""Cut out boxes from the image.
|
129 |
+
|
130 |
+
Arguments:
|
131 |
+
bounding_boxes: a float numpy array of shape [n, 5].
|
132 |
+
img: an instance of PIL.Image.
|
133 |
+
size: an integer, size of cutouts.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
a float numpy array of shape [n, 3, size, size].
|
137 |
+
"""
|
138 |
+
|
139 |
+
num_boxes = len(bounding_boxes)
|
140 |
+
width, height = img.size
|
141 |
+
|
142 |
+
[dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height)
|
143 |
+
img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')
|
144 |
+
|
145 |
+
for i in range(num_boxes):
|
146 |
+
img_box = np.zeros((h[i], w[i], 3), 'uint8')
|
147 |
+
|
148 |
+
img_array = np.asarray(img, 'uint8')
|
149 |
+
img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \
|
150 |
+
img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]
|
151 |
+
|
152 |
+
# resize
|
153 |
+
img_box = Image.fromarray(img_box)
|
154 |
+
img_box = img_box.resize((size, size), Image.BILINEAR)
|
155 |
+
img_box = np.asarray(img_box, 'float32')
|
156 |
+
|
157 |
+
img_boxes[i, :, :, :] = _preprocess(img_box)
|
158 |
+
|
159 |
+
return img_boxes
|
160 |
+
|
161 |
+
|
162 |
+
def correct_bboxes(bboxes, width, height):
|
163 |
+
"""Crop boxes that are too big and get coordinates
|
164 |
+
with respect to cutouts.
|
165 |
+
|
166 |
+
Arguments:
|
167 |
+
bboxes: a float numpy array of shape [n, 5],
|
168 |
+
where each row is (xmin, ymin, xmax, ymax, score).
|
169 |
+
width: a float number.
|
170 |
+
height: a float number.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
dy, dx, edy, edx: a int numpy arrays of shape [n],
|
174 |
+
coordinates of the boxes with respect to the cutouts.
|
175 |
+
y, x, ey, ex: a int numpy arrays of shape [n],
|
176 |
+
corrected ymin, xmin, ymax, xmax.
|
177 |
+
h, w: a int numpy arrays of shape [n],
|
178 |
+
just heights and widths of boxes.
|
179 |
+
|
180 |
+
in the following order:
|
181 |
+
[dy, edy, dx, edx, y, ey, x, ex, w, h].
|
182 |
+
"""
|
183 |
+
|
184 |
+
x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
|
185 |
+
w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
|
186 |
+
num_boxes = bboxes.shape[0]
|
187 |
+
|
188 |
+
# 'e' stands for end
|
189 |
+
# (x, y) -> (ex, ey)
|
190 |
+
x, y, ex, ey = x1, y1, x2, y2
|
191 |
+
|
192 |
+
# we need to cut out a box from the image.
|
193 |
+
# (x, y, ex, ey) are corrected coordinates of the box
|
194 |
+
# in the image.
|
195 |
+
# (dx, dy, edx, edy) are coordinates of the box in the cutout
|
196 |
+
# from the image.
|
197 |
+
dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,))
|
198 |
+
edx, edy = w.copy() - 1.0, h.copy() - 1.0
|
199 |
+
|
200 |
+
# if box's bottom right corner is too far right
|
201 |
+
ind = np.where(ex > width - 1.0)[0]
|
202 |
+
edx[ind] = w[ind] + width - 2.0 - ex[ind]
|
203 |
+
ex[ind] = width - 1.0
|
204 |
+
|
205 |
+
# if box's bottom right corner is too low
|
206 |
+
ind = np.where(ey > height - 1.0)[0]
|
207 |
+
edy[ind] = h[ind] + height - 2.0 - ey[ind]
|
208 |
+
ey[ind] = height - 1.0
|
209 |
+
|
210 |
+
# if box's top left corner is too far left
|
211 |
+
ind = np.where(x < 0.0)[0]
|
212 |
+
dx[ind] = 0.0 - x[ind]
|
213 |
+
x[ind] = 0.0
|
214 |
+
|
215 |
+
# if box's top left corner is too high
|
216 |
+
ind = np.where(y < 0.0)[0]
|
217 |
+
dy[ind] = 0.0 - y[ind]
|
218 |
+
y[ind] = 0.0
|
219 |
+
|
220 |
+
return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
|
221 |
+
return_list = [i.astype('int32') for i in return_list]
|
222 |
+
|
223 |
+
return return_list
|
224 |
+
|
225 |
+
|
226 |
+
def _preprocess(img):
|
227 |
+
"""Preprocessing step before feeding the network.
|
228 |
+
|
229 |
+
Arguments:
|
230 |
+
img: a float numpy array of shape [h, w, c].
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
a float numpy array of shape [1, c, h, w].
|
234 |
+
"""
|
235 |
+
img = img.transpose((2, 0, 1))
|
236 |
+
img = np.expand_dims(img, 0)
|
237 |
+
img = (img - 127.5) * 0.0078125
|
238 |
+
return img
|
models/mtcnn/mtcnn_pytorch/src/detector.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch.autograd import Variable
|
4 |
+
from .get_nets import PNet, RNet, ONet
|
5 |
+
from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
|
6 |
+
from .first_stage import run_first_stage
|
7 |
+
|
8 |
+
|
9 |
+
def detect_faces(image, min_face_size=20.0,
|
10 |
+
thresholds=[0.6, 0.7, 0.8],
|
11 |
+
nms_thresholds=[0.7, 0.7, 0.7]):
|
12 |
+
"""
|
13 |
+
Arguments:
|
14 |
+
image: an instance of PIL.Image.
|
15 |
+
min_face_size: a float number.
|
16 |
+
thresholds: a list of length 3.
|
17 |
+
nms_thresholds: a list of length 3.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
|
21 |
+
bounding boxes and facial landmarks.
|
22 |
+
"""
|
23 |
+
|
24 |
+
# LOAD MODELS
|
25 |
+
pnet = PNet()
|
26 |
+
rnet = RNet()
|
27 |
+
onet = ONet()
|
28 |
+
onet.eval()
|
29 |
+
|
30 |
+
# BUILD AN IMAGE PYRAMID
|
31 |
+
width, height = image.size
|
32 |
+
min_length = min(height, width)
|
33 |
+
|
34 |
+
min_detection_size = 12
|
35 |
+
factor = 0.707 # sqrt(0.5)
|
36 |
+
|
37 |
+
# scales for scaling the image
|
38 |
+
scales = []
|
39 |
+
|
40 |
+
# scales the image so that
|
41 |
+
# minimum size that we can detect equals to
|
42 |
+
# minimum face size that we want to detect
|
43 |
+
m = min_detection_size / min_face_size
|
44 |
+
min_length *= m
|
45 |
+
|
46 |
+
factor_count = 0
|
47 |
+
while min_length > min_detection_size:
|
48 |
+
scales.append(m * factor ** factor_count)
|
49 |
+
min_length *= factor
|
50 |
+
factor_count += 1
|
51 |
+
|
52 |
+
# STAGE 1
|
53 |
+
|
54 |
+
# it will be returned
|
55 |
+
bounding_boxes = []
|
56 |
+
|
57 |
+
with torch.no_grad():
|
58 |
+
# run P-Net on different scales
|
59 |
+
for s in scales:
|
60 |
+
boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0])
|
61 |
+
bounding_boxes.append(boxes)
|
62 |
+
|
63 |
+
# collect boxes (and offsets, and scores) from different scales
|
64 |
+
bounding_boxes = [i for i in bounding_boxes if i is not None]
|
65 |
+
bounding_boxes = np.vstack(bounding_boxes)
|
66 |
+
|
67 |
+
keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
|
68 |
+
bounding_boxes = bounding_boxes[keep]
|
69 |
+
|
70 |
+
# use offsets predicted by pnet to transform bounding boxes
|
71 |
+
bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
|
72 |
+
# shape [n_boxes, 5]
|
73 |
+
|
74 |
+
bounding_boxes = convert_to_square(bounding_boxes)
|
75 |
+
bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
|
76 |
+
|
77 |
+
# STAGE 2
|
78 |
+
|
79 |
+
img_boxes = get_image_boxes(bounding_boxes, image, size=24)
|
80 |
+
img_boxes = torch.FloatTensor(img_boxes)
|
81 |
+
|
82 |
+
output = rnet(img_boxes)
|
83 |
+
offsets = output[0].data.numpy() # shape [n_boxes, 4]
|
84 |
+
probs = output[1].data.numpy() # shape [n_boxes, 2]
|
85 |
+
|
86 |
+
keep = np.where(probs[:, 1] > thresholds[1])[0]
|
87 |
+
bounding_boxes = bounding_boxes[keep]
|
88 |
+
bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
|
89 |
+
offsets = offsets[keep]
|
90 |
+
|
91 |
+
keep = nms(bounding_boxes, nms_thresholds[1])
|
92 |
+
bounding_boxes = bounding_boxes[keep]
|
93 |
+
bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
|
94 |
+
bounding_boxes = convert_to_square(bounding_boxes)
|
95 |
+
bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
|
96 |
+
|
97 |
+
# STAGE 3
|
98 |
+
|
99 |
+
img_boxes = get_image_boxes(bounding_boxes, image, size=48)
|
100 |
+
if len(img_boxes) == 0:
|
101 |
+
return [], []
|
102 |
+
img_boxes = torch.FloatTensor(img_boxes)
|
103 |
+
output = onet(img_boxes)
|
104 |
+
landmarks = output[0].data.numpy() # shape [n_boxes, 10]
|
105 |
+
offsets = output[1].data.numpy() # shape [n_boxes, 4]
|
106 |
+
probs = output[2].data.numpy() # shape [n_boxes, 2]
|
107 |
+
|
108 |
+
keep = np.where(probs[:, 1] > thresholds[2])[0]
|
109 |
+
bounding_boxes = bounding_boxes[keep]
|
110 |
+
bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
|
111 |
+
offsets = offsets[keep]
|
112 |
+
landmarks = landmarks[keep]
|
113 |
+
|
114 |
+
# compute landmark points
|
115 |
+
width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
|
116 |
+
height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
|
117 |
+
xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
|
118 |
+
landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
|
119 |
+
landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
|
120 |
+
|
121 |
+
bounding_boxes = calibrate_box(bounding_boxes, offsets)
|
122 |
+
keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
|
123 |
+
bounding_boxes = bounding_boxes[keep]
|
124 |
+
landmarks = landmarks[keep]
|
125 |
+
|
126 |
+
return bounding_boxes, landmarks
|
models/mtcnn/mtcnn_pytorch/src/first_stage.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Variable
|
3 |
+
import math
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
from .box_utils import nms, _preprocess
|
7 |
+
|
8 |
+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
9 |
+
device = 'cuda:0'
|
10 |
+
|
11 |
+
|
12 |
+
def run_first_stage(image, net, scale, threshold):
|
13 |
+
"""Run P-Net, generate bounding boxes, and do NMS.
|
14 |
+
|
15 |
+
Arguments:
|
16 |
+
image: an instance of PIL.Image.
|
17 |
+
net: an instance of pytorch's nn.Module, P-Net.
|
18 |
+
scale: a float number,
|
19 |
+
scale width and height of the image by this number.
|
20 |
+
threshold: a float number,
|
21 |
+
threshold on the probability of a face when generating
|
22 |
+
bounding boxes from predictions of the net.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
a float numpy array of shape [n_boxes, 9],
|
26 |
+
bounding boxes with scores and offsets (4 + 1 + 4).
|
27 |
+
"""
|
28 |
+
|
29 |
+
# scale the image and convert it to a float array
|
30 |
+
width, height = image.size
|
31 |
+
sw, sh = math.ceil(width * scale), math.ceil(height * scale)
|
32 |
+
img = image.resize((sw, sh), Image.BILINEAR)
|
33 |
+
img = np.asarray(img, 'float32')
|
34 |
+
|
35 |
+
img = torch.FloatTensor(_preprocess(img)).to(device)
|
36 |
+
with torch.no_grad():
|
37 |
+
output = net(img)
|
38 |
+
probs = output[1].cpu().data.numpy()[0, 1, :, :]
|
39 |
+
offsets = output[0].cpu().data.numpy()
|
40 |
+
# probs: probability of a face at each sliding window
|
41 |
+
# offsets: transformations to true bounding boxes
|
42 |
+
|
43 |
+
boxes = _generate_bboxes(probs, offsets, scale, threshold)
|
44 |
+
if len(boxes) == 0:
|
45 |
+
return None
|
46 |
+
|
47 |
+
keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
|
48 |
+
return boxes[keep]
|
49 |
+
|
50 |
+
|
51 |
+
def _generate_bboxes(probs, offsets, scale, threshold):
|
52 |
+
"""Generate bounding boxes at places
|
53 |
+
where there is probably a face.
|
54 |
+
|
55 |
+
Arguments:
|
56 |
+
probs: a float numpy array of shape [n, m].
|
57 |
+
offsets: a float numpy array of shape [1, 4, n, m].
|
58 |
+
scale: a float number,
|
59 |
+
width and height of the image were scaled by this number.
|
60 |
+
threshold: a float number.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
a float numpy array of shape [n_boxes, 9]
|
64 |
+
"""
|
65 |
+
|
66 |
+
# applying P-Net is equivalent, in some sense, to
|
67 |
+
# moving 12x12 window with stride 2
|
68 |
+
stride = 2
|
69 |
+
cell_size = 12
|
70 |
+
|
71 |
+
# indices of boxes where there is probably a face
|
72 |
+
inds = np.where(probs > threshold)
|
73 |
+
|
74 |
+
if inds[0].size == 0:
|
75 |
+
return np.array([])
|
76 |
+
|
77 |
+
# transformations of bounding boxes
|
78 |
+
tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
|
79 |
+
# they are defined as:
|
80 |
+
# w = x2 - x1 + 1
|
81 |
+
# h = y2 - y1 + 1
|
82 |
+
# x1_true = x1 + tx1*w
|
83 |
+
# x2_true = x2 + tx2*w
|
84 |
+
# y1_true = y1 + ty1*h
|
85 |
+
# y2_true = y2 + ty2*h
|
86 |
+
|
87 |
+
offsets = np.array([tx1, ty1, tx2, ty2])
|
88 |
+
score = probs[inds[0], inds[1]]
|
89 |
+
|
90 |
+
# P-Net is applied to scaled images
|
91 |
+
# so we need to rescale bounding boxes back
|
92 |
+
bounding_boxes = np.vstack([
|
93 |
+
np.round((stride * inds[1] + 1.0) / scale),
|
94 |
+
np.round((stride * inds[0] + 1.0) / scale),
|
95 |
+
np.round((stride * inds[1] + 1.0 + cell_size) / scale),
|
96 |
+
np.round((stride * inds[0] + 1.0 + cell_size) / scale),
|
97 |
+
score, offsets
|
98 |
+
])
|
99 |
+
# why one is added?
|
100 |
+
|
101 |
+
return bounding_boxes.T
|
models/mtcnn/mtcnn_pytorch/src/get_nets.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from collections import OrderedDict
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from configs.paths_config import model_paths
|
8 |
+
PNET_PATH = model_paths["mtcnn_pnet"]
|
9 |
+
ONET_PATH = model_paths["mtcnn_onet"]
|
10 |
+
RNET_PATH = model_paths["mtcnn_rnet"]
|
11 |
+
|
12 |
+
|
13 |
+
class Flatten(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self):
|
16 |
+
super(Flatten, self).__init__()
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
"""
|
20 |
+
Arguments:
|
21 |
+
x: a float tensor with shape [batch_size, c, h, w].
|
22 |
+
Returns:
|
23 |
+
a float tensor with shape [batch_size, c*h*w].
|
24 |
+
"""
|
25 |
+
|
26 |
+
# without this pretrained model isn't working
|
27 |
+
x = x.transpose(3, 2).contiguous()
|
28 |
+
|
29 |
+
return x.view(x.size(0), -1)
|
30 |
+
|
31 |
+
|
32 |
+
class PNet(nn.Module):
|
33 |
+
|
34 |
+
def __init__(self):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
# suppose we have input with size HxW, then
|
38 |
+
# after first layer: H - 2,
|
39 |
+
# after pool: ceil((H - 2)/2),
|
40 |
+
# after second conv: ceil((H - 2)/2) - 2,
|
41 |
+
# after last conv: ceil((H - 2)/2) - 4,
|
42 |
+
# and the same for W
|
43 |
+
|
44 |
+
self.features = nn.Sequential(OrderedDict([
|
45 |
+
('conv1', nn.Conv2d(3, 10, 3, 1)),
|
46 |
+
('prelu1', nn.PReLU(10)),
|
47 |
+
('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
|
48 |
+
|
49 |
+
('conv2', nn.Conv2d(10, 16, 3, 1)),
|
50 |
+
('prelu2', nn.PReLU(16)),
|
51 |
+
|
52 |
+
('conv3', nn.Conv2d(16, 32, 3, 1)),
|
53 |
+
('prelu3', nn.PReLU(32))
|
54 |
+
]))
|
55 |
+
|
56 |
+
self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
|
57 |
+
self.conv4_2 = nn.Conv2d(32, 4, 1, 1)
|
58 |
+
|
59 |
+
weights = np.load(PNET_PATH, allow_pickle=True)[()]
|
60 |
+
for n, p in self.named_parameters():
|
61 |
+
p.data = torch.FloatTensor(weights[n])
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
"""
|
65 |
+
Arguments:
|
66 |
+
x: a float tensor with shape [batch_size, 3, h, w].
|
67 |
+
Returns:
|
68 |
+
b: a float tensor with shape [batch_size, 4, h', w'].
|
69 |
+
a: a float tensor with shape [batch_size, 2, h', w'].
|
70 |
+
"""
|
71 |
+
x = self.features(x)
|
72 |
+
a = self.conv4_1(x)
|
73 |
+
b = self.conv4_2(x)
|
74 |
+
a = F.softmax(a, dim=-1)
|
75 |
+
return b, a
|
76 |
+
|
77 |
+
|
78 |
+
class RNet(nn.Module):
|
79 |
+
|
80 |
+
def __init__(self):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.features = nn.Sequential(OrderedDict([
|
84 |
+
('conv1', nn.Conv2d(3, 28, 3, 1)),
|
85 |
+
('prelu1', nn.PReLU(28)),
|
86 |
+
('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
|
87 |
+
|
88 |
+
('conv2', nn.Conv2d(28, 48, 3, 1)),
|
89 |
+
('prelu2', nn.PReLU(48)),
|
90 |
+
('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
|
91 |
+
|
92 |
+
('conv3', nn.Conv2d(48, 64, 2, 1)),
|
93 |
+
('prelu3', nn.PReLU(64)),
|
94 |
+
|
95 |
+
('flatten', Flatten()),
|
96 |
+
('conv4', nn.Linear(576, 128)),
|
97 |
+
('prelu4', nn.PReLU(128))
|
98 |
+
]))
|
99 |
+
|
100 |
+
self.conv5_1 = nn.Linear(128, 2)
|
101 |
+
self.conv5_2 = nn.Linear(128, 4)
|
102 |
+
|
103 |
+
weights = np.load(RNET_PATH, allow_pickle=True)[()]
|
104 |
+
for n, p in self.named_parameters():
|
105 |
+
p.data = torch.FloatTensor(weights[n])
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
"""
|
109 |
+
Arguments:
|
110 |
+
x: a float tensor with shape [batch_size, 3, h, w].
|
111 |
+
Returns:
|
112 |
+
b: a float tensor with shape [batch_size, 4].
|
113 |
+
a: a float tensor with shape [batch_size, 2].
|
114 |
+
"""
|
115 |
+
x = self.features(x)
|
116 |
+
a = self.conv5_1(x)
|
117 |
+
b = self.conv5_2(x)
|
118 |
+
a = F.softmax(a, dim=-1)
|
119 |
+
return b, a
|
120 |
+
|
121 |
+
|
122 |
+
class ONet(nn.Module):
|
123 |
+
|
124 |
+
def __init__(self):
|
125 |
+
super().__init__()
|
126 |
+
|
127 |
+
self.features = nn.Sequential(OrderedDict([
|
128 |
+
('conv1', nn.Conv2d(3, 32, 3, 1)),
|
129 |
+
('prelu1', nn.PReLU(32)),
|
130 |
+
('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
|
131 |
+
|
132 |
+
('conv2', nn.Conv2d(32, 64, 3, 1)),
|
133 |
+
('prelu2', nn.PReLU(64)),
|
134 |
+
('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
|
135 |
+
|
136 |
+
('conv3', nn.Conv2d(64, 64, 3, 1)),
|
137 |
+
('prelu3', nn.PReLU(64)),
|
138 |
+
('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
|
139 |
+
|
140 |
+
('conv4', nn.Conv2d(64, 128, 2, 1)),
|
141 |
+
('prelu4', nn.PReLU(128)),
|
142 |
+
|
143 |
+
('flatten', Flatten()),
|
144 |
+
('conv5', nn.Linear(1152, 256)),
|
145 |
+
('drop5', nn.Dropout(0.25)),
|
146 |
+
('prelu5', nn.PReLU(256)),
|
147 |
+
]))
|
148 |
+
|
149 |
+
self.conv6_1 = nn.Linear(256, 2)
|
150 |
+
self.conv6_2 = nn.Linear(256, 4)
|
151 |
+
self.conv6_3 = nn.Linear(256, 10)
|
152 |
+
|
153 |
+
weights = np.load(ONET_PATH, allow_pickle=True)[()]
|
154 |
+
for n, p in self.named_parameters():
|
155 |
+
p.data = torch.FloatTensor(weights[n])
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
"""
|
159 |
+
Arguments:
|
160 |
+
x: a float tensor with shape [batch_size, 3, h, w].
|
161 |
+
Returns:
|
162 |
+
c: a float tensor with shape [batch_size, 10].
|
163 |
+
b: a float tensor with shape [batch_size, 4].
|
164 |
+
a: a float tensor with shape [batch_size, 2].
|
165 |
+
"""
|
166 |
+
x = self.features(x)
|
167 |
+
a = self.conv6_1(x)
|
168 |
+
b = self.conv6_2(x)
|
169 |
+
c = self.conv6_3(x)
|
170 |
+
a = F.softmax(a, dim=-1)
|
171 |
+
return c, b, a
|
models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Tue Jul 11 06:54:28 2017
|
4 |
+
|
5 |
+
@author: zhaoyafei
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from numpy.linalg import inv, norm, lstsq
|
10 |
+
from numpy.linalg import matrix_rank as rank
|
11 |
+
|
12 |
+
|
13 |
+
class MatlabCp2tormException(Exception):
|
14 |
+
def __str__(self):
|
15 |
+
return 'In File {}:{}'.format(
|
16 |
+
__file__, super.__str__(self))
|
17 |
+
|
18 |
+
|
19 |
+
def tformfwd(trans, uv):
|
20 |
+
"""
|
21 |
+
Function:
|
22 |
+
----------
|
23 |
+
apply affine transform 'trans' to uv
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
----------
|
27 |
+
@trans: 3x3 np.array
|
28 |
+
transform matrix
|
29 |
+
@uv: Kx2 np.array
|
30 |
+
each row is a pair of coordinates (x, y)
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
----------
|
34 |
+
@xy: Kx2 np.array
|
35 |
+
each row is a pair of transformed coordinates (x, y)
|
36 |
+
"""
|
37 |
+
uv = np.hstack((
|
38 |
+
uv, np.ones((uv.shape[0], 1))
|
39 |
+
))
|
40 |
+
xy = np.dot(uv, trans)
|
41 |
+
xy = xy[:, 0:-1]
|
42 |
+
return xy
|
43 |
+
|
44 |
+
|
45 |
+
def tforminv(trans, uv):
|
46 |
+
"""
|
47 |
+
Function:
|
48 |
+
----------
|
49 |
+
apply the inverse of affine transform 'trans' to uv
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
----------
|
53 |
+
@trans: 3x3 np.array
|
54 |
+
transform matrix
|
55 |
+
@uv: Kx2 np.array
|
56 |
+
each row is a pair of coordinates (x, y)
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
----------
|
60 |
+
@xy: Kx2 np.array
|
61 |
+
each row is a pair of inverse-transformed coordinates (x, y)
|
62 |
+
"""
|
63 |
+
Tinv = inv(trans)
|
64 |
+
xy = tformfwd(Tinv, uv)
|
65 |
+
return xy
|
66 |
+
|
67 |
+
|
68 |
+
def findNonreflectiveSimilarity(uv, xy, options=None):
|
69 |
+
options = {'K': 2}
|
70 |
+
|
71 |
+
K = options['K']
|
72 |
+
M = xy.shape[0]
|
73 |
+
x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
|
74 |
+
y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
|
75 |
+
# print('--->x, y:\n', x, y
|
76 |
+
|
77 |
+
tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
|
78 |
+
tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
|
79 |
+
X = np.vstack((tmp1, tmp2))
|
80 |
+
# print('--->X.shape: ', X.shape
|
81 |
+
# print('X:\n', X
|
82 |
+
|
83 |
+
u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
|
84 |
+
v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
|
85 |
+
U = np.vstack((u, v))
|
86 |
+
# print('--->U.shape: ', U.shape
|
87 |
+
# print('U:\n', U
|
88 |
+
|
89 |
+
# We know that X * r = U
|
90 |
+
if rank(X) >= 2 * K:
|
91 |
+
r, _, _, _ = lstsq(X, U, rcond=None) # Make sure this is what I want
|
92 |
+
r = np.squeeze(r)
|
93 |
+
else:
|
94 |
+
raise Exception('cp2tform:twoUniquePointsReq')
|
95 |
+
|
96 |
+
# print('--->r:\n', r
|
97 |
+
|
98 |
+
sc = r[0]
|
99 |
+
ss = r[1]
|
100 |
+
tx = r[2]
|
101 |
+
ty = r[3]
|
102 |
+
|
103 |
+
Tinv = np.array([
|
104 |
+
[sc, -ss, 0],
|
105 |
+
[ss, sc, 0],
|
106 |
+
[tx, ty, 1]
|
107 |
+
])
|
108 |
+
|
109 |
+
# print('--->Tinv:\n', Tinv
|
110 |
+
|
111 |
+
T = inv(Tinv)
|
112 |
+
# print('--->T:\n', T
|
113 |
+
|
114 |
+
T[:, 2] = np.array([0, 0, 1])
|
115 |
+
|
116 |
+
return T, Tinv
|
117 |
+
|
118 |
+
|
119 |
+
def findSimilarity(uv, xy, options=None):
|
120 |
+
options = {'K': 2}
|
121 |
+
|
122 |
+
# uv = np.array(uv)
|
123 |
+
# xy = np.array(xy)
|
124 |
+
|
125 |
+
# Solve for trans1
|
126 |
+
trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
|
127 |
+
|
128 |
+
# Solve for trans2
|
129 |
+
|
130 |
+
# manually reflect the xy data across the Y-axis
|
131 |
+
xyR = xy
|
132 |
+
xyR[:, 0] = -1 * xyR[:, 0]
|
133 |
+
|
134 |
+
trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
|
135 |
+
|
136 |
+
# manually reflect the tform to undo the reflection done on xyR
|
137 |
+
TreflectY = np.array([
|
138 |
+
[-1, 0, 0],
|
139 |
+
[0, 1, 0],
|
140 |
+
[0, 0, 1]
|
141 |
+
])
|
142 |
+
|
143 |
+
trans2 = np.dot(trans2r, TreflectY)
|
144 |
+
|
145 |
+
# Figure out if trans1 or trans2 is better
|
146 |
+
xy1 = tformfwd(trans1, uv)
|
147 |
+
norm1 = norm(xy1 - xy)
|
148 |
+
|
149 |
+
xy2 = tformfwd(trans2, uv)
|
150 |
+
norm2 = norm(xy2 - xy)
|
151 |
+
|
152 |
+
if norm1 <= norm2:
|
153 |
+
return trans1, trans1_inv
|
154 |
+
else:
|
155 |
+
trans2_inv = inv(trans2)
|
156 |
+
return trans2, trans2_inv
|
157 |
+
|
158 |
+
|
159 |
+
def get_similarity_transform(src_pts, dst_pts, reflective=True):
|
160 |
+
"""
|
161 |
+
Function:
|
162 |
+
----------
|
163 |
+
Find Similarity Transform Matrix 'trans':
|
164 |
+
u = src_pts[:, 0]
|
165 |
+
v = src_pts[:, 1]
|
166 |
+
x = dst_pts[:, 0]
|
167 |
+
y = dst_pts[:, 1]
|
168 |
+
[x, y, 1] = [u, v, 1] * trans
|
169 |
+
|
170 |
+
Parameters:
|
171 |
+
----------
|
172 |
+
@src_pts: Kx2 np.array
|
173 |
+
source points, each row is a pair of coordinates (x, y)
|
174 |
+
@dst_pts: Kx2 np.array
|
175 |
+
destination points, each row is a pair of transformed
|
176 |
+
coordinates (x, y)
|
177 |
+
@reflective: True or False
|
178 |
+
if True:
|
179 |
+
use reflective similarity transform
|
180 |
+
else:
|
181 |
+
use non-reflective similarity transform
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
----------
|
185 |
+
@trans: 3x3 np.array
|
186 |
+
transform matrix from uv to xy
|
187 |
+
trans_inv: 3x3 np.array
|
188 |
+
inverse of trans, transform matrix from xy to uv
|
189 |
+
"""
|
190 |
+
|
191 |
+
if reflective:
|
192 |
+
trans, trans_inv = findSimilarity(src_pts, dst_pts)
|
193 |
+
else:
|
194 |
+
trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
|
195 |
+
|
196 |
+
return trans, trans_inv
|
197 |
+
|
198 |
+
|
199 |
+
def cvt_tform_mat_for_cv2(trans):
|
200 |
+
"""
|
201 |
+
Function:
|
202 |
+
----------
|
203 |
+
Convert Transform Matrix 'trans' into 'cv2_trans' which could be
|
204 |
+
directly used by cv2.warpAffine():
|
205 |
+
u = src_pts[:, 0]
|
206 |
+
v = src_pts[:, 1]
|
207 |
+
x = dst_pts[:, 0]
|
208 |
+
y = dst_pts[:, 1]
|
209 |
+
[x, y].T = cv_trans * [u, v, 1].T
|
210 |
+
|
211 |
+
Parameters:
|
212 |
+
----------
|
213 |
+
@trans: 3x3 np.array
|
214 |
+
transform matrix from uv to xy
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
----------
|
218 |
+
@cv2_trans: 2x3 np.array
|
219 |
+
transform matrix from src_pts to dst_pts, could be directly used
|
220 |
+
for cv2.warpAffine()
|
221 |
+
"""
|
222 |
+
cv2_trans = trans[:, 0:2].T
|
223 |
+
|
224 |
+
return cv2_trans
|
225 |
+
|
226 |
+
|
227 |
+
def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
|
228 |
+
"""
|
229 |
+
Function:
|
230 |
+
----------
|
231 |
+
Find Similarity Transform Matrix 'cv2_trans' which could be
|
232 |
+
directly used by cv2.warpAffine():
|
233 |
+
u = src_pts[:, 0]
|
234 |
+
v = src_pts[:, 1]
|
235 |
+
x = dst_pts[:, 0]
|
236 |
+
y = dst_pts[:, 1]
|
237 |
+
[x, y].T = cv_trans * [u, v, 1].T
|
238 |
+
|
239 |
+
Parameters:
|
240 |
+
----------
|
241 |
+
@src_pts: Kx2 np.array
|
242 |
+
source points, each row is a pair of coordinates (x, y)
|
243 |
+
@dst_pts: Kx2 np.array
|
244 |
+
destination points, each row is a pair of transformed
|
245 |
+
coordinates (x, y)
|
246 |
+
reflective: True or False
|
247 |
+
if True:
|
248 |
+
use reflective similarity transform
|
249 |
+
else:
|
250 |
+
use non-reflective similarity transform
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
----------
|
254 |
+
@cv2_trans: 2x3 np.array
|
255 |
+
transform matrix from src_pts to dst_pts, could be directly used
|
256 |
+
for cv2.warpAffine()
|
257 |
+
"""
|
258 |
+
trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
|
259 |
+
cv2_trans = cvt_tform_mat_for_cv2(trans)
|
260 |
+
|
261 |
+
return cv2_trans
|
262 |
+
|
263 |
+
|
264 |
+
if __name__ == '__main__':
|
265 |
+
"""
|
266 |
+
u = [0, 6, -2]
|
267 |
+
v = [0, 3, 5]
|
268 |
+
x = [-1, 0, 4]
|
269 |
+
y = [-1, -10, 4]
|
270 |
+
|
271 |
+
# In Matlab, run:
|
272 |
+
#
|
273 |
+
# uv = [u'; v'];
|
274 |
+
# xy = [x'; y'];
|
275 |
+
# tform_sim=cp2tform(uv,xy,'similarity');
|
276 |
+
#
|
277 |
+
# trans = tform_sim.tdata.T
|
278 |
+
# ans =
|
279 |
+
# -0.0764 -1.6190 0
|
280 |
+
# 1.6190 -0.0764 0
|
281 |
+
# -3.2156 0.0290 1.0000
|
282 |
+
# trans_inv = tform_sim.tdata.Tinv
|
283 |
+
# ans =
|
284 |
+
#
|
285 |
+
# -0.0291 0.6163 0
|
286 |
+
# -0.6163 -0.0291 0
|
287 |
+
# -0.0756 1.9826 1.0000
|
288 |
+
# xy_m=tformfwd(tform_sim, u,v)
|
289 |
+
#
|
290 |
+
# xy_m =
|
291 |
+
#
|
292 |
+
# -3.2156 0.0290
|
293 |
+
# 1.1833 -9.9143
|
294 |
+
# 5.0323 2.8853
|
295 |
+
# uv_m=tforminv(tform_sim, x,y)
|
296 |
+
#
|
297 |
+
# uv_m =
|
298 |
+
#
|
299 |
+
# 0.5698 1.3953
|
300 |
+
# 6.0872 2.2733
|
301 |
+
# -2.6570 4.3314
|
302 |
+
"""
|
303 |
+
u = [0, 6, -2]
|
304 |
+
v = [0, 3, 5]
|
305 |
+
x = [-1, 0, 4]
|
306 |
+
y = [-1, -10, 4]
|
307 |
+
|
308 |
+
uv = np.array((u, v)).T
|
309 |
+
xy = np.array((x, y)).T
|
310 |
+
|
311 |
+
print('\n--->uv:')
|
312 |
+
print(uv)
|
313 |
+
print('\n--->xy:')
|
314 |
+
print(xy)
|
315 |
+
|
316 |
+
trans, trans_inv = get_similarity_transform(uv, xy)
|
317 |
+
|
318 |
+
print('\n--->trans matrix:')
|
319 |
+
print(trans)
|
320 |
+
|
321 |
+
print('\n--->trans_inv matrix:')
|
322 |
+
print(trans_inv)
|
323 |
+
|
324 |
+
print('\n---> apply transform to uv')
|
325 |
+
print('\nxy_m = uv_augmented * trans')
|
326 |
+
uv_aug = np.hstack((
|
327 |
+
uv, np.ones((uv.shape[0], 1))
|
328 |
+
))
|
329 |
+
xy_m = np.dot(uv_aug, trans)
|
330 |
+
print(xy_m)
|
331 |
+
|
332 |
+
print('\nxy_m = tformfwd(trans, uv)')
|
333 |
+
xy_m = tformfwd(trans, uv)
|
334 |
+
print(xy_m)
|
335 |
+
|
336 |
+
print('\n---> apply inverse transform to xy')
|
337 |
+
print('\nuv_m = xy_augmented * trans_inv')
|
338 |
+
xy_aug = np.hstack((
|
339 |
+
xy, np.ones((xy.shape[0], 1))
|
340 |
+
))
|
341 |
+
uv_m = np.dot(xy_aug, trans_inv)
|
342 |
+
print(uv_m)
|
343 |
+
|
344 |
+
print('\nuv_m = tformfwd(trans_inv, xy)')
|
345 |
+
uv_m = tformfwd(trans_inv, xy)
|
346 |
+
print(uv_m)
|
347 |
+
|
348 |
+
uv_m = tforminv(trans, xy)
|
349 |
+
print('\nuv_m = tforminv(trans, xy)')
|
350 |
+
print(uv_m)
|
models/mtcnn/mtcnn_pytorch/src/visualization_utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import ImageDraw
|
2 |
+
|
3 |
+
|
4 |
+
def show_bboxes(img, bounding_boxes, facial_landmarks=[]):
|
5 |
+
"""Draw bounding boxes and facial landmarks.
|
6 |
+
|
7 |
+
Arguments:
|
8 |
+
img: an instance of PIL.Image.
|
9 |
+
bounding_boxes: a float numpy array of shape [n, 5].
|
10 |
+
facial_landmarks: a float numpy array of shape [n, 10].
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
an instance of PIL.Image.
|
14 |
+
"""
|
15 |
+
|
16 |
+
img_copy = img.copy()
|
17 |
+
draw = ImageDraw.Draw(img_copy)
|
18 |
+
|
19 |
+
for b in bounding_boxes:
|
20 |
+
draw.rectangle([
|
21 |
+
(b[0], b[1]), (b[2], b[3])
|
22 |
+
], outline='white')
|
23 |
+
|
24 |
+
for p in facial_landmarks:
|
25 |
+
for i in range(5):
|
26 |
+
draw.ellipse([
|
27 |
+
(p[i] - 1.0, p[i + 5] - 1.0),
|
28 |
+
(p[i] + 1.0, p[i + 5] + 1.0)
|
29 |
+
], outline='blue')
|
30 |
+
|
31 |
+
return img_copy
|
models/mtcnn/mtcnn_pytorch/src/weights/onet.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:313141c3646bebb73cb8350a2d5fee4c7f044fb96304b46ccc21aeea8b818f83
|
3 |
+
size 2345483
|
models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:03e19e5c473932ab38f5a6308fe6210624006994a687e858d1dcda53c66f18cb
|
3 |
+
size 41271
|
models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5660aad67688edc9e8a3dd4e47ed120932835e06a8a711a423252a6f2c747083
|
3 |
+
size 604651
|
models/psp.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file defines the core research contribution
|
3 |
+
"""
|
4 |
+
import matplotlib
|
5 |
+
matplotlib.use('Agg')
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from models.encoders import psp_encoders
|
11 |
+
from models.stylegan2.model import Generator
|
12 |
+
from configs.paths_config import model_paths
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
def get_keys(d, name):
|
16 |
+
if 'state_dict' in d:
|
17 |
+
d = d['state_dict']
|
18 |
+
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
|
19 |
+
return d_filt
|
20 |
+
|
21 |
+
|
22 |
+
class pSp(nn.Module):
|
23 |
+
|
24 |
+
def __init__(self, opts, ckpt=None):
|
25 |
+
super(pSp, self).__init__()
|
26 |
+
self.set_opts(opts)
|
27 |
+
# compute number of style inputs based on the output resolution
|
28 |
+
self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
|
29 |
+
# Define architecture
|
30 |
+
self.encoder = self.set_encoder()
|
31 |
+
self.decoder = Generator(self.opts.output_size, 512, 8)
|
32 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
33 |
+
# Load weights if needed
|
34 |
+
self.load_weights(ckpt)
|
35 |
+
|
36 |
+
def set_encoder(self):
|
37 |
+
if self.opts.encoder_type == 'GradualStyleEncoder':
|
38 |
+
encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
|
39 |
+
elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
|
40 |
+
encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
|
41 |
+
elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
|
42 |
+
encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
|
43 |
+
else:
|
44 |
+
raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
|
45 |
+
return encoder
|
46 |
+
|
47 |
+
def load_weights(self, ckpt=None):
|
48 |
+
if self.opts.checkpoint_path is not None:
|
49 |
+
print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
|
50 |
+
if ckpt is None:
|
51 |
+
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
|
52 |
+
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=False)
|
53 |
+
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=False)
|
54 |
+
self.__load_latent_avg(ckpt)
|
55 |
+
else:
|
56 |
+
print('Loading encoders weights from irse50!')
|
57 |
+
encoder_ckpt = torch.load(model_paths['ir_se50'])
|
58 |
+
# if input to encoder is not an RGB image, do not load the input layer weights
|
59 |
+
if self.opts.label_nc != 0:
|
60 |
+
encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
|
61 |
+
self.encoder.load_state_dict(encoder_ckpt, strict=False)
|
62 |
+
print('Loading decoder weights from pretrained!')
|
63 |
+
ckpt = torch.load(self.opts.stylegan_weights)
|
64 |
+
self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
|
65 |
+
if self.opts.learn_in_w:
|
66 |
+
self.__load_latent_avg(ckpt, repeat=1)
|
67 |
+
else:
|
68 |
+
self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)
|
69 |
+
# for video toonification, we load G0' model
|
70 |
+
if self.opts.toonify_weights is not None: ##### modified
|
71 |
+
ckpt = torch.load(self.opts.toonify_weights)
|
72 |
+
self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
|
73 |
+
self.opts.toonify_weights = None
|
74 |
+
|
75 |
+
# x1: image for first-layer feature f.
|
76 |
+
# x2: image for style latent code w+. If not specified, x2=x1.
|
77 |
+
# inject_latent: for sketch/mask-to-face translation, another latent code to fuse with w+
|
78 |
+
# latent_mask: fuse w+ and inject_latent with the mask (1~7 use w+ and 8~18 use inject_latent)
|
79 |
+
# use_feature: use f. Otherwise, use the orginal StyleGAN first-layer constant 4*4 feature
|
80 |
+
# first_layer_feature_ind: always=0, means the 1st layer of G accept f
|
81 |
+
# use_skip: use skip connection.
|
82 |
+
# zero_noise: use zero noises.
|
83 |
+
# editing_w: the editing vector v for video face editing
|
84 |
+
def forward(self, x1, x2=None, resize=True, latent_mask=None, randomize_noise=True,
|
85 |
+
inject_latent=None, return_latents=False, alpha=None, use_feature=True,
|
86 |
+
first_layer_feature_ind=0, use_skip=False, zero_noise=False, editing_w=None): ##### modified
|
87 |
+
|
88 |
+
feats = None # f and the skipped encoder features
|
89 |
+
codes, feats = self.encoder(x1, return_feat=True, return_full=use_skip) ##### modified
|
90 |
+
if x2 is not None: ##### modified
|
91 |
+
codes = self.encoder(x2) ##### modified
|
92 |
+
# normalize with respect to the center of an average face
|
93 |
+
if self.opts.start_from_latent_avg:
|
94 |
+
if self.opts.learn_in_w:
|
95 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
|
96 |
+
else:
|
97 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
|
98 |
+
|
99 |
+
# E_W^{1:7}(T(x1)) concatenate E_W^{8:18}(w~)
|
100 |
+
if latent_mask is not None:
|
101 |
+
for i in latent_mask:
|
102 |
+
if inject_latent is not None:
|
103 |
+
if alpha is not None:
|
104 |
+
codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
|
105 |
+
else:
|
106 |
+
codes[:, i] = inject_latent[:, i]
|
107 |
+
else:
|
108 |
+
codes[:, i] = 0
|
109 |
+
|
110 |
+
first_layer_feats, skip_layer_feats, fusion = None, None, None ##### modified
|
111 |
+
if use_feature: ##### modified
|
112 |
+
first_layer_feats = feats[0:2] # use f
|
113 |
+
if use_skip: ##### modified
|
114 |
+
skip_layer_feats = feats[2:] # use skipped encoder feature
|
115 |
+
fusion = self.encoder.fusion # use fusion layer to fuse encoder feature and decoder feature.
|
116 |
+
|
117 |
+
images, result_latent = self.decoder([codes],
|
118 |
+
input_is_latent=True,
|
119 |
+
randomize_noise=randomize_noise,
|
120 |
+
return_latents=return_latents,
|
121 |
+
first_layer_feature=first_layer_feats,
|
122 |
+
first_layer_feature_ind=first_layer_feature_ind,
|
123 |
+
skip_layer_feature=skip_layer_feats,
|
124 |
+
fusion_block=fusion,
|
125 |
+
zero_noise=zero_noise,
|
126 |
+
editing_w=editing_w) ##### modified
|
127 |
+
|
128 |
+
if resize:
|
129 |
+
if self.opts.output_size == 1024: ##### modified
|
130 |
+
images = F.adaptive_avg_pool2d(images, (images.shape[2]//4, images.shape[3]//4)) ##### modified
|
131 |
+
else:
|
132 |
+
images = self.face_pool(images)
|
133 |
+
|
134 |
+
if return_latents:
|
135 |
+
return images, result_latent
|
136 |
+
else:
|
137 |
+
return images
|
138 |
+
|
139 |
+
def set_opts(self, opts):
|
140 |
+
self.opts = opts
|
141 |
+
|
142 |
+
def __load_latent_avg(self, ckpt, repeat=None):
|
143 |
+
if 'latent_avg' in ckpt:
|
144 |
+
self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
|
145 |
+
if repeat is not None:
|
146 |
+
self.latent_avg = self.latent_avg.repeat(repeat, 1)
|
147 |
+
else:
|
148 |
+
self.latent_avg = None
|
models/stylegan2/__init__.py
ADDED
File without changes
|
models/stylegan2/lpips/__init__.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
from __future__ import division
|
4 |
+
from __future__ import print_function
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
#from skimage.measure import compare_ssim
|
8 |
+
from skimage.metrics import structural_similarity as compare_ssim
|
9 |
+
import torch
|
10 |
+
from torch.autograd import Variable
|
11 |
+
|
12 |
+
from models.stylegan2.lpips import dist_model
|
13 |
+
|
14 |
+
class PerceptualLoss(torch.nn.Module):
|
15 |
+
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
|
16 |
+
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
|
17 |
+
super(PerceptualLoss, self).__init__()
|
18 |
+
print('Setting up Perceptual loss...')
|
19 |
+
self.use_gpu = use_gpu
|
20 |
+
self.spatial = spatial
|
21 |
+
self.gpu_ids = gpu_ids
|
22 |
+
self.model = dist_model.DistModel()
|
23 |
+
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
|
24 |
+
print('...[%s] initialized'%self.model.name())
|
25 |
+
print('...Done')
|
26 |
+
|
27 |
+
def forward(self, pred, target, normalize=False):
|
28 |
+
"""
|
29 |
+
Pred and target are Variables.
|
30 |
+
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
|
31 |
+
If normalize is False, assumes the images are already between [-1,+1]
|
32 |
+
|
33 |
+
Inputs pred and target are Nx3xHxW
|
34 |
+
Output pytorch Variable N long
|
35 |
+
"""
|
36 |
+
|
37 |
+
if normalize:
|
38 |
+
target = 2 * target - 1
|
39 |
+
pred = 2 * pred - 1
|
40 |
+
|
41 |
+
return self.model.forward(target, pred)
|
42 |
+
|
43 |
+
def normalize_tensor(in_feat,eps=1e-10):
|
44 |
+
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
|
45 |
+
return in_feat/(norm_factor+eps)
|
46 |
+
|
47 |
+
def l2(p0, p1, range=255.):
|
48 |
+
return .5*np.mean((p0 / range - p1 / range)**2)
|
49 |
+
|
50 |
+
def psnr(p0, p1, peak=255.):
|
51 |
+
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
|
52 |
+
|
53 |
+
def dssim(p0, p1, range=255.):
|
54 |
+
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
|
55 |
+
|
56 |
+
def rgb2lab(in_img,mean_cent=False):
|
57 |
+
from skimage import color
|
58 |
+
img_lab = color.rgb2lab(in_img)
|
59 |
+
if(mean_cent):
|
60 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
61 |
+
return img_lab
|
62 |
+
|
63 |
+
def tensor2np(tensor_obj):
|
64 |
+
# change dimension of a tensor object into a numpy array
|
65 |
+
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
|
66 |
+
|
67 |
+
def np2tensor(np_obj):
|
68 |
+
# change dimenion of np array into tensor array
|
69 |
+
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
70 |
+
|
71 |
+
def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
|
72 |
+
# image tensor to lab tensor
|
73 |
+
from skimage import color
|
74 |
+
|
75 |
+
img = tensor2im(image_tensor)
|
76 |
+
img_lab = color.rgb2lab(img)
|
77 |
+
if(mc_only):
|
78 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
79 |
+
if(to_norm and not mc_only):
|
80 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
81 |
+
img_lab = img_lab/100.
|
82 |
+
|
83 |
+
return np2tensor(img_lab)
|
84 |
+
|
85 |
+
def tensorlab2tensor(lab_tensor,return_inbnd=False):
|
86 |
+
from skimage import color
|
87 |
+
import warnings
|
88 |
+
warnings.filterwarnings("ignore")
|
89 |
+
|
90 |
+
lab = tensor2np(lab_tensor)*100.
|
91 |
+
lab[:,:,0] = lab[:,:,0]+50
|
92 |
+
|
93 |
+
rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
|
94 |
+
if(return_inbnd):
|
95 |
+
# convert back to lab, see if we match
|
96 |
+
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
|
97 |
+
mask = 1.*np.isclose(lab_back,lab,atol=2.)
|
98 |
+
mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
|
99 |
+
return (im2tensor(rgb_back),mask)
|
100 |
+
else:
|
101 |
+
return im2tensor(rgb_back)
|
102 |
+
|
103 |
+
def rgb2lab(input):
|
104 |
+
from skimage import color
|
105 |
+
return color.rgb2lab(input / 255.)
|
106 |
+
|
107 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
|
108 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
109 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
110 |
+
return image_numpy.astype(imtype)
|
111 |
+
|
112 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
|
113 |
+
return torch.Tensor((image / factor - cent)
|
114 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
115 |
+
|
116 |
+
def tensor2vec(vector_tensor):
|
117 |
+
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
|
118 |
+
|
119 |
+
def voc_ap(rec, prec, use_07_metric=False):
|
120 |
+
""" ap = voc_ap(rec, prec, [use_07_metric])
|
121 |
+
Compute VOC AP given precision and recall.
|
122 |
+
If use_07_metric is true, uses the
|
123 |
+
VOC 07 11 point method (default:False).
|
124 |
+
"""
|
125 |
+
if use_07_metric:
|
126 |
+
# 11 point metric
|
127 |
+
ap = 0.
|
128 |
+
for t in np.arange(0., 1.1, 0.1):
|
129 |
+
if np.sum(rec >= t) == 0:
|
130 |
+
p = 0
|
131 |
+
else:
|
132 |
+
p = np.max(prec[rec >= t])
|
133 |
+
ap = ap + p / 11.
|
134 |
+
else:
|
135 |
+
# correct AP calculation
|
136 |
+
# first append sentinel values at the end
|
137 |
+
mrec = np.concatenate(([0.], rec, [1.]))
|
138 |
+
mpre = np.concatenate(([0.], prec, [0.]))
|
139 |
+
|
140 |
+
# compute the precision envelope
|
141 |
+
for i in range(mpre.size - 1, 0, -1):
|
142 |
+
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
|
143 |
+
|
144 |
+
# to calculate area under PR curve, look for points
|
145 |
+
# where X axis (recall) changes value
|
146 |
+
i = np.where(mrec[1:] != mrec[:-1])[0]
|
147 |
+
|
148 |
+
# and sum (\Delta recall) * prec
|
149 |
+
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
|
150 |
+
return ap
|
151 |
+
|
152 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
|
153 |
+
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
|
154 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
155 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
156 |
+
return image_numpy.astype(imtype)
|
157 |
+
|
158 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
|
159 |
+
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
|
160 |
+
return torch.Tensor((image / factor - cent)
|
161 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
models/stylegan2/lpips/base_model.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.autograd import Variable
|
5 |
+
from pdb import set_trace as st
|
6 |
+
from IPython import embed
|
7 |
+
|
8 |
+
class BaseModel():
|
9 |
+
def __init__(self):
|
10 |
+
pass;
|
11 |
+
|
12 |
+
def name(self):
|
13 |
+
return 'BaseModel'
|
14 |
+
|
15 |
+
def initialize(self, use_gpu=True, gpu_ids=[0]):
|
16 |
+
self.use_gpu = use_gpu
|
17 |
+
self.gpu_ids = gpu_ids
|
18 |
+
|
19 |
+
def forward(self):
|
20 |
+
pass
|
21 |
+
|
22 |
+
def get_image_paths(self):
|
23 |
+
pass
|
24 |
+
|
25 |
+
def optimize_parameters(self):
|
26 |
+
pass
|
27 |
+
|
28 |
+
def get_current_visuals(self):
|
29 |
+
return self.input
|
30 |
+
|
31 |
+
def get_current_errors(self):
|
32 |
+
return {}
|
33 |
+
|
34 |
+
def save(self, label):
|
35 |
+
pass
|
36 |
+
|
37 |
+
# helper saving function that can be used by subclasses
|
38 |
+
def save_network(self, network, path, network_label, epoch_label):
|
39 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
40 |
+
save_path = os.path.join(path, save_filename)
|
41 |
+
torch.save(network.state_dict(), save_path)
|
42 |
+
|
43 |
+
# helper loading function that can be used by subclasses
|
44 |
+
def load_network(self, network, network_label, epoch_label):
|
45 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
46 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
47 |
+
print('Loading network from %s'%save_path)
|
48 |
+
network.load_state_dict(torch.load(save_path))
|
49 |
+
|
50 |
+
def update_learning_rate():
|
51 |
+
pass
|
52 |
+
|
53 |
+
def get_image_paths(self):
|
54 |
+
return self.image_paths
|
55 |
+
|
56 |
+
def save_done(self, flag=False):
|
57 |
+
np.save(os.path.join(self.save_dir, 'done_flag'),flag)
|
58 |
+
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
|
models/stylegan2/lpips/dist_model.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import os
|
9 |
+
from collections import OrderedDict
|
10 |
+
from torch.autograd import Variable
|
11 |
+
import itertools
|
12 |
+
from models.stylegan2.lpips.base_model import BaseModel
|
13 |
+
from scipy.ndimage import zoom
|
14 |
+
import fractions
|
15 |
+
import functools
|
16 |
+
import skimage.transform
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
from IPython import embed
|
20 |
+
|
21 |
+
from models.stylegan2.lpips import networks_basic as networks
|
22 |
+
import models.stylegan2.lpips as util
|
23 |
+
|
24 |
+
class DistModel(BaseModel):
|
25 |
+
def name(self):
|
26 |
+
return self.model_name
|
27 |
+
|
28 |
+
def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
|
29 |
+
use_gpu=True, printNet=False, spatial=False,
|
30 |
+
is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
|
31 |
+
'''
|
32 |
+
INPUTS
|
33 |
+
model - ['net-lin'] for linearly calibrated network
|
34 |
+
['net'] for off-the-shelf network
|
35 |
+
['L2'] for L2 distance in Lab colorspace
|
36 |
+
['SSIM'] for ssim in RGB colorspace
|
37 |
+
net - ['squeeze','alex','vgg']
|
38 |
+
model_path - if None, will look in weights/[NET_NAME].pth
|
39 |
+
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
|
40 |
+
use_gpu - bool - whether or not to use a GPU
|
41 |
+
printNet - bool - whether or not to print network architecture out
|
42 |
+
spatial - bool - whether to output an array containing varying distances across spatial dimensions
|
43 |
+
spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
|
44 |
+
spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
|
45 |
+
spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
|
46 |
+
is_train - bool - [True] for training mode
|
47 |
+
lr - float - initial learning rate
|
48 |
+
beta1 - float - initial momentum term for adam
|
49 |
+
version - 0.1 for latest, 0.0 was original (with a bug)
|
50 |
+
gpu_ids - int array - [0] by default, gpus to use
|
51 |
+
'''
|
52 |
+
BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
|
53 |
+
|
54 |
+
self.model = model
|
55 |
+
self.net = net
|
56 |
+
self.is_train = is_train
|
57 |
+
self.spatial = spatial
|
58 |
+
self.gpu_ids = gpu_ids
|
59 |
+
self.model_name = '%s [%s]'%(model,net)
|
60 |
+
|
61 |
+
if(self.model == 'net-lin'): # pretrained net + linear layer
|
62 |
+
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
|
63 |
+
use_dropout=True, spatial=spatial, version=version, lpips=True)
|
64 |
+
kw = {}
|
65 |
+
if not use_gpu:
|
66 |
+
kw['map_location'] = 'cpu'
|
67 |
+
if(model_path is None):
|
68 |
+
import inspect
|
69 |
+
model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
|
70 |
+
|
71 |
+
if(not is_train):
|
72 |
+
print('Loading model from: %s'%model_path)
|
73 |
+
self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
|
74 |
+
|
75 |
+
elif(self.model=='net'): # pretrained network
|
76 |
+
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
|
77 |
+
elif(self.model in ['L2','l2']):
|
78 |
+
self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
|
79 |
+
self.model_name = 'L2'
|
80 |
+
elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
|
81 |
+
self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
|
82 |
+
self.model_name = 'SSIM'
|
83 |
+
else:
|
84 |
+
raise ValueError("Model [%s] not recognized." % self.model)
|
85 |
+
|
86 |
+
self.parameters = list(self.net.parameters())
|
87 |
+
|
88 |
+
if self.is_train: # training mode
|
89 |
+
# extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
|
90 |
+
self.rankLoss = networks.BCERankingLoss()
|
91 |
+
self.parameters += list(self.rankLoss.net.parameters())
|
92 |
+
self.lr = lr
|
93 |
+
self.old_lr = lr
|
94 |
+
self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
|
95 |
+
else: # test mode
|
96 |
+
self.net.eval()
|
97 |
+
|
98 |
+
if(use_gpu):
|
99 |
+
self.net.to(gpu_ids[0])
|
100 |
+
self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
|
101 |
+
if(self.is_train):
|
102 |
+
self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
|
103 |
+
|
104 |
+
if(printNet):
|
105 |
+
print('---------- Networks initialized -------------')
|
106 |
+
networks.print_network(self.net)
|
107 |
+
print('-----------------------------------------------')
|
108 |
+
|
109 |
+
def forward(self, in0, in1, retPerLayer=False):
|
110 |
+
''' Function computes the distance between image patches in0 and in1
|
111 |
+
INPUTS
|
112 |
+
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
|
113 |
+
OUTPUT
|
114 |
+
computed distances between in0 and in1
|
115 |
+
'''
|
116 |
+
|
117 |
+
return self.net.forward(in0, in1, retPerLayer=retPerLayer)
|
118 |
+
|
119 |
+
# ***** TRAINING FUNCTIONS *****
|
120 |
+
def optimize_parameters(self):
|
121 |
+
self.forward_train()
|
122 |
+
self.optimizer_net.zero_grad()
|
123 |
+
self.backward_train()
|
124 |
+
self.optimizer_net.step()
|
125 |
+
self.clamp_weights()
|
126 |
+
|
127 |
+
def clamp_weights(self):
|
128 |
+
for module in self.net.modules():
|
129 |
+
if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
|
130 |
+
module.weight.data = torch.clamp(module.weight.data,min=0)
|
131 |
+
|
132 |
+
def set_input(self, data):
|
133 |
+
self.input_ref = data['ref']
|
134 |
+
self.input_p0 = data['p0']
|
135 |
+
self.input_p1 = data['p1']
|
136 |
+
self.input_judge = data['judge']
|
137 |
+
|
138 |
+
if(self.use_gpu):
|
139 |
+
self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
|
140 |
+
self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
|
141 |
+
self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
|
142 |
+
self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
|
143 |
+
|
144 |
+
self.var_ref = Variable(self.input_ref,requires_grad=True)
|
145 |
+
self.var_p0 = Variable(self.input_p0,requires_grad=True)
|
146 |
+
self.var_p1 = Variable(self.input_p1,requires_grad=True)
|
147 |
+
|
148 |
+
def forward_train(self): # run forward pass
|
149 |
+
# print(self.net.module.scaling_layer.shift)
|
150 |
+
# print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
|
151 |
+
|
152 |
+
self.d0 = self.forward(self.var_ref, self.var_p0)
|
153 |
+
self.d1 = self.forward(self.var_ref, self.var_p1)
|
154 |
+
self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
|
155 |
+
|
156 |
+
self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
|
157 |
+
|
158 |
+
self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
|
159 |
+
|
160 |
+
return self.loss_total
|
161 |
+
|
162 |
+
def backward_train(self):
|
163 |
+
torch.mean(self.loss_total).backward()
|
164 |
+
|
165 |
+
def compute_accuracy(self,d0,d1,judge):
|
166 |
+
''' d0, d1 are Variables, judge is a Tensor '''
|
167 |
+
d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
|
168 |
+
judge_per = judge.cpu().numpy().flatten()
|
169 |
+
return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
|
170 |
+
|
171 |
+
def get_current_errors(self):
|
172 |
+
retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
|
173 |
+
('acc_r', self.acc_r)])
|
174 |
+
|
175 |
+
for key in retDict.keys():
|
176 |
+
retDict[key] = np.mean(retDict[key])
|
177 |
+
|
178 |
+
return retDict
|
179 |
+
|
180 |
+
def get_current_visuals(self):
|
181 |
+
zoom_factor = 256/self.var_ref.data.size()[2]
|
182 |
+
|
183 |
+
ref_img = util.tensor2im(self.var_ref.data)
|
184 |
+
p0_img = util.tensor2im(self.var_p0.data)
|
185 |
+
p1_img = util.tensor2im(self.var_p1.data)
|
186 |
+
|
187 |
+
ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
|
188 |
+
p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
|
189 |
+
p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
|
190 |
+
|
191 |
+
return OrderedDict([('ref', ref_img_vis),
|
192 |
+
('p0', p0_img_vis),
|
193 |
+
('p1', p1_img_vis)])
|
194 |
+
|
195 |
+
def save(self, path, label):
|
196 |
+
if(self.use_gpu):
|
197 |
+
self.save_network(self.net.module, path, '', label)
|
198 |
+
else:
|
199 |
+
self.save_network(self.net, path, '', label)
|
200 |
+
self.save_network(self.rankLoss.net, path, 'rank', label)
|
201 |
+
|
202 |
+
def update_learning_rate(self,nepoch_decay):
|
203 |
+
lrd = self.lr / nepoch_decay
|
204 |
+
lr = self.old_lr - lrd
|
205 |
+
|
206 |
+
for param_group in self.optimizer_net.param_groups:
|
207 |
+
param_group['lr'] = lr
|
208 |
+
|
209 |
+
print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
|
210 |
+
self.old_lr = lr
|
211 |
+
|
212 |
+
def score_2afc_dataset(data_loader, func, name=''):
|
213 |
+
''' Function computes Two Alternative Forced Choice (2AFC) score using
|
214 |
+
distance function 'func' in dataset 'data_loader'
|
215 |
+
INPUTS
|
216 |
+
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
|
217 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
218 |
+
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
|
219 |
+
OUTPUTS
|
220 |
+
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
|
221 |
+
[1] - dictionary with following elements
|
222 |
+
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
|
223 |
+
gts - N array in [0,1], preferred patch selected by human evaluators
|
224 |
+
(closer to "0" for left patch p0, "1" for right patch p1,
|
225 |
+
"0.6" means 60pct people preferred right patch, 40pct preferred left)
|
226 |
+
scores - N array in [0,1], corresponding to what percentage function agreed with humans
|
227 |
+
CONSTS
|
228 |
+
N - number of test triplets in data_loader
|
229 |
+
'''
|
230 |
+
|
231 |
+
d0s = []
|
232 |
+
d1s = []
|
233 |
+
gts = []
|
234 |
+
|
235 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
236 |
+
d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
|
237 |
+
d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
|
238 |
+
gts+=data['judge'].cpu().numpy().flatten().tolist()
|
239 |
+
|
240 |
+
d0s = np.array(d0s)
|
241 |
+
d1s = np.array(d1s)
|
242 |
+
gts = np.array(gts)
|
243 |
+
scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
|
244 |
+
|
245 |
+
return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
|
246 |
+
|
247 |
+
def score_jnd_dataset(data_loader, func, name=''):
|
248 |
+
''' Function computes JND score using distance function 'func' in dataset 'data_loader'
|
249 |
+
INPUTS
|
250 |
+
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
|
251 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
252 |
+
pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
|
253 |
+
OUTPUTS
|
254 |
+
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
|
255 |
+
[1] - dictionary with following elements
|
256 |
+
ds - N array containing distances between two patches shown to human evaluator
|
257 |
+
sames - N array containing fraction of people who thought the two patches were identical
|
258 |
+
CONSTS
|
259 |
+
N - number of test triplets in data_loader
|
260 |
+
'''
|
261 |
+
|
262 |
+
ds = []
|
263 |
+
gts = []
|
264 |
+
|
265 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
266 |
+
ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
|
267 |
+
gts+=data['same'].cpu().numpy().flatten().tolist()
|
268 |
+
|
269 |
+
sames = np.array(gts)
|
270 |
+
ds = np.array(ds)
|
271 |
+
|
272 |
+
sorted_inds = np.argsort(ds)
|
273 |
+
ds_sorted = ds[sorted_inds]
|
274 |
+
sames_sorted = sames[sorted_inds]
|
275 |
+
|
276 |
+
TPs = np.cumsum(sames_sorted)
|
277 |
+
FPs = np.cumsum(1-sames_sorted)
|
278 |
+
FNs = np.sum(sames_sorted)-TPs
|
279 |
+
|
280 |
+
precs = TPs/(TPs+FPs)
|
281 |
+
recs = TPs/(TPs+FNs)
|
282 |
+
score = util.voc_ap(recs,precs)
|
283 |
+
|
284 |
+
return(score, dict(ds=ds,sames=sames))
|
models/stylegan2/lpips/networks_basic.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.init as init
|
8 |
+
from torch.autograd import Variable
|
9 |
+
import numpy as np
|
10 |
+
from pdb import set_trace as st
|
11 |
+
from skimage import color
|
12 |
+
from IPython import embed
|
13 |
+
from models.stylegan2.lpips import pretrained_networks as pn
|
14 |
+
|
15 |
+
import models.stylegan2.lpips as util
|
16 |
+
|
17 |
+
def spatial_average(in_tens, keepdim=True):
|
18 |
+
return in_tens.mean([2,3],keepdim=keepdim)
|
19 |
+
|
20 |
+
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
|
21 |
+
in_H = in_tens.shape[2]
|
22 |
+
scale_factor = 1.*out_H/in_H
|
23 |
+
|
24 |
+
return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
|
25 |
+
|
26 |
+
# Learned perceptual metric
|
27 |
+
class PNetLin(nn.Module):
|
28 |
+
def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
|
29 |
+
super(PNetLin, self).__init__()
|
30 |
+
|
31 |
+
self.pnet_type = pnet_type
|
32 |
+
self.pnet_tune = pnet_tune
|
33 |
+
self.pnet_rand = pnet_rand
|
34 |
+
self.spatial = spatial
|
35 |
+
self.lpips = lpips
|
36 |
+
self.version = version
|
37 |
+
self.scaling_layer = ScalingLayer()
|
38 |
+
|
39 |
+
if(self.pnet_type in ['vgg','vgg16']):
|
40 |
+
net_type = pn.vgg16
|
41 |
+
self.chns = [64,128,256,512,512]
|
42 |
+
elif(self.pnet_type=='alex'):
|
43 |
+
net_type = pn.alexnet
|
44 |
+
self.chns = [64,192,384,256,256]
|
45 |
+
elif(self.pnet_type=='squeeze'):
|
46 |
+
net_type = pn.squeezenet
|
47 |
+
self.chns = [64,128,256,384,384,512,512]
|
48 |
+
self.L = len(self.chns)
|
49 |
+
|
50 |
+
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
|
51 |
+
|
52 |
+
if(lpips):
|
53 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
54 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
55 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
56 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
57 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
58 |
+
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
|
59 |
+
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
|
60 |
+
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
61 |
+
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
62 |
+
self.lins+=[self.lin5,self.lin6]
|
63 |
+
|
64 |
+
def forward(self, in0, in1, retPerLayer=False):
|
65 |
+
# v0.0 - original release had a bug, where input was not scaled
|
66 |
+
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
|
67 |
+
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
|
68 |
+
feats0, feats1, diffs = {}, {}, {}
|
69 |
+
|
70 |
+
for kk in range(self.L):
|
71 |
+
feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
|
72 |
+
diffs[kk] = (feats0[kk]-feats1[kk])**2
|
73 |
+
|
74 |
+
if(self.lpips):
|
75 |
+
if(self.spatial):
|
76 |
+
res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
|
77 |
+
else:
|
78 |
+
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
|
79 |
+
else:
|
80 |
+
if(self.spatial):
|
81 |
+
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
|
82 |
+
else:
|
83 |
+
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
|
84 |
+
|
85 |
+
val = res[0]
|
86 |
+
for l in range(1,self.L):
|
87 |
+
val += res[l]
|
88 |
+
|
89 |
+
if(retPerLayer):
|
90 |
+
return (val, res)
|
91 |
+
else:
|
92 |
+
return val
|
93 |
+
|
94 |
+
class ScalingLayer(nn.Module):
|
95 |
+
def __init__(self):
|
96 |
+
super(ScalingLayer, self).__init__()
|
97 |
+
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
|
98 |
+
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
|
99 |
+
|
100 |
+
def forward(self, inp):
|
101 |
+
return (inp - self.shift) / self.scale
|
102 |
+
|
103 |
+
|
104 |
+
class NetLinLayer(nn.Module):
|
105 |
+
''' A single linear layer which does a 1x1 conv '''
|
106 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
107 |
+
super(NetLinLayer, self).__init__()
|
108 |
+
|
109 |
+
layers = [nn.Dropout(),] if(use_dropout) else []
|
110 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
|
111 |
+
self.model = nn.Sequential(*layers)
|
112 |
+
|
113 |
+
|
114 |
+
class Dist2LogitLayer(nn.Module):
|
115 |
+
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
|
116 |
+
def __init__(self, chn_mid=32, use_sigmoid=True):
|
117 |
+
super(Dist2LogitLayer, self).__init__()
|
118 |
+
|
119 |
+
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
|
120 |
+
layers += [nn.LeakyReLU(0.2,True),]
|
121 |
+
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
|
122 |
+
layers += [nn.LeakyReLU(0.2,True),]
|
123 |
+
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
|
124 |
+
if(use_sigmoid):
|
125 |
+
layers += [nn.Sigmoid(),]
|
126 |
+
self.model = nn.Sequential(*layers)
|
127 |
+
|
128 |
+
def forward(self,d0,d1,eps=0.1):
|
129 |
+
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
|
130 |
+
|
131 |
+
class BCERankingLoss(nn.Module):
|
132 |
+
def __init__(self, chn_mid=32):
|
133 |
+
super(BCERankingLoss, self).__init__()
|
134 |
+
self.net = Dist2LogitLayer(chn_mid=chn_mid)
|
135 |
+
# self.parameters = list(self.net.parameters())
|
136 |
+
self.loss = torch.nn.BCELoss()
|
137 |
+
|
138 |
+
def forward(self, d0, d1, judge):
|
139 |
+
per = (judge+1.)/2.
|
140 |
+
self.logit = self.net.forward(d0,d1)
|
141 |
+
return self.loss(self.logit, per)
|
142 |
+
|
143 |
+
# L2, DSSIM metrics
|
144 |
+
class FakeNet(nn.Module):
|
145 |
+
def __init__(self, use_gpu=True, colorspace='Lab'):
|
146 |
+
super(FakeNet, self).__init__()
|
147 |
+
self.use_gpu = use_gpu
|
148 |
+
self.colorspace=colorspace
|
149 |
+
|
150 |
+
class L2(FakeNet):
|
151 |
+
|
152 |
+
def forward(self, in0, in1, retPerLayer=None):
|
153 |
+
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
154 |
+
|
155 |
+
if(self.colorspace=='RGB'):
|
156 |
+
(N,C,X,Y) = in0.size()
|
157 |
+
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
|
158 |
+
return value
|
159 |
+
elif(self.colorspace=='Lab'):
|
160 |
+
value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
|
161 |
+
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
162 |
+
ret_var = Variable( torch.Tensor((value,) ) )
|
163 |
+
if(self.use_gpu):
|
164 |
+
ret_var = ret_var.cuda()
|
165 |
+
return ret_var
|
166 |
+
|
167 |
+
class DSSIM(FakeNet):
|
168 |
+
|
169 |
+
def forward(self, in0, in1, retPerLayer=None):
|
170 |
+
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
171 |
+
|
172 |
+
if(self.colorspace=='RGB'):
|
173 |
+
value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
|
174 |
+
elif(self.colorspace=='Lab'):
|
175 |
+
value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
|
176 |
+
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
177 |
+
ret_var = Variable( torch.Tensor((value,) ) )
|
178 |
+
if(self.use_gpu):
|
179 |
+
ret_var = ret_var.cuda()
|
180 |
+
return ret_var
|
181 |
+
|
182 |
+
def print_network(net):
|
183 |
+
num_params = 0
|
184 |
+
for param in net.parameters():
|
185 |
+
num_params += param.numel()
|
186 |
+
print('Network',net)
|
187 |
+
print('Total number of parameters: %d' % num_params)
|
models/stylegan2/lpips/pretrained_networks.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torchvision import models as tv
|
4 |
+
from IPython import embed
|
5 |
+
|
6 |
+
class squeezenet(torch.nn.Module):
|
7 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
8 |
+
super(squeezenet, self).__init__()
|
9 |
+
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
10 |
+
self.slice1 = torch.nn.Sequential()
|
11 |
+
self.slice2 = torch.nn.Sequential()
|
12 |
+
self.slice3 = torch.nn.Sequential()
|
13 |
+
self.slice4 = torch.nn.Sequential()
|
14 |
+
self.slice5 = torch.nn.Sequential()
|
15 |
+
self.slice6 = torch.nn.Sequential()
|
16 |
+
self.slice7 = torch.nn.Sequential()
|
17 |
+
self.N_slices = 7
|
18 |
+
for x in range(2):
|
19 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
20 |
+
for x in range(2,5):
|
21 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
22 |
+
for x in range(5, 8):
|
23 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
24 |
+
for x in range(8, 10):
|
25 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
26 |
+
for x in range(10, 11):
|
27 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
28 |
+
for x in range(11, 12):
|
29 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
30 |
+
for x in range(12, 13):
|
31 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
32 |
+
if not requires_grad:
|
33 |
+
for param in self.parameters():
|
34 |
+
param.requires_grad = False
|
35 |
+
|
36 |
+
def forward(self, X):
|
37 |
+
h = self.slice1(X)
|
38 |
+
h_relu1 = h
|
39 |
+
h = self.slice2(h)
|
40 |
+
h_relu2 = h
|
41 |
+
h = self.slice3(h)
|
42 |
+
h_relu3 = h
|
43 |
+
h = self.slice4(h)
|
44 |
+
h_relu4 = h
|
45 |
+
h = self.slice5(h)
|
46 |
+
h_relu5 = h
|
47 |
+
h = self.slice6(h)
|
48 |
+
h_relu6 = h
|
49 |
+
h = self.slice7(h)
|
50 |
+
h_relu7 = h
|
51 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
|
52 |
+
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
|
53 |
+
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
class alexnet(torch.nn.Module):
|
58 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
59 |
+
super(alexnet, self).__init__()
|
60 |
+
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
61 |
+
self.slice1 = torch.nn.Sequential()
|
62 |
+
self.slice2 = torch.nn.Sequential()
|
63 |
+
self.slice3 = torch.nn.Sequential()
|
64 |
+
self.slice4 = torch.nn.Sequential()
|
65 |
+
self.slice5 = torch.nn.Sequential()
|
66 |
+
self.N_slices = 5
|
67 |
+
for x in range(2):
|
68 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
69 |
+
for x in range(2, 5):
|
70 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
71 |
+
for x in range(5, 8):
|
72 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
73 |
+
for x in range(8, 10):
|
74 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
75 |
+
for x in range(10, 12):
|
76 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
77 |
+
if not requires_grad:
|
78 |
+
for param in self.parameters():
|
79 |
+
param.requires_grad = False
|
80 |
+
|
81 |
+
def forward(self, X):
|
82 |
+
h = self.slice1(X)
|
83 |
+
h_relu1 = h
|
84 |
+
h = self.slice2(h)
|
85 |
+
h_relu2 = h
|
86 |
+
h = self.slice3(h)
|
87 |
+
h_relu3 = h
|
88 |
+
h = self.slice4(h)
|
89 |
+
h_relu4 = h
|
90 |
+
h = self.slice5(h)
|
91 |
+
h_relu5 = h
|
92 |
+
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
93 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
94 |
+
|
95 |
+
return out
|
96 |
+
|
97 |
+
class vgg16(torch.nn.Module):
|
98 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
99 |
+
super(vgg16, self).__init__()
|
100 |
+
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
|
101 |
+
self.slice1 = torch.nn.Sequential()
|
102 |
+
self.slice2 = torch.nn.Sequential()
|
103 |
+
self.slice3 = torch.nn.Sequential()
|
104 |
+
self.slice4 = torch.nn.Sequential()
|
105 |
+
self.slice5 = torch.nn.Sequential()
|
106 |
+
self.N_slices = 5
|
107 |
+
for x in range(4):
|
108 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
109 |
+
for x in range(4, 9):
|
110 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
111 |
+
for x in range(9, 16):
|
112 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
113 |
+
for x in range(16, 23):
|
114 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
115 |
+
for x in range(23, 30):
|
116 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
117 |
+
if not requires_grad:
|
118 |
+
for param in self.parameters():
|
119 |
+
param.requires_grad = False
|
120 |
+
|
121 |
+
def forward(self, X):
|
122 |
+
h = self.slice1(X)
|
123 |
+
h_relu1_2 = h
|
124 |
+
h = self.slice2(h)
|
125 |
+
h_relu2_2 = h
|
126 |
+
h = self.slice3(h)
|
127 |
+
h_relu3_3 = h
|
128 |
+
h = self.slice4(h)
|
129 |
+
h_relu4_3 = h
|
130 |
+
h = self.slice5(h)
|
131 |
+
h_relu5_3 = h
|
132 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
133 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
134 |
+
|
135 |
+
return out
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
class resnet(torch.nn.Module):
|
140 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
141 |
+
super(resnet, self).__init__()
|
142 |
+
if(num==18):
|
143 |
+
self.net = tv.resnet18(pretrained=pretrained)
|
144 |
+
elif(num==34):
|
145 |
+
self.net = tv.resnet34(pretrained=pretrained)
|
146 |
+
elif(num==50):
|
147 |
+
self.net = tv.resnet50(pretrained=pretrained)
|
148 |
+
elif(num==101):
|
149 |
+
self.net = tv.resnet101(pretrained=pretrained)
|
150 |
+
elif(num==152):
|
151 |
+
self.net = tv.resnet152(pretrained=pretrained)
|
152 |
+
self.N_slices = 5
|
153 |
+
|
154 |
+
self.conv1 = self.net.conv1
|
155 |
+
self.bn1 = self.net.bn1
|
156 |
+
self.relu = self.net.relu
|
157 |
+
self.maxpool = self.net.maxpool
|
158 |
+
self.layer1 = self.net.layer1
|
159 |
+
self.layer2 = self.net.layer2
|
160 |
+
self.layer3 = self.net.layer3
|
161 |
+
self.layer4 = self.net.layer4
|
162 |
+
|
163 |
+
def forward(self, X):
|
164 |
+
h = self.conv1(X)
|
165 |
+
h = self.bn1(h)
|
166 |
+
h = self.relu(h)
|
167 |
+
h_relu1 = h
|
168 |
+
h = self.maxpool(h)
|
169 |
+
h = self.layer1(h)
|
170 |
+
h_conv2 = h
|
171 |
+
h = self.layer2(h)
|
172 |
+
h_conv3 = h
|
173 |
+
h = self.layer3(h)
|
174 |
+
h_conv4 = h
|
175 |
+
h = self.layer4(h)
|
176 |
+
h_conv5 = h
|
177 |
+
|
178 |
+
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
|
179 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
180 |
+
|
181 |
+
return out
|
models/stylegan2/lpips/weights/v0.0/alex.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
|
3 |
+
size 5455
|
models/stylegan2/lpips/weights/v0.0/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
|
3 |
+
size 10057
|
models/stylegan2/lpips/weights/v0.0/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
|
3 |
+
size 6735
|
models/stylegan2/lpips/weights/v0.1/alex.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
|
3 |
+
size 6009
|