multimodalart HF staff commited on
Commit
a891a57
·
verified ·
1 Parent(s): 7454c19

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. .gitignore +17 -0
  3. .vscode/settings.json +19 -0
  4. LICENSE +21 -0
  5. app.py +154 -0
  6. assets/docs/inference.gif +0 -0
  7. assets/docs/showcase.gif +3 -0
  8. assets/docs/showcase2.gif +3 -0
  9. assets/examples/driving/d0.mp4 +3 -0
  10. assets/examples/driving/d1.mp4 +0 -0
  11. assets/examples/driving/d2.mp4 +0 -0
  12. assets/examples/driving/d3.mp4 +3 -0
  13. assets/examples/driving/d5.mp4 +0 -0
  14. assets/examples/driving/d6.mp4 +3 -0
  15. assets/examples/driving/d7.mp4 +0 -0
  16. assets/examples/driving/d8.mp4 +0 -0
  17. assets/examples/driving/d9.mp4 +3 -0
  18. assets/examples/source/s0.jpg +0 -0
  19. assets/examples/source/s1.jpg +0 -0
  20. assets/examples/source/s10.jpg +0 -0
  21. assets/examples/source/s2.jpg +0 -0
  22. assets/examples/source/s3.jpg +0 -0
  23. assets/examples/source/s4.jpg +0 -0
  24. assets/examples/source/s5.jpg +0 -0
  25. assets/examples/source/s6.jpg +0 -0
  26. assets/examples/source/s7.jpg +0 -0
  27. assets/examples/source/s8.jpg +0 -0
  28. assets/examples/source/s9.jpg +0 -0
  29. assets/gradio_description_animation.md +7 -0
  30. assets/gradio_description_retargeting.md +1 -0
  31. assets/gradio_description_upload.md +2 -0
  32. assets/gradio_title.md +10 -0
  33. inference.py +33 -0
  34. pretrained_weights/.gitkeep +0 -0
  35. readme.md +143 -0
  36. requirements.txt +22 -0
  37. speed.py +192 -0
  38. src/config/__init__.py +0 -0
  39. src/config/argument_config.py +44 -0
  40. src/config/base_config.py +29 -0
  41. src/config/crop_config.py +18 -0
  42. src/config/inference_config.py +49 -0
  43. src/config/models.yaml +43 -0
  44. src/gradio_pipeline.py +140 -0
  45. src/live_portrait_pipeline.py +190 -0
  46. src/live_portrait_wrapper.py +307 -0
  47. src/modules/__init__.py +0 -0
  48. src/modules/appearance_feature_extractor.py +48 -0
  49. src/modules/convnextv2.py +149 -0
  50. src/modules/dense_motion.py +104 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/docs/showcase.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/docs/showcase2.gif filter=lfs diff=lfs merge=lfs -text
38
+ assets/examples/driving/d0.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ assets/examples/driving/d3.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ assets/examples/driving/d6.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ assets/examples/driving/d9.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ **/__pycache__/
4
+ *.py[cod]
5
+ **/*.py[cod]
6
+ *$py.class
7
+
8
+ # Model weights
9
+ **/*.pth
10
+ **/*.onnx
11
+
12
+ # Ipython notebook
13
+ *.ipynb
14
+
15
+ # Temporary files or benchmark resources
16
+ animations/*
17
+ tmp/*
.vscode/settings.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[python]": {
3
+ "editor.tabSize": 4
4
+ },
5
+ "files.eol": "\n",
6
+ "files.insertFinalNewline": true,
7
+ "files.trimFinalNewlines": true,
8
+ "files.trimTrailingWhitespace": true,
9
+ "files.exclude": {
10
+ "**/.git": true,
11
+ "**/.svn": true,
12
+ "**/.hg": true,
13
+ "**/CVS": true,
14
+ "**/.DS_Store": true,
15
+ "**/Thumbs.db": true,
16
+ "**/*.crswap": true,
17
+ "**/__pycache__": true
18
+ }
19
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ The entrance of the gradio
5
+ """
6
+
7
+ import tyro
8
+ import gradio as gr
9
+ import os.path as osp
10
+ from src.utils.helper import load_description
11
+ from src.gradio_pipeline import GradioPipeline
12
+ from src.config.crop_config import CropConfig
13
+ from src.config.argument_config import ArgumentConfig
14
+ from src.config.inference_config import InferenceConfig
15
+
16
+
17
+ def partial_fields(target_class, kwargs):
18
+ return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
19
+
20
+
21
+ # set tyro theme
22
+ tyro.extras.set_accent_color("bright_cyan")
23
+ args = tyro.cli(ArgumentConfig)
24
+
25
+ # specify configs for inference
26
+ inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
27
+ crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
28
+ gradio_pipeline = GradioPipeline(
29
+ inference_cfg=inference_cfg,
30
+ crop_cfg=crop_cfg,
31
+ args=args
32
+ )
33
+ # assets
34
+ title_md = "assets/gradio_title.md"
35
+ example_portrait_dir = "assets/examples/source"
36
+ example_video_dir = "assets/examples/driving"
37
+ data_examples = [
38
+ [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
39
+ [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, True],
40
+ [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d5.mp4"), True, True, True, True],
41
+ [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d6.mp4"), True, True, True, True],
42
+ [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True, True],
43
+ ]
44
+ #################### interface logic ####################
45
+
46
+ # Define components first
47
+ eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
48
+ lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
49
+ retargeting_input_image = gr.Image(type="numpy")
50
+ output_image = gr.Image(type="numpy")
51
+ output_image_paste_back = gr.Image(type="numpy")
52
+ output_video = gr.Video()
53
+ output_video_concat = gr.Video()
54
+
55
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
56
+ gr.HTML(load_description(title_md))
57
+ gr.Markdown(load_description("assets/gradio_description_upload.md"))
58
+ with gr.Row():
59
+ with gr.Accordion(open=True, label="Source Portrait"):
60
+ image_input = gr.Image(type="filepath")
61
+ with gr.Accordion(open=True, label="Driving Video"):
62
+ video_input = gr.Video()
63
+ gr.Markdown(load_description("assets/gradio_description_animation.md"))
64
+ with gr.Row():
65
+ with gr.Accordion(open=True, label="Animation Options"):
66
+ with gr.Row():
67
+ flag_relative_input = gr.Checkbox(value=True, label="relative motion")
68
+ flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
69
+ flag_remap_input = gr.Checkbox(value=True, label="paste-back")
70
+ with gr.Row():
71
+ with gr.Column():
72
+ process_button_animation = gr.Button("🚀 Animate", variant="primary")
73
+ with gr.Column():
74
+ process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
75
+ with gr.Row():
76
+ with gr.Column():
77
+ with gr.Accordion(open=True, label="The animated video in the original image space"):
78
+ output_video.render()
79
+ with gr.Column():
80
+ with gr.Accordion(open=True, label="The animated video"):
81
+ output_video_concat.render()
82
+ with gr.Row():
83
+ # Examples
84
+ gr.Markdown("## You could choose the examples below ⬇️")
85
+ with gr.Row():
86
+ gr.Examples(
87
+ examples=data_examples,
88
+ inputs=[
89
+ image_input,
90
+ video_input,
91
+ flag_relative_input,
92
+ flag_do_crop_input,
93
+ flag_remap_input
94
+ ],
95
+ examples_per_page=5
96
+ )
97
+ gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
98
+ with gr.Row():
99
+ eye_retargeting_slider.render()
100
+ lip_retargeting_slider.render()
101
+ with gr.Row():
102
+ process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
103
+ process_button_reset_retargeting = gr.ClearButton(
104
+ [
105
+ eye_retargeting_slider,
106
+ lip_retargeting_slider,
107
+ retargeting_input_image,
108
+ output_image,
109
+ output_image_paste_back
110
+ ],
111
+ value="🧹 Clear"
112
+ )
113
+ with gr.Row():
114
+ with gr.Column():
115
+ with gr.Accordion(open=True, label="Retargeting Input"):
116
+ retargeting_input_image.render()
117
+ with gr.Column():
118
+ with gr.Accordion(open=True, label="Retargeting Result"):
119
+ output_image.render()
120
+ with gr.Column():
121
+ with gr.Accordion(open=True, label="Paste-back Result"):
122
+ output_image_paste_back.render()
123
+ # binding functions for buttons
124
+ process_button_retargeting.click(
125
+ fn=gradio_pipeline.execute_image,
126
+ inputs=[eye_retargeting_slider, lip_retargeting_slider],
127
+ outputs=[output_image, output_image_paste_back],
128
+ show_progress=True
129
+ )
130
+ process_button_animation.click(
131
+ fn=gradio_pipeline.execute_video,
132
+ inputs=[
133
+ image_input,
134
+ video_input,
135
+ flag_relative_input,
136
+ flag_do_crop_input,
137
+ flag_remap_input
138
+ ],
139
+ outputs=[output_video, output_video_concat],
140
+ show_progress=True
141
+ )
142
+ image_input.change(
143
+ fn=gradio_pipeline.prepare_retargeting,
144
+ inputs=image_input,
145
+ outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
146
+ )
147
+
148
+ ##########################################################
149
+
150
+ demo.launch(
151
+ server_name=args.server_name,
152
+ server_port=args.server_port,
153
+ share=args.share,
154
+ )
assets/docs/inference.gif ADDED
assets/docs/showcase.gif ADDED

Git LFS Details

  • SHA256: 7bca5f38bfd555bf7c013312d87883afdf39d97fba719ac171c60f897af49e21
  • Pointer size: 132 Bytes
  • Size of remote file: 6.62 MB
assets/docs/showcase2.gif ADDED

Git LFS Details

  • SHA256: eb1fffb139681775780b2956e7d0289f55d199c1a3e14ab263887864d4b0d586
  • Pointer size: 132 Bytes
  • Size of remote file: 2.88 MB
assets/examples/driving/d0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63f6f9962e1fdf6e6722172e7a18155204858d5d5ce3b1e0646c150360c33bed
3
+ size 2958395
assets/examples/driving/d1.mp4 ADDED
Binary file (48.8 kB). View file
 
assets/examples/driving/d2.mp4 ADDED
Binary file (47.8 kB). View file
 
assets/examples/driving/d3.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef5c86e49b1b43dcb1449b499eb5a7f0cbae2f78aec08b5598193be1e4257099
3
+ size 1430968
assets/examples/driving/d5.mp4 ADDED
Binary file (135 kB). View file
 
assets/examples/driving/d6.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00e3ea79bbf28cbdc4fbb67ec655d9a0fe876e880ec45af55ae481348d0c0fff
3
+ size 1967790
assets/examples/driving/d7.mp4 ADDED
Binary file (185 kB). View file
 
assets/examples/driving/d8.mp4 ADDED
Binary file (312 kB). View file
 
assets/examples/driving/d9.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a414aa1d547be35306d692065a2157434bf40a6025ba8e30ce12e5bb322cc33
3
+ size 2257929
assets/examples/source/s0.jpg ADDED
assets/examples/source/s1.jpg ADDED
assets/examples/source/s10.jpg ADDED
assets/examples/source/s2.jpg ADDED
assets/examples/source/s3.jpg ADDED
assets/examples/source/s4.jpg ADDED
assets/examples/source/s5.jpg ADDED
assets/examples/source/s6.jpg ADDED
assets/examples/source/s7.jpg ADDED
assets/examples/source/s8.jpg ADDED
assets/examples/source/s9.jpg ADDED
assets/gradio_description_animation.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
2
+ <div style="font-size: 1.2em; margin-left: 20px;">
3
+ 1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
4
+ </div>
5
+ <div style="font-size: 1.2em; margin-left: 20px;">
6
+ 2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
7
+ </div>
assets/gradio_description_retargeting.md ADDED
@@ -0,0 +1 @@
 
 
1
+ <span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
assets/gradio_description_upload.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ## 🤗 This is the official gradio demo for **LivePortrait**.
2
+ <div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
assets/gradio_title.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
2
+ <div>
3
+ <h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
4
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;>
5
+ <a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
6
+ <a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
7
+ <a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
8
+ </div>
9
+ </div>
10
+ </div>
inference.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ import tyro
4
+ from src.config.argument_config import ArgumentConfig
5
+ from src.config.inference_config import InferenceConfig
6
+ from src.config.crop_config import CropConfig
7
+ from src.live_portrait_pipeline import LivePortraitPipeline
8
+
9
+
10
+ def partial_fields(target_class, kwargs):
11
+ return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
12
+
13
+
14
+ def main():
15
+ # set tyro theme
16
+ tyro.extras.set_accent_color("bright_cyan")
17
+ args = tyro.cli(ArgumentConfig)
18
+
19
+ # specify configs for inference
20
+ inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
21
+ crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
22
+
23
+ live_portrait_pipeline = LivePortraitPipeline(
24
+ inference_cfg=inference_cfg,
25
+ crop_cfg=crop_cfg
26
+ )
27
+
28
+ # run
29
+ live_portrait_pipeline.execute(args)
30
+
31
+
32
+ if __name__ == '__main__':
33
+ main()
pretrained_weights/.gitkeep ADDED
File without changes
readme.md ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
2
+
3
+ <div align='center'>
4
+ <a href='https://github.com/cleardusk' target='_blank'><strong>Jianzhu Guo</strong></a><sup> 1†</sup>&emsp;
5
+ <a href='https://github.com/KwaiVGI' target='_blank'><strong>Dingyun Zhang</strong></a><sup> 1,2</sup>&emsp;
6
+ <a href='https://github.com/KwaiVGI' target='_blank'><strong>Xiaoqiang Liu</strong></a><sup> 1</sup>&emsp;
7
+ <a href='https://github.com/KwaiVGI' target='_blank'><strong>Zhizhou Zhong</strong></a><sup> 1,3</sup>&emsp;
8
+ <a href='https://scholar.google.com.hk/citations?user=_8k1ubAAAAAJ' target='_blank'><strong>Yuan Zhang</strong></a><sup> 1</sup>&emsp;
9
+ </div>
10
+
11
+ <div align='center'>
12
+ <a href='https://scholar.google.com/citations?user=P6MraaYAAAAJ' target='_blank'><strong>Pengfei Wan</strong></a><sup> 1</sup>&emsp;
13
+ <a href='https://openreview.net/profile?id=~Di_ZHANG3' target='_blank'><strong>Di Zhang</strong></a><sup> 1</sup>&emsp;
14
+ </div>
15
+
16
+ <div align='center'>
17
+ <sup>1 </sup>Kuaishou Technology&emsp; <sup>2 </sup>University of Science and Technology of China&emsp; <sup>3 </sup>Fudan University&emsp;
18
+ </div>
19
+
20
+ <br>
21
+ <div align="center">
22
+ <!-- <a href='LICENSE'><img src='https://img.shields.io/badge/license-MIT-yellow'></a> -->
23
+ <a href='https://liveportrait.github.io'><img src='https://img.shields.io/badge/Project-Homepage-green'></a>
24
+ <a href='https://arxiv.org/pdf/2407.03168'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
25
+ </div>
26
+ <br>
27
+
28
+ <p align="center">
29
+ <img src="./assets/docs/showcase2.gif" alt="showcase">
30
+ <br>
31
+ 🔥 For more results, visit our <a href="https://liveportrait.github.io/"><strong>homepage</strong></a> 🔥
32
+ </p>
33
+
34
+
35
+
36
+ ## 🔥 Updates
37
+ - **`2024/07/04`**: 🔥 We released the initial version of the inference code and models. Continuous updates, stay tuned!
38
+ - **`2024/07/04`**: 😊 We released the [homepage](https://liveportrait.github.io) and technical report on [arXiv](https://arxiv.org/pdf/2407.03168).
39
+
40
+ ## Introduction
41
+ This repo, named **LivePortrait**, contains the official PyTorch implementation of our paper [LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control](https://arxiv.org/pdf/2407.03168).
42
+ We are actively updating and improving this repository. If you find any bugs or have suggestions, welcome to raise issues or submit pull requests (PR) 💖.
43
+
44
+ ## 🔥 Getting Started
45
+ ### 1. Clone the code and prepare the environment
46
+ ```bash
47
+ git clone https://github.com/KwaiVGI/LivePortrait
48
+ cd LivePortrait
49
+
50
+ # create env using conda
51
+ conda create -n LivePortrait python==3.9.18
52
+ conda activate LivePortrait
53
+ # install dependencies with pip
54
+ pip install -r requirements.txt
55
+ ```
56
+
57
+ ### 2. Download pretrained weights
58
+ Download our pretrained LivePortrait weights and face detection models of InsightFace from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows:
59
+ ```text
60
+ pretrained_weights
61
+ ├── insightface
62
+ │ └── models
63
+ │ └── buffalo_l
64
+ │ ├── 2d106det.onnx
65
+ │ └── det_10g.onnx
66
+ └── liveportrait
67
+ ├── base_models
68
+ │ ├── appearance_feature_extractor.pth
69
+ │ ├── motion_extractor.pth
70
+ │ ├── spade_generator.pth
71
+ │ └── warping_module.pth
72
+ ├── landmark.onnx
73
+ └── retargeting_models
74
+ └── stitching_retargeting_module.pth
75
+ ```
76
+
77
+ ### 3. Inference 🚀
78
+
79
+ ```bash
80
+ python inference.py
81
+ ```
82
+
83
+ If the script runs successfully, you will get an output mp4 file named `animations/s6--d0_concat.mp4`. This file includes the following results: driving video, input image, and generated result.
84
+
85
+ <p align="center">
86
+ <img src="./assets/docs/inference.gif" alt="image">
87
+ </p>
88
+
89
+ Or, you can change the input by specifying the `-s` and `-d` arguments:
90
+
91
+ ```bash
92
+ python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4
93
+
94
+ # or disable pasting back
95
+ python inference.py -s assets/examples/source/s9.jpg -d assets/examples/driving/d0.mp4 --no_flag_pasteback
96
+
97
+ # more options to see
98
+ python inference.py -h
99
+ ```
100
+
101
+ **More interesting results can be found in our [Homepage](https://liveportrait.github.io)** 😊
102
+
103
+ ### 4. Gradio interface
104
+
105
+ We also provide a Gradio interface for a better experience, just run by:
106
+
107
+ ```bash
108
+ python app.py
109
+ ```
110
+
111
+ ### 5. Inference speed evaluation 🚀🚀🚀
112
+ We have also provided a script to evaluate the inference speed of each module:
113
+
114
+ ```bash
115
+ python speed.py
116
+ ```
117
+
118
+ Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
119
+
120
+ | Model | Parameters(M) | Model Size(MB) | Inference(ms) |
121
+ |-----------------------------------|:-------------:|:--------------:|:-------------:|
122
+ | Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
123
+ | Motion Extractor | 28.12 | 108 | 0.84 |
124
+ | Spade Generator | 55.37 | 212 | 7.59 |
125
+ | Warping Module | 45.53 | 174 | 5.21 |
126
+ | Stitching and Retargeting Modules| 0.23 | 2.3 | 0.31 |
127
+
128
+ *Note: the listed values of Stitching and Retargeting Modules represent the combined parameter counts and the total sequential inference time of three MLP networks.*
129
+
130
+
131
+ ## Acknowledgements
132
+ We would like to thank the contributors of [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [Open Facevid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis), [SPADE](https://github.com/NVlabs/SPADE), [InsightFace](https://github.com/deepinsight/insightface) repositories, for their open research and contributions.
133
+
134
+ ## Citation 💖
135
+ If you find LivePortrait useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
136
+ ```bibtex
137
+ @article{guo2024live,
138
+ title = {LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control},
139
+ author = {Jianzhu Guo and Dingyun Zhang and Xiaoqiang Liu and Zhizhou Zhong and Yuan Zhang and Pengfei Wan and Di Zhang},
140
+ year = {2024},
141
+ journal = {arXiv preprint:2407.03168},
142
+ }
143
+ ```
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+ torch==2.3.0
3
+ torchvision==0.18.0
4
+ torchaudio==2.3.0
5
+
6
+ numpy==1.26.4
7
+ pyyaml==6.0.1
8
+ opencv-python==4.10.0.84
9
+ scipy==1.13.1
10
+ imageio==2.34.2
11
+ lmdb==1.4.1
12
+ tqdm==4.66.4
13
+ rich==13.7.1
14
+ ffmpeg==1.4
15
+ onnxruntime-gpu==1.18.0
16
+ onnx==1.16.1
17
+ scikit-image==0.24.0
18
+ albumentations==1.4.10
19
+ matplotlib==3.9.0
20
+ imageio-ffmpeg==0.5.1
21
+ tyro==0.8.5
22
+ gradio==4.37.1
speed.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ Benchmark the inference speed of each module in LivePortrait.
5
+
6
+ TODO: heavy GPT style, need to refactor
7
+ """
8
+
9
+ import yaml
10
+ import torch
11
+ import time
12
+ import numpy as np
13
+ from src.utils.helper import load_model, concat_feat
14
+ from src.config.inference_config import InferenceConfig
15
+
16
+
17
+ def initialize_inputs(batch_size=1):
18
+ """
19
+ Generate random input tensors and move them to GPU
20
+ """
21
+ feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
22
+ kp_source = torch.randn(batch_size, 21, 3).cuda().half()
23
+ kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
24
+ source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
25
+ generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
26
+ eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
27
+ lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
28
+ feat_stitching = concat_feat(kp_source, kp_driving).half()
29
+ feat_eye = concat_feat(kp_source, eye_close_ratio).half()
30
+ feat_lip = concat_feat(kp_source, lip_close_ratio).half()
31
+
32
+ inputs = {
33
+ 'feature_3d': feature_3d,
34
+ 'kp_source': kp_source,
35
+ 'kp_driving': kp_driving,
36
+ 'source_image': source_image,
37
+ 'generator_input': generator_input,
38
+ 'feat_stitching': feat_stitching,
39
+ 'feat_eye': feat_eye,
40
+ 'feat_lip': feat_lip
41
+ }
42
+
43
+ return inputs
44
+
45
+
46
+ def load_and_compile_models(cfg, model_config):
47
+ """
48
+ Load and compile models for inference
49
+ """
50
+ appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
51
+ motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
52
+ warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
53
+ spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
54
+ stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
55
+
56
+ models_with_params = [
57
+ ('Appearance Feature Extractor', appearance_feature_extractor),
58
+ ('Motion Extractor', motion_extractor),
59
+ ('Warping Network', warping_module),
60
+ ('SPADE Decoder', spade_generator)
61
+ ]
62
+
63
+ compiled_models = {}
64
+ for name, model in models_with_params:
65
+ model = model.half()
66
+ model = torch.compile(model, mode='max-autotune') # Optimize for inference
67
+ model.eval() # Switch to evaluation mode
68
+ compiled_models[name] = model
69
+
70
+ retargeting_models = ['stitching', 'eye', 'lip']
71
+ for retarget in retargeting_models:
72
+ module = stitching_retargeting_module[retarget].half()
73
+ module = torch.compile(module, mode='max-autotune') # Optimize for inference
74
+ module.eval() # Switch to evaluation mode
75
+ stitching_retargeting_module[retarget] = module
76
+
77
+ return compiled_models, stitching_retargeting_module
78
+
79
+
80
+ def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
81
+ """
82
+ Warm up models to prepare them for benchmarking
83
+ """
84
+ print("Warm up start!")
85
+ with torch.no_grad():
86
+ for _ in range(10):
87
+ compiled_models['Appearance Feature Extractor'](inputs['source_image'])
88
+ compiled_models['Motion Extractor'](inputs['source_image'])
89
+ compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
90
+ compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
91
+ stitching_retargeting_module['stitching'](inputs['feat_stitching'])
92
+ stitching_retargeting_module['eye'](inputs['feat_eye'])
93
+ stitching_retargeting_module['lip'](inputs['feat_lip'])
94
+ print("Warm up end!")
95
+
96
+
97
+ def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
98
+ """
99
+ Measure inference times for each model
100
+ """
101
+ times = {name: [] for name in compiled_models.keys()}
102
+ times['Retargeting Models'] = []
103
+
104
+ overall_times = []
105
+
106
+ with torch.no_grad():
107
+ for _ in range(100):
108
+ torch.cuda.synchronize()
109
+ overall_start = time.time()
110
+
111
+ start = time.time()
112
+ compiled_models['Appearance Feature Extractor'](inputs['source_image'])
113
+ torch.cuda.synchronize()
114
+ times['Appearance Feature Extractor'].append(time.time() - start)
115
+
116
+ start = time.time()
117
+ compiled_models['Motion Extractor'](inputs['source_image'])
118
+ torch.cuda.synchronize()
119
+ times['Motion Extractor'].append(time.time() - start)
120
+
121
+ start = time.time()
122
+ compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
123
+ torch.cuda.synchronize()
124
+ times['Warping Network'].append(time.time() - start)
125
+
126
+ start = time.time()
127
+ compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
128
+ torch.cuda.synchronize()
129
+ times['SPADE Decoder'].append(time.time() - start)
130
+
131
+ start = time.time()
132
+ stitching_retargeting_module['stitching'](inputs['feat_stitching'])
133
+ stitching_retargeting_module['eye'](inputs['feat_eye'])
134
+ stitching_retargeting_module['lip'](inputs['feat_lip'])
135
+ torch.cuda.synchronize()
136
+ times['Retargeting Models'].append(time.time() - start)
137
+
138
+ overall_times.append(time.time() - overall_start)
139
+
140
+ return times, overall_times
141
+
142
+
143
+ def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
144
+ """
145
+ Print benchmark results with average and standard deviation of inference times
146
+ """
147
+ average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
148
+ std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}
149
+
150
+ for name, model in compiled_models.items():
151
+ num_params = sum(p.numel() for p in model.parameters())
152
+ num_params_in_millions = num_params / 1e6
153
+ print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")
154
+
155
+ for index, retarget in enumerate(retargeting_models):
156
+ num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
157
+ num_params_in_millions = num_params / 1e6
158
+ print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")
159
+
160
+ for name, avg_time in average_times.items():
161
+ std_time = std_times[name]
162
+ print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
163
+
164
+
165
+ def main():
166
+ """
167
+ Main function to benchmark speed and model parameters
168
+ """
169
+ # Sample input tensors
170
+ inputs = initialize_inputs()
171
+
172
+ # Load configuration
173
+ cfg = InferenceConfig(device_id=0)
174
+ model_config_path = cfg.models_config
175
+ with open(model_config_path, 'r') as file:
176
+ model_config = yaml.safe_load(file)
177
+
178
+ # Load and compile models
179
+ compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
180
+
181
+ # Warm up models
182
+ warm_up_models(compiled_models, stitching_retargeting_module, inputs)
183
+
184
+ # Measure inference times
185
+ times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)
186
+
187
+ # Print benchmark results
188
+ print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ main()
src/config/__init__.py ADDED
File without changes
src/config/argument_config.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ config for user
5
+ """
6
+
7
+ import os.path as osp
8
+ from dataclasses import dataclass
9
+ import tyro
10
+ from typing_extensions import Annotated
11
+ from .base_config import PrintableConfig, make_abs_path
12
+
13
+
14
+ @dataclass(repr=False) # use repr from PrintableConfig
15
+ class ArgumentConfig(PrintableConfig):
16
+ ########## input arguments ##########
17
+ source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
18
+ driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
19
+ output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
20
+ #####################################
21
+
22
+ ########## inference arguments ##########
23
+ device_id: int = 0
24
+ flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
25
+ flag_eye_retargeting: bool = False
26
+ flag_lip_retargeting: bool = False
27
+ flag_stitching: bool = True # we recommend setting it to True!
28
+ flag_relative: bool = True # whether to use relative motion
29
+ flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
30
+ flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
31
+ flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
32
+ #########################################
33
+
34
+ ########## crop arguments ##########
35
+ dsize: int = 512
36
+ scale: float = 2.3
37
+ vx_ratio: float = 0 # vx ratio
38
+ vy_ratio: float = -0.125 # vy ratio +up, -down
39
+ ####################################
40
+
41
+ ########## gradio arguments ##########
42
+ server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
43
+ share: bool = True
44
+ server_name: str = "0.0.0.0"
src/config/base_config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ pretty printing class
5
+ """
6
+
7
+ from __future__ import annotations
8
+ import os.path as osp
9
+ from typing import Tuple
10
+
11
+
12
+ def make_abs_path(fn):
13
+ return osp.join(osp.dirname(osp.realpath(__file__)), fn)
14
+
15
+
16
+ class PrintableConfig: # pylint: disable=too-few-public-methods
17
+ """Printable Config defining str function"""
18
+
19
+ def __repr__(self):
20
+ lines = [self.__class__.__name__ + ":"]
21
+ for key, val in vars(self).items():
22
+ if isinstance(val, Tuple):
23
+ flattened_val = "["
24
+ for item in val:
25
+ flattened_val += str(item) + "\n"
26
+ flattened_val = flattened_val.rstrip("\n")
27
+ val = flattened_val + "]"
28
+ lines += f"{key}: {str(val)}".split("\n")
29
+ return "\n ".join(lines)
src/config/crop_config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ parameters used for crop faces
5
+ """
6
+
7
+ import os.path as osp
8
+ from dataclasses import dataclass
9
+ from typing import Union, List
10
+ from .base_config import PrintableConfig
11
+
12
+
13
+ @dataclass(repr=False) # use repr from PrintableConfig
14
+ class CropConfig(PrintableConfig):
15
+ dsize: int = 512 # crop size
16
+ scale: float = 2.3 # scale factor
17
+ vx_ratio: float = 0 # vx ratio
18
+ vy_ratio: float = -0.125 # vy ratio +up, -down
src/config/inference_config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ config dataclass used for inference
5
+ """
6
+
7
+ import os.path as osp
8
+ from dataclasses import dataclass
9
+ from typing import Literal, Tuple
10
+ from .base_config import PrintableConfig, make_abs_path
11
+
12
+
13
+ @dataclass(repr=False) # use repr from PrintableConfig
14
+ class InferenceConfig(PrintableConfig):
15
+ models_config: str = make_abs_path('./models.yaml') # portrait animation config
16
+ checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint
17
+ checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint
18
+ checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint
19
+ checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint
20
+
21
+ checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint
22
+ flag_use_half_precision: bool = True # whether to use half precision
23
+
24
+ flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
25
+ lip_zero_threshold: float = 0.03
26
+
27
+ flag_eye_retargeting: bool = False
28
+ flag_lip_retargeting: bool = False
29
+ flag_stitching: bool = True # we recommend setting it to True!
30
+
31
+ flag_relative: bool = True # whether to use relative motion
32
+ anchor_frame: int = 0 # set this value if find_best_frame is True
33
+
34
+ input_shape: Tuple[int, int] = (256, 256) # input shape
35
+ output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
36
+ output_fps: int = 30 # fps for output video
37
+ crf: int = 15 # crf for output video
38
+
39
+ flag_write_result: bool = True # whether to write output video
40
+ flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
41
+ mask_crop = None
42
+ flag_write_gif: bool = False
43
+ size_gif: int = 256
44
+ ref_max_shape: int = 1280
45
+ ref_shape_n: int = 2
46
+
47
+ device_id: int = 0
48
+ flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
49
+ flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
src/config/models.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_params:
2
+ appearance_feature_extractor_params: # the F in the paper
3
+ image_channel: 3
4
+ block_expansion: 64
5
+ num_down_blocks: 2
6
+ max_features: 512
7
+ reshape_channel: 32
8
+ reshape_depth: 16
9
+ num_resblocks: 6
10
+ motion_extractor_params: # the M in the paper
11
+ num_kp: 21
12
+ backbone: convnextv2_tiny
13
+ warping_module_params: # the W in the paper
14
+ num_kp: 21
15
+ block_expansion: 64
16
+ max_features: 512
17
+ num_down_blocks: 2
18
+ reshape_channel: 32
19
+ estimate_occlusion_map: True
20
+ dense_motion_params:
21
+ block_expansion: 32
22
+ max_features: 1024
23
+ num_blocks: 5
24
+ reshape_depth: 16
25
+ compress: 4
26
+ spade_generator_params: # the G in the paper
27
+ upscale: 2 # represents upsample factor 256x256 -> 512x512
28
+ block_expansion: 64
29
+ max_features: 512
30
+ num_down_blocks: 2
31
+ stitching_retargeting_module_params: # the S in the paper
32
+ stitching:
33
+ input_size: 126 # (21*3)*2
34
+ hidden_sizes: [128, 128, 64]
35
+ output_size: 65 # (21*3)+2(tx,ty)
36
+ lip:
37
+ input_size: 65 # (21*3)+2
38
+ hidden_sizes: [128, 128, 64]
39
+ output_size: 63 # (21*3)
40
+ eye:
41
+ input_size: 66 # (21*3)+3
42
+ hidden_sizes: [256, 256, 128, 128, 64]
43
+ output_size: 63 # (21*3)
src/gradio_pipeline.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ Pipeline for gradio
5
+ """
6
+ import gradio as gr
7
+ from .config.argument_config import ArgumentConfig
8
+ from .live_portrait_pipeline import LivePortraitPipeline
9
+ from .utils.io import load_img_online
10
+ from .utils.rprint import rlog as log
11
+ from .utils.crop import prepare_paste_back, paste_back
12
+ from .utils.camera import get_rotation_matrix
13
+ from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
14
+
15
+ def update_args(args, user_args):
16
+ """update the args according to user inputs
17
+ """
18
+ for k, v in user_args.items():
19
+ if hasattr(args, k):
20
+ setattr(args, k, v)
21
+ return args
22
+
23
+ class GradioPipeline(LivePortraitPipeline):
24
+
25
+ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
26
+ super().__init__(inference_cfg, crop_cfg)
27
+ # self.live_portrait_wrapper = self.live_portrait_wrapper
28
+ self.args = args
29
+ # for single image retargeting
30
+ self.start_prepare = False
31
+ self.f_s_user = None
32
+ self.x_c_s_info_user = None
33
+ self.x_s_user = None
34
+ self.source_lmk_user = None
35
+ self.mask_ori = None
36
+ self.img_rgb = None
37
+ self.crop_M_c2o = None
38
+
39
+
40
+ def execute_video(
41
+ self,
42
+ input_image_path,
43
+ input_video_path,
44
+ flag_relative_input,
45
+ flag_do_crop_input,
46
+ flag_remap_input,
47
+ ):
48
+ """ for video driven potrait animation
49
+ """
50
+ if input_image_path is not None and input_video_path is not None:
51
+ args_user = {
52
+ 'source_image': input_image_path,
53
+ 'driving_info': input_video_path,
54
+ 'flag_relative': flag_relative_input,
55
+ 'flag_do_crop': flag_do_crop_input,
56
+ 'flag_pasteback': flag_remap_input,
57
+ }
58
+ # update config from user input
59
+ self.args = update_args(self.args, args_user)
60
+ self.live_portrait_wrapper.update_config(self.args.__dict__)
61
+ self.cropper.update_config(self.args.__dict__)
62
+ # video driven animation
63
+ video_path, video_path_concat = self.execute(self.args)
64
+ gr.Info("Run successfully!", duration=2)
65
+ return video_path, video_path_concat,
66
+ else:
67
+ raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
68
+
69
+ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
70
+ """ for single image retargeting
71
+ """
72
+ if input_eye_ratio is None or input_eye_ratio is None:
73
+ raise gr.Error("Invalid ratio input 💥!", duration=5)
74
+ elif self.f_s_user is None:
75
+ if self.start_prepare:
76
+ raise gr.Error(
77
+ "The source portrait is under processing 💥! Please wait for a second.",
78
+ duration=5
79
+ )
80
+ else:
81
+ raise gr.Error(
82
+ "The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
83
+ duration=5
84
+ )
85
+ else:
86
+ # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
87
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
88
+ eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
89
+ # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
90
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
91
+ lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
92
+ num_kp = self.x_s_user.shape[1]
93
+ # default: use x_s
94
+ x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
95
+ # D(W(f_s; x_s, x′_d))
96
+ out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
97
+ out = self.live_portrait_wrapper.parse_output(out['out'])[0]
98
+ out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
99
+ gr.Info("Run successfully!", duration=2)
100
+ return out, out_to_ori_blend
101
+
102
+
103
+ def prepare_retargeting(self, input_image_path, flag_do_crop = True):
104
+ """ for single image retargeting
105
+ """
106
+ if input_image_path is not None:
107
+ gr.Info("Upload successfully!", duration=2)
108
+ self.start_prepare = True
109
+ inference_cfg = self.live_portrait_wrapper.cfg
110
+ ######## process source portrait ########
111
+ img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
112
+ log(f"Load source image from {input_image_path}.")
113
+ crop_info = self.cropper.crop_single_image(img_rgb)
114
+ if flag_do_crop:
115
+ I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
116
+ else:
117
+ I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
118
+ x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
119
+ R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
120
+ ############################################
121
+
122
+ # record global info for next time use
123
+ self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
124
+ self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
125
+ self.x_s_info_user = x_s_info
126
+ self.source_lmk_user = crop_info['lmk_crop']
127
+ self.img_rgb = img_rgb
128
+ self.crop_M_c2o = crop_info['M_c2o']
129
+ self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
130
+ # update slider
131
+ eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
132
+ eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
133
+ lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
134
+ lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
135
+ # for vis
136
+ self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
137
+ return eye_close_ratio, lip_close_ratio, self.I_s_vis
138
+ else:
139
+ # when press the clear button, go here
140
+ return 0.8, 0.8, self.I_s_vis
src/live_portrait_pipeline.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ Pipeline of LivePortrait
5
+ """
6
+
7
+ # TODO:
8
+ # 1. 当前假定所有的模板都是已经裁好的,需要修改下
9
+ # 2. pick样例图 source + driving
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import pickle
14
+ import os.path as osp
15
+ from rich.progress import track
16
+
17
+ from .config.argument_config import ArgumentConfig
18
+ from .config.inference_config import InferenceConfig
19
+ from .config.crop_config import CropConfig
20
+ from .utils.cropper import Cropper
21
+ from .utils.camera import get_rotation_matrix
22
+ from .utils.video import images2video, concat_frames
23
+ from .utils.crop import _transform_img, prepare_paste_back, paste_back
24
+ from .utils.retargeting_utils import calc_lip_close_ratio
25
+ from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
26
+ from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
27
+ from .utils.rprint import rlog as log
28
+ from .live_portrait_wrapper import LivePortraitWrapper
29
+
30
+
31
+ def make_abs_path(fn):
32
+ return osp.join(osp.dirname(osp.realpath(__file__)), fn)
33
+
34
+
35
+ class LivePortraitPipeline(object):
36
+
37
+ def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
38
+ self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
39
+ self.cropper = Cropper(crop_cfg=crop_cfg)
40
+
41
+ def execute(self, args: ArgumentConfig):
42
+ inference_cfg = self.live_portrait_wrapper.cfg # for convenience
43
+ ######## process source portrait ########
44
+ img_rgb = load_image_rgb(args.source_image)
45
+ img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
46
+ log(f"Load source image from {args.source_image}")
47
+ crop_info = self.cropper.crop_single_image(img_rgb)
48
+ source_lmk = crop_info['lmk_crop']
49
+ img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
50
+ if inference_cfg.flag_do_crop:
51
+ I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
52
+ else:
53
+ I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
54
+ x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
55
+ x_c_s = x_s_info['kp']
56
+ R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
57
+ f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
58
+ x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
59
+
60
+ if inference_cfg.flag_lip_zero:
61
+ # let lip-open scalar to be 0 at first
62
+ c_d_lip_before_animation = [0.]
63
+ combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
64
+ if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
65
+ inference_cfg.flag_lip_zero = False
66
+ else:
67
+ lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
68
+ ############################################
69
+
70
+ ######## process driving info ########
71
+ if is_video(args.driving_info):
72
+ log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
73
+ # TODO: 这里track一下驱动视频 -> 构建模板
74
+ driving_rgb_lst = load_driving_info(args.driving_info)
75
+ driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
76
+ I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
77
+ n_frames = I_d_lst.shape[0]
78
+ if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
79
+ driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
80
+ input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
81
+ elif is_template(args.driving_info):
82
+ log(f"Load from video templates {args.driving_info}")
83
+ with open(args.driving_info, 'rb') as f:
84
+ template_lst, driving_lmk_lst = pickle.load(f)
85
+ n_frames = template_lst[0]['n_frames']
86
+ input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
87
+ else:
88
+ raise Exception("Unsupported driving types!")
89
+ #########################################
90
+
91
+ ######## prepare for pasteback ########
92
+ if inference_cfg.flag_pasteback:
93
+ mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
94
+ I_p_paste_lst = []
95
+ #########################################
96
+
97
+ I_p_lst = []
98
+ R_d_0, x_d_0_info = None, None
99
+ for i in track(range(n_frames), description='Animating...', total=n_frames):
100
+ if is_video(args.driving_info):
101
+ # extract kp info by M
102
+ I_d_i = I_d_lst[i]
103
+ x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
104
+ R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
105
+ else:
106
+ # from template
107
+ x_d_i_info = template_lst[i]
108
+ x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
109
+ R_d_i = x_d_i_info['R_d']
110
+
111
+ if i == 0:
112
+ R_d_0 = R_d_i
113
+ x_d_0_info = x_d_i_info
114
+
115
+ if inference_cfg.flag_relative:
116
+ R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
117
+ delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
118
+ scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
119
+ t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
120
+ else:
121
+ R_new = R_d_i
122
+ delta_new = x_d_i_info['exp']
123
+ scale_new = x_s_info['scale']
124
+ t_new = x_d_i_info['t']
125
+
126
+ t_new[..., 2].fill_(0) # zero tz
127
+ x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
128
+
129
+ # Algorithm 1:
130
+ if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
131
+ # without stitching or retargeting
132
+ if inference_cfg.flag_lip_zero:
133
+ x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
134
+ else:
135
+ pass
136
+ elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
137
+ # with stitching and without retargeting
138
+ if inference_cfg.flag_lip_zero:
139
+ x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
140
+ else:
141
+ x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
142
+ else:
143
+ eyes_delta, lip_delta = None, None
144
+ if inference_cfg.flag_eye_retargeting:
145
+ c_d_eyes_i = input_eye_ratio_lst[i]
146
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
147
+ # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
148
+ eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
149
+ if inference_cfg.flag_lip_retargeting:
150
+ c_d_lip_i = input_lip_ratio_lst[i]
151
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
152
+ # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
153
+ lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
154
+
155
+ if inference_cfg.flag_relative: # use x_s
156
+ x_d_i_new = x_s + \
157
+ (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
158
+ (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
159
+ else: # use x_d,i
160
+ x_d_i_new = x_d_i_new + \
161
+ (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
162
+ (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
163
+
164
+ if inference_cfg.flag_stitching:
165
+ x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
166
+
167
+ out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
168
+ I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
169
+ I_p_lst.append(I_p_i)
170
+
171
+ if inference_cfg.flag_pasteback:
172
+ I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
173
+ I_p_paste_lst.append(I_p_i_to_ori_blend)
174
+
175
+ mkdir(args.output_dir)
176
+ wfp_concat = None
177
+ if is_video(args.driving_info):
178
+ frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
179
+ # save (driving frames, source image, drived frames) result
180
+ wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
181
+ images2video(frames_concatenated, wfp=wfp_concat)
182
+
183
+ # save drived result
184
+ wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
185
+ if inference_cfg.flag_pasteback:
186
+ images2video(I_p_paste_lst, wfp=wfp)
187
+ else:
188
+ images2video(I_p_lst, wfp=wfp)
189
+
190
+ return wfp, wfp_concat
src/live_portrait_wrapper.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ Wrapper for LivePortrait core functions
5
+ """
6
+
7
+ import os.path as osp
8
+ import numpy as np
9
+ import cv2
10
+ import torch
11
+ import yaml
12
+
13
+ from .utils.timer import Timer
14
+ from .utils.helper import load_model, concat_feat
15
+ from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
16
+ from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
17
+ from .config.inference_config import InferenceConfig
18
+ from .utils.rprint import rlog as log
19
+
20
+
21
+ class LivePortraitWrapper(object):
22
+
23
+ def __init__(self, cfg: InferenceConfig):
24
+
25
+ model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
26
+
27
+ # init F
28
+ self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
29
+ log(f'Load appearance_feature_extractor done.')
30
+ # init M
31
+ self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
32
+ log(f'Load motion_extractor done.')
33
+ # init W
34
+ self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
35
+ log(f'Load warping_module done.')
36
+ # init G
37
+ self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
38
+ log(f'Load spade_generator done.')
39
+ # init S and R
40
+ if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
41
+ self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
42
+ log(f'Load stitching_retargeting_module done.')
43
+ else:
44
+ self.stitching_retargeting_module = None
45
+
46
+ self.cfg = cfg
47
+ self.device_id = cfg.device_id
48
+ self.timer = Timer()
49
+
50
+ def update_config(self, user_args):
51
+ for k, v in user_args.items():
52
+ if hasattr(self.cfg, k):
53
+ setattr(self.cfg, k, v)
54
+
55
+ def prepare_source(self, img: np.ndarray) -> torch.Tensor:
56
+ """ construct the input as standard
57
+ img: HxWx3, uint8, 256x256
58
+ """
59
+ h, w = img.shape[:2]
60
+ if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
61
+ x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
62
+ else:
63
+ x = img.copy()
64
+
65
+ if x.ndim == 3:
66
+ x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
67
+ elif x.ndim == 4:
68
+ x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
69
+ else:
70
+ raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
71
+ x = np.clip(x, 0, 1) # clip to 0~1
72
+ x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
73
+ x = x.cuda(self.device_id)
74
+ return x
75
+
76
+ def prepare_driving_videos(self, imgs) -> torch.Tensor:
77
+ """ construct the input as standard
78
+ imgs: NxBxHxWx3, uint8
79
+ """
80
+ if isinstance(imgs, list):
81
+ _imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
82
+ elif isinstance(imgs, np.ndarray):
83
+ _imgs = imgs
84
+ else:
85
+ raise ValueError(f'imgs type error: {type(imgs)}')
86
+
87
+ y = _imgs.astype(np.float32) / 255.
88
+ y = np.clip(y, 0, 1) # clip to 0~1
89
+ y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
90
+ y = y.cuda(self.device_id)
91
+
92
+ return y
93
+
94
+ def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
95
+ """ get the appearance feature of the image by F
96
+ x: Bx3xHxW, normalized to 0~1
97
+ """
98
+ with torch.no_grad():
99
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
100
+ feature_3d = self.appearance_feature_extractor(x)
101
+
102
+ return feature_3d.float()
103
+
104
+ def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
105
+ """ get the implicit keypoint information
106
+ x: Bx3xHxW, normalized to 0~1
107
+ flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape
108
+ return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
109
+ """
110
+ with torch.no_grad():
111
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
112
+ kp_info = self.motion_extractor(x)
113
+
114
+ if self.cfg.flag_use_half_precision:
115
+ # float the dict
116
+ for k, v in kp_info.items():
117
+ if isinstance(v, torch.Tensor):
118
+ kp_info[k] = v.float()
119
+
120
+ flag_refine_info: bool = kwargs.get('flag_refine_info', True)
121
+ if flag_refine_info:
122
+ bs = kp_info['kp'].shape[0]
123
+ kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1
124
+ kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1
125
+ kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1
126
+ kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
127
+ kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3
128
+
129
+ return kp_info
130
+
131
+ def get_pose_dct(self, kp_info: dict) -> dict:
132
+ pose_dct = dict(
133
+ pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
134
+ yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
135
+ roll=headpose_pred_to_degree(kp_info['roll']).item(),
136
+ )
137
+ return pose_dct
138
+
139
+ def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
140
+
141
+ # get the canonical keypoints of source image by M
142
+ source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
143
+ source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])
144
+
145
+ # get the canonical keypoints of first driving frame by M
146
+ driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
147
+ driving_first_frame_rotation = get_rotation_matrix(
148
+ driving_first_frame_kp_info['pitch'],
149
+ driving_first_frame_kp_info['yaw'],
150
+ driving_first_frame_kp_info['roll']
151
+ )
152
+
153
+ # get feature volume by F
154
+ source_feature_3d = self.extract_feature_3d(source_prepared)
155
+
156
+ return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation
157
+
158
+ def transform_keypoint(self, kp_info: dict):
159
+ """
160
+ transform the implicit keypoints with the pose, shift, and expression deformation
161
+ kp: BxNx3
162
+ """
163
+ kp = kp_info['kp'] # (bs, k, 3)
164
+ pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
165
+
166
+ t, exp = kp_info['t'], kp_info['exp']
167
+ scale = kp_info['scale']
168
+
169
+ pitch = headpose_pred_to_degree(pitch)
170
+ yaw = headpose_pred_to_degree(yaw)
171
+ roll = headpose_pred_to_degree(roll)
172
+
173
+ bs = kp.shape[0]
174
+ if kp.ndim == 2:
175
+ num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
176
+ else:
177
+ num_kp = kp.shape[1] # Bxnum_kpx3
178
+
179
+ rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
180
+
181
+ # Eqn.2: s * (R * x_c,s + exp) + t
182
+ kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
183
+ kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
184
+ kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
185
+
186
+ return kp_transformed
187
+
188
+ def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
189
+ """
190
+ kp_source: BxNx3
191
+ eye_close_ratio: Bx3
192
+ Return: Bx(3*num_kp+2)
193
+ """
194
+ feat_eye = concat_feat(kp_source, eye_close_ratio)
195
+
196
+ with torch.no_grad():
197
+ delta = self.stitching_retargeting_module['eye'](feat_eye)
198
+
199
+ return delta
200
+
201
+ def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
202
+ """
203
+ kp_source: BxNx3
204
+ lip_close_ratio: Bx2
205
+ """
206
+ feat_lip = concat_feat(kp_source, lip_close_ratio)
207
+
208
+ with torch.no_grad():
209
+ delta = self.stitching_retargeting_module['lip'](feat_lip)
210
+
211
+ return delta
212
+
213
+ def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
214
+ """
215
+ kp_source: BxNx3
216
+ kp_driving: BxNx3
217
+ Return: Bx(3*num_kp+2)
218
+ """
219
+ feat_stiching = concat_feat(kp_source, kp_driving)
220
+
221
+ with torch.no_grad():
222
+ delta = self.stitching_retargeting_module['stitching'](feat_stiching)
223
+
224
+ return delta
225
+
226
+ def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
227
+ """ conduct the stitching
228
+ kp_source: Bxnum_kpx3
229
+ kp_driving: Bxnum_kpx3
230
+ """
231
+
232
+ if self.stitching_retargeting_module is not None:
233
+
234
+ bs, num_kp = kp_source.shape[:2]
235
+
236
+ kp_driving_new = kp_driving.clone()
237
+ delta = self.stitch(kp_source, kp_driving_new)
238
+
239
+ delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
240
+ delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
241
+
242
+ kp_driving_new += delta_exp
243
+ kp_driving_new[..., :2] += delta_tx_ty
244
+
245
+ return kp_driving_new
246
+
247
+ return kp_driving
248
+
249
+ def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
250
+ """ get the image after the warping of the implicit keypoints
251
+ feature_3d: Bx32x16x64x64, feature volume
252
+ kp_source: BxNx3
253
+ kp_driving: BxNx3
254
+ """
255
+ # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
256
+ with torch.no_grad():
257
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
258
+ # get decoder input
259
+ ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
260
+ # decode
261
+ ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
262
+
263
+ # float the dict
264
+ if self.cfg.flag_use_half_precision:
265
+ for k, v in ret_dct.items():
266
+ if isinstance(v, torch.Tensor):
267
+ ret_dct[k] = v.float()
268
+
269
+ return ret_dct
270
+
271
+ def parse_output(self, out: torch.Tensor) -> np.ndarray:
272
+ """ construct the output as standard
273
+ return: 1xHxWx3, uint8
274
+ """
275
+ out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
276
+ out = np.clip(out, 0, 1) # clip to 0~1
277
+ out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
278
+
279
+ return out
280
+
281
+ def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
282
+ input_eye_ratio_lst = []
283
+ input_lip_ratio_lst = []
284
+ for lmk in driving_lmk_lst:
285
+ # for eyes retargeting
286
+ input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
287
+ # for lip retargeting
288
+ input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
289
+ return input_eye_ratio_lst, input_lip_ratio_lst
290
+
291
+ def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
292
+ eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
293
+ eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
294
+ input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
295
+ # [c_s,eyes, c_d,eyes,i]
296
+ combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
297
+ return combined_eye_ratio_tensor
298
+
299
+ def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
300
+ lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
301
+ lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
302
+ # [c_s,lip, c_d,lip,i]
303
+ input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
304
+ if input_lip_ratio_tensor.shape != [1, 1]:
305
+ input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
306
+ combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
307
+ return combined_lip_ratio_tensor
src/modules/__init__.py ADDED
File without changes
src/modules/appearance_feature_extractor.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume.
5
+ """
6
+
7
+ import torch
8
+ from torch import nn
9
+ from .util import SameBlock2d, DownBlock2d, ResBlock3d
10
+
11
+
12
+ class AppearanceFeatureExtractor(nn.Module):
13
+
14
+ def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks):
15
+ super(AppearanceFeatureExtractor, self).__init__()
16
+ self.image_channel = image_channel
17
+ self.block_expansion = block_expansion
18
+ self.num_down_blocks = num_down_blocks
19
+ self.max_features = max_features
20
+ self.reshape_channel = reshape_channel
21
+ self.reshape_depth = reshape_depth
22
+
23
+ self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
24
+
25
+ down_blocks = []
26
+ for i in range(num_down_blocks):
27
+ in_features = min(max_features, block_expansion * (2 ** i))
28
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
29
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
30
+ self.down_blocks = nn.ModuleList(down_blocks)
31
+
32
+ self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
33
+
34
+ self.resblocks_3d = torch.nn.Sequential()
35
+ for i in range(num_resblocks):
36
+ self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
37
+
38
+ def forward(self, source_image):
39
+ out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256
40
+
41
+ for i in range(len(self.down_blocks)):
42
+ out = self.down_blocks[i](out)
43
+ out = self.second(out)
44
+ bs, c, h, w = out.shape # ->Bx512x64x64
45
+
46
+ f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64
47
+ f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64
48
+ return f_s
src/modules/convnextv2.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ This moudle is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ # from timm.models.layers import trunc_normal_, DropPath
10
+ from .util import LayerNorm, DropPath, trunc_normal_, GRN
11
+
12
+ __all__ = ['convnextv2_tiny']
13
+
14
+
15
+ class Block(nn.Module):
16
+ """ ConvNeXtV2 Block.
17
+
18
+ Args:
19
+ dim (int): Number of input channels.
20
+ drop_path (float): Stochastic depth rate. Default: 0.0
21
+ """
22
+
23
+ def __init__(self, dim, drop_path=0.):
24
+ super().__init__()
25
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
26
+ self.norm = LayerNorm(dim, eps=1e-6)
27
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
28
+ self.act = nn.GELU()
29
+ self.grn = GRN(4 * dim)
30
+ self.pwconv2 = nn.Linear(4 * dim, dim)
31
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
32
+
33
+ def forward(self, x):
34
+ input = x
35
+ x = self.dwconv(x)
36
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
37
+ x = self.norm(x)
38
+ x = self.pwconv1(x)
39
+ x = self.act(x)
40
+ x = self.grn(x)
41
+ x = self.pwconv2(x)
42
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
43
+
44
+ x = input + self.drop_path(x)
45
+ return x
46
+
47
+
48
+ class ConvNeXtV2(nn.Module):
49
+ """ ConvNeXt V2
50
+
51
+ Args:
52
+ in_chans (int): Number of input image channels. Default: 3
53
+ num_classes (int): Number of classes for classification head. Default: 1000
54
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
55
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
56
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
57
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ in_chans=3,
63
+ depths=[3, 3, 9, 3],
64
+ dims=[96, 192, 384, 768],
65
+ drop_path_rate=0.,
66
+ **kwargs
67
+ ):
68
+ super().__init__()
69
+ self.depths = depths
70
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
71
+ stem = nn.Sequential(
72
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
73
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
74
+ )
75
+ self.downsample_layers.append(stem)
76
+ for i in range(3):
77
+ downsample_layer = nn.Sequential(
78
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
79
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
80
+ )
81
+ self.downsample_layers.append(downsample_layer)
82
+
83
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
84
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
85
+ cur = 0
86
+ for i in range(4):
87
+ stage = nn.Sequential(
88
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
89
+ )
90
+ self.stages.append(stage)
91
+ cur += depths[i]
92
+
93
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
94
+
95
+ # NOTE: the output semantic items
96
+ num_bins = kwargs.get('num_bins', 66)
97
+ num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
98
+ self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
99
+
100
+ # print('dims[-1]: ', dims[-1])
101
+ self.fc_scale = nn.Linear(dims[-1], 1) # scale
102
+ self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
103
+ self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
104
+ self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
105
+ self.fc_t = nn.Linear(dims[-1], 3) # translation
106
+ self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
107
+
108
+ def _init_weights(self, m):
109
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
110
+ trunc_normal_(m.weight, std=.02)
111
+ nn.init.constant_(m.bias, 0)
112
+
113
+ def forward_features(self, x):
114
+ for i in range(4):
115
+ x = self.downsample_layers[i](x)
116
+ x = self.stages[i](x)
117
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
118
+
119
+ def forward(self, x):
120
+ x = self.forward_features(x)
121
+
122
+ # implicit keypoints
123
+ kp = self.fc_kp(x)
124
+
125
+ # pose and expression deformation
126
+ pitch = self.fc_pitch(x)
127
+ yaw = self.fc_yaw(x)
128
+ roll = self.fc_roll(x)
129
+ t = self.fc_t(x)
130
+ exp = self.fc_exp(x)
131
+ scale = self.fc_scale(x)
132
+
133
+ ret_dct = {
134
+ 'pitch': pitch,
135
+ 'yaw': yaw,
136
+ 'roll': roll,
137
+ 't': t,
138
+ 'exp': exp,
139
+ 'scale': scale,
140
+
141
+ 'kp': kp, # canonical keypoint
142
+ }
143
+
144
+ return ret_dct
145
+
146
+
147
+ def convnextv2_tiny(**kwargs):
148
+ model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
149
+ return model
src/modules/dense_motion.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ The module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
5
+ """
6
+
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ import torch
10
+ from .util import Hourglass, make_coordinate_grid, kp2gaussian
11
+
12
+
13
+ class DenseMotionNetwork(nn.Module):
14
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
15
+ super(DenseMotionNetwork, self).__init__()
16
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G
17
+
18
+ self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large
19
+ self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G
20
+ self.norm = nn.BatchNorm3d(compress, affine=True)
21
+ self.num_kp = num_kp
22
+ self.flag_estimate_occlusion_map = estimate_occlusion_map
23
+
24
+ if self.flag_estimate_occlusion_map:
25
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
26
+ else:
27
+ self.occlusion = None
28
+
29
+ def create_sparse_motions(self, feature, kp_driving, kp_source):
30
+ bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64)
31
+ identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3)
32
+ identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3)
33
+ coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)
34
+
35
+ k = coordinate_grid.shape[1]
36
+
37
+ # NOTE: there lacks an one-order flow
38
+ driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
39
+
40
+ # adding background feature
41
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
42
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3)
43
+ return sparse_motions
44
+
45
+ def create_deformed_feature(self, feature, sparse_motions):
46
+ bs, _, d, h, w = feature.shape
47
+ feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
48
+ feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
49
+ sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
50
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
51
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
52
+
53
+ return sparse_deformed
54
+
55
+ def create_heatmap_representations(self, feature, kp_driving, kp_source):
56
+ spatial_size = feature.shape[3:] # (d=16, h=64, w=64)
57
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
58
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
59
+ heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
60
+
61
+ # adding background feature
62
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device)
63
+ heatmap = torch.cat([zeros, heatmap], dim=1)
64
+ heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
65
+ return heatmap
66
+
67
+ def forward(self, feature, kp_driving, kp_source):
68
+ bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64)
69
+
70
+ feature = self.compress(feature) # (bs, 4, 16, 64, 64)
71
+ feature = self.norm(feature) # (bs, 4, 16, 64, 64)
72
+ feature = F.relu(feature) # (bs, 4, 16, 64, 64)
73
+
74
+ out_dict = dict()
75
+
76
+ # 1. deform 3d feature
77
+ sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3)
78
+ deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64)
79
+
80
+ # 2. (bs, 1+num_kp, d, h, w)
81
+ heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w)
82
+
83
+ input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
84
+ input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)
85
+
86
+ prediction = self.hourglass(input)
87
+
88
+ mask = self.mask(prediction)
89
+ mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64)
90
+ out_dict['mask'] = mask
91
+ mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
92
+ sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
93
+ deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place
94
+ deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
95
+
96
+ out_dict['deformation'] = deformation
97
+
98
+ if self.flag_estimate_occlusion_map:
99
+ bs, _, d, h, w = prediction.shape
100
+ prediction_reshape = prediction.view(bs, -1, h, w)
101
+ occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64
102
+ out_dict['occlusion_map'] = occlusion_map
103
+
104
+ return out_dict