Zaesar PKUWilliamYang commited on
Commit
0483f57
·
0 Parent(s):

Duplicate from PKUWilliamYang/StyleGANEX

Browse files

Co-authored-by: Shuai Yang <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. packages.txt +2 -0
  2. .gitattributes +34 -0
  3. README.md +10 -0
  4. app.py +112 -0
  5. configs/__init__.py +0 -0
  6. configs/data_configs.py +48 -0
  7. configs/dataset_config.yml +60 -0
  8. configs/paths_config.py +25 -0
  9. configs/transforms_config.py +242 -0
  10. datasets/__init__.py +0 -0
  11. datasets/augmentations.py +110 -0
  12. datasets/ffhq_degradation_dataset.py +235 -0
  13. datasets/gt_res_dataset.py +32 -0
  14. datasets/images_dataset.py +33 -0
  15. datasets/inference_dataset.py +22 -0
  16. latent_optimization.py +107 -0
  17. models/__init__.py +0 -0
  18. models/bisenet/LICENSE +21 -0
  19. models/bisenet/README.md +68 -0
  20. models/bisenet/model.py +283 -0
  21. models/bisenet/resnet.py +109 -0
  22. models/encoders/__init__.py +0 -0
  23. models/encoders/helpers.py +119 -0
  24. models/encoders/model_irse.py +84 -0
  25. models/encoders/psp_encoders.py +357 -0
  26. models/mtcnn/__init__.py +0 -0
  27. models/mtcnn/mtcnn.py +156 -0
  28. models/mtcnn/mtcnn_pytorch/__init__.py +0 -0
  29. models/mtcnn/mtcnn_pytorch/src/__init__.py +2 -0
  30. models/mtcnn/mtcnn_pytorch/src/align_trans.py +304 -0
  31. models/mtcnn/mtcnn_pytorch/src/box_utils.py +238 -0
  32. models/mtcnn/mtcnn_pytorch/src/detector.py +126 -0
  33. models/mtcnn/mtcnn_pytorch/src/first_stage.py +101 -0
  34. models/mtcnn/mtcnn_pytorch/src/get_nets.py +171 -0
  35. models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py +350 -0
  36. models/mtcnn/mtcnn_pytorch/src/visualization_utils.py +31 -0
  37. models/mtcnn/mtcnn_pytorch/src/weights/onet.npy +3 -0
  38. models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy +3 -0
  39. models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy +3 -0
  40. models/psp.py +148 -0
  41. models/stylegan2/__init__.py +0 -0
  42. models/stylegan2/lpips/__init__.py +161 -0
  43. models/stylegan2/lpips/base_model.py +58 -0
  44. models/stylegan2/lpips/dist_model.py +284 -0
  45. models/stylegan2/lpips/networks_basic.py +187 -0
  46. models/stylegan2/lpips/pretrained_networks.py +181 -0
  47. models/stylegan2/lpips/weights/v0.0/alex.pth +3 -0
  48. models/stylegan2/lpips/weights/v0.0/squeeze.pth +3 -0
  49. models/stylegan2/lpips/weights/v0.0/vgg.pth +3 -0
  50. models/stylegan2/lpips/weights/v0.1/alex.pth +3 -0
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ bzip2
2
+ cmake
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: StyleGANEX
3
+ sdk: gradio
4
+ emoji: 🐨
5
+ colorFrom: pink
6
+ colorTo: yellow
7
+ app_file: app.py
8
+ pinned: false
9
+ duplicated_from: PKUWilliamYang/StyleGANEX
10
+ ---
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import pathlib
5
+ import torch
6
+ import gradio as gr
7
+
8
+ from webUI.app_task import *
9
+ from webUI.styleganex_model import Model
10
+
11
+ def parse_args() -> argparse.Namespace:
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--device', type=str, default='cpu')
14
+ parser.add_argument('--theme', type=str)
15
+ parser.add_argument('--share', action='store_true')
16
+ parser.add_argument('--port', type=int)
17
+ parser.add_argument('--disable-queue',
18
+ dest='enable_queue',
19
+ action='store_false')
20
+ return parser.parse_args()
21
+
22
+ DESCRIPTION = '''
23
+ <div align=center>
24
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
25
+ Face Manipulation with <a href="https://github.com/williamyang1991/StyleGANEX">StyleGANEX</a>
26
+ </h1>
27
+ <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
28
+ <a href="https://huggingface.co/spaces/PKUWilliamYang/StyleGANEX?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>
29
+ <p/>
30
+ <img style="margin-top: 0em" src="https://raw.githubusercontent.com/williamyang1991/tmpfile/master/imgs/example.jpg" alt="example">
31
+ </div>
32
+ '''
33
+ ARTICLE = r"""
34
+ If StyleGANEX is helpful, please help to ⭐ the <a href='https://github.com/williamyang1991/StyleGANEX' target='_blank'>Github Repo</a>. Thanks!
35
+ [![GitHub Stars](https://img.shields.io/github/stars/williamyang1991/StyleGANEX?style=social)](https://github.com/williamyang1991/StyleGANEX)
36
+ ---
37
+ 📝 **Citation**
38
+ If our work is useful for your research, please consider citing:
39
+ ```bibtex
40
+ @article{yang2023styleganex,
41
+ title = {StyleGANEX: StyleGAN-Based Manipulation Beyond Cropped Aligned Faces},
42
+ author = {Yang, Shuai and Jiang, Liming and Liu, Ziwei and and Loy, Chen Change},
43
+ journal = {arXiv preprint arXiv:2303.06146},
44
+ year={2023},
45
+ }
46
+ ```
47
+ 📋 **License**
48
+ This project is licensed under <a rel="license" href="https://github.com/williamyang1991/VToonify/blob/main/LICENSE.md">S-Lab License 1.0</a>.
49
+ Redistribution and use for non-commercial purposes should follow this license.
50
+
51
+ 📧 **Contact**
52
+ If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
53
+ """
54
+
55
+ FOOTER = '<div align=center><img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.laobi.icu/badge?page_id=williamyang1991/styleganex" /></div>'
56
+
57
+ def main():
58
+ args = parse_args()
59
+ args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
60
+ print('*** Now using %s.'%(args.device))
61
+ model = Model(device=args.device)
62
+
63
+
64
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/234_sketch.jpg',
65
+ '234_sketch.jpg')
66
+ torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/output/ILip77SbmOE_inversion.pt',
67
+ 'ILip77SbmOE_inversion.pt')
68
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/ILip77SbmOE.png',
69
+ 'ILip77SbmOE.png')
70
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/ILip77SbmOE_mask.png',
71
+ 'ILip77SbmOE_mask.png')
72
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/williamyang1991/StyleGANEX/main/data/pexels-daniel-xavier-1239291.jpg',
73
+ 'pexels-daniel-xavier-1239291.jpg')
74
+ torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/529_2.mp4',
75
+ '529_2.mp4')
76
+ torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/684.mp4',
77
+ '684.mp4')
78
+ torch.hub.download_url_to_file('https://github.com/williamyang1991/StyleGANEX/raw/main/data/pexels-anthony-shkraba-production-8136210.mp4',
79
+ 'pexels-anthony-shkraba-production-8136210.mp4')
80
+
81
+
82
+ with gr.Blocks(css='style.css') as demo:
83
+ gr.Markdown(DESCRIPTION)
84
+ with gr.Tabs():
85
+ with gr.TabItem('Inversion for Editing'):
86
+ create_demo_inversion(model.process_inversion, allow_optimization=False)
87
+ with gr.TabItem('Image Face Toonify'):
88
+ create_demo_toonify(model.process_toonify)
89
+ with gr.TabItem('Video Face Toonify'):
90
+ create_demo_vtoonify(model.process_vtoonify, max_frame_num=12)
91
+ with gr.TabItem('Image Face Editing'):
92
+ create_demo_editing(model.process_editing)
93
+ with gr.TabItem('Video Face Editing'):
94
+ create_demo_vediting(model.process_vediting, max_frame_num=12)
95
+ with gr.TabItem('Sketch2Face'):
96
+ create_demo_s2f(model.process_s2f)
97
+ with gr.TabItem('Mask2Face'):
98
+ create_demo_m2f(model.process_m2f)
99
+ with gr.TabItem('SR'):
100
+ create_demo_sr(model.process_sr)
101
+ gr.Markdown(ARTICLE)
102
+ gr.Markdown(FOOTER)
103
+
104
+ demo.launch(
105
+ enable_queue=args.enable_queue,
106
+ server_port=args.port,
107
+ share=args.share,
108
+ )
109
+
110
+ if __name__ == '__main__':
111
+ main()
112
+
configs/__init__.py ADDED
File without changes
configs/data_configs.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from configs import transforms_config
2
+ from configs.paths_config import dataset_paths
3
+
4
+
5
+ DATASETS = {
6
+ 'ffhq_encode': {
7
+ 'transforms': transforms_config.EncodeTransforms,
8
+ 'train_source_root': dataset_paths['ffhq'],
9
+ 'train_target_root': dataset_paths['ffhq'],
10
+ 'test_source_root': dataset_paths['ffhq_test'],
11
+ 'test_target_root': dataset_paths['ffhq_test'],
12
+ },
13
+ 'ffhq_sketch_to_face': {
14
+ 'transforms': transforms_config.SketchToImageTransforms,
15
+ 'train_source_root': dataset_paths['ffhq_train_sketch'],
16
+ 'train_target_root': dataset_paths['ffhq'],
17
+ 'test_source_root': dataset_paths['ffhq_test_sketch'],
18
+ 'test_target_root': dataset_paths['ffhq_test'],
19
+ },
20
+ 'ffhq_seg_to_face': {
21
+ 'transforms': transforms_config.SegToImageTransforms,
22
+ 'train_source_root': dataset_paths['ffhq_train_segmentation'],
23
+ 'train_target_root': dataset_paths['ffhq'],
24
+ 'test_source_root': dataset_paths['ffhq_test_segmentation'],
25
+ 'test_target_root': dataset_paths['ffhq_test'],
26
+ },
27
+ 'ffhq_super_resolution': {
28
+ 'transforms': transforms_config.SuperResTransforms,
29
+ 'train_source_root': dataset_paths['ffhq'],
30
+ 'train_target_root': dataset_paths['ffhq1280'],
31
+ 'test_source_root': dataset_paths['ffhq_test'],
32
+ 'test_target_root': dataset_paths['ffhq1280_test'],
33
+ },
34
+ 'toonify': {
35
+ 'transforms': transforms_config.ToonifyTransforms,
36
+ 'train_source_root': dataset_paths['toonify_in'],
37
+ 'train_target_root': dataset_paths['toonify_out'],
38
+ 'test_source_root': dataset_paths['toonify_test_in'],
39
+ 'test_target_root': dataset_paths['toonify_test_out'],
40
+ },
41
+ 'ffhq_edit': {
42
+ 'transforms': transforms_config.EditingTransforms,
43
+ 'train_source_root': dataset_paths['ffhq'],
44
+ 'train_target_root': dataset_paths['ffhq'],
45
+ 'test_source_root': dataset_paths['ffhq_test'],
46
+ 'test_target_root': dataset_paths['ffhq_test'],
47
+ },
48
+ }
configs/dataset_config.yml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset and data loader settings
2
+ datasets:
3
+ train:
4
+ name: FFHQ
5
+ type: FFHQDegradationDataset
6
+ # dataroot_gt: datasets/ffhq/ffhq_512.lmdb
7
+ dataroot_gt: ../../../../share/shuaiyang/ffhq/realign1280x1280test/
8
+ io_backend:
9
+ # type: lmdb
10
+ type: disk
11
+
12
+ use_hflip: true
13
+ mean: [0.5, 0.5, 0.5]
14
+ std: [0.5, 0.5, 0.5]
15
+ out_size: 1280
16
+ scale: 4
17
+
18
+ blur_kernel_size: 41
19
+ kernel_list: ['iso', 'aniso']
20
+ kernel_prob: [0.5, 0.5]
21
+ blur_sigma: [0.1, 10]
22
+ downsample_range: [4, 40]
23
+ noise_range: [0, 20]
24
+ jpeg_range: [60, 100]
25
+
26
+ # color jitter and gray
27
+ #color_jitter_prob: 0.3
28
+ #color_jitter_shift: 20
29
+ #color_jitter_pt_prob: 0.3
30
+ #gray_prob: 0.01
31
+
32
+ # If you do not want colorization, please set
33
+ color_jitter_prob: ~
34
+ color_jitter_pt_prob: ~
35
+ gray_prob: 0.01
36
+ gt_gray: True
37
+
38
+ crop_components: true
39
+ component_path: ./pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
40
+ eye_enlarge_ratio: 1.4
41
+
42
+ # data loader
43
+ use_shuffle: true
44
+ num_worker_per_gpu: 6
45
+ batch_size_per_gpu: 4
46
+ dataset_enlarge_ratio: 1
47
+ prefetch_mode: ~
48
+
49
+ val:
50
+ # Please modify accordingly to use your own validation
51
+ # Or comment the val block if do not need validation during training
52
+ name: validation
53
+ type: PairedImageDataset
54
+ dataroot_lq: datasets/faces/validation/input
55
+ dataroot_gt: datasets/faces/validation/reference
56
+ io_backend:
57
+ type: disk
58
+ mean: [0.5, 0.5, 0.5]
59
+ std: [0.5, 0.5, 0.5]
60
+ scale: 1
configs/paths_config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_paths = {
2
+ 'ffhq': 'data/train/ffhq/realign320x320/',
3
+ 'ffhq_test': 'data/train/ffhq/realign320x320test/',
4
+ 'ffhq1280': 'data/train/ffhq/realign1280x1280/',
5
+ 'ffhq1280_test': 'data/train/ffhq/realign1280x1280test/',
6
+ 'ffhq_train_sketch': 'data/train/ffhq/realign640x640sketch/',
7
+ 'ffhq_test_sketch': 'data/train/ffhq/realign640x640sketchtest/',
8
+ 'ffhq_train_segmentation': 'data/train/ffhq/realign320x320mask/',
9
+ 'ffhq_test_segmentation': 'data/train/ffhq/realign320x320masktest/',
10
+ 'toonify_in': 'data/train/pixar/trainA/',
11
+ 'toonify_out': 'data/train/pixar/trainB/',
12
+ 'toonify_test_in': 'data/train/pixar/testA/',
13
+ 'toonify_test_out': 'data/train/testB/',
14
+ }
15
+
16
+ model_paths = {
17
+ 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
18
+ 'ir_se50': 'pretrained_models/model_ir_se50.pth',
19
+ 'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
20
+ 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
21
+ 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
22
+ 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
23
+ 'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
24
+ 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar'
25
+ }
configs/transforms_config.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import torchvision.transforms as transforms
3
+ from datasets import augmentations
4
+
5
+
6
+ class TransformsConfig(object):
7
+
8
+ def __init__(self, opts):
9
+ self.opts = opts
10
+
11
+ @abstractmethod
12
+ def get_transforms(self):
13
+ pass
14
+
15
+
16
+ class EncodeTransforms(TransformsConfig):
17
+
18
+ def __init__(self, opts):
19
+ super(EncodeTransforms, self).__init__(opts)
20
+
21
+ def get_transforms(self):
22
+ transforms_dict = {
23
+ 'transform_gt_train': transforms.Compose([
24
+ transforms.Resize((320, 320)),
25
+ transforms.RandomHorizontalFlip(0.5),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
28
+ 'transform_source': None,
29
+ 'transform_test': transforms.Compose([
30
+ transforms.Resize((320, 320)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
33
+ 'transform_inference': transforms.Compose([
34
+ transforms.Resize((320, 320)),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
37
+ }
38
+ return transforms_dict
39
+
40
+
41
+ class FrontalizationTransforms(TransformsConfig):
42
+
43
+ def __init__(self, opts):
44
+ super(FrontalizationTransforms, self).__init__(opts)
45
+
46
+ def get_transforms(self):
47
+ transforms_dict = {
48
+ 'transform_gt_train': transforms.Compose([
49
+ transforms.Resize((256, 256)),
50
+ transforms.RandomHorizontalFlip(0.5),
51
+ transforms.ToTensor(),
52
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
53
+ 'transform_source': transforms.Compose([
54
+ transforms.Resize((256, 256)),
55
+ transforms.RandomHorizontalFlip(0.5),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
58
+ 'transform_test': transforms.Compose([
59
+ transforms.Resize((256, 256)),
60
+ transforms.ToTensor(),
61
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
62
+ 'transform_inference': transforms.Compose([
63
+ transforms.Resize((256, 256)),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
66
+ }
67
+ return transforms_dict
68
+
69
+
70
+ class SketchToImageTransforms(TransformsConfig):
71
+
72
+ def __init__(self, opts):
73
+ super(SketchToImageTransforms, self).__init__(opts)
74
+
75
+ def get_transforms(self):
76
+ transforms_dict = {
77
+ 'transform_gt_train': transforms.Compose([
78
+ transforms.Resize((320, 320)),
79
+ transforms.ToTensor(),
80
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
81
+ 'transform_source': transforms.Compose([
82
+ transforms.Resize((320, 320)),
83
+ transforms.ToTensor()]),
84
+ 'transform_test': transforms.Compose([
85
+ transforms.Resize((320, 320)),
86
+ transforms.ToTensor(),
87
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
88
+ 'transform_inference': transforms.Compose([
89
+ transforms.Resize((320, 320)),
90
+ transforms.ToTensor()]),
91
+ }
92
+ return transforms_dict
93
+
94
+
95
+ class SegToImageTransforms(TransformsConfig):
96
+
97
+ def __init__(self, opts):
98
+ super(SegToImageTransforms, self).__init__(opts)
99
+
100
+ def get_transforms(self):
101
+ transforms_dict = {
102
+ 'transform_gt_train': transforms.Compose([
103
+ transforms.Resize((320, 320)),
104
+ transforms.ToTensor(),
105
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
106
+ 'transform_source': transforms.Compose([
107
+ transforms.Resize((320, 320)),
108
+ augmentations.ToOneHot(self.opts.label_nc),
109
+ transforms.ToTensor()]),
110
+ 'transform_test': transforms.Compose([
111
+ transforms.Resize((320, 320)),
112
+ transforms.ToTensor(),
113
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
114
+ 'transform_inference': transforms.Compose([
115
+ transforms.Resize((320, 320)),
116
+ augmentations.ToOneHot(self.opts.label_nc),
117
+ transforms.ToTensor()])
118
+ }
119
+ return transforms_dict
120
+
121
+
122
+ class SuperResTransforms(TransformsConfig):
123
+
124
+ def __init__(self, opts):
125
+ super(SuperResTransforms, self).__init__(opts)
126
+
127
+ def get_transforms(self):
128
+ if self.opts.resize_factors is None:
129
+ self.opts.resize_factors = '1,2,4,8,16,32'
130
+ factors = [int(f) for f in self.opts.resize_factors.split(",")]
131
+ print("Performing down-sampling with factors: {}".format(factors))
132
+ transforms_dict = {
133
+ 'transform_gt_train': transforms.Compose([
134
+ transforms.Resize((1280, 1280)),
135
+ transforms.ToTensor(),
136
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
137
+ 'transform_source': transforms.Compose([
138
+ transforms.Resize((320, 320)),
139
+ augmentations.BilinearResize(factors=factors),
140
+ transforms.Resize((320, 320)),
141
+ transforms.ToTensor(),
142
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
143
+ 'transform_test': transforms.Compose([
144
+ transforms.Resize((1280, 1280)),
145
+ transforms.ToTensor(),
146
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
147
+ 'transform_inference': transforms.Compose([
148
+ transforms.Resize((320, 320)),
149
+ augmentations.BilinearResize(factors=factors),
150
+ transforms.Resize((320, 320)),
151
+ transforms.ToTensor(),
152
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
153
+ }
154
+ return transforms_dict
155
+
156
+
157
+ class SuperResTransforms_320(TransformsConfig):
158
+
159
+ def __init__(self, opts):
160
+ super(SuperResTransforms_320, self).__init__(opts)
161
+
162
+ def get_transforms(self):
163
+ if self.opts.resize_factors is None:
164
+ self.opts.resize_factors = '1,2,4,8,16,32'
165
+ factors = [int(f) for f in self.opts.resize_factors.split(",")]
166
+ print("Performing down-sampling with factors: {}".format(factors))
167
+ transforms_dict = {
168
+ 'transform_gt_train': transforms.Compose([
169
+ transforms.Resize((320, 320)),
170
+ transforms.ToTensor(),
171
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
172
+ 'transform_source': transforms.Compose([
173
+ transforms.Resize((320, 320)),
174
+ augmentations.BilinearResize(factors=factors),
175
+ transforms.Resize((320, 320)),
176
+ transforms.ToTensor(),
177
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
178
+ 'transform_test': transforms.Compose([
179
+ transforms.Resize((320, 320)),
180
+ transforms.ToTensor(),
181
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
182
+ 'transform_inference': transforms.Compose([
183
+ transforms.Resize((320, 320)),
184
+ augmentations.BilinearResize(factors=factors),
185
+ transforms.Resize((320, 320)),
186
+ transforms.ToTensor(),
187
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
188
+ }
189
+ return transforms_dict
190
+
191
+
192
+ class ToonifyTransforms(TransformsConfig):
193
+
194
+ def __init__(self, opts):
195
+ super(ToonifyTransforms, self).__init__(opts)
196
+
197
+ def get_transforms(self):
198
+ transforms_dict = {
199
+ 'transform_gt_train': transforms.Compose([
200
+ transforms.Resize((1024, 1024)),
201
+ transforms.ToTensor(),
202
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
203
+ 'transform_source': transforms.Compose([
204
+ transforms.Resize((256, 256)),
205
+ transforms.ToTensor(),
206
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
207
+ 'transform_test': transforms.Compose([
208
+ transforms.Resize((1024, 1024)),
209
+ transforms.ToTensor(),
210
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
211
+ 'transform_inference': transforms.Compose([
212
+ transforms.Resize((256, 256)),
213
+ transforms.ToTensor(),
214
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
215
+ }
216
+ return transforms_dict
217
+
218
+ class EditingTransforms(TransformsConfig):
219
+
220
+ def __init__(self, opts):
221
+ super(EditingTransforms, self).__init__(opts)
222
+
223
+ def get_transforms(self):
224
+ transforms_dict = {
225
+ 'transform_gt_train': transforms.Compose([
226
+ transforms.Resize((1280, 1280)),
227
+ transforms.ToTensor(),
228
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
229
+ 'transform_source': transforms.Compose([
230
+ transforms.Resize((320, 320)),
231
+ transforms.ToTensor(),
232
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
233
+ 'transform_test': transforms.Compose([
234
+ transforms.Resize((1280, 1280)),
235
+ transforms.ToTensor(),
236
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
237
+ 'transform_inference': transforms.Compose([
238
+ transforms.Resize((320, 320)),
239
+ transforms.ToTensor(),
240
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
241
+ }
242
+ return transforms_dict
datasets/__init__.py ADDED
File without changes
datasets/augmentations.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from torchvision import transforms
6
+
7
+
8
+ class ToOneHot(object):
9
+ """ Convert the input PIL image to a one-hot torch tensor """
10
+ def __init__(self, n_classes=None):
11
+ self.n_classes = n_classes
12
+
13
+ def onehot_initialization(self, a):
14
+ if self.n_classes is None:
15
+ self.n_classes = len(np.unique(a))
16
+ out = np.zeros(a.shape + (self.n_classes, ), dtype=int)
17
+ out[self.__all_idx(a, axis=2)] = 1
18
+ return out
19
+
20
+ def __all_idx(self, idx, axis):
21
+ grid = np.ogrid[tuple(map(slice, idx.shape))]
22
+ grid.insert(axis, idx)
23
+ return tuple(grid)
24
+
25
+ def __call__(self, img):
26
+ img = np.array(img)
27
+ one_hot = self.onehot_initialization(img)
28
+ return one_hot
29
+
30
+
31
+ class BilinearResize(object):
32
+ def __init__(self, factors=[1, 2, 4, 8, 16, 32]):
33
+ self.factors = factors
34
+
35
+ def __call__(self, image):
36
+ factor = np.random.choice(self.factors, size=1)[0]
37
+ D = BicubicDownSample(factor=factor, cuda=False)
38
+ img_tensor = transforms.ToTensor()(image).unsqueeze(0)
39
+ img_tensor_lr = D(img_tensor)[0].clamp(0, 1)
40
+ img_low_res = transforms.ToPILImage()(img_tensor_lr)
41
+ return img_low_res
42
+
43
+
44
+ class BicubicDownSample(nn.Module):
45
+ def bicubic_kernel(self, x, a=-0.50):
46
+ """
47
+ This equation is exactly copied from the website below:
48
+ https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
49
+ """
50
+ abs_x = torch.abs(x)
51
+ if abs_x <= 1.:
52
+ return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
53
+ elif 1. < abs_x < 2.:
54
+ return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
55
+ else:
56
+ return 0.0
57
+
58
+ def __init__(self, factor=4, cuda=True, padding='reflect'):
59
+ super().__init__()
60
+ self.factor = factor
61
+ size = factor * 4
62
+ k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
63
+ for i in range(size)], dtype=torch.float32)
64
+ k = k / torch.sum(k)
65
+ k1 = torch.reshape(k, shape=(1, 1, size, 1))
66
+ self.k1 = torch.cat([k1, k1, k1], dim=0)
67
+ k2 = torch.reshape(k, shape=(1, 1, 1, size))
68
+ self.k2 = torch.cat([k2, k2, k2], dim=0)
69
+ self.cuda = '.cuda' if cuda else ''
70
+ self.padding = padding
71
+ for param in self.parameters():
72
+ param.requires_grad = False
73
+
74
+ def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
75
+ filter_height = self.factor * 4
76
+ filter_width = self.factor * 4
77
+ stride = self.factor
78
+
79
+ pad_along_height = max(filter_height - stride, 0)
80
+ pad_along_width = max(filter_width - stride, 0)
81
+ filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
82
+ filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
83
+
84
+ # compute actual padding values for each side
85
+ pad_top = pad_along_height // 2
86
+ pad_bottom = pad_along_height - pad_top
87
+ pad_left = pad_along_width // 2
88
+ pad_right = pad_along_width - pad_left
89
+
90
+ # apply mirror padding
91
+ if nhwc:
92
+ x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW
93
+
94
+ # downscaling performed by 1-d convolution
95
+ x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
96
+ x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
97
+ if clip_round:
98
+ x = torch.clamp(torch.round(x), 0.0, 255.)
99
+
100
+ x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
101
+ x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
102
+ if clip_round:
103
+ x = torch.clamp(torch.round(x), 0.0, 255.)
104
+
105
+ if nhwc:
106
+ x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
107
+ if byte_output:
108
+ return x.type('torch.ByteTensor'.format(self.cuda))
109
+ else:
110
+ return x
datasets/ffhq_degradation_dataset.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os.path as osp
5
+ import torch
6
+ import torch.utils.data as data
7
+ from basicsr.data import degradations as degradations
8
+ from basicsr.data.data_util import paths_from_folder
9
+ from basicsr.data.transforms import augment
10
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
11
+ from basicsr.utils.registry import DATASET_REGISTRY
12
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
13
+ normalize)
14
+
15
+
16
+ @DATASET_REGISTRY.register()
17
+ class FFHQDegradationDataset(data.Dataset):
18
+ """FFHQ dataset for GFPGAN.
19
+ It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.
20
+ Args:
21
+ opt (dict): Config for train datasets. It contains the following keys:
22
+ dataroot_gt (str): Data root path for gt.
23
+ io_backend (dict): IO backend type and other kwarg.
24
+ mean (list | tuple): Image mean.
25
+ std (list | tuple): Image std.
26
+ use_hflip (bool): Whether to horizontally flip.
27
+ Please see more options in the codes.
28
+ """
29
+
30
+ def __init__(self, opt):
31
+ super(FFHQDegradationDataset, self).__init__()
32
+ self.opt = opt
33
+ # file client (io backend)
34
+ self.file_client = None
35
+ self.io_backend_opt = opt['io_backend']
36
+
37
+ self.gt_folder = opt['dataroot_gt']
38
+ self.mean = opt['mean']
39
+ self.std = opt['std']
40
+ self.out_size = opt['out_size']
41
+
42
+ self.crop_components = opt.get('crop_components', False) # facial components
43
+ self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions
44
+
45
+ if self.crop_components:
46
+ # load component list from a pre-process pth files
47
+ self.components_list = torch.load(opt.get('component_path'))
48
+
49
+ # file client (lmdb io backend)
50
+ if self.io_backend_opt['type'] == 'lmdb':
51
+ self.io_backend_opt['db_paths'] = self.gt_folder
52
+ if not self.gt_folder.endswith('.lmdb'):
53
+ raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
54
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
55
+ self.paths = [line.split('.')[0] for line in fin]
56
+ else:
57
+ # disk backend: scan file list from a folder
58
+ self.paths = paths_from_folder(self.gt_folder)
59
+
60
+ # degradation configurations
61
+ self.blur_kernel_size = opt['blur_kernel_size']
62
+ self.kernel_list = opt['kernel_list']
63
+ self.kernel_prob = opt['kernel_prob']
64
+ self.blur_sigma = opt['blur_sigma']
65
+ self.downsample_range = opt['downsample_range']
66
+ self.noise_range = opt['noise_range']
67
+ self.jpeg_range = opt['jpeg_range']
68
+
69
+ # color jitter
70
+ self.color_jitter_prob = opt.get('color_jitter_prob')
71
+ self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
72
+ self.color_jitter_shift = opt.get('color_jitter_shift', 20)
73
+ # to gray
74
+ self.gray_prob = opt.get('gray_prob')
75
+
76
+ logger = get_root_logger()
77
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
78
+ logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
79
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
80
+ logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
81
+
82
+ if self.color_jitter_prob is not None:
83
+ logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
84
+ if self.gray_prob is not None:
85
+ logger.info(f'Use random gray. Prob: {self.gray_prob}')
86
+ self.color_jitter_shift /= 255.
87
+
88
+ @staticmethod
89
+ def color_jitter(img, shift):
90
+ """jitter color: randomly jitter the RGB values, in numpy formats"""
91
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
92
+ img = img + jitter_val
93
+ img = np.clip(img, 0, 1)
94
+ return img
95
+
96
+ @staticmethod
97
+ def color_jitter_pt(img, brightness, contrast, saturation, hue):
98
+ """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
99
+ fn_idx = torch.randperm(4)
100
+ for fn_id in fn_idx:
101
+ if fn_id == 0 and brightness is not None:
102
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
103
+ img = adjust_brightness(img, brightness_factor)
104
+
105
+ if fn_id == 1 and contrast is not None:
106
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
107
+ img = adjust_contrast(img, contrast_factor)
108
+
109
+ if fn_id == 2 and saturation is not None:
110
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
111
+ img = adjust_saturation(img, saturation_factor)
112
+
113
+ if fn_id == 3 and hue is not None:
114
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
115
+ img = adjust_hue(img, hue_factor)
116
+ return img
117
+
118
+ def get_component_coordinates(self, index, status):
119
+ """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
120
+ components_bbox = self.components_list[f'{index:08d}']
121
+ if status[0]: # hflip
122
+ # exchange right and left eye
123
+ tmp = components_bbox['left_eye']
124
+ components_bbox['left_eye'] = components_bbox['right_eye']
125
+ components_bbox['right_eye'] = tmp
126
+ # modify the width coordinate
127
+ components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
128
+ components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
129
+ components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
130
+
131
+ # get coordinates
132
+ locations = []
133
+ for part in ['left_eye', 'right_eye', 'mouth']:
134
+ mean = components_bbox[part][0:2]
135
+ mean[0] = mean[0] * 2 + 128 ########
136
+ mean[1] = mean[1] * 2 + 128 ########
137
+ half_len = components_bbox[part][2] * 2 ########
138
+ if 'eye' in part:
139
+ half_len *= self.eye_enlarge_ratio
140
+ loc = np.hstack((mean - half_len + 1, mean + half_len))
141
+ loc = torch.from_numpy(loc).float()
142
+ locations.append(loc)
143
+ return locations
144
+
145
+ def __getitem__(self, index):
146
+ if self.file_client is None:
147
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
148
+
149
+ # load gt image
150
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
151
+ gt_path = self.paths[index]
152
+ img_bytes = self.file_client.get(gt_path)
153
+ img_gt = imfrombytes(img_bytes, float32=True)
154
+
155
+ # random horizontal flip
156
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
157
+ h, w, _ = img_gt.shape
158
+
159
+ # get facial component coordinates
160
+ if self.crop_components:
161
+ locations = self.get_component_coordinates(index, status)
162
+ loc_left_eye, loc_right_eye, loc_mouth = locations
163
+
164
+ # ------------------------ generate lq image ------------------------ #
165
+ # blur
166
+ kernel = degradations.random_mixed_kernels(
167
+ self.kernel_list,
168
+ self.kernel_prob,
169
+ self.blur_kernel_size,
170
+ self.blur_sigma,
171
+ self.blur_sigma, [-math.pi, math.pi],
172
+ noise_range=None)
173
+ img_lq = cv2.filter2D(img_gt, -1, kernel)
174
+ # downsample
175
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
176
+ img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
177
+ # noise
178
+ if self.noise_range is not None:
179
+ img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
180
+ # jpeg compression
181
+ if self.jpeg_range is not None:
182
+ img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
183
+
184
+ # resize to original size
185
+ img_lq = cv2.resize(img_lq, (int(w // self.opt['scale']), int(h // self.opt['scale'])), interpolation=cv2.INTER_LINEAR)
186
+
187
+ # random color jitter (only for lq)
188
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
189
+ img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
190
+ # random to gray (only for lq)
191
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
192
+ img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
193
+ img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
194
+ if self.opt.get('gt_gray'): # whether convert GT to gray images
195
+ img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
196
+ img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
197
+
198
+ # BGR to RGB, HWC to CHW, numpy to tensor
199
+ #img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
200
+ img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
201
+ img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
202
+
203
+ # random color jitter (pytorch version) (only for lq)
204
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
205
+ brightness = self.opt.get('brightness', (0.5, 1.5))
206
+ contrast = self.opt.get('contrast', (0.5, 1.5))
207
+ saturation = self.opt.get('saturation', (0, 1.5))
208
+ hue = self.opt.get('hue', (-0.1, 0.1))
209
+ img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
210
+
211
+ # round and clip
212
+ img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
213
+
214
+ # normalize
215
+ normalize(img_gt, self.mean, self.std, inplace=True)
216
+ normalize(img_lq, self.mean, self.std, inplace=True)
217
+
218
+ '''
219
+ if self.crop_components:
220
+ return_dict = {
221
+ 'lq': img_lq,
222
+ 'gt': img_gt,
223
+ 'gt_path': gt_path,
224
+ 'loc_left_eye': loc_left_eye,
225
+ 'loc_right_eye': loc_right_eye,
226
+ 'loc_mouth': loc_mouth
227
+ }
228
+ return return_dict
229
+ else:
230
+ return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
231
+ '''
232
+ return img_lq, img_gt
233
+
234
+ def __len__(self):
235
+ return len(self.paths)
datasets/gt_res_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # encoding: utf-8
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+
7
+
8
+ class GTResDataset(Dataset):
9
+
10
+ def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
11
+ self.pairs = []
12
+ for f in os.listdir(root_path):
13
+ image_path = os.path.join(root_path, f)
14
+ gt_path = os.path.join(gt_dir, f)
15
+ if f.endswith(".jpg") or f.endswith(".png"):
16
+ self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
17
+ self.transform = transform
18
+ self.transform_train = transform_train
19
+
20
+ def __len__(self):
21
+ return len(self.pairs)
22
+
23
+ def __getitem__(self, index):
24
+ from_path, to_path, _ = self.pairs[index]
25
+ from_im = Image.open(from_path).convert('RGB')
26
+ to_im = Image.open(to_path).convert('RGB')
27
+
28
+ if self.transform:
29
+ to_im = self.transform(to_im)
30
+ from_im = self.transform(from_im)
31
+
32
+ return from_im, to_im
datasets/images_dataset.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
+ class ImagesDataset(Dataset):
7
+
8
+ def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
9
+ self.source_paths = sorted(data_utils.make_dataset(source_root))
10
+ self.target_paths = sorted(data_utils.make_dataset(target_root))
11
+ self.source_transform = source_transform
12
+ self.target_transform = target_transform
13
+ self.opts = opts
14
+
15
+ def __len__(self):
16
+ return len(self.source_paths)
17
+
18
+ def __getitem__(self, index):
19
+ from_path = self.source_paths[index]
20
+ from_im = Image.open(from_path)
21
+ from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
22
+
23
+ to_path = self.target_paths[index]
24
+ to_im = Image.open(to_path).convert('RGB')
25
+ if self.target_transform:
26
+ to_im = self.target_transform(to_im)
27
+
28
+ if self.source_transform:
29
+ from_im = self.source_transform(from_im)
30
+ else:
31
+ from_im = to_im
32
+
33
+ return from_im, to_im
datasets/inference_dataset.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
+ class InferenceDataset(Dataset):
7
+
8
+ def __init__(self, root, opts, transform=None):
9
+ self.paths = sorted(data_utils.make_dataset(root))
10
+ self.transform = transform
11
+ self.opts = opts
12
+
13
+ def __len__(self):
14
+ return len(self.paths)
15
+
16
+ def __getitem__(self, index):
17
+ from_path = self.paths[index]
18
+ from_im = Image.open(from_path)
19
+ from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
20
+ if self.transform:
21
+ from_im = self.transform(from_im)
22
+ return from_im
latent_optimization.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import models.stylegan2.lpips as lpips
2
+ from torch import autograd, optim
3
+ from torchvision import transforms, utils
4
+ from tqdm import tqdm
5
+ import torch
6
+ from scripts.align_all_parallel import align_face
7
+ from utils.inference_utils import noise_regularize, noise_normalize_, get_lr, latent_noise, visualize
8
+
9
+ def latent_optimization(frame, pspex, landmarkpredictor, step=500, device='cuda'):
10
+ percept = lpips.PerceptualLoss(
11
+ model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
12
+ )
13
+
14
+ transform = transforms.Compose([
15
+ transforms.ToTensor(),
16
+ transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5,0.5,0.5]),
17
+ ])
18
+
19
+ with torch.no_grad():
20
+
21
+ noise_sample = torch.randn(1000, 512, device=device)
22
+ latent_out = pspex.decoder.style(noise_sample)
23
+ latent_mean = latent_out.mean(0)
24
+ latent_std = ((latent_out - latent_mean).pow(2).sum() / 1000) ** 0.5
25
+
26
+ y = transform(frame).unsqueeze(dim=0).to(device)
27
+ I_ = align_face(frame, landmarkpredictor)
28
+ I_ = transform(I_).unsqueeze(dim=0).to(device)
29
+ wplus = pspex.encoder(I_) + pspex.latent_avg.unsqueeze(0)
30
+ _, f = pspex.encoder(y, return_feat=True)
31
+ latent_in = wplus.detach().clone()
32
+ feat = [f[0].detach().clone(), f[1].detach().clone()]
33
+
34
+
35
+
36
+ # wplus and f to optimize
37
+ latent_in.requires_grad = True
38
+ feat[0].requires_grad = True
39
+ feat[1].requires_grad = True
40
+
41
+ noises_single = pspex.decoder.make_noise()
42
+ basic_height, basic_width = int(y.shape[2]*32/256), int(y.shape[3]*32/256)
43
+ noises = []
44
+ for noise in noises_single:
45
+ noises.append(noise.new_empty(y.shape[0], 1, max(basic_height, int(y.shape[2]*noise.shape[2]/256)),
46
+ max(basic_width, int(y.shape[3]*noise.shape[2]/256))).normal_())
47
+ for noise in noises:
48
+ noise.requires_grad = True
49
+
50
+ init_lr=0.05
51
+ optimizer = optim.Adam(feat + noises, lr=init_lr)
52
+ optimizer2 = optim.Adam([latent_in], lr=init_lr)
53
+ noise_weight = 0.05 * 0.2
54
+
55
+ pbar = tqdm(range(step))
56
+ latent_path = []
57
+
58
+ for i in pbar:
59
+ t = i / step
60
+ lr = get_lr(t, init_lr)
61
+ optimizer.param_groups[0]["lr"] = lr
62
+ optimizer2.param_groups[0]["lr"] = get_lr(t, init_lr)
63
+
64
+ noise_strength = latent_std * noise_weight * max(0, 1 - t / 0.75) ** 2
65
+ latent_n = latent_noise(latent_in, noise_strength.item())
66
+
67
+ y_hat, _ = pspex.decoder([latent_n], input_is_latent=True, randomize_noise=False,
68
+ first_layer_feature=feat, noise=noises)
69
+
70
+
71
+ batch, channel, height, width = y_hat.shape
72
+
73
+ if height > y.shape[2]:
74
+ factor = height // y.shape[2]
75
+
76
+ y_hat = y_hat.reshape(
77
+ batch, channel, height // factor, factor, width // factor, factor
78
+ )
79
+ y_hat = y_hat.mean([3, 5])
80
+
81
+ p_loss = percept(y_hat, y).sum()
82
+ n_loss = noise_regularize(noises) * 1e3
83
+
84
+ loss = p_loss + n_loss
85
+
86
+ optimizer.zero_grad()
87
+ optimizer2.zero_grad()
88
+ loss.backward()
89
+ optimizer.step()
90
+ optimizer2.step()
91
+
92
+ noise_normalize_(noises)
93
+
94
+ ''' for visualization
95
+ if (i + 1) % 100 == 0 or i == 0:
96
+ viz = torch.cat((y_hat,y,y_hat-y), dim=3)
97
+ visualize(torch.clamp(viz[0].cpu(),-1,1), 60)
98
+ '''
99
+
100
+ pbar.set_description(
101
+ (
102
+ f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
103
+ f" lr: {lr:.4f}"
104
+ )
105
+ )
106
+
107
+ return latent_n, feat, noises, wplus, f
models/__init__.py ADDED
File without changes
models/bisenet/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 zll
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
models/bisenet/README.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # face-parsing.PyTorch
2
+
3
+ <p align="center">
4
+ <a href="https://github.com/zllrunning/face-parsing.PyTorch">
5
+ <img class="page-image" src="https://github.com/zllrunning/face-parsing.PyTorch/blob/master/6.jpg" >
6
+ </a>
7
+ </p>
8
+
9
+ ### Contents
10
+ - [Training](#training)
11
+ - [Demo](#Demo)
12
+ - [References](#references)
13
+
14
+ ## Training
15
+
16
+ 1. Prepare training data:
17
+ -- download [CelebAMask-HQ dataset](https://github.com/switchablenorms/CelebAMask-HQ)
18
+
19
+ -- change file path in the `prepropess_data.py` and run
20
+ ```Shell
21
+ python prepropess_data.py
22
+ ```
23
+
24
+ 2. Train the model using CelebAMask-HQ dataset:
25
+ Just run the train script:
26
+ ```
27
+ $ CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
28
+ ```
29
+
30
+ If you do not wish to train the model, you can download [our pre-trained model](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812) and save it in `res/cp`.
31
+
32
+
33
+ ## Demo
34
+ 1. Evaluate the trained model using:
35
+ ```Shell
36
+ # evaluate using GPU
37
+ python test.py
38
+ ```
39
+
40
+ ## Face makeup using parsing maps
41
+ [**face-makeup.PyTorch**](https://github.com/zllrunning/face-makeup.PyTorch)
42
+ <table>
43
+
44
+ <tr>
45
+ <th>&nbsp;</th>
46
+ <th>Hair</th>
47
+ <th>Lip</th>
48
+ </tr>
49
+
50
+ <!-- Line 1: Original Input -->
51
+ <tr>
52
+ <td><em>Original Input</em></td>
53
+ <td><img src="makeup/116_ori.png" height="256" width="256" alt="Original Input"></td>
54
+ <td><img src="makeup/116_lip_ori.png" height="256" width="256" alt="Original Input"></td>
55
+ </tr>
56
+
57
+ <!-- Line 3: Color -->
58
+ <tr>
59
+ <td>Color</td>
60
+ <td><img src="makeup/116_1.png" height="256" width="256" alt="Color"></td>
61
+ <td><img src="makeup/116_3.png" height="256" width="256" alt="Color"></td>
62
+ </tr>
63
+
64
+ </table>
65
+
66
+
67
+ ## References
68
+ - [BiSeNet](https://github.com/CoinCheung/BiSeNet)
models/bisenet/model.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from models.bisenet.resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ self.resnet = Resnet18()
96
+ self.arm16 = AttentionRefinementModule(256, 128)
97
+ self.arm32 = AttentionRefinementModule(512, 128)
98
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
+
102
+ self.init_weight()
103
+
104
+ def forward(self, x):
105
+ H0, W0 = x.size()[2:]
106
+ feat8, feat16, feat32 = self.resnet(x)
107
+ H8, W8 = feat8.size()[2:]
108
+ H16, W16 = feat16.size()[2:]
109
+ H32, W32 = feat32.size()[2:]
110
+
111
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
+ avg = self.conv_avg(avg)
113
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
+
115
+ feat32_arm = self.arm32(feat32)
116
+ feat32_sum = feat32_arm + avg_up
117
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
+ feat32_up = self.conv_head32(feat32_up)
119
+
120
+ feat16_arm = self.arm16(feat16)
121
+ feat16_sum = feat16_arm + feat32_up
122
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
+ feat16_up = self.conv_head16(feat16_up)
124
+
125
+ return feat8, feat16_up, feat32_up # x8, x8, x16
126
+
127
+ def init_weight(self):
128
+ for ly in self.children():
129
+ if isinstance(ly, nn.Conv2d):
130
+ nn.init.kaiming_normal_(ly.weight, a=1)
131
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132
+
133
+ def get_params(self):
134
+ wd_params, nowd_params = [], []
135
+ for name, module in self.named_modules():
136
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
137
+ wd_params.append(module.weight)
138
+ if not module.bias is None:
139
+ nowd_params.append(module.bias)
140
+ elif isinstance(module, nn.BatchNorm2d):
141
+ nowd_params += list(module.parameters())
142
+ return wd_params, nowd_params
143
+
144
+
145
+ ### This is not used, since I replace this with the resnet feature with the same size
146
+ class SpatialPath(nn.Module):
147
+ def __init__(self, *args, **kwargs):
148
+ super(SpatialPath, self).__init__()
149
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153
+ self.init_weight()
154
+
155
+ def forward(self, x):
156
+ feat = self.conv1(x)
157
+ feat = self.conv2(feat)
158
+ feat = self.conv3(feat)
159
+ feat = self.conv_out(feat)
160
+ return feat
161
+
162
+ def init_weight(self):
163
+ for ly in self.children():
164
+ if isinstance(ly, nn.Conv2d):
165
+ nn.init.kaiming_normal_(ly.weight, a=1)
166
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167
+
168
+ def get_params(self):
169
+ wd_params, nowd_params = [], []
170
+ for name, module in self.named_modules():
171
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172
+ wd_params.append(module.weight)
173
+ if not module.bias is None:
174
+ nowd_params.append(module.bias)
175
+ elif isinstance(module, nn.BatchNorm2d):
176
+ nowd_params += list(module.parameters())
177
+ return wd_params, nowd_params
178
+
179
+
180
+ class FeatureFusionModule(nn.Module):
181
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
182
+ super(FeatureFusionModule, self).__init__()
183
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184
+ self.conv1 = nn.Conv2d(out_chan,
185
+ out_chan//4,
186
+ kernel_size = 1,
187
+ stride = 1,
188
+ padding = 0,
189
+ bias = False)
190
+ self.conv2 = nn.Conv2d(out_chan//4,
191
+ out_chan,
192
+ kernel_size = 1,
193
+ stride = 1,
194
+ padding = 0,
195
+ bias = False)
196
+ self.relu = nn.ReLU(inplace=True)
197
+ self.sigmoid = nn.Sigmoid()
198
+ self.init_weight()
199
+
200
+ def forward(self, fsp, fcp):
201
+ fcat = torch.cat([fsp, fcp], dim=1)
202
+ feat = self.convblk(fcat)
203
+ atten = F.avg_pool2d(feat, feat.size()[2:])
204
+ atten = self.conv1(atten)
205
+ atten = self.relu(atten)
206
+ atten = self.conv2(atten)
207
+ atten = self.sigmoid(atten)
208
+ feat_atten = torch.mul(feat, atten)
209
+ feat_out = feat_atten + feat
210
+ return feat_out
211
+
212
+ def init_weight(self):
213
+ for ly in self.children():
214
+ if isinstance(ly, nn.Conv2d):
215
+ nn.init.kaiming_normal_(ly.weight, a=1)
216
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217
+
218
+ def get_params(self):
219
+ wd_params, nowd_params = [], []
220
+ for name, module in self.named_modules():
221
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222
+ wd_params.append(module.weight)
223
+ if not module.bias is None:
224
+ nowd_params.append(module.bias)
225
+ elif isinstance(module, nn.BatchNorm2d):
226
+ nowd_params += list(module.parameters())
227
+ return wd_params, nowd_params
228
+
229
+
230
+ class BiSeNet(nn.Module):
231
+ def __init__(self, n_classes, *args, **kwargs):
232
+ super(BiSeNet, self).__init__()
233
+ self.cp = ContextPath()
234
+ ## here self.sp is deleted
235
+ self.ffm = FeatureFusionModule(256, 256)
236
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
237
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239
+ self.init_weight()
240
+
241
+ def forward(self, x):
242
+ H, W = x.size()[2:]
243
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
246
+
247
+ feat_out = self.conv_out(feat_fuse)
248
+ feat_out16 = self.conv_out16(feat_cp8)
249
+ feat_out32 = self.conv_out32(feat_cp16)
250
+
251
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254
+ return feat_out, feat_out16, feat_out32
255
+
256
+ def init_weight(self):
257
+ for ly in self.children():
258
+ if isinstance(ly, nn.Conv2d):
259
+ nn.init.kaiming_normal_(ly.weight, a=1)
260
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261
+
262
+ def get_params(self):
263
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264
+ for name, child in self.named_children():
265
+ child_wd_params, child_nowd_params = child.get_params()
266
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267
+ lr_mul_wd_params += child_wd_params
268
+ lr_mul_nowd_params += child_nowd_params
269
+ else:
270
+ wd_params += child_wd_params
271
+ nowd_params += child_nowd_params
272
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273
+
274
+
275
+ if __name__ == "__main__":
276
+ net = BiSeNet(19)
277
+ net.cuda()
278
+ net.eval()
279
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
280
+ out, out16, out32 = net(in_ten)
281
+ print(out.shape)
282
+
283
+ net.get_params()
models/bisenet/resnet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ def __init__(self, in_chan, out_chan, stride=1):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
24
+ self.bn1 = nn.BatchNorm2d(out_chan)
25
+ self.conv2 = conv3x3(out_chan, out_chan)
26
+ self.bn2 = nn.BatchNorm2d(out_chan)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ if in_chan != out_chan or stride != 1:
30
+ self.downsample = nn.Sequential(
31
+ nn.Conv2d(in_chan, out_chan,
32
+ kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum-1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = nn.BatchNorm2d(64)
64
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69
+ self.init_weight()
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = F.relu(self.bn1(x))
74
+ x = self.maxpool(x)
75
+
76
+ x = self.layer1(x)
77
+ feat8 = self.layer2(x) # 1/8
78
+ feat16 = self.layer3(feat8) # 1/16
79
+ feat32 = self.layer4(feat16) # 1/32
80
+ return feat8, feat16, feat32
81
+
82
+ def init_weight(self):
83
+ state_dict = modelzoo.load_url(resnet18_url)
84
+ self_state_dict = self.state_dict()
85
+ for k, v in state_dict.items():
86
+ if 'fc' in k: continue
87
+ self_state_dict.update({k: v})
88
+ self.load_state_dict(self_state_dict)
89
+
90
+ def get_params(self):
91
+ wd_params, nowd_params = [], []
92
+ for name, module in self.named_modules():
93
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
94
+ wd_params.append(module.weight)
95
+ if not module.bias is None:
96
+ nowd_params.append(module.bias)
97
+ elif isinstance(module, nn.BatchNorm2d):
98
+ nowd_params += list(module.parameters())
99
+ return wd_params, nowd_params
100
+
101
+
102
+ if __name__ == "__main__":
103
+ net = Resnet18()
104
+ x = torch.randn(16, 3, 224, 224)
105
+ out = net(x)
106
+ print(out[0].size())
107
+ print(out[1].size())
108
+ print(out[2].size())
109
+ net.get_params()
models/encoders/__init__.py ADDED
File without changes
models/encoders/helpers.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
4
+
5
+ """
6
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
7
+ """
8
+
9
+
10
+ class Flatten(Module):
11
+ def forward(self, input):
12
+ return input.view(input.size(0), -1)
13
+
14
+
15
+ def l2_norm(input, axis=1):
16
+ norm = torch.norm(input, 2, axis, True)
17
+ output = torch.div(input, norm)
18
+ return output
19
+
20
+
21
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
22
+ """ A named tuple describing a ResNet block. """
23
+
24
+
25
+ def get_block(in_channel, depth, num_units, stride=2):
26
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
27
+
28
+
29
+ def get_blocks(num_layers):
30
+ if num_layers == 50:
31
+ blocks = [
32
+ get_block(in_channel=64, depth=64, num_units=3),
33
+ get_block(in_channel=64, depth=128, num_units=4),
34
+ get_block(in_channel=128, depth=256, num_units=14),
35
+ get_block(in_channel=256, depth=512, num_units=3)
36
+ ]
37
+ elif num_layers == 100:
38
+ blocks = [
39
+ get_block(in_channel=64, depth=64, num_units=3),
40
+ get_block(in_channel=64, depth=128, num_units=13),
41
+ get_block(in_channel=128, depth=256, num_units=30),
42
+ get_block(in_channel=256, depth=512, num_units=3)
43
+ ]
44
+ elif num_layers == 152:
45
+ blocks = [
46
+ get_block(in_channel=64, depth=64, num_units=3),
47
+ get_block(in_channel=64, depth=128, num_units=8),
48
+ get_block(in_channel=128, depth=256, num_units=36),
49
+ get_block(in_channel=256, depth=512, num_units=3)
50
+ ]
51
+ else:
52
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
53
+ return blocks
54
+
55
+
56
+ class SEModule(Module):
57
+ def __init__(self, channels, reduction):
58
+ super(SEModule, self).__init__()
59
+ self.avg_pool = AdaptiveAvgPool2d(1)
60
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
61
+ self.relu = ReLU(inplace=True)
62
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
63
+ self.sigmoid = Sigmoid()
64
+
65
+ def forward(self, x):
66
+ module_input = x
67
+ x = self.avg_pool(x)
68
+ x = self.fc1(x)
69
+ x = self.relu(x)
70
+ x = self.fc2(x)
71
+ x = self.sigmoid(x)
72
+ return module_input * x
73
+
74
+
75
+ class bottleneck_IR(Module):
76
+ def __init__(self, in_channel, depth, stride):
77
+ super(bottleneck_IR, self).__init__()
78
+ if in_channel == depth:
79
+ self.shortcut_layer = MaxPool2d(1, stride)
80
+ else:
81
+ self.shortcut_layer = Sequential(
82
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
83
+ BatchNorm2d(depth)
84
+ )
85
+ self.res_layer = Sequential(
86
+ BatchNorm2d(in_channel),
87
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
88
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
89
+ )
90
+
91
+ def forward(self, x):
92
+ shortcut = self.shortcut_layer(x)
93
+ res = self.res_layer(x)
94
+ return res + shortcut
95
+
96
+
97
+ class bottleneck_IR_SE(Module):
98
+ def __init__(self, in_channel, depth, stride):
99
+ super(bottleneck_IR_SE, self).__init__()
100
+ if in_channel == depth:
101
+ self.shortcut_layer = MaxPool2d(1, stride)
102
+ else:
103
+ self.shortcut_layer = Sequential(
104
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
105
+ BatchNorm2d(depth)
106
+ )
107
+ self.res_layer = Sequential(
108
+ BatchNorm2d(in_channel),
109
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
110
+ PReLU(depth),
111
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
112
+ BatchNorm2d(depth),
113
+ SEModule(depth, 16)
114
+ )
115
+
116
+ def forward(self, x):
117
+ shortcut = self.shortcut_layer(x)
118
+ res = self.res_layer(x)
119
+ return res + shortcut
models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3
+
4
+ """
5
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6
+ """
7
+
8
+
9
+ class Backbone(Module):
10
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
11
+ super(Backbone, self).__init__()
12
+ assert input_size in [112, 224], "input_size should be 112 or 224"
13
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
14
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
15
+ blocks = get_blocks(num_layers)
16
+ if mode == 'ir':
17
+ unit_module = bottleneck_IR
18
+ elif mode == 'ir_se':
19
+ unit_module = bottleneck_IR_SE
20
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
21
+ BatchNorm2d(64),
22
+ PReLU(64))
23
+ if input_size == 112:
24
+ self.output_layer = Sequential(BatchNorm2d(512),
25
+ Dropout(drop_ratio),
26
+ Flatten(),
27
+ Linear(512 * 7 * 7, 512),
28
+ BatchNorm1d(512, affine=affine))
29
+ else:
30
+ self.output_layer = Sequential(BatchNorm2d(512),
31
+ Dropout(drop_ratio),
32
+ Flatten(),
33
+ Linear(512 * 14 * 14, 512),
34
+ BatchNorm1d(512, affine=affine))
35
+
36
+ modules = []
37
+ for block in blocks:
38
+ for bottleneck in block:
39
+ modules.append(unit_module(bottleneck.in_channel,
40
+ bottleneck.depth,
41
+ bottleneck.stride))
42
+ self.body = Sequential(*modules)
43
+
44
+ def forward(self, x):
45
+ x = self.input_layer(x)
46
+ x = self.body(x)
47
+ x = self.output_layer(x)
48
+ return l2_norm(x)
49
+
50
+
51
+ def IR_50(input_size):
52
+ """Constructs a ir-50 model."""
53
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
54
+ return model
55
+
56
+
57
+ def IR_101(input_size):
58
+ """Constructs a ir-101 model."""
59
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
60
+ return model
61
+
62
+
63
+ def IR_152(input_size):
64
+ """Constructs a ir-152 model."""
65
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
66
+ return model
67
+
68
+
69
+ def IR_SE_50(input_size):
70
+ """Constructs a ir_se-50 model."""
71
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
72
+ return model
73
+
74
+
75
+ def IR_SE_101(input_size):
76
+ """Constructs a ir_se-101 model."""
77
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
78
+ return model
79
+
80
+
81
+ def IR_SE_152(input_size):
82
+ """Constructs a ir_se-152 model."""
83
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
84
+ return model
models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module
6
+
7
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
8
+ from models.stylegan2.model import EqualLinear
9
+
10
+
11
+ class GradualStyleBlock(Module):
12
+ def __init__(self, in_c, out_c, spatial, max_pooling=False):
13
+ super(GradualStyleBlock, self).__init__()
14
+ self.out_c = out_c
15
+ self.spatial = spatial
16
+ self.max_pooling = max_pooling
17
+ num_pools = int(np.log2(spatial))
18
+ modules = []
19
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
20
+ nn.LeakyReLU()]
21
+ for i in range(num_pools - 1):
22
+ modules += [
23
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
24
+ nn.LeakyReLU()
25
+ ]
26
+ self.convs = nn.Sequential(*modules)
27
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
28
+
29
+ def forward(self, x):
30
+ x = self.convs(x)
31
+ # To make E accept more general H*W images, we add global average pooling to
32
+ # resize all features to 1*1*512 before mapping to latent codes
33
+ if self.max_pooling:
34
+ x = F.adaptive_max_pool2d(x, 1) ##### modified
35
+ else:
36
+ x = F.adaptive_avg_pool2d(x, 1) ##### modified
37
+ x = x.view(-1, self.out_c)
38
+ x = self.linear(x)
39
+ return x
40
+
41
+ class AdaptiveInstanceNorm(nn.Module):
42
+ def __init__(self, fin, style_dim=512):
43
+ super().__init__()
44
+
45
+ self.norm = nn.InstanceNorm2d(fin, affine=False)
46
+ self.style = nn.Linear(style_dim, fin * 2)
47
+
48
+ self.style.bias.data[:fin] = 1
49
+ self.style.bias.data[fin:] = 0
50
+
51
+ def forward(self, input, style):
52
+ style = self.style(style).unsqueeze(2).unsqueeze(3)
53
+ gamma, beta = style.chunk(2, 1)
54
+ out = self.norm(input)
55
+ out = gamma * out + beta
56
+ return out
57
+
58
+
59
+ class FusionLayer(Module): ##### modified
60
+ def __init__(self, inchannel, outchannel, use_skip_torgb=True, use_att=0):
61
+ super(FusionLayer, self).__init__()
62
+
63
+ self.transform = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=1, padding=1),
64
+ nn.LeakyReLU())
65
+ self.fusion_out = nn.Conv2d(outchannel*2, outchannel, kernel_size=3, stride=1, padding=1)
66
+ self.fusion_out.weight.data *= 0.01
67
+ self.fusion_out.weight[:,0:outchannel,1,1].data += torch.eye(outchannel)
68
+
69
+ self.use_skip_torgb = use_skip_torgb
70
+ if use_skip_torgb:
71
+ self.fusion_skip = nn.Conv2d(3+outchannel, 3, kernel_size=3, stride=1, padding=1)
72
+ self.fusion_skip.weight.data *= 0.01
73
+ self.fusion_skip.weight[:,0:3,1,1].data += torch.eye(3)
74
+
75
+ self.use_att = use_att
76
+ if use_att:
77
+ modules = []
78
+ modules.append(nn.Linear(512, outchannel))
79
+ for _ in range(use_att):
80
+ modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
81
+ modules.append(nn.Linear(outchannel, outchannel))
82
+ modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
83
+ self.linear = Sequential(*modules)
84
+ self.norm = AdaptiveInstanceNorm(outchannel*2, outchannel)
85
+ self.conv = nn.Conv2d(outchannel*2, 1, 3, 1, 1, bias=True)
86
+
87
+ def forward(self, feat, out, skip, editing_w=None):
88
+ x = self.transform(feat)
89
+ # similar to VToonify, use editing vector as condition
90
+ # fuse encoder feature and decoder feature with a predicted attention mask m_E
91
+ # if self.use_att = False, just fuse them with a simple conv layer
92
+ if self.use_att and editing_w is not None:
93
+ label = self.linear(editing_w)
94
+ m_E = (F.relu(self.conv(self.norm(torch.cat([out, abs(out-x)], dim=1), label)))).tanh()
95
+ x = x * m_E
96
+ out = self.fusion_out(torch.cat((out, x), dim=1))
97
+ if self.use_skip_torgb:
98
+ skip = self.fusion_skip(torch.cat((skip, x), dim=1))
99
+ return out, skip
100
+
101
+
102
+ class ResnetBlock(nn.Module):
103
+ def __init__(self, dim):
104
+ super(ResnetBlock, self).__init__()
105
+
106
+ self.conv_block = nn.Sequential(Conv2d(dim, dim, 3, 1, 1),
107
+ nn.LeakyReLU(),
108
+ Conv2d(dim, dim, 3, 1, 1))
109
+ self.relu = nn.LeakyReLU()
110
+
111
+ def forward(self, x):
112
+ out = x + self.conv_block(x)
113
+ return self.relu(out)
114
+
115
+ # trainable light-weight translation network T
116
+ # for sketch/mask-to-face translation,
117
+ # we add a trainable T to map y to an intermediate domain where E can more easily extract features.
118
+ class ResnetGenerator(nn.Module):
119
+ def __init__(self, in_channel=19, res_num=2):
120
+ super(ResnetGenerator, self).__init__()
121
+
122
+ modules = []
123
+ modules.append(Conv2d(in_channel, 16, 3, 2, 1))
124
+ modules.append(nn.LeakyReLU())
125
+ modules.append(Conv2d(16, 16, 3, 2, 1))
126
+ modules.append(nn.LeakyReLU())
127
+ for _ in range(res_num):
128
+ modules.append(ResnetBlock(16))
129
+ for _ in range(2):
130
+ modules.append(nn.ConvTranspose2d(16, 16, 3, 2, 1, output_padding=1))
131
+ modules.append(nn.LeakyReLU())
132
+ modules.append(Conv2d(16, 64, 3, 1, 1, bias=False))
133
+ modules.append(BatchNorm2d(64))
134
+ modules.append(PReLU(64))
135
+ self.model = Sequential(*modules)
136
+
137
+ def forward(self, input):
138
+ return self.model(input)
139
+
140
+ class GradualStyleEncoder(Module):
141
+ def __init__(self, num_layers, mode='ir', opts=None):
142
+ super(GradualStyleEncoder, self).__init__()
143
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
144
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
145
+ blocks = get_blocks(num_layers)
146
+ if mode == 'ir':
147
+ unit_module = bottleneck_IR
148
+ elif mode == 'ir_se':
149
+ unit_module = bottleneck_IR_SE
150
+
151
+ # for sketch/mask-to-face translation, add a new network T
152
+ if opts.input_nc != 3:
153
+ self.input_label_layer = ResnetGenerator(opts.input_nc, opts.res_num)
154
+
155
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
156
+ BatchNorm2d(64),
157
+ PReLU(64))
158
+ modules = []
159
+ for block in blocks:
160
+ for bottleneck in block:
161
+ modules.append(unit_module(bottleneck.in_channel,
162
+ bottleneck.depth,
163
+ bottleneck.stride))
164
+ self.body = Sequential(*modules)
165
+
166
+ self.styles = nn.ModuleList()
167
+ self.style_count = opts.n_styles
168
+ self.coarse_ind = 3
169
+ self.middle_ind = 7
170
+ for i in range(self.style_count):
171
+ if i < self.coarse_ind:
172
+ style = GradualStyleBlock(512, 512, 16, 'max_pooling' in opts and opts.max_pooling)
173
+ elif i < self.middle_ind:
174
+ style = GradualStyleBlock(512, 512, 32, 'max_pooling' in opts and opts.max_pooling)
175
+ else:
176
+ style = GradualStyleBlock(512, 512, 64, 'max_pooling' in opts and opts.max_pooling)
177
+ self.styles.append(style)
178
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
179
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
180
+
181
+ # we concatenate pSp features in the middle layers and
182
+ # add a convolution layer to map the concatenated features to the first-layer input feature f of G.
183
+ self.featlayer = nn.Conv2d(768, 512, kernel_size=1, stride=1, padding=0) ##### modified
184
+ self.skiplayer = nn.Conv2d(768, 3, kernel_size=1, stride=1, padding=0) ##### modified
185
+
186
+ # skip connection
187
+ if 'use_skip' in opts and opts.use_skip: ##### modified
188
+ self.fusion = nn.ModuleList()
189
+ channels = [[256,512], [256,512], [256,512], [256,512], [128,512], [64,256], [64,128]]
190
+ # opts.skip_max_layer: how many layers are skipped to the decoder
191
+ for inc, outc in channels[:max(1, min(7, opts.skip_max_layer))]: # from 4 to 256
192
+ self.fusion.append(FusionLayer(inc, outc, opts.use_skip_torgb, opts.use_att))
193
+
194
+ def _upsample_add(self, x, y):
195
+ '''Upsample and add two feature maps.
196
+ Args:
197
+ x: (Variable) top feature map to be upsampled.
198
+ y: (Variable) lateral feature map.
199
+ Returns:
200
+ (Variable) added feature map.
201
+ Note in PyTorch, when input size is odd, the upsampled feature map
202
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
203
+ maybe not equal to the lateral feature map size.
204
+ e.g.
205
+ original input size: [N,_,15,15] ->
206
+ conv2d feature map size: [N,_,8,8] ->
207
+ upsampled feature map size: [N,_,16,16]
208
+ So we choose bilinear upsample which supports arbitrary output sizes.
209
+ '''
210
+ _, _, H, W = y.size()
211
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
212
+
213
+ # return_feat: return f
214
+ # return_full: return f and the skipped encoder features
215
+ # return [out, feats]
216
+ # out is the style latent code w+
217
+ # feats[0] is f for the 1st conv layer, feats[1] is f for the 1st torgb layer
218
+ # feats[2-8] is the skipped encoder features
219
+ def forward(self, x, return_feat=False, return_full=False): ##### modified
220
+ if x.shape[1] != 3:
221
+ x = self.input_label_layer(x)
222
+ else:
223
+ x = self.input_layer(x)
224
+ c256 = x ##### modified
225
+
226
+ latents = []
227
+ modulelist = list(self.body._modules.values())
228
+ for i, l in enumerate(modulelist):
229
+ x = l(x)
230
+ if i == 2: ##### modified
231
+ c128 = x
232
+ elif i == 6:
233
+ c1 = x
234
+ elif i == 10: ##### modified
235
+ c21 = x ##### modified
236
+ elif i == 15: ##### modified
237
+ c22 = x ##### modified
238
+ elif i == 20:
239
+ c2 = x
240
+ elif i == 23:
241
+ c3 = x
242
+
243
+ for j in range(self.coarse_ind):
244
+ latents.append(self.styles[j](c3))
245
+
246
+ p2 = self._upsample_add(c3, self.latlayer1(c2))
247
+ for j in range(self.coarse_ind, self.middle_ind):
248
+ latents.append(self.styles[j](p2))
249
+
250
+ p1 = self._upsample_add(p2, self.latlayer2(c1))
251
+ for j in range(self.middle_ind, self.style_count):
252
+ latents.append(self.styles[j](p1))
253
+
254
+ out = torch.stack(latents, dim=1)
255
+
256
+ if not return_feat:
257
+ return out
258
+
259
+ feats = [self.featlayer(torch.cat((c21, c22, c2), dim=1)), self.skiplayer(torch.cat((c21, c22, c2), dim=1))]
260
+
261
+ if return_full: ##### modified
262
+ feats += [c2, c2, c22, c21, c1, c128, c256]
263
+
264
+ return out, feats
265
+
266
+
267
+ # only compute the first-layer feature f
268
+ # E_F in the paper
269
+ def get_feat(self, x): ##### modified
270
+ # for sketch/mask-to-face translation
271
+ # use a trainable light-weight translation network T
272
+ if x.shape[1] != 3:
273
+ x = self.input_label_layer(x)
274
+ else:
275
+ x = self.input_layer(x)
276
+
277
+ latents = []
278
+ modulelist = list(self.body._modules.values())
279
+ for i, l in enumerate(modulelist):
280
+ x = l(x)
281
+ if i == 10: ##### modified
282
+ c21 = x ##### modified
283
+ elif i == 15: ##### modified
284
+ c22 = x ##### modified
285
+ elif i == 20:
286
+ c2 = x
287
+ break
288
+ return self.featlayer(torch.cat((c21, c22, c2), dim=1))
289
+
290
+ class BackboneEncoderUsingLastLayerIntoW(Module):
291
+ def __init__(self, num_layers, mode='ir', opts=None):
292
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
293
+ print('Using BackboneEncoderUsingLastLayerIntoW')
294
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
295
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
296
+ blocks = get_blocks(num_layers)
297
+ if mode == 'ir':
298
+ unit_module = bottleneck_IR
299
+ elif mode == 'ir_se':
300
+ unit_module = bottleneck_IR_SE
301
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
302
+ BatchNorm2d(64),
303
+ PReLU(64))
304
+ self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
305
+ self.linear = EqualLinear(512, 512, lr_mul=1)
306
+ modules = []
307
+ for block in blocks:
308
+ for bottleneck in block:
309
+ modules.append(unit_module(bottleneck.in_channel,
310
+ bottleneck.depth,
311
+ bottleneck.stride))
312
+ self.body = Sequential(*modules)
313
+
314
+ def forward(self, x):
315
+ x = self.input_layer(x)
316
+ x = self.body(x)
317
+ x = self.output_pool(x)
318
+ x = x.view(-1, 512)
319
+ x = self.linear(x)
320
+ return x
321
+
322
+
323
+ class BackboneEncoderUsingLastLayerIntoWPlus(Module):
324
+ def __init__(self, num_layers, mode='ir', opts=None):
325
+ super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
326
+ print('Using BackboneEncoderUsingLastLayerIntoWPlus')
327
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
328
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
329
+ blocks = get_blocks(num_layers)
330
+ if mode == 'ir':
331
+ unit_module = bottleneck_IR
332
+ elif mode == 'ir_se':
333
+ unit_module = bottleneck_IR_SE
334
+ self.n_styles = opts.n_styles
335
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
336
+ BatchNorm2d(64),
337
+ PReLU(64))
338
+ self.output_layer_2 = Sequential(BatchNorm2d(512),
339
+ torch.nn.AdaptiveAvgPool2d((7, 7)),
340
+ Flatten(),
341
+ Linear(512 * 7 * 7, 512))
342
+ self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1)
343
+ modules = []
344
+ for block in blocks:
345
+ for bottleneck in block:
346
+ modules.append(unit_module(bottleneck.in_channel,
347
+ bottleneck.depth,
348
+ bottleneck.stride))
349
+ self.body = Sequential(*modules)
350
+
351
+ def forward(self, x):
352
+ x = self.input_layer(x)
353
+ x = self.body(x)
354
+ x = self.output_layer_2(x)
355
+ x = self.linear(x)
356
+ x = x.view(-1, self.n_styles, 512)
357
+ return x
models/mtcnn/__init__.py ADDED
File without changes
models/mtcnn/mtcnn.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet
5
+ from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
6
+ from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage
7
+ from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face
8
+
9
+ device = 'cuda:0'
10
+
11
+
12
+ class MTCNN():
13
+ def __init__(self):
14
+ print(device)
15
+ self.pnet = PNet().to(device)
16
+ self.rnet = RNet().to(device)
17
+ self.onet = ONet().to(device)
18
+ self.pnet.eval()
19
+ self.rnet.eval()
20
+ self.onet.eval()
21
+ self.refrence = get_reference_facial_points(default_square=True)
22
+
23
+ def align(self, img):
24
+ _, landmarks = self.detect_faces(img)
25
+ if len(landmarks) == 0:
26
+ return None, None
27
+ facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)]
28
+ warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
29
+ return Image.fromarray(warped_face), tfm
30
+
31
+ def align_multi(self, img, limit=None, min_face_size=30.0):
32
+ boxes, landmarks = self.detect_faces(img, min_face_size)
33
+ if limit:
34
+ boxes = boxes[:limit]
35
+ landmarks = landmarks[:limit]
36
+ faces = []
37
+ tfms = []
38
+ for landmark in landmarks:
39
+ facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)]
40
+ warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
41
+ faces.append(Image.fromarray(warped_face))
42
+ tfms.append(tfm)
43
+ return boxes, faces, tfms
44
+
45
+ def detect_faces(self, image, min_face_size=20.0,
46
+ thresholds=[0.15, 0.25, 0.35],
47
+ nms_thresholds=[0.7, 0.7, 0.7]):
48
+ """
49
+ Arguments:
50
+ image: an instance of PIL.Image.
51
+ min_face_size: a float number.
52
+ thresholds: a list of length 3.
53
+ nms_thresholds: a list of length 3.
54
+
55
+ Returns:
56
+ two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
57
+ bounding boxes and facial landmarks.
58
+ """
59
+
60
+ # BUILD AN IMAGE PYRAMID
61
+ width, height = image.size
62
+ min_length = min(height, width)
63
+
64
+ min_detection_size = 12
65
+ factor = 0.707 # sqrt(0.5)
66
+
67
+ # scales for scaling the image
68
+ scales = []
69
+
70
+ # scales the image so that
71
+ # minimum size that we can detect equals to
72
+ # minimum face size that we want to detect
73
+ m = min_detection_size / min_face_size
74
+ min_length *= m
75
+
76
+ factor_count = 0
77
+ while min_length > min_detection_size:
78
+ scales.append(m * factor ** factor_count)
79
+ min_length *= factor
80
+ factor_count += 1
81
+
82
+ # STAGE 1
83
+
84
+ # it will be returned
85
+ bounding_boxes = []
86
+
87
+ with torch.no_grad():
88
+ # run P-Net on different scales
89
+ for s in scales:
90
+ boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0])
91
+ bounding_boxes.append(boxes)
92
+
93
+ # collect boxes (and offsets, and scores) from different scales
94
+ bounding_boxes = [i for i in bounding_boxes if i is not None]
95
+ bounding_boxes = np.vstack(bounding_boxes)
96
+
97
+ keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
98
+ bounding_boxes = bounding_boxes[keep]
99
+
100
+ # use offsets predicted by pnet to transform bounding boxes
101
+ bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
102
+ # shape [n_boxes, 5]
103
+
104
+ bounding_boxes = convert_to_square(bounding_boxes)
105
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
106
+
107
+ # STAGE 2
108
+
109
+ img_boxes = get_image_boxes(bounding_boxes, image, size=24)
110
+ img_boxes = torch.FloatTensor(img_boxes).to(device)
111
+
112
+ output = self.rnet(img_boxes)
113
+ offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4]
114
+ probs = output[1].cpu().data.numpy() # shape [n_boxes, 2]
115
+
116
+ keep = np.where(probs[:, 1] > thresholds[1])[0]
117
+ bounding_boxes = bounding_boxes[keep]
118
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
119
+ offsets = offsets[keep]
120
+
121
+ keep = nms(bounding_boxes, nms_thresholds[1])
122
+ bounding_boxes = bounding_boxes[keep]
123
+ bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
124
+ bounding_boxes = convert_to_square(bounding_boxes)
125
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
126
+
127
+ # STAGE 3
128
+
129
+ img_boxes = get_image_boxes(bounding_boxes, image, size=48)
130
+ if len(img_boxes) == 0:
131
+ return [], []
132
+ img_boxes = torch.FloatTensor(img_boxes).to(device)
133
+ output = self.onet(img_boxes)
134
+ landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10]
135
+ offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4]
136
+ probs = output[2].cpu().data.numpy() # shape [n_boxes, 2]
137
+
138
+ keep = np.where(probs[:, 1] > thresholds[2])[0]
139
+ bounding_boxes = bounding_boxes[keep]
140
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
141
+ offsets = offsets[keep]
142
+ landmarks = landmarks[keep]
143
+
144
+ # compute landmark points
145
+ width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
146
+ height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
147
+ xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
148
+ landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
149
+ landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
150
+
151
+ bounding_boxes = calibrate_box(bounding_boxes, offsets)
152
+ keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
153
+ bounding_boxes = bounding_boxes[keep]
154
+ landmarks = landmarks[keep]
155
+
156
+ return bounding_boxes, landmarks
models/mtcnn/mtcnn_pytorch/__init__.py ADDED
File without changes
models/mtcnn/mtcnn_pytorch/src/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .visualization_utils import show_bboxes
2
+ from .detector import detect_faces
models/mtcnn/mtcnn_pytorch/src/align_trans.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Apr 24 15:43:29 2017
4
+ @author: zhaoy
5
+ """
6
+ import numpy as np
7
+ import cv2
8
+
9
+ # from scipy.linalg import lstsq
10
+ # from scipy.ndimage import geometric_transform # , map_coordinates
11
+
12
+ from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2
13
+
14
+ # reference facial points, a list of coordinates (x,y)
15
+ REFERENCE_FACIAL_POINTS = [
16
+ [30.29459953, 51.69630051],
17
+ [65.53179932, 51.50139999],
18
+ [48.02519989, 71.73660278],
19
+ [33.54930115, 92.3655014],
20
+ [62.72990036, 92.20410156]
21
+ ]
22
+
23
+ DEFAULT_CROP_SIZE = (96, 112)
24
+
25
+
26
+ class FaceWarpException(Exception):
27
+ def __str__(self):
28
+ return 'In File {}:{}'.format(
29
+ __file__, super.__str__(self))
30
+
31
+
32
+ def get_reference_facial_points(output_size=None,
33
+ inner_padding_factor=0.0,
34
+ outer_padding=(0, 0),
35
+ default_square=False):
36
+ """
37
+ Function:
38
+ ----------
39
+ get reference 5 key points according to crop settings:
40
+ 0. Set default crop_size:
41
+ if default_square:
42
+ crop_size = (112, 112)
43
+ else:
44
+ crop_size = (96, 112)
45
+ 1. Pad the crop_size by inner_padding_factor in each side;
46
+ 2. Resize crop_size into (output_size - outer_padding*2),
47
+ pad into output_size with outer_padding;
48
+ 3. Output reference_5point;
49
+ Parameters:
50
+ ----------
51
+ @output_size: (w, h) or None
52
+ size of aligned face image
53
+ @inner_padding_factor: (w_factor, h_factor)
54
+ padding factor for inner (w, h)
55
+ @outer_padding: (w_pad, h_pad)
56
+ each row is a pair of coordinates (x, y)
57
+ @default_square: True or False
58
+ if True:
59
+ default crop_size = (112, 112)
60
+ else:
61
+ default crop_size = (96, 112);
62
+ !!! make sure, if output_size is not None:
63
+ (output_size - outer_padding)
64
+ = some_scale * (default crop_size * (1.0 + inner_padding_factor))
65
+ Returns:
66
+ ----------
67
+ @reference_5point: 5x2 np.array
68
+ each row is a pair of transformed coordinates (x, y)
69
+ """
70
+ # print('\n===> get_reference_facial_points():')
71
+
72
+ # print('---> Params:')
73
+ # print(' output_size: ', output_size)
74
+ # print(' inner_padding_factor: ', inner_padding_factor)
75
+ # print(' outer_padding:', outer_padding)
76
+ # print(' default_square: ', default_square)
77
+
78
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
79
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
80
+
81
+ # 0) make the inner region a square
82
+ if default_square:
83
+ size_diff = max(tmp_crop_size) - tmp_crop_size
84
+ tmp_5pts += size_diff / 2
85
+ tmp_crop_size += size_diff
86
+
87
+ # print('---> default:')
88
+ # print(' crop_size = ', tmp_crop_size)
89
+ # print(' reference_5pts = ', tmp_5pts)
90
+
91
+ if (output_size and
92
+ output_size[0] == tmp_crop_size[0] and
93
+ output_size[1] == tmp_crop_size[1]):
94
+ # print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
95
+ return tmp_5pts
96
+
97
+ if (inner_padding_factor == 0 and
98
+ outer_padding == (0, 0)):
99
+ if output_size is None:
100
+ # print('No paddings to do: return default reference points')
101
+ return tmp_5pts
102
+ else:
103
+ raise FaceWarpException(
104
+ 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
105
+
106
+ # check output size
107
+ if not (0 <= inner_padding_factor <= 1.0):
108
+ raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
109
+
110
+ if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
111
+ and output_size is None):
112
+ output_size = tmp_crop_size * \
113
+ (1 + inner_padding_factor * 2).astype(np.int32)
114
+ output_size += np.array(outer_padding)
115
+ # print(' deduced from paddings, output_size = ', output_size)
116
+
117
+ if not (outer_padding[0] < output_size[0]
118
+ and outer_padding[1] < output_size[1]):
119
+ raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
120
+ 'and outer_padding[1] < output_size[1])')
121
+
122
+ # 1) pad the inner region according inner_padding_factor
123
+ # print('---> STEP1: pad the inner region according inner_padding_factor')
124
+ if inner_padding_factor > 0:
125
+ size_diff = tmp_crop_size * inner_padding_factor * 2
126
+ tmp_5pts += size_diff / 2
127
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
128
+
129
+ # print(' crop_size = ', tmp_crop_size)
130
+ # print(' reference_5pts = ', tmp_5pts)
131
+
132
+ # 2) resize the padded inner region
133
+ # print('---> STEP2: resize the padded inner region')
134
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
135
+ # print(' crop_size = ', tmp_crop_size)
136
+ # print(' size_bf_outer_pad = ', size_bf_outer_pad)
137
+
138
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
139
+ raise FaceWarpException('Must have (output_size - outer_padding)'
140
+ '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
141
+
142
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
143
+ # print(' resize scale_factor = ', scale_factor)
144
+ tmp_5pts = tmp_5pts * scale_factor
145
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
146
+ # tmp_5pts = tmp_5pts + size_diff / 2
147
+ tmp_crop_size = size_bf_outer_pad
148
+ # print(' crop_size = ', tmp_crop_size)
149
+ # print(' reference_5pts = ', tmp_5pts)
150
+
151
+ # 3) add outer_padding to make output_size
152
+ reference_5point = tmp_5pts + np.array(outer_padding)
153
+ tmp_crop_size = output_size
154
+ # print('---> STEP3: add outer_padding to make output_size')
155
+ # print(' crop_size = ', tmp_crop_size)
156
+ # print(' reference_5pts = ', tmp_5pts)
157
+
158
+ # print('===> end get_reference_facial_points\n')
159
+
160
+ return reference_5point
161
+
162
+
163
+ def get_affine_transform_matrix(src_pts, dst_pts):
164
+ """
165
+ Function:
166
+ ----------
167
+ get affine transform matrix 'tfm' from src_pts to dst_pts
168
+ Parameters:
169
+ ----------
170
+ @src_pts: Kx2 np.array
171
+ source points matrix, each row is a pair of coordinates (x, y)
172
+ @dst_pts: Kx2 np.array
173
+ destination points matrix, each row is a pair of coordinates (x, y)
174
+ Returns:
175
+ ----------
176
+ @tfm: 2x3 np.array
177
+ transform matrix from src_pts to dst_pts
178
+ """
179
+
180
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
181
+ n_pts = src_pts.shape[0]
182
+ ones = np.ones((n_pts, 1), src_pts.dtype)
183
+ src_pts_ = np.hstack([src_pts, ones])
184
+ dst_pts_ = np.hstack([dst_pts, ones])
185
+
186
+ # #print(('src_pts_:\n' + str(src_pts_))
187
+ # #print(('dst_pts_:\n' + str(dst_pts_))
188
+
189
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
190
+
191
+ # #print(('np.linalg.lstsq return A: \n' + str(A))
192
+ # #print(('np.linalg.lstsq return res: \n' + str(res))
193
+ # #print(('np.linalg.lstsq return rank: \n' + str(rank))
194
+ # #print(('np.linalg.lstsq return s: \n' + str(s))
195
+
196
+ if rank == 3:
197
+ tfm = np.float32([
198
+ [A[0, 0], A[1, 0], A[2, 0]],
199
+ [A[0, 1], A[1, 1], A[2, 1]]
200
+ ])
201
+ elif rank == 2:
202
+ tfm = np.float32([
203
+ [A[0, 0], A[1, 0], 0],
204
+ [A[0, 1], A[1, 1], 0]
205
+ ])
206
+
207
+ return tfm
208
+
209
+
210
+ def warp_and_crop_face(src_img,
211
+ facial_pts,
212
+ reference_pts=None,
213
+ crop_size=(96, 112),
214
+ align_type='smilarity'):
215
+ """
216
+ Function:
217
+ ----------
218
+ apply affine transform 'trans' to uv
219
+ Parameters:
220
+ ----------
221
+ @src_img: 3x3 np.array
222
+ input image
223
+ @facial_pts: could be
224
+ 1)a list of K coordinates (x,y)
225
+ or
226
+ 2) Kx2 or 2xK np.array
227
+ each row or col is a pair of coordinates (x, y)
228
+ @reference_pts: could be
229
+ 1) a list of K coordinates (x,y)
230
+ or
231
+ 2) Kx2 or 2xK np.array
232
+ each row or col is a pair of coordinates (x, y)
233
+ or
234
+ 3) None
235
+ if None, use default reference facial points
236
+ @crop_size: (w, h)
237
+ output face image size
238
+ @align_type: transform type, could be one of
239
+ 1) 'similarity': use similarity transform
240
+ 2) 'cv2_affine': use the first 3 points to do affine transform,
241
+ by calling cv2.getAffineTransform()
242
+ 3) 'affine': use all points to do affine transform
243
+ Returns:
244
+ ----------
245
+ @face_img: output face image with size (w, h) = @crop_size
246
+ """
247
+
248
+ if reference_pts is None:
249
+ if crop_size[0] == 96 and crop_size[1] == 112:
250
+ reference_pts = REFERENCE_FACIAL_POINTS
251
+ else:
252
+ default_square = False
253
+ inner_padding_factor = 0
254
+ outer_padding = (0, 0)
255
+ output_size = crop_size
256
+
257
+ reference_pts = get_reference_facial_points(output_size,
258
+ inner_padding_factor,
259
+ outer_padding,
260
+ default_square)
261
+
262
+ ref_pts = np.float32(reference_pts)
263
+ ref_pts_shp = ref_pts.shape
264
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
265
+ raise FaceWarpException(
266
+ 'reference_pts.shape must be (K,2) or (2,K) and K>2')
267
+
268
+ if ref_pts_shp[0] == 2:
269
+ ref_pts = ref_pts.T
270
+
271
+ src_pts = np.float32(facial_pts)
272
+ src_pts_shp = src_pts.shape
273
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
274
+ raise FaceWarpException(
275
+ 'facial_pts.shape must be (K,2) or (2,K) and K>2')
276
+
277
+ if src_pts_shp[0] == 2:
278
+ src_pts = src_pts.T
279
+
280
+ # #print('--->src_pts:\n', src_pts
281
+ # #print('--->ref_pts\n', ref_pts
282
+
283
+ if src_pts.shape != ref_pts.shape:
284
+ raise FaceWarpException(
285
+ 'facial_pts and reference_pts must have the same shape')
286
+
287
+ if align_type is 'cv2_affine':
288
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
289
+ # #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm))
290
+ elif align_type is 'affine':
291
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
292
+ # #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm))
293
+ else:
294
+ tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
295
+ # #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm))
296
+
297
+ # #print('--->Transform matrix: '
298
+ # #print(('type(tfm):' + str(type(tfm)))
299
+ # #print(('tfm.dtype:' + str(tfm.dtype))
300
+ # #print( tfm
301
+
302
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
303
+
304
+ return face_img, tfm
models/mtcnn/mtcnn_pytorch/src/box_utils.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
+
5
+ def nms(boxes, overlap_threshold=0.5, mode='union'):
6
+ """Non-maximum suppression.
7
+
8
+ Arguments:
9
+ boxes: a float numpy array of shape [n, 5],
10
+ where each row is (xmin, ymin, xmax, ymax, score).
11
+ overlap_threshold: a float number.
12
+ mode: 'union' or 'min'.
13
+
14
+ Returns:
15
+ list with indices of the selected boxes
16
+ """
17
+
18
+ # if there are no boxes, return the empty list
19
+ if len(boxes) == 0:
20
+ return []
21
+
22
+ # list of picked indices
23
+ pick = []
24
+
25
+ # grab the coordinates of the bounding boxes
26
+ x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]
27
+
28
+ area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
29
+ ids = np.argsort(score) # in increasing order
30
+
31
+ while len(ids) > 0:
32
+
33
+ # grab index of the largest value
34
+ last = len(ids) - 1
35
+ i = ids[last]
36
+ pick.append(i)
37
+
38
+ # compute intersections
39
+ # of the box with the largest score
40
+ # with the rest of boxes
41
+
42
+ # left top corner of intersection boxes
43
+ ix1 = np.maximum(x1[i], x1[ids[:last]])
44
+ iy1 = np.maximum(y1[i], y1[ids[:last]])
45
+
46
+ # right bottom corner of intersection boxes
47
+ ix2 = np.minimum(x2[i], x2[ids[:last]])
48
+ iy2 = np.minimum(y2[i], y2[ids[:last]])
49
+
50
+ # width and height of intersection boxes
51
+ w = np.maximum(0.0, ix2 - ix1 + 1.0)
52
+ h = np.maximum(0.0, iy2 - iy1 + 1.0)
53
+
54
+ # intersections' areas
55
+ inter = w * h
56
+ if mode == 'min':
57
+ overlap = inter / np.minimum(area[i], area[ids[:last]])
58
+ elif mode == 'union':
59
+ # intersection over union (IoU)
60
+ overlap = inter / (area[i] + area[ids[:last]] - inter)
61
+
62
+ # delete all boxes where overlap is too big
63
+ ids = np.delete(
64
+ ids,
65
+ np.concatenate([[last], np.where(overlap > overlap_threshold)[0]])
66
+ )
67
+
68
+ return pick
69
+
70
+
71
+ def convert_to_square(bboxes):
72
+ """Convert bounding boxes to a square form.
73
+
74
+ Arguments:
75
+ bboxes: a float numpy array of shape [n, 5].
76
+
77
+ Returns:
78
+ a float numpy array of shape [n, 5],
79
+ squared bounding boxes.
80
+ """
81
+
82
+ square_bboxes = np.zeros_like(bboxes)
83
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
84
+ h = y2 - y1 + 1.0
85
+ w = x2 - x1 + 1.0
86
+ max_side = np.maximum(h, w)
87
+ square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
88
+ square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
89
+ square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
90
+ square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
91
+ return square_bboxes
92
+
93
+
94
+ def calibrate_box(bboxes, offsets):
95
+ """Transform bounding boxes to be more like true bounding boxes.
96
+ 'offsets' is one of the outputs of the nets.
97
+
98
+ Arguments:
99
+ bboxes: a float numpy array of shape [n, 5].
100
+ offsets: a float numpy array of shape [n, 4].
101
+
102
+ Returns:
103
+ a float numpy array of shape [n, 5].
104
+ """
105
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
106
+ w = x2 - x1 + 1.0
107
+ h = y2 - y1 + 1.0
108
+ w = np.expand_dims(w, 1)
109
+ h = np.expand_dims(h, 1)
110
+
111
+ # this is what happening here:
112
+ # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
113
+ # x1_true = x1 + tx1*w
114
+ # y1_true = y1 + ty1*h
115
+ # x2_true = x2 + tx2*w
116
+ # y2_true = y2 + ty2*h
117
+ # below is just more compact form of this
118
+
119
+ # are offsets always such that
120
+ # x1 < x2 and y1 < y2 ?
121
+
122
+ translation = np.hstack([w, h, w, h]) * offsets
123
+ bboxes[:, 0:4] = bboxes[:, 0:4] + translation
124
+ return bboxes
125
+
126
+
127
+ def get_image_boxes(bounding_boxes, img, size=24):
128
+ """Cut out boxes from the image.
129
+
130
+ Arguments:
131
+ bounding_boxes: a float numpy array of shape [n, 5].
132
+ img: an instance of PIL.Image.
133
+ size: an integer, size of cutouts.
134
+
135
+ Returns:
136
+ a float numpy array of shape [n, 3, size, size].
137
+ """
138
+
139
+ num_boxes = len(bounding_boxes)
140
+ width, height = img.size
141
+
142
+ [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height)
143
+ img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')
144
+
145
+ for i in range(num_boxes):
146
+ img_box = np.zeros((h[i], w[i], 3), 'uint8')
147
+
148
+ img_array = np.asarray(img, 'uint8')
149
+ img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \
150
+ img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]
151
+
152
+ # resize
153
+ img_box = Image.fromarray(img_box)
154
+ img_box = img_box.resize((size, size), Image.BILINEAR)
155
+ img_box = np.asarray(img_box, 'float32')
156
+
157
+ img_boxes[i, :, :, :] = _preprocess(img_box)
158
+
159
+ return img_boxes
160
+
161
+
162
+ def correct_bboxes(bboxes, width, height):
163
+ """Crop boxes that are too big and get coordinates
164
+ with respect to cutouts.
165
+
166
+ Arguments:
167
+ bboxes: a float numpy array of shape [n, 5],
168
+ where each row is (xmin, ymin, xmax, ymax, score).
169
+ width: a float number.
170
+ height: a float number.
171
+
172
+ Returns:
173
+ dy, dx, edy, edx: a int numpy arrays of shape [n],
174
+ coordinates of the boxes with respect to the cutouts.
175
+ y, x, ey, ex: a int numpy arrays of shape [n],
176
+ corrected ymin, xmin, ymax, xmax.
177
+ h, w: a int numpy arrays of shape [n],
178
+ just heights and widths of boxes.
179
+
180
+ in the following order:
181
+ [dy, edy, dx, edx, y, ey, x, ex, w, h].
182
+ """
183
+
184
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
185
+ w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
186
+ num_boxes = bboxes.shape[0]
187
+
188
+ # 'e' stands for end
189
+ # (x, y) -> (ex, ey)
190
+ x, y, ex, ey = x1, y1, x2, y2
191
+
192
+ # we need to cut out a box from the image.
193
+ # (x, y, ex, ey) are corrected coordinates of the box
194
+ # in the image.
195
+ # (dx, dy, edx, edy) are coordinates of the box in the cutout
196
+ # from the image.
197
+ dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,))
198
+ edx, edy = w.copy() - 1.0, h.copy() - 1.0
199
+
200
+ # if box's bottom right corner is too far right
201
+ ind = np.where(ex > width - 1.0)[0]
202
+ edx[ind] = w[ind] + width - 2.0 - ex[ind]
203
+ ex[ind] = width - 1.0
204
+
205
+ # if box's bottom right corner is too low
206
+ ind = np.where(ey > height - 1.0)[0]
207
+ edy[ind] = h[ind] + height - 2.0 - ey[ind]
208
+ ey[ind] = height - 1.0
209
+
210
+ # if box's top left corner is too far left
211
+ ind = np.where(x < 0.0)[0]
212
+ dx[ind] = 0.0 - x[ind]
213
+ x[ind] = 0.0
214
+
215
+ # if box's top left corner is too high
216
+ ind = np.where(y < 0.0)[0]
217
+ dy[ind] = 0.0 - y[ind]
218
+ y[ind] = 0.0
219
+
220
+ return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
221
+ return_list = [i.astype('int32') for i in return_list]
222
+
223
+ return return_list
224
+
225
+
226
+ def _preprocess(img):
227
+ """Preprocessing step before feeding the network.
228
+
229
+ Arguments:
230
+ img: a float numpy array of shape [h, w, c].
231
+
232
+ Returns:
233
+ a float numpy array of shape [1, c, h, w].
234
+ """
235
+ img = img.transpose((2, 0, 1))
236
+ img = np.expand_dims(img, 0)
237
+ img = (img - 127.5) * 0.0078125
238
+ return img
models/mtcnn/mtcnn_pytorch/src/detector.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.autograd import Variable
4
+ from .get_nets import PNet, RNet, ONet
5
+ from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
6
+ from .first_stage import run_first_stage
7
+
8
+
9
+ def detect_faces(image, min_face_size=20.0,
10
+ thresholds=[0.6, 0.7, 0.8],
11
+ nms_thresholds=[0.7, 0.7, 0.7]):
12
+ """
13
+ Arguments:
14
+ image: an instance of PIL.Image.
15
+ min_face_size: a float number.
16
+ thresholds: a list of length 3.
17
+ nms_thresholds: a list of length 3.
18
+
19
+ Returns:
20
+ two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
21
+ bounding boxes and facial landmarks.
22
+ """
23
+
24
+ # LOAD MODELS
25
+ pnet = PNet()
26
+ rnet = RNet()
27
+ onet = ONet()
28
+ onet.eval()
29
+
30
+ # BUILD AN IMAGE PYRAMID
31
+ width, height = image.size
32
+ min_length = min(height, width)
33
+
34
+ min_detection_size = 12
35
+ factor = 0.707 # sqrt(0.5)
36
+
37
+ # scales for scaling the image
38
+ scales = []
39
+
40
+ # scales the image so that
41
+ # minimum size that we can detect equals to
42
+ # minimum face size that we want to detect
43
+ m = min_detection_size / min_face_size
44
+ min_length *= m
45
+
46
+ factor_count = 0
47
+ while min_length > min_detection_size:
48
+ scales.append(m * factor ** factor_count)
49
+ min_length *= factor
50
+ factor_count += 1
51
+
52
+ # STAGE 1
53
+
54
+ # it will be returned
55
+ bounding_boxes = []
56
+
57
+ with torch.no_grad():
58
+ # run P-Net on different scales
59
+ for s in scales:
60
+ boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0])
61
+ bounding_boxes.append(boxes)
62
+
63
+ # collect boxes (and offsets, and scores) from different scales
64
+ bounding_boxes = [i for i in bounding_boxes if i is not None]
65
+ bounding_boxes = np.vstack(bounding_boxes)
66
+
67
+ keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
68
+ bounding_boxes = bounding_boxes[keep]
69
+
70
+ # use offsets predicted by pnet to transform bounding boxes
71
+ bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
72
+ # shape [n_boxes, 5]
73
+
74
+ bounding_boxes = convert_to_square(bounding_boxes)
75
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
76
+
77
+ # STAGE 2
78
+
79
+ img_boxes = get_image_boxes(bounding_boxes, image, size=24)
80
+ img_boxes = torch.FloatTensor(img_boxes)
81
+
82
+ output = rnet(img_boxes)
83
+ offsets = output[0].data.numpy() # shape [n_boxes, 4]
84
+ probs = output[1].data.numpy() # shape [n_boxes, 2]
85
+
86
+ keep = np.where(probs[:, 1] > thresholds[1])[0]
87
+ bounding_boxes = bounding_boxes[keep]
88
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
89
+ offsets = offsets[keep]
90
+
91
+ keep = nms(bounding_boxes, nms_thresholds[1])
92
+ bounding_boxes = bounding_boxes[keep]
93
+ bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
94
+ bounding_boxes = convert_to_square(bounding_boxes)
95
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
96
+
97
+ # STAGE 3
98
+
99
+ img_boxes = get_image_boxes(bounding_boxes, image, size=48)
100
+ if len(img_boxes) == 0:
101
+ return [], []
102
+ img_boxes = torch.FloatTensor(img_boxes)
103
+ output = onet(img_boxes)
104
+ landmarks = output[0].data.numpy() # shape [n_boxes, 10]
105
+ offsets = output[1].data.numpy() # shape [n_boxes, 4]
106
+ probs = output[2].data.numpy() # shape [n_boxes, 2]
107
+
108
+ keep = np.where(probs[:, 1] > thresholds[2])[0]
109
+ bounding_boxes = bounding_boxes[keep]
110
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
111
+ offsets = offsets[keep]
112
+ landmarks = landmarks[keep]
113
+
114
+ # compute landmark points
115
+ width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
116
+ height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
117
+ xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
118
+ landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
119
+ landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
120
+
121
+ bounding_boxes = calibrate_box(bounding_boxes, offsets)
122
+ keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
123
+ bounding_boxes = bounding_boxes[keep]
124
+ landmarks = landmarks[keep]
125
+
126
+ return bounding_boxes, landmarks
models/mtcnn/mtcnn_pytorch/src/first_stage.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.autograd import Variable
3
+ import math
4
+ from PIL import Image
5
+ import numpy as np
6
+ from .box_utils import nms, _preprocess
7
+
8
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
9
+ device = 'cuda:0'
10
+
11
+
12
+ def run_first_stage(image, net, scale, threshold):
13
+ """Run P-Net, generate bounding boxes, and do NMS.
14
+
15
+ Arguments:
16
+ image: an instance of PIL.Image.
17
+ net: an instance of pytorch's nn.Module, P-Net.
18
+ scale: a float number,
19
+ scale width and height of the image by this number.
20
+ threshold: a float number,
21
+ threshold on the probability of a face when generating
22
+ bounding boxes from predictions of the net.
23
+
24
+ Returns:
25
+ a float numpy array of shape [n_boxes, 9],
26
+ bounding boxes with scores and offsets (4 + 1 + 4).
27
+ """
28
+
29
+ # scale the image and convert it to a float array
30
+ width, height = image.size
31
+ sw, sh = math.ceil(width * scale), math.ceil(height * scale)
32
+ img = image.resize((sw, sh), Image.BILINEAR)
33
+ img = np.asarray(img, 'float32')
34
+
35
+ img = torch.FloatTensor(_preprocess(img)).to(device)
36
+ with torch.no_grad():
37
+ output = net(img)
38
+ probs = output[1].cpu().data.numpy()[0, 1, :, :]
39
+ offsets = output[0].cpu().data.numpy()
40
+ # probs: probability of a face at each sliding window
41
+ # offsets: transformations to true bounding boxes
42
+
43
+ boxes = _generate_bboxes(probs, offsets, scale, threshold)
44
+ if len(boxes) == 0:
45
+ return None
46
+
47
+ keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
48
+ return boxes[keep]
49
+
50
+
51
+ def _generate_bboxes(probs, offsets, scale, threshold):
52
+ """Generate bounding boxes at places
53
+ where there is probably a face.
54
+
55
+ Arguments:
56
+ probs: a float numpy array of shape [n, m].
57
+ offsets: a float numpy array of shape [1, 4, n, m].
58
+ scale: a float number,
59
+ width and height of the image were scaled by this number.
60
+ threshold: a float number.
61
+
62
+ Returns:
63
+ a float numpy array of shape [n_boxes, 9]
64
+ """
65
+
66
+ # applying P-Net is equivalent, in some sense, to
67
+ # moving 12x12 window with stride 2
68
+ stride = 2
69
+ cell_size = 12
70
+
71
+ # indices of boxes where there is probably a face
72
+ inds = np.where(probs > threshold)
73
+
74
+ if inds[0].size == 0:
75
+ return np.array([])
76
+
77
+ # transformations of bounding boxes
78
+ tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
79
+ # they are defined as:
80
+ # w = x2 - x1 + 1
81
+ # h = y2 - y1 + 1
82
+ # x1_true = x1 + tx1*w
83
+ # x2_true = x2 + tx2*w
84
+ # y1_true = y1 + ty1*h
85
+ # y2_true = y2 + ty2*h
86
+
87
+ offsets = np.array([tx1, ty1, tx2, ty2])
88
+ score = probs[inds[0], inds[1]]
89
+
90
+ # P-Net is applied to scaled images
91
+ # so we need to rescale bounding boxes back
92
+ bounding_boxes = np.vstack([
93
+ np.round((stride * inds[1] + 1.0) / scale),
94
+ np.round((stride * inds[0] + 1.0) / scale),
95
+ np.round((stride * inds[1] + 1.0 + cell_size) / scale),
96
+ np.round((stride * inds[0] + 1.0 + cell_size) / scale),
97
+ score, offsets
98
+ ])
99
+ # why one is added?
100
+
101
+ return bounding_boxes.T
models/mtcnn/mtcnn_pytorch/src/get_nets.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import OrderedDict
5
+ import numpy as np
6
+
7
+ from configs.paths_config import model_paths
8
+ PNET_PATH = model_paths["mtcnn_pnet"]
9
+ ONET_PATH = model_paths["mtcnn_onet"]
10
+ RNET_PATH = model_paths["mtcnn_rnet"]
11
+
12
+
13
+ class Flatten(nn.Module):
14
+
15
+ def __init__(self):
16
+ super(Flatten, self).__init__()
17
+
18
+ def forward(self, x):
19
+ """
20
+ Arguments:
21
+ x: a float tensor with shape [batch_size, c, h, w].
22
+ Returns:
23
+ a float tensor with shape [batch_size, c*h*w].
24
+ """
25
+
26
+ # without this pretrained model isn't working
27
+ x = x.transpose(3, 2).contiguous()
28
+
29
+ return x.view(x.size(0), -1)
30
+
31
+
32
+ class PNet(nn.Module):
33
+
34
+ def __init__(self):
35
+ super().__init__()
36
+
37
+ # suppose we have input with size HxW, then
38
+ # after first layer: H - 2,
39
+ # after pool: ceil((H - 2)/2),
40
+ # after second conv: ceil((H - 2)/2) - 2,
41
+ # after last conv: ceil((H - 2)/2) - 4,
42
+ # and the same for W
43
+
44
+ self.features = nn.Sequential(OrderedDict([
45
+ ('conv1', nn.Conv2d(3, 10, 3, 1)),
46
+ ('prelu1', nn.PReLU(10)),
47
+ ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
48
+
49
+ ('conv2', nn.Conv2d(10, 16, 3, 1)),
50
+ ('prelu2', nn.PReLU(16)),
51
+
52
+ ('conv3', nn.Conv2d(16, 32, 3, 1)),
53
+ ('prelu3', nn.PReLU(32))
54
+ ]))
55
+
56
+ self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
57
+ self.conv4_2 = nn.Conv2d(32, 4, 1, 1)
58
+
59
+ weights = np.load(PNET_PATH, allow_pickle=True)[()]
60
+ for n, p in self.named_parameters():
61
+ p.data = torch.FloatTensor(weights[n])
62
+
63
+ def forward(self, x):
64
+ """
65
+ Arguments:
66
+ x: a float tensor with shape [batch_size, 3, h, w].
67
+ Returns:
68
+ b: a float tensor with shape [batch_size, 4, h', w'].
69
+ a: a float tensor with shape [batch_size, 2, h', w'].
70
+ """
71
+ x = self.features(x)
72
+ a = self.conv4_1(x)
73
+ b = self.conv4_2(x)
74
+ a = F.softmax(a, dim=-1)
75
+ return b, a
76
+
77
+
78
+ class RNet(nn.Module):
79
+
80
+ def __init__(self):
81
+ super().__init__()
82
+
83
+ self.features = nn.Sequential(OrderedDict([
84
+ ('conv1', nn.Conv2d(3, 28, 3, 1)),
85
+ ('prelu1', nn.PReLU(28)),
86
+ ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
87
+
88
+ ('conv2', nn.Conv2d(28, 48, 3, 1)),
89
+ ('prelu2', nn.PReLU(48)),
90
+ ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
91
+
92
+ ('conv3', nn.Conv2d(48, 64, 2, 1)),
93
+ ('prelu3', nn.PReLU(64)),
94
+
95
+ ('flatten', Flatten()),
96
+ ('conv4', nn.Linear(576, 128)),
97
+ ('prelu4', nn.PReLU(128))
98
+ ]))
99
+
100
+ self.conv5_1 = nn.Linear(128, 2)
101
+ self.conv5_2 = nn.Linear(128, 4)
102
+
103
+ weights = np.load(RNET_PATH, allow_pickle=True)[()]
104
+ for n, p in self.named_parameters():
105
+ p.data = torch.FloatTensor(weights[n])
106
+
107
+ def forward(self, x):
108
+ """
109
+ Arguments:
110
+ x: a float tensor with shape [batch_size, 3, h, w].
111
+ Returns:
112
+ b: a float tensor with shape [batch_size, 4].
113
+ a: a float tensor with shape [batch_size, 2].
114
+ """
115
+ x = self.features(x)
116
+ a = self.conv5_1(x)
117
+ b = self.conv5_2(x)
118
+ a = F.softmax(a, dim=-1)
119
+ return b, a
120
+
121
+
122
+ class ONet(nn.Module):
123
+
124
+ def __init__(self):
125
+ super().__init__()
126
+
127
+ self.features = nn.Sequential(OrderedDict([
128
+ ('conv1', nn.Conv2d(3, 32, 3, 1)),
129
+ ('prelu1', nn.PReLU(32)),
130
+ ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
131
+
132
+ ('conv2', nn.Conv2d(32, 64, 3, 1)),
133
+ ('prelu2', nn.PReLU(64)),
134
+ ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
135
+
136
+ ('conv3', nn.Conv2d(64, 64, 3, 1)),
137
+ ('prelu3', nn.PReLU(64)),
138
+ ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
139
+
140
+ ('conv4', nn.Conv2d(64, 128, 2, 1)),
141
+ ('prelu4', nn.PReLU(128)),
142
+
143
+ ('flatten', Flatten()),
144
+ ('conv5', nn.Linear(1152, 256)),
145
+ ('drop5', nn.Dropout(0.25)),
146
+ ('prelu5', nn.PReLU(256)),
147
+ ]))
148
+
149
+ self.conv6_1 = nn.Linear(256, 2)
150
+ self.conv6_2 = nn.Linear(256, 4)
151
+ self.conv6_3 = nn.Linear(256, 10)
152
+
153
+ weights = np.load(ONET_PATH, allow_pickle=True)[()]
154
+ for n, p in self.named_parameters():
155
+ p.data = torch.FloatTensor(weights[n])
156
+
157
+ def forward(self, x):
158
+ """
159
+ Arguments:
160
+ x: a float tensor with shape [batch_size, 3, h, w].
161
+ Returns:
162
+ c: a float tensor with shape [batch_size, 10].
163
+ b: a float tensor with shape [batch_size, 4].
164
+ a: a float tensor with shape [batch_size, 2].
165
+ """
166
+ x = self.features(x)
167
+ a = self.conv6_1(x)
168
+ b = self.conv6_2(x)
169
+ c = self.conv6_3(x)
170
+ a = F.softmax(a, dim=-1)
171
+ return c, b, a
models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Jul 11 06:54:28 2017
4
+
5
+ @author: zhaoyafei
6
+ """
7
+
8
+ import numpy as np
9
+ from numpy.linalg import inv, norm, lstsq
10
+ from numpy.linalg import matrix_rank as rank
11
+
12
+
13
+ class MatlabCp2tormException(Exception):
14
+ def __str__(self):
15
+ return 'In File {}:{}'.format(
16
+ __file__, super.__str__(self))
17
+
18
+
19
+ def tformfwd(trans, uv):
20
+ """
21
+ Function:
22
+ ----------
23
+ apply affine transform 'trans' to uv
24
+
25
+ Parameters:
26
+ ----------
27
+ @trans: 3x3 np.array
28
+ transform matrix
29
+ @uv: Kx2 np.array
30
+ each row is a pair of coordinates (x, y)
31
+
32
+ Returns:
33
+ ----------
34
+ @xy: Kx2 np.array
35
+ each row is a pair of transformed coordinates (x, y)
36
+ """
37
+ uv = np.hstack((
38
+ uv, np.ones((uv.shape[0], 1))
39
+ ))
40
+ xy = np.dot(uv, trans)
41
+ xy = xy[:, 0:-1]
42
+ return xy
43
+
44
+
45
+ def tforminv(trans, uv):
46
+ """
47
+ Function:
48
+ ----------
49
+ apply the inverse of affine transform 'trans' to uv
50
+
51
+ Parameters:
52
+ ----------
53
+ @trans: 3x3 np.array
54
+ transform matrix
55
+ @uv: Kx2 np.array
56
+ each row is a pair of coordinates (x, y)
57
+
58
+ Returns:
59
+ ----------
60
+ @xy: Kx2 np.array
61
+ each row is a pair of inverse-transformed coordinates (x, y)
62
+ """
63
+ Tinv = inv(trans)
64
+ xy = tformfwd(Tinv, uv)
65
+ return xy
66
+
67
+
68
+ def findNonreflectiveSimilarity(uv, xy, options=None):
69
+ options = {'K': 2}
70
+
71
+ K = options['K']
72
+ M = xy.shape[0]
73
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
74
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
75
+ # print('--->x, y:\n', x, y
76
+
77
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
78
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
79
+ X = np.vstack((tmp1, tmp2))
80
+ # print('--->X.shape: ', X.shape
81
+ # print('X:\n', X
82
+
83
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
84
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
85
+ U = np.vstack((u, v))
86
+ # print('--->U.shape: ', U.shape
87
+ # print('U:\n', U
88
+
89
+ # We know that X * r = U
90
+ if rank(X) >= 2 * K:
91
+ r, _, _, _ = lstsq(X, U, rcond=None) # Make sure this is what I want
92
+ r = np.squeeze(r)
93
+ else:
94
+ raise Exception('cp2tform:twoUniquePointsReq')
95
+
96
+ # print('--->r:\n', r
97
+
98
+ sc = r[0]
99
+ ss = r[1]
100
+ tx = r[2]
101
+ ty = r[3]
102
+
103
+ Tinv = np.array([
104
+ [sc, -ss, 0],
105
+ [ss, sc, 0],
106
+ [tx, ty, 1]
107
+ ])
108
+
109
+ # print('--->Tinv:\n', Tinv
110
+
111
+ T = inv(Tinv)
112
+ # print('--->T:\n', T
113
+
114
+ T[:, 2] = np.array([0, 0, 1])
115
+
116
+ return T, Tinv
117
+
118
+
119
+ def findSimilarity(uv, xy, options=None):
120
+ options = {'K': 2}
121
+
122
+ # uv = np.array(uv)
123
+ # xy = np.array(xy)
124
+
125
+ # Solve for trans1
126
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
127
+
128
+ # Solve for trans2
129
+
130
+ # manually reflect the xy data across the Y-axis
131
+ xyR = xy
132
+ xyR[:, 0] = -1 * xyR[:, 0]
133
+
134
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
135
+
136
+ # manually reflect the tform to undo the reflection done on xyR
137
+ TreflectY = np.array([
138
+ [-1, 0, 0],
139
+ [0, 1, 0],
140
+ [0, 0, 1]
141
+ ])
142
+
143
+ trans2 = np.dot(trans2r, TreflectY)
144
+
145
+ # Figure out if trans1 or trans2 is better
146
+ xy1 = tformfwd(trans1, uv)
147
+ norm1 = norm(xy1 - xy)
148
+
149
+ xy2 = tformfwd(trans2, uv)
150
+ norm2 = norm(xy2 - xy)
151
+
152
+ if norm1 <= norm2:
153
+ return trans1, trans1_inv
154
+ else:
155
+ trans2_inv = inv(trans2)
156
+ return trans2, trans2_inv
157
+
158
+
159
+ def get_similarity_transform(src_pts, dst_pts, reflective=True):
160
+ """
161
+ Function:
162
+ ----------
163
+ Find Similarity Transform Matrix 'trans':
164
+ u = src_pts[:, 0]
165
+ v = src_pts[:, 1]
166
+ x = dst_pts[:, 0]
167
+ y = dst_pts[:, 1]
168
+ [x, y, 1] = [u, v, 1] * trans
169
+
170
+ Parameters:
171
+ ----------
172
+ @src_pts: Kx2 np.array
173
+ source points, each row is a pair of coordinates (x, y)
174
+ @dst_pts: Kx2 np.array
175
+ destination points, each row is a pair of transformed
176
+ coordinates (x, y)
177
+ @reflective: True or False
178
+ if True:
179
+ use reflective similarity transform
180
+ else:
181
+ use non-reflective similarity transform
182
+
183
+ Returns:
184
+ ----------
185
+ @trans: 3x3 np.array
186
+ transform matrix from uv to xy
187
+ trans_inv: 3x3 np.array
188
+ inverse of trans, transform matrix from xy to uv
189
+ """
190
+
191
+ if reflective:
192
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
193
+ else:
194
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
195
+
196
+ return trans, trans_inv
197
+
198
+
199
+ def cvt_tform_mat_for_cv2(trans):
200
+ """
201
+ Function:
202
+ ----------
203
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
204
+ directly used by cv2.warpAffine():
205
+ u = src_pts[:, 0]
206
+ v = src_pts[:, 1]
207
+ x = dst_pts[:, 0]
208
+ y = dst_pts[:, 1]
209
+ [x, y].T = cv_trans * [u, v, 1].T
210
+
211
+ Parameters:
212
+ ----------
213
+ @trans: 3x3 np.array
214
+ transform matrix from uv to xy
215
+
216
+ Returns:
217
+ ----------
218
+ @cv2_trans: 2x3 np.array
219
+ transform matrix from src_pts to dst_pts, could be directly used
220
+ for cv2.warpAffine()
221
+ """
222
+ cv2_trans = trans[:, 0:2].T
223
+
224
+ return cv2_trans
225
+
226
+
227
+ def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
228
+ """
229
+ Function:
230
+ ----------
231
+ Find Similarity Transform Matrix 'cv2_trans' which could be
232
+ directly used by cv2.warpAffine():
233
+ u = src_pts[:, 0]
234
+ v = src_pts[:, 1]
235
+ x = dst_pts[:, 0]
236
+ y = dst_pts[:, 1]
237
+ [x, y].T = cv_trans * [u, v, 1].T
238
+
239
+ Parameters:
240
+ ----------
241
+ @src_pts: Kx2 np.array
242
+ source points, each row is a pair of coordinates (x, y)
243
+ @dst_pts: Kx2 np.array
244
+ destination points, each row is a pair of transformed
245
+ coordinates (x, y)
246
+ reflective: True or False
247
+ if True:
248
+ use reflective similarity transform
249
+ else:
250
+ use non-reflective similarity transform
251
+
252
+ Returns:
253
+ ----------
254
+ @cv2_trans: 2x3 np.array
255
+ transform matrix from src_pts to dst_pts, could be directly used
256
+ for cv2.warpAffine()
257
+ """
258
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
259
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
260
+
261
+ return cv2_trans
262
+
263
+
264
+ if __name__ == '__main__':
265
+ """
266
+ u = [0, 6, -2]
267
+ v = [0, 3, 5]
268
+ x = [-1, 0, 4]
269
+ y = [-1, -10, 4]
270
+
271
+ # In Matlab, run:
272
+ #
273
+ # uv = [u'; v'];
274
+ # xy = [x'; y'];
275
+ # tform_sim=cp2tform(uv,xy,'similarity');
276
+ #
277
+ # trans = tform_sim.tdata.T
278
+ # ans =
279
+ # -0.0764 -1.6190 0
280
+ # 1.6190 -0.0764 0
281
+ # -3.2156 0.0290 1.0000
282
+ # trans_inv = tform_sim.tdata.Tinv
283
+ # ans =
284
+ #
285
+ # -0.0291 0.6163 0
286
+ # -0.6163 -0.0291 0
287
+ # -0.0756 1.9826 1.0000
288
+ # xy_m=tformfwd(tform_sim, u,v)
289
+ #
290
+ # xy_m =
291
+ #
292
+ # -3.2156 0.0290
293
+ # 1.1833 -9.9143
294
+ # 5.0323 2.8853
295
+ # uv_m=tforminv(tform_sim, x,y)
296
+ #
297
+ # uv_m =
298
+ #
299
+ # 0.5698 1.3953
300
+ # 6.0872 2.2733
301
+ # -2.6570 4.3314
302
+ """
303
+ u = [0, 6, -2]
304
+ v = [0, 3, 5]
305
+ x = [-1, 0, 4]
306
+ y = [-1, -10, 4]
307
+
308
+ uv = np.array((u, v)).T
309
+ xy = np.array((x, y)).T
310
+
311
+ print('\n--->uv:')
312
+ print(uv)
313
+ print('\n--->xy:')
314
+ print(xy)
315
+
316
+ trans, trans_inv = get_similarity_transform(uv, xy)
317
+
318
+ print('\n--->trans matrix:')
319
+ print(trans)
320
+
321
+ print('\n--->trans_inv matrix:')
322
+ print(trans_inv)
323
+
324
+ print('\n---> apply transform to uv')
325
+ print('\nxy_m = uv_augmented * trans')
326
+ uv_aug = np.hstack((
327
+ uv, np.ones((uv.shape[0], 1))
328
+ ))
329
+ xy_m = np.dot(uv_aug, trans)
330
+ print(xy_m)
331
+
332
+ print('\nxy_m = tformfwd(trans, uv)')
333
+ xy_m = tformfwd(trans, uv)
334
+ print(xy_m)
335
+
336
+ print('\n---> apply inverse transform to xy')
337
+ print('\nuv_m = xy_augmented * trans_inv')
338
+ xy_aug = np.hstack((
339
+ xy, np.ones((xy.shape[0], 1))
340
+ ))
341
+ uv_m = np.dot(xy_aug, trans_inv)
342
+ print(uv_m)
343
+
344
+ print('\nuv_m = tformfwd(trans_inv, xy)')
345
+ uv_m = tformfwd(trans_inv, xy)
346
+ print(uv_m)
347
+
348
+ uv_m = tforminv(trans, xy)
349
+ print('\nuv_m = tforminv(trans, xy)')
350
+ print(uv_m)
models/mtcnn/mtcnn_pytorch/src/visualization_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import ImageDraw
2
+
3
+
4
+ def show_bboxes(img, bounding_boxes, facial_landmarks=[]):
5
+ """Draw bounding boxes and facial landmarks.
6
+
7
+ Arguments:
8
+ img: an instance of PIL.Image.
9
+ bounding_boxes: a float numpy array of shape [n, 5].
10
+ facial_landmarks: a float numpy array of shape [n, 10].
11
+
12
+ Returns:
13
+ an instance of PIL.Image.
14
+ """
15
+
16
+ img_copy = img.copy()
17
+ draw = ImageDraw.Draw(img_copy)
18
+
19
+ for b in bounding_boxes:
20
+ draw.rectangle([
21
+ (b[0], b[1]), (b[2], b[3])
22
+ ], outline='white')
23
+
24
+ for p in facial_landmarks:
25
+ for i in range(5):
26
+ draw.ellipse([
27
+ (p[i] - 1.0, p[i + 5] - 1.0),
28
+ (p[i] + 1.0, p[i + 5] + 1.0)
29
+ ], outline='blue')
30
+
31
+ return img_copy
models/mtcnn/mtcnn_pytorch/src/weights/onet.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:313141c3646bebb73cb8350a2d5fee4c7f044fb96304b46ccc21aeea8b818f83
3
+ size 2345483
models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03e19e5c473932ab38f5a6308fe6210624006994a687e858d1dcda53c66f18cb
3
+ size 41271
models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5660aad67688edc9e8a3dd4e47ed120932835e06a8a711a423252a6f2c747083
3
+ size 604651
models/psp.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file defines the core research contribution
3
+ """
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+ from models.encoders import psp_encoders
11
+ from models.stylegan2.model import Generator
12
+ from configs.paths_config import model_paths
13
+ import torch.nn.functional as F
14
+
15
+ def get_keys(d, name):
16
+ if 'state_dict' in d:
17
+ d = d['state_dict']
18
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
19
+ return d_filt
20
+
21
+
22
+ class pSp(nn.Module):
23
+
24
+ def __init__(self, opts, ckpt=None):
25
+ super(pSp, self).__init__()
26
+ self.set_opts(opts)
27
+ # compute number of style inputs based on the output resolution
28
+ self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
29
+ # Define architecture
30
+ self.encoder = self.set_encoder()
31
+ self.decoder = Generator(self.opts.output_size, 512, 8)
32
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
33
+ # Load weights if needed
34
+ self.load_weights(ckpt)
35
+
36
+ def set_encoder(self):
37
+ if self.opts.encoder_type == 'GradualStyleEncoder':
38
+ encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
39
+ elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
40
+ encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
41
+ elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
42
+ encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
43
+ else:
44
+ raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
45
+ return encoder
46
+
47
+ def load_weights(self, ckpt=None):
48
+ if self.opts.checkpoint_path is not None:
49
+ print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
50
+ if ckpt is None:
51
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
52
+ self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=False)
53
+ self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=False)
54
+ self.__load_latent_avg(ckpt)
55
+ else:
56
+ print('Loading encoders weights from irse50!')
57
+ encoder_ckpt = torch.load(model_paths['ir_se50'])
58
+ # if input to encoder is not an RGB image, do not load the input layer weights
59
+ if self.opts.label_nc != 0:
60
+ encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
61
+ self.encoder.load_state_dict(encoder_ckpt, strict=False)
62
+ print('Loading decoder weights from pretrained!')
63
+ ckpt = torch.load(self.opts.stylegan_weights)
64
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
65
+ if self.opts.learn_in_w:
66
+ self.__load_latent_avg(ckpt, repeat=1)
67
+ else:
68
+ self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)
69
+ # for video toonification, we load G0' model
70
+ if self.opts.toonify_weights is not None: ##### modified
71
+ ckpt = torch.load(self.opts.toonify_weights)
72
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
73
+ self.opts.toonify_weights = None
74
+
75
+ # x1: image for first-layer feature f.
76
+ # x2: image for style latent code w+. If not specified, x2=x1.
77
+ # inject_latent: for sketch/mask-to-face translation, another latent code to fuse with w+
78
+ # latent_mask: fuse w+ and inject_latent with the mask (1~7 use w+ and 8~18 use inject_latent)
79
+ # use_feature: use f. Otherwise, use the orginal StyleGAN first-layer constant 4*4 feature
80
+ # first_layer_feature_ind: always=0, means the 1st layer of G accept f
81
+ # use_skip: use skip connection.
82
+ # zero_noise: use zero noises.
83
+ # editing_w: the editing vector v for video face editing
84
+ def forward(self, x1, x2=None, resize=True, latent_mask=None, randomize_noise=True,
85
+ inject_latent=None, return_latents=False, alpha=None, use_feature=True,
86
+ first_layer_feature_ind=0, use_skip=False, zero_noise=False, editing_w=None): ##### modified
87
+
88
+ feats = None # f and the skipped encoder features
89
+ codes, feats = self.encoder(x1, return_feat=True, return_full=use_skip) ##### modified
90
+ if x2 is not None: ##### modified
91
+ codes = self.encoder(x2) ##### modified
92
+ # normalize with respect to the center of an average face
93
+ if self.opts.start_from_latent_avg:
94
+ if self.opts.learn_in_w:
95
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
96
+ else:
97
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
98
+
99
+ # E_W^{1:7}(T(x1)) concatenate E_W^{8:18}(w~)
100
+ if latent_mask is not None:
101
+ for i in latent_mask:
102
+ if inject_latent is not None:
103
+ if alpha is not None:
104
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
105
+ else:
106
+ codes[:, i] = inject_latent[:, i]
107
+ else:
108
+ codes[:, i] = 0
109
+
110
+ first_layer_feats, skip_layer_feats, fusion = None, None, None ##### modified
111
+ if use_feature: ##### modified
112
+ first_layer_feats = feats[0:2] # use f
113
+ if use_skip: ##### modified
114
+ skip_layer_feats = feats[2:] # use skipped encoder feature
115
+ fusion = self.encoder.fusion # use fusion layer to fuse encoder feature and decoder feature.
116
+
117
+ images, result_latent = self.decoder([codes],
118
+ input_is_latent=True,
119
+ randomize_noise=randomize_noise,
120
+ return_latents=return_latents,
121
+ first_layer_feature=first_layer_feats,
122
+ first_layer_feature_ind=first_layer_feature_ind,
123
+ skip_layer_feature=skip_layer_feats,
124
+ fusion_block=fusion,
125
+ zero_noise=zero_noise,
126
+ editing_w=editing_w) ##### modified
127
+
128
+ if resize:
129
+ if self.opts.output_size == 1024: ##### modified
130
+ images = F.adaptive_avg_pool2d(images, (images.shape[2]//4, images.shape[3]//4)) ##### modified
131
+ else:
132
+ images = self.face_pool(images)
133
+
134
+ if return_latents:
135
+ return images, result_latent
136
+ else:
137
+ return images
138
+
139
+ def set_opts(self, opts):
140
+ self.opts = opts
141
+
142
+ def __load_latent_avg(self, ckpt, repeat=None):
143
+ if 'latent_avg' in ckpt:
144
+ self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
145
+ if repeat is not None:
146
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
147
+ else:
148
+ self.latent_avg = None
models/stylegan2/__init__.py ADDED
File without changes
models/stylegan2/lpips/__init__.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import numpy as np
7
+ #from skimage.measure import compare_ssim
8
+ from skimage.metrics import structural_similarity as compare_ssim
9
+ import torch
10
+ from torch.autograd import Variable
11
+
12
+ from models.stylegan2.lpips import dist_model
13
+
14
+ class PerceptualLoss(torch.nn.Module):
15
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
16
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
17
+ super(PerceptualLoss, self).__init__()
18
+ print('Setting up Perceptual loss...')
19
+ self.use_gpu = use_gpu
20
+ self.spatial = spatial
21
+ self.gpu_ids = gpu_ids
22
+ self.model = dist_model.DistModel()
23
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
24
+ print('...[%s] initialized'%self.model.name())
25
+ print('...Done')
26
+
27
+ def forward(self, pred, target, normalize=False):
28
+ """
29
+ Pred and target are Variables.
30
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
31
+ If normalize is False, assumes the images are already between [-1,+1]
32
+
33
+ Inputs pred and target are Nx3xHxW
34
+ Output pytorch Variable N long
35
+ """
36
+
37
+ if normalize:
38
+ target = 2 * target - 1
39
+ pred = 2 * pred - 1
40
+
41
+ return self.model.forward(target, pred)
42
+
43
+ def normalize_tensor(in_feat,eps=1e-10):
44
+ norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
45
+ return in_feat/(norm_factor+eps)
46
+
47
+ def l2(p0, p1, range=255.):
48
+ return .5*np.mean((p0 / range - p1 / range)**2)
49
+
50
+ def psnr(p0, p1, peak=255.):
51
+ return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
52
+
53
+ def dssim(p0, p1, range=255.):
54
+ return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
55
+
56
+ def rgb2lab(in_img,mean_cent=False):
57
+ from skimage import color
58
+ img_lab = color.rgb2lab(in_img)
59
+ if(mean_cent):
60
+ img_lab[:,:,0] = img_lab[:,:,0]-50
61
+ return img_lab
62
+
63
+ def tensor2np(tensor_obj):
64
+ # change dimension of a tensor object into a numpy array
65
+ return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
66
+
67
+ def np2tensor(np_obj):
68
+ # change dimenion of np array into tensor array
69
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
70
+
71
+ def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
72
+ # image tensor to lab tensor
73
+ from skimage import color
74
+
75
+ img = tensor2im(image_tensor)
76
+ img_lab = color.rgb2lab(img)
77
+ if(mc_only):
78
+ img_lab[:,:,0] = img_lab[:,:,0]-50
79
+ if(to_norm and not mc_only):
80
+ img_lab[:,:,0] = img_lab[:,:,0]-50
81
+ img_lab = img_lab/100.
82
+
83
+ return np2tensor(img_lab)
84
+
85
+ def tensorlab2tensor(lab_tensor,return_inbnd=False):
86
+ from skimage import color
87
+ import warnings
88
+ warnings.filterwarnings("ignore")
89
+
90
+ lab = tensor2np(lab_tensor)*100.
91
+ lab[:,:,0] = lab[:,:,0]+50
92
+
93
+ rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
94
+ if(return_inbnd):
95
+ # convert back to lab, see if we match
96
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
97
+ mask = 1.*np.isclose(lab_back,lab,atol=2.)
98
+ mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
99
+ return (im2tensor(rgb_back),mask)
100
+ else:
101
+ return im2tensor(rgb_back)
102
+
103
+ def rgb2lab(input):
104
+ from skimage import color
105
+ return color.rgb2lab(input / 255.)
106
+
107
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
108
+ image_numpy = image_tensor[0].cpu().float().numpy()
109
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
110
+ return image_numpy.astype(imtype)
111
+
112
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
113
+ return torch.Tensor((image / factor - cent)
114
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
115
+
116
+ def tensor2vec(vector_tensor):
117
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
118
+
119
+ def voc_ap(rec, prec, use_07_metric=False):
120
+ """ ap = voc_ap(rec, prec, [use_07_metric])
121
+ Compute VOC AP given precision and recall.
122
+ If use_07_metric is true, uses the
123
+ VOC 07 11 point method (default:False).
124
+ """
125
+ if use_07_metric:
126
+ # 11 point metric
127
+ ap = 0.
128
+ for t in np.arange(0., 1.1, 0.1):
129
+ if np.sum(rec >= t) == 0:
130
+ p = 0
131
+ else:
132
+ p = np.max(prec[rec >= t])
133
+ ap = ap + p / 11.
134
+ else:
135
+ # correct AP calculation
136
+ # first append sentinel values at the end
137
+ mrec = np.concatenate(([0.], rec, [1.]))
138
+ mpre = np.concatenate(([0.], prec, [0.]))
139
+
140
+ # compute the precision envelope
141
+ for i in range(mpre.size - 1, 0, -1):
142
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
143
+
144
+ # to calculate area under PR curve, look for points
145
+ # where X axis (recall) changes value
146
+ i = np.where(mrec[1:] != mrec[:-1])[0]
147
+
148
+ # and sum (\Delta recall) * prec
149
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
150
+ return ap
151
+
152
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
153
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
154
+ image_numpy = image_tensor[0].cpu().float().numpy()
155
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
156
+ return image_numpy.astype(imtype)
157
+
158
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
159
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
160
+ return torch.Tensor((image / factor - cent)
161
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
models/stylegan2/lpips/base_model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.autograd import Variable
5
+ from pdb import set_trace as st
6
+ from IPython import embed
7
+
8
+ class BaseModel():
9
+ def __init__(self):
10
+ pass;
11
+
12
+ def name(self):
13
+ return 'BaseModel'
14
+
15
+ def initialize(self, use_gpu=True, gpu_ids=[0]):
16
+ self.use_gpu = use_gpu
17
+ self.gpu_ids = gpu_ids
18
+
19
+ def forward(self):
20
+ pass
21
+
22
+ def get_image_paths(self):
23
+ pass
24
+
25
+ def optimize_parameters(self):
26
+ pass
27
+
28
+ def get_current_visuals(self):
29
+ return self.input
30
+
31
+ def get_current_errors(self):
32
+ return {}
33
+
34
+ def save(self, label):
35
+ pass
36
+
37
+ # helper saving function that can be used by subclasses
38
+ def save_network(self, network, path, network_label, epoch_label):
39
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40
+ save_path = os.path.join(path, save_filename)
41
+ torch.save(network.state_dict(), save_path)
42
+
43
+ # helper loading function that can be used by subclasses
44
+ def load_network(self, network, network_label, epoch_label):
45
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46
+ save_path = os.path.join(self.save_dir, save_filename)
47
+ print('Loading network from %s'%save_path)
48
+ network.load_state_dict(torch.load(save_path))
49
+
50
+ def update_learning_rate():
51
+ pass
52
+
53
+ def get_image_paths(self):
54
+ return self.image_paths
55
+
56
+ def save_done(self, flag=False):
57
+ np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
models/stylegan2/lpips/dist_model.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import os
9
+ from collections import OrderedDict
10
+ from torch.autograd import Variable
11
+ import itertools
12
+ from models.stylegan2.lpips.base_model import BaseModel
13
+ from scipy.ndimage import zoom
14
+ import fractions
15
+ import functools
16
+ import skimage.transform
17
+ from tqdm import tqdm
18
+
19
+ from IPython import embed
20
+
21
+ from models.stylegan2.lpips import networks_basic as networks
22
+ import models.stylegan2.lpips as util
23
+
24
+ class DistModel(BaseModel):
25
+ def name(self):
26
+ return self.model_name
27
+
28
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
29
+ use_gpu=True, printNet=False, spatial=False,
30
+ is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
31
+ '''
32
+ INPUTS
33
+ model - ['net-lin'] for linearly calibrated network
34
+ ['net'] for off-the-shelf network
35
+ ['L2'] for L2 distance in Lab colorspace
36
+ ['SSIM'] for ssim in RGB colorspace
37
+ net - ['squeeze','alex','vgg']
38
+ model_path - if None, will look in weights/[NET_NAME].pth
39
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
40
+ use_gpu - bool - whether or not to use a GPU
41
+ printNet - bool - whether or not to print network architecture out
42
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
43
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
44
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
45
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
46
+ is_train - bool - [True] for training mode
47
+ lr - float - initial learning rate
48
+ beta1 - float - initial momentum term for adam
49
+ version - 0.1 for latest, 0.0 was original (with a bug)
50
+ gpu_ids - int array - [0] by default, gpus to use
51
+ '''
52
+ BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
53
+
54
+ self.model = model
55
+ self.net = net
56
+ self.is_train = is_train
57
+ self.spatial = spatial
58
+ self.gpu_ids = gpu_ids
59
+ self.model_name = '%s [%s]'%(model,net)
60
+
61
+ if(self.model == 'net-lin'): # pretrained net + linear layer
62
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
63
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
64
+ kw = {}
65
+ if not use_gpu:
66
+ kw['map_location'] = 'cpu'
67
+ if(model_path is None):
68
+ import inspect
69
+ model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
70
+
71
+ if(not is_train):
72
+ print('Loading model from: %s'%model_path)
73
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
74
+
75
+ elif(self.model=='net'): # pretrained network
76
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
77
+ elif(self.model in ['L2','l2']):
78
+ self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
79
+ self.model_name = 'L2'
80
+ elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
81
+ self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
82
+ self.model_name = 'SSIM'
83
+ else:
84
+ raise ValueError("Model [%s] not recognized." % self.model)
85
+
86
+ self.parameters = list(self.net.parameters())
87
+
88
+ if self.is_train: # training mode
89
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
90
+ self.rankLoss = networks.BCERankingLoss()
91
+ self.parameters += list(self.rankLoss.net.parameters())
92
+ self.lr = lr
93
+ self.old_lr = lr
94
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
95
+ else: # test mode
96
+ self.net.eval()
97
+
98
+ if(use_gpu):
99
+ self.net.to(gpu_ids[0])
100
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
101
+ if(self.is_train):
102
+ self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
103
+
104
+ if(printNet):
105
+ print('---------- Networks initialized -------------')
106
+ networks.print_network(self.net)
107
+ print('-----------------------------------------------')
108
+
109
+ def forward(self, in0, in1, retPerLayer=False):
110
+ ''' Function computes the distance between image patches in0 and in1
111
+ INPUTS
112
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
113
+ OUTPUT
114
+ computed distances between in0 and in1
115
+ '''
116
+
117
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
118
+
119
+ # ***** TRAINING FUNCTIONS *****
120
+ def optimize_parameters(self):
121
+ self.forward_train()
122
+ self.optimizer_net.zero_grad()
123
+ self.backward_train()
124
+ self.optimizer_net.step()
125
+ self.clamp_weights()
126
+
127
+ def clamp_weights(self):
128
+ for module in self.net.modules():
129
+ if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
130
+ module.weight.data = torch.clamp(module.weight.data,min=0)
131
+
132
+ def set_input(self, data):
133
+ self.input_ref = data['ref']
134
+ self.input_p0 = data['p0']
135
+ self.input_p1 = data['p1']
136
+ self.input_judge = data['judge']
137
+
138
+ if(self.use_gpu):
139
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
140
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
141
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
142
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
143
+
144
+ self.var_ref = Variable(self.input_ref,requires_grad=True)
145
+ self.var_p0 = Variable(self.input_p0,requires_grad=True)
146
+ self.var_p1 = Variable(self.input_p1,requires_grad=True)
147
+
148
+ def forward_train(self): # run forward pass
149
+ # print(self.net.module.scaling_layer.shift)
150
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
151
+
152
+ self.d0 = self.forward(self.var_ref, self.var_p0)
153
+ self.d1 = self.forward(self.var_ref, self.var_p1)
154
+ self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
155
+
156
+ self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
157
+
158
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
159
+
160
+ return self.loss_total
161
+
162
+ def backward_train(self):
163
+ torch.mean(self.loss_total).backward()
164
+
165
+ def compute_accuracy(self,d0,d1,judge):
166
+ ''' d0, d1 are Variables, judge is a Tensor '''
167
+ d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
168
+ judge_per = judge.cpu().numpy().flatten()
169
+ return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
170
+
171
+ def get_current_errors(self):
172
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
173
+ ('acc_r', self.acc_r)])
174
+
175
+ for key in retDict.keys():
176
+ retDict[key] = np.mean(retDict[key])
177
+
178
+ return retDict
179
+
180
+ def get_current_visuals(self):
181
+ zoom_factor = 256/self.var_ref.data.size()[2]
182
+
183
+ ref_img = util.tensor2im(self.var_ref.data)
184
+ p0_img = util.tensor2im(self.var_p0.data)
185
+ p1_img = util.tensor2im(self.var_p1.data)
186
+
187
+ ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
188
+ p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
189
+ p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
190
+
191
+ return OrderedDict([('ref', ref_img_vis),
192
+ ('p0', p0_img_vis),
193
+ ('p1', p1_img_vis)])
194
+
195
+ def save(self, path, label):
196
+ if(self.use_gpu):
197
+ self.save_network(self.net.module, path, '', label)
198
+ else:
199
+ self.save_network(self.net, path, '', label)
200
+ self.save_network(self.rankLoss.net, path, 'rank', label)
201
+
202
+ def update_learning_rate(self,nepoch_decay):
203
+ lrd = self.lr / nepoch_decay
204
+ lr = self.old_lr - lrd
205
+
206
+ for param_group in self.optimizer_net.param_groups:
207
+ param_group['lr'] = lr
208
+
209
+ print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
210
+ self.old_lr = lr
211
+
212
+ def score_2afc_dataset(data_loader, func, name=''):
213
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
214
+ distance function 'func' in dataset 'data_loader'
215
+ INPUTS
216
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
217
+ func - callable distance function - calling d=func(in0,in1) should take 2
218
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
219
+ OUTPUTS
220
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
221
+ [1] - dictionary with following elements
222
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
223
+ gts - N array in [0,1], preferred patch selected by human evaluators
224
+ (closer to "0" for left patch p0, "1" for right patch p1,
225
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
226
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
227
+ CONSTS
228
+ N - number of test triplets in data_loader
229
+ '''
230
+
231
+ d0s = []
232
+ d1s = []
233
+ gts = []
234
+
235
+ for data in tqdm(data_loader.load_data(), desc=name):
236
+ d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
237
+ d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
238
+ gts+=data['judge'].cpu().numpy().flatten().tolist()
239
+
240
+ d0s = np.array(d0s)
241
+ d1s = np.array(d1s)
242
+ gts = np.array(gts)
243
+ scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
244
+
245
+ return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
246
+
247
+ def score_jnd_dataset(data_loader, func, name=''):
248
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
249
+ INPUTS
250
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
251
+ func - callable distance function - calling d=func(in0,in1) should take 2
252
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
253
+ OUTPUTS
254
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
255
+ [1] - dictionary with following elements
256
+ ds - N array containing distances between two patches shown to human evaluator
257
+ sames - N array containing fraction of people who thought the two patches were identical
258
+ CONSTS
259
+ N - number of test triplets in data_loader
260
+ '''
261
+
262
+ ds = []
263
+ gts = []
264
+
265
+ for data in tqdm(data_loader.load_data(), desc=name):
266
+ ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
267
+ gts+=data['same'].cpu().numpy().flatten().tolist()
268
+
269
+ sames = np.array(gts)
270
+ ds = np.array(ds)
271
+
272
+ sorted_inds = np.argsort(ds)
273
+ ds_sorted = ds[sorted_inds]
274
+ sames_sorted = sames[sorted_inds]
275
+
276
+ TPs = np.cumsum(sames_sorted)
277
+ FPs = np.cumsum(1-sames_sorted)
278
+ FNs = np.sum(sames_sorted)-TPs
279
+
280
+ precs = TPs/(TPs+FPs)
281
+ recs = TPs/(TPs+FNs)
282
+ score = util.voc_ap(recs,precs)
283
+
284
+ return(score, dict(ds=ds,sames=sames))
models/stylegan2/lpips/networks_basic.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+ from torch.autograd import Variable
9
+ import numpy as np
10
+ from pdb import set_trace as st
11
+ from skimage import color
12
+ from IPython import embed
13
+ from models.stylegan2.lpips import pretrained_networks as pn
14
+
15
+ import models.stylegan2.lpips as util
16
+
17
+ def spatial_average(in_tens, keepdim=True):
18
+ return in_tens.mean([2,3],keepdim=keepdim)
19
+
20
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
21
+ in_H = in_tens.shape[2]
22
+ scale_factor = 1.*out_H/in_H
23
+
24
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
25
+
26
+ # Learned perceptual metric
27
+ class PNetLin(nn.Module):
28
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
29
+ super(PNetLin, self).__init__()
30
+
31
+ self.pnet_type = pnet_type
32
+ self.pnet_tune = pnet_tune
33
+ self.pnet_rand = pnet_rand
34
+ self.spatial = spatial
35
+ self.lpips = lpips
36
+ self.version = version
37
+ self.scaling_layer = ScalingLayer()
38
+
39
+ if(self.pnet_type in ['vgg','vgg16']):
40
+ net_type = pn.vgg16
41
+ self.chns = [64,128,256,512,512]
42
+ elif(self.pnet_type=='alex'):
43
+ net_type = pn.alexnet
44
+ self.chns = [64,192,384,256,256]
45
+ elif(self.pnet_type=='squeeze'):
46
+ net_type = pn.squeezenet
47
+ self.chns = [64,128,256,384,384,512,512]
48
+ self.L = len(self.chns)
49
+
50
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
51
+
52
+ if(lpips):
53
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
54
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
55
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
56
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
57
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
58
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
59
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
60
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
61
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
62
+ self.lins+=[self.lin5,self.lin6]
63
+
64
+ def forward(self, in0, in1, retPerLayer=False):
65
+ # v0.0 - original release had a bug, where input was not scaled
66
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
67
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
68
+ feats0, feats1, diffs = {}, {}, {}
69
+
70
+ for kk in range(self.L):
71
+ feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
72
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
73
+
74
+ if(self.lpips):
75
+ if(self.spatial):
76
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
77
+ else:
78
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
79
+ else:
80
+ if(self.spatial):
81
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
82
+ else:
83
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
84
+
85
+ val = res[0]
86
+ for l in range(1,self.L):
87
+ val += res[l]
88
+
89
+ if(retPerLayer):
90
+ return (val, res)
91
+ else:
92
+ return val
93
+
94
+ class ScalingLayer(nn.Module):
95
+ def __init__(self):
96
+ super(ScalingLayer, self).__init__()
97
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
98
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
99
+
100
+ def forward(self, inp):
101
+ return (inp - self.shift) / self.scale
102
+
103
+
104
+ class NetLinLayer(nn.Module):
105
+ ''' A single linear layer which does a 1x1 conv '''
106
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
107
+ super(NetLinLayer, self).__init__()
108
+
109
+ layers = [nn.Dropout(),] if(use_dropout) else []
110
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
111
+ self.model = nn.Sequential(*layers)
112
+
113
+
114
+ class Dist2LogitLayer(nn.Module):
115
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
116
+ def __init__(self, chn_mid=32, use_sigmoid=True):
117
+ super(Dist2LogitLayer, self).__init__()
118
+
119
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
120
+ layers += [nn.LeakyReLU(0.2,True),]
121
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
122
+ layers += [nn.LeakyReLU(0.2,True),]
123
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
124
+ if(use_sigmoid):
125
+ layers += [nn.Sigmoid(),]
126
+ self.model = nn.Sequential(*layers)
127
+
128
+ def forward(self,d0,d1,eps=0.1):
129
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
130
+
131
+ class BCERankingLoss(nn.Module):
132
+ def __init__(self, chn_mid=32):
133
+ super(BCERankingLoss, self).__init__()
134
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
135
+ # self.parameters = list(self.net.parameters())
136
+ self.loss = torch.nn.BCELoss()
137
+
138
+ def forward(self, d0, d1, judge):
139
+ per = (judge+1.)/2.
140
+ self.logit = self.net.forward(d0,d1)
141
+ return self.loss(self.logit, per)
142
+
143
+ # L2, DSSIM metrics
144
+ class FakeNet(nn.Module):
145
+ def __init__(self, use_gpu=True, colorspace='Lab'):
146
+ super(FakeNet, self).__init__()
147
+ self.use_gpu = use_gpu
148
+ self.colorspace=colorspace
149
+
150
+ class L2(FakeNet):
151
+
152
+ def forward(self, in0, in1, retPerLayer=None):
153
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
154
+
155
+ if(self.colorspace=='RGB'):
156
+ (N,C,X,Y) = in0.size()
157
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
158
+ return value
159
+ elif(self.colorspace=='Lab'):
160
+ value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
161
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
162
+ ret_var = Variable( torch.Tensor((value,) ) )
163
+ if(self.use_gpu):
164
+ ret_var = ret_var.cuda()
165
+ return ret_var
166
+
167
+ class DSSIM(FakeNet):
168
+
169
+ def forward(self, in0, in1, retPerLayer=None):
170
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
171
+
172
+ if(self.colorspace=='RGB'):
173
+ value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
174
+ elif(self.colorspace=='Lab'):
175
+ value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
176
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
177
+ ret_var = Variable( torch.Tensor((value,) ) )
178
+ if(self.use_gpu):
179
+ ret_var = ret_var.cuda()
180
+ return ret_var
181
+
182
+ def print_network(net):
183
+ num_params = 0
184
+ for param in net.parameters():
185
+ num_params += param.numel()
186
+ print('Network',net)
187
+ print('Total number of parameters: %d' % num_params)
models/stylegan2/lpips/pretrained_networks.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+ from IPython import embed
5
+
6
+ class squeezenet(torch.nn.Module):
7
+ def __init__(self, requires_grad=False, pretrained=True):
8
+ super(squeezenet, self).__init__()
9
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ self.slice7 = torch.nn.Sequential()
17
+ self.N_slices = 7
18
+ for x in range(2):
19
+ self.slice1.add_module(str(x), pretrained_features[x])
20
+ for x in range(2,5):
21
+ self.slice2.add_module(str(x), pretrained_features[x])
22
+ for x in range(5, 8):
23
+ self.slice3.add_module(str(x), pretrained_features[x])
24
+ for x in range(8, 10):
25
+ self.slice4.add_module(str(x), pretrained_features[x])
26
+ for x in range(10, 11):
27
+ self.slice5.add_module(str(x), pretrained_features[x])
28
+ for x in range(11, 12):
29
+ self.slice6.add_module(str(x), pretrained_features[x])
30
+ for x in range(12, 13):
31
+ self.slice7.add_module(str(x), pretrained_features[x])
32
+ if not requires_grad:
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ def forward(self, X):
37
+ h = self.slice1(X)
38
+ h_relu1 = h
39
+ h = self.slice2(h)
40
+ h_relu2 = h
41
+ h = self.slice3(h)
42
+ h_relu3 = h
43
+ h = self.slice4(h)
44
+ h_relu4 = h
45
+ h = self.slice5(h)
46
+ h_relu5 = h
47
+ h = self.slice6(h)
48
+ h_relu6 = h
49
+ h = self.slice7(h)
50
+ h_relu7 = h
51
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53
+
54
+ return out
55
+
56
+
57
+ class alexnet(torch.nn.Module):
58
+ def __init__(self, requires_grad=False, pretrained=True):
59
+ super(alexnet, self).__init__()
60
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61
+ self.slice1 = torch.nn.Sequential()
62
+ self.slice2 = torch.nn.Sequential()
63
+ self.slice3 = torch.nn.Sequential()
64
+ self.slice4 = torch.nn.Sequential()
65
+ self.slice5 = torch.nn.Sequential()
66
+ self.N_slices = 5
67
+ for x in range(2):
68
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69
+ for x in range(2, 5):
70
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(5, 8):
72
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(8, 10):
74
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(10, 12):
76
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77
+ if not requires_grad:
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, X):
82
+ h = self.slice1(X)
83
+ h_relu1 = h
84
+ h = self.slice2(h)
85
+ h_relu2 = h
86
+ h = self.slice3(h)
87
+ h_relu3 = h
88
+ h = self.slice4(h)
89
+ h_relu4 = h
90
+ h = self.slice5(h)
91
+ h_relu5 = h
92
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94
+
95
+ return out
96
+
97
+ class vgg16(torch.nn.Module):
98
+ def __init__(self, requires_grad=False, pretrained=True):
99
+ super(vgg16, self).__init__()
100
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101
+ self.slice1 = torch.nn.Sequential()
102
+ self.slice2 = torch.nn.Sequential()
103
+ self.slice3 = torch.nn.Sequential()
104
+ self.slice4 = torch.nn.Sequential()
105
+ self.slice5 = torch.nn.Sequential()
106
+ self.N_slices = 5
107
+ for x in range(4):
108
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
109
+ for x in range(4, 9):
110
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(9, 16):
112
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(16, 23):
114
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(23, 30):
116
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
117
+ if not requires_grad:
118
+ for param in self.parameters():
119
+ param.requires_grad = False
120
+
121
+ def forward(self, X):
122
+ h = self.slice1(X)
123
+ h_relu1_2 = h
124
+ h = self.slice2(h)
125
+ h_relu2_2 = h
126
+ h = self.slice3(h)
127
+ h_relu3_3 = h
128
+ h = self.slice4(h)
129
+ h_relu4_3 = h
130
+ h = self.slice5(h)
131
+ h_relu5_3 = h
132
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134
+
135
+ return out
136
+
137
+
138
+
139
+ class resnet(torch.nn.Module):
140
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
141
+ super(resnet, self).__init__()
142
+ if(num==18):
143
+ self.net = tv.resnet18(pretrained=pretrained)
144
+ elif(num==34):
145
+ self.net = tv.resnet34(pretrained=pretrained)
146
+ elif(num==50):
147
+ self.net = tv.resnet50(pretrained=pretrained)
148
+ elif(num==101):
149
+ self.net = tv.resnet101(pretrained=pretrained)
150
+ elif(num==152):
151
+ self.net = tv.resnet152(pretrained=pretrained)
152
+ self.N_slices = 5
153
+
154
+ self.conv1 = self.net.conv1
155
+ self.bn1 = self.net.bn1
156
+ self.relu = self.net.relu
157
+ self.maxpool = self.net.maxpool
158
+ self.layer1 = self.net.layer1
159
+ self.layer2 = self.net.layer2
160
+ self.layer3 = self.net.layer3
161
+ self.layer4 = self.net.layer4
162
+
163
+ def forward(self, X):
164
+ h = self.conv1(X)
165
+ h = self.bn1(h)
166
+ h = self.relu(h)
167
+ h_relu1 = h
168
+ h = self.maxpool(h)
169
+ h = self.layer1(h)
170
+ h_conv2 = h
171
+ h = self.layer2(h)
172
+ h_conv3 = h
173
+ h = self.layer3(h)
174
+ h_conv4 = h
175
+ h = self.layer4(h)
176
+ h_conv5 = h
177
+
178
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180
+
181
+ return out
models/stylegan2/lpips/weights/v0.0/alex.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
3
+ size 5455
models/stylegan2/lpips/weights/v0.0/squeeze.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
3
+ size 10057
models/stylegan2/lpips/weights/v0.0/vgg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
3
+ size 6735
models/stylegan2/lpips/weights/v0.1/alex.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
3
+ size 6009