Evgeny Zhukov committed
Commit 2ba4412 · 1 Parent(s): 45e557f

Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. .gitattributes +2 -0
  2. UniAnimate/.gitignore +18 -0
  3. UniAnimate/README.md +344 -0
  4. UniAnimate/configs/UniAnimate_infer.yaml +98 -0
  5. UniAnimate/configs/UniAnimate_infer_long.yaml +101 -0
  6. UniAnimate/dwpose/__init__.py +0 -0
  7. UniAnimate/dwpose/onnxdet.py +127 -0
  8. UniAnimate/dwpose/onnxpose.py +360 -0
  9. UniAnimate/dwpose/util.py +336 -0
  10. UniAnimate/dwpose/wholebody.py +48 -0
  11. UniAnimate/environment.yaml +236 -0
  12. UniAnimate/inference.py +18 -0
  13. UniAnimate/requirements.txt +201 -0
  14. UniAnimate/run_align_pose.py +712 -0
  15. UniAnimate/test_func/save_targer_keys.py +108 -0
  16. UniAnimate/test_func/test_EndDec.py +95 -0
  17. UniAnimate/test_func/test_dataset.py +152 -0
  18. UniAnimate/test_func/test_models.py +56 -0
  19. UniAnimate/test_func/test_save_video.py +24 -0
  20. UniAnimate/tools/__init__.py +3 -0
  21. UniAnimate/tools/datasets/__init__.py +2 -0
  22. UniAnimate/tools/datasets/image_dataset.py +86 -0
  23. UniAnimate/tools/datasets/video_dataset.py +118 -0
  24. UniAnimate/tools/inferences/__init__.py +2 -0
  25. UniAnimate/tools/inferences/inference_unianimate_entrance.py +483 -0
  26. UniAnimate/tools/inferences/inference_unianimate_long_entrance.py +508 -0
  27. UniAnimate/tools/modules/__init__.py +7 -0
  28. UniAnimate/tools/modules/autoencoder.py +690 -0
  29. UniAnimate/tools/modules/clip_embedder.py +212 -0
  30. UniAnimate/tools/modules/config.py +206 -0
  31. UniAnimate/tools/modules/diffusions/__init__.py +1 -0
  32. UniAnimate/tools/modules/diffusions/diffusion_ddim.py +1121 -0
  33. UniAnimate/tools/modules/diffusions/diffusion_gauss.py +498 -0
  34. UniAnimate/tools/modules/diffusions/losses.py +28 -0
  35. UniAnimate/tools/modules/diffusions/schedules.py +166 -0
  36. UniAnimate/tools/modules/embedding_manager.py +179 -0
  37. UniAnimate/tools/modules/unet/__init__.py +2 -0
  38. UniAnimate/tools/modules/unet/mha_flash.py +103 -0
  39. UniAnimate/tools/modules/unet/unet_unianimate.py +659 -0
  40. UniAnimate/tools/modules/unet/util.py +1741 -0
  41. UniAnimate/utils/__init__.py +0 -0
  42. UniAnimate/utils/assign_cfg.py +78 -0
  43. UniAnimate/utils/config.py +230 -0
  44. UniAnimate/utils/distributed.py +430 -0
  45. UniAnimate/utils/logging.py +90 -0
  46. UniAnimate/utils/mp4_to_gif.py +16 -0
  47. UniAnimate/utils/multi_port.py +9 -0
  48. UniAnimate/utils/optim/__init__.py +2 -0
  49. UniAnimate/utils/optim/adafactor.py +230 -0
  50. UniAnimate/utils/optim/lr_scheduler.py +58 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+
+ UniAnimate/data/** filter=lfs diff=lfs merge=lfs -text
UniAnimate/.gitignore ADDED
@@ -0,0 +1,18 @@
+ *.pkl
+ *.pt
+ *.mov
+ *.pth
+ *.mov
+ *.npz
+ *.npy
+ *.boj
+ *.onnx
+ *.tar
+ *.bin
+ cache*
+ .DS_Store
+ *DS_Store
+ outputs/
+ **/__pycache__
+ ***/__pycache__
+ */__pycache__
UniAnimate/README.md ADDED
@@ -0,0 +1,344 @@
1
+ <!-- main documents -->
2
+
3
+
4
+ <div align="center">
5
+
6
+
7
+ # UniAnimate: Taming Unified Video Diffusion Models for Consistent Human Image Animation
8
+
9
+ [Xiang Wang](https://scholar.google.com.hk/citations?user=cQbXvkcAAAAJ&hl=zh-CN&oi=sra)<sup>1</sup>, [Shiwei Zhang](https://scholar.google.com.hk/citations?user=ZO3OQ-8AAAAJ&hl=zh-CN)<sup>2</sup>, [Changxin Gao](https://scholar.google.com.hk/citations?user=4tku-lwAAAAJ&hl=zh-CN)<sup>1</sup>, [Jiayu Wang](#)<sup>2</sup>, [Xiaoqiang Zhou](https://scholar.google.com.hk/citations?user=Z2BTkNIAAAAJ&hl=zh-CN&oi=ao)<sup>3</sup>, [Yingya Zhang](https://scholar.google.com.hk/citations?user=16RDSEUAAAAJ&hl=zh-CN)<sup>2</sup> , [Luxin Yan](#)<sup>1</sup> , [Nong Sang](https://scholar.google.com.hk/citations?user=ky_ZowEAAAAJ&hl=zh-CN)<sup>1</sup>
10
+ <sup>1</sup>HUST &nbsp; <sup>2</sup>Alibaba Group &nbsp; <sup>3</sup>USTC
11
+
12
+
13
+ [🎨 Project Page](https://unianimate.github.io/)
14
+
15
+
16
+ <p align="middle">
17
+ <img src='https://img.alicdn.com/imgextra/i4/O1CN01bW2Y491JkHAUK4W0i_!!6000000001066-2-tps-2352-1460.png' width='784'>
18
+
19
+ Demo cases generated by the proposed UniAnimate
20
+ </p>
21
+
22
+
23
+ </div>
24
+
25
+ ## 🔥 News
26
+ - **[2024/07/19]** 🔥 We added a **<font color=red>noise prior</font>** to the code (refer to line 381: `noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise)` in `tools/inferences/inference_unianimate_long_entrance.py`), which can help achieve better appearance preservation (such as the background), especially in long video generation; a minimal sketch of this noise prior is shown right after this news list. In addition, we are considering releasing an upgraded version of UniAnimate if we obtain an open-source license from the company.
27
+ - **[2024/06/26]** For cards with large GPU memory, such as the A100, we support parallel denoising of multiple segments to accelerate long video inference. You can change `context_batch_size: 1` in `configs/UniAnimate_infer_long.yaml` to a value greater than 1, such as `context_batch_size: 4`, which improves the inference speed to a certain extent.
28
+ - **[2024/06/15]** 🔥🔥🔥 By offloading CLIP and VAE and explicitly adding torch.float16 (i.e., set `CPU_CLIP_VAE: True` in `configs/UniAnimate_infer.yaml`), the GPU memory can be greatly reduced. Now generating a 32x768x512 video clip only requires **~12G GPU memory**. Refer to [this issue](https://github.com/ali-vilab/UniAnimate/issues/10) for more details. Thanks to [@blackight](https://github.com/blackight) for the contribution!
29
+ - **[2024/06/13]** **🔥🔥🔥 <font color=red>We released the code and models for human image animation, enjoy it!</font>**
30
+ - **[2024/06/13]** We have submitted the code to the company for approval, and **the code is expected to be released today or tomorrow**.
31
+ - **[2024/06/03]** We initialized this GitHub repository and planned to release the paper.
32
+
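+ Below is a minimal, self-contained sketch of the noise-prior idea mentioned in the 2024/07/19 news item. The real `diffusion.q_sample` in this repository may have a different signature; the toy schedule and tensor shapes here are illustrative assumptions only.
+
+ ```
+ import torch
+
+ def q_sample(x0, t, noise, alphas_cumprod):
+     # Forward-diffuse x0 to timestep t: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise
+     a_bar = alphas_cumprod[t]
+     return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
+
+ # Toy linear_sd-style schedule and a stand-in for the encoded reference frame (assumptions).
+ betas = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000) ** 2
+ alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
+ ref_latent = torch.randn(1, 4, 32, 96, 64)
+ noise = torch.randn_like(ref_latent)
+
+ # Analogue of the quoted line: start denoising from the reference latent diffused to a
+ # late timestep (noise_prior_value, e.g. 939) instead of from pure Gaussian noise, which
+ # helps preserve low-frequency appearance such as the background.
+ noise = q_sample(ref_latent.clone(), 939, noise, alphas_cumprod)
+ ```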
33
+
34
+ ## TODO
35
+
36
+ - [x] Release the models, inference code, and pose alignment code.
37
+ - [x] Support generating both short and long videos.
38
+ - [ ] Release the models for longer video generation in one batch.
39
+ - [ ] Release models based on VideoLCM for faster video synthesis.
40
+ - [ ] Train the models on higher-resolution videos.
41
+
42
+
43
+
44
+ ## Introduction
45
+
46
+ <div align="center">
47
+ <p align="middle">
48
+ <img src='https://img.alicdn.com/imgextra/i3/O1CN01VvncFJ1ueRudiMOZu_!!6000000006062-2-tps-2654-1042.png' width='784'>
49
+
50
+ Overall framework of UniAnimate
51
+ </p>
52
+ </div>
53
+
54
+ Recent diffusion-based human image animation techniques have demonstrated impressive success in synthesizing videos that faithfully follow a given reference identity and a sequence of desired movement poses. Despite this, there are still two limitations: i) an extra reference model is required to align the identity image with the main video branch, which significantly increases the optimization burden and model parameters; ii) the generated video is usually short in time (e.g., 24 frames), hampering practical applications. To address these shortcomings, we present a UniAnimate framework to enable efficient and long-term human video generation. First, to reduce the optimization difficulty and ensure temporal coherence, we map the reference image along with the posture guidance and noise video into a common feature space by incorporating a unified video diffusion model. Second, we propose a unified noise input that supports random noised input as well as first frame conditioned input, which enhances the ability to generate long-term video. Finally, to further efficiently handle long sequences, we explore an alternative temporal modeling architecture based on state space model to replace the original computation-consuming temporal Transformer. Extensive experimental results indicate that UniAnimate achieves superior synthesis results over existing state-of-the-art counterparts in both quantitative and qualitative evaluations. Notably, UniAnimate can even generate highly consistent one-minute videos by iteratively employing the first frame conditioning strategy.
55
+
56
+
57
+ ## Getting Started with UniAnimate
58
+
59
+
60
+ ### (1) Installation
61
+
62
+ Install the Python dependencies:
63
+
64
+ ```
65
+ git clone https://github.com/ali-vilab/UniAnimate.git
66
+ cd UniAnimate
67
+ conda create -n UniAnimate python=3.9
68
+ conda activate UniAnimate
69
+ conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
70
+ pip install -r requirements.txt
71
+ ```
72
+ We also provide all the dependencies in `environment.yaml`.
73
+
74
+ **Note**: for the Windows operating system, you can refer to [this issue](https://github.com/ali-vilab/UniAnimate/issues/11) to install the dependencies. Thanks to [@zephirusgit](https://github.com/zephirusgit) for the contribution. If you encounter the problem `The shape of the 2D attn_mask is torch.Size([77, 77]), but should be (1, 1).`, please refer to [this issue](https://github.com/ali-vilab/UniAnimate/issues/61) to solve it; thanks to [@Isi-dev](https://github.com/Isi-dev) for the contribution.
75
+
76
+ ### (2) Download the pretrained checkpoints
77
+
78
+ Install ModelScope first (`pip install modelscope`), then download the models in Python:
+ ```
+ from modelscope.hub.snapshot_download import snapshot_download
+ model_dir = snapshot_download('iic/unianimate', cache_dir='checkpoints/')
+ ```
84
+ Then you may need the following command to move the checkpoints into the `checkpoints/` directory:
85
+ ```
86
+ mv ./checkpoints/iic/unianimate/* ./checkpoints/
87
+ ```
88
+
89
+ Finally, the model weights will be organized in `./checkpoints/` as follows:
90
+ ```
91
+ ./checkpoints/
92
+ |---- dw-ll_ucoco_384.onnx
93
+ |---- open_clip_pytorch_model.bin
94
+ |---- unianimate_16f_32f_non_ema_223000.pth
95
+ |---- v2-1_512-ema-pruned.ckpt
96
+ └---- yolox_l.onnx
97
+ ```
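+ As a quick sanity check (a minimal sketch, not part of the released scripts), you can verify that all expected weights are in place before running inference:
+
+ ```
+ from pathlib import Path
+
+ # File names taken from the checkpoint layout listed above.
+ expected = [
+     "dw-ll_ucoco_384.onnx",
+     "open_clip_pytorch_model.bin",
+     "unianimate_16f_32f_non_ema_223000.pth",
+     "v2-1_512-ema-pruned.ckpt",
+     "yolox_l.onnx",
+ ]
+ ckpt_dir = Path("checkpoints")
+ missing = [name for name in expected if not (ckpt_dir / name).exists()]
+ if missing:
+     raise FileNotFoundError(f"Missing checkpoint files in {ckpt_dir}: {missing}")
+ print("All UniAnimate checkpoints found.")
+ ```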
98
+
99
+ ### (3) Pose alignment **(Important)**
100
+
101
+ Rescale the target pose sequence to match the pose of the reference image:
102
+ ```
103
+ # reference image 1
104
+ python run_align_pose.py --ref_name data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg --source_video_paths data/videos/source_video.mp4 --saved_pose_dir data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full
105
+
106
+ # reference image 2
107
+ python run_align_pose.py --ref_name data/images/musk.jpg --source_video_paths data/videos/source_video.mp4 --saved_pose_dir data/saved_pose/musk
108
+
109
+ # reference image 3
110
+ python run_align_pose.py --ref_name data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg --source_video_paths data/videos/source_video.mp4 --saved_pose_dir data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full
111
+
112
+ # reference image 4
113
+ python run_align_pose.py --ref_name data/images/IMG_20240514_104337.jpg --source_video_paths data/videos/source_video.mp4 --saved_pose_dir data/saved_pose/IMG_20240514_104337
114
+ ```
115
+ We have already provided the processed target poses for the demo videos in `data/saved_pose`; if you run our demo video example, this step can be skipped. In addition, you need to install onnxruntime-gpu (`pip install onnxruntime-gpu==1.13.1`) to run pose alignment on the GPU.
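+ If you have several reference images, the commands above can also be driven from a small helper script. The loop below is only a convenience sketch (it assumes the demo directory layout shown above) and simply calls `run_align_pose.py` once per image:
+
+ ```
+ import subprocess
+ from pathlib import Path
+
+ source_video = "data/videos/source_video.mp4"
+ for ref in sorted(Path("data/images").glob("*.jpg")):
+     out_dir = Path("data/saved_pose") / ref.stem
+     subprocess.run(
+         ["python", "run_align_pose.py",
+          "--ref_name", str(ref),
+          "--source_video_paths", source_video,
+          "--saved_pose_dir", str(out_dir)],
+         check=True,
+     )
+ ```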
116
+
117
+ **<font color=red>&#10004; Some tips</font>**:
118
+
119
+ - > In pose alignment, the first frame in the target pose sequence is used to calculate the scale coefficient of the alignment. Therefore, if the first frame in the target pose sequence contains the entire face and body pose (including hands and feet), it helps obtain a more accurate estimate and better video generation results.
120
+
121
+
122
+
123
+ ### (4) Run the UniAnimate model to generate videos
124
+
125
+ #### (4.1) Generating video clips (32 frames with 768x512 resolution)
126
+
127
+ Execute the following command to generate video clips:
128
+ ```
129
+ python inference.py --cfg configs/UniAnimate_infer.yaml
130
+ ```
131
+ After this, 32-frame video clips with 768x512 resolution will be generated:
132
+
133
+
134
+ <table>
135
+ <center>
136
+ <tr>
137
+ <td ><center>
138
+ <image height="260" src="assets/1.gif"></image>
139
+ </center></td>
140
+ <td ><center>
141
+ <image height="260" src="assets/2.gif"></image>
142
+ </center></td>
143
+ </tr>
144
+ <tr>
145
+ <td ><center>
146
+ <p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYV1hTb2g3Zlpmb1E/Vk9HZHZkdDBUQzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
147
+ </center></td>
148
+ <td ><center>
149
+ <p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYYzNUUWRKR043c1FaZkVHSkpSMnpoeTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
150
+ </center></td>
151
+ </tr>
152
+ </center>
153
+ </table>
154
+ </center>
155
+
156
+ <!-- <table>
157
+ <center>
158
+ <tr>
159
+ <td ><center>
160
+ <video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYYzNUUWRKR043c1FaZkVHSkpSMnpoeTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
161
+ </td>
162
+ <td ><center>
163
+ <video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYV1hTb2g3Zlpmb1E/Vk9HZHZkdDBUQzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
164
+ </td>
165
+ </tr>
166
+ </table> -->
167
+
168
+
169
+ **<font color=red>&#10004; Some tips</font>**:
170
+
171
+ - > To run the model, **~12G** ~~26G~~ of GPU memory will be used. If your GPU has less memory than this, you can change `max_frames: 32` in `configs/UniAnimate_infer.yaml` to a smaller value, e.g., 24, 16, or 8. Our model is compatible with all of them.
172
+
173
+
174
+ #### (4.2) Generating video clips (32 frames with 1216x768 resolution)
175
+
176
+ If you want to synthesize higher-resolution results, you can change `resolution: [512, 768]` in `configs/UniAnimate_infer.yaml` to `resolution: [768, 1216]`, and then execute the following command to generate video clips:
177
+ ```
178
+ python inference.py --cfg configs/UniAnimate_infer.yaml
179
+ ```
180
+ After this, 32-frame video clips with 1216x768 resolution will be generated:
181
+
182
+
183
+
184
+ <table>
185
+ <center>
186
+ <tr>
187
+ <td ><center>
188
+ <image height="260" src="assets/3.gif"></image>
189
+ </center></td>
190
+ <td ><center>
191
+ <image height="260" src="assets/4.gif"></image>
192
+ </center></td>
193
+ </tr>
194
+ <tr>
195
+ <td ><center>
196
+ <p>Click <a href="https://cloud.video.taobao.com/vod/play/NTFJUWJ1YXphUzU5b3dhZHJlQk1YZjA3emppMWNJbHhXSlN6WmZHc2FTYTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
197
+ </center></td>
198
+ <td ><center>
199
+ <p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYYklMcGdIRFlDcXcwVEU5ZnR0VlBpRzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
200
+ </center></td>
201
+ </tr>
202
+ </center>
203
+ </table>
204
+ </center>
205
+
206
+ <!-- <table>
207
+ <center>
208
+ <tr>
209
+ <td ><center>
210
+ <video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/NTFJUWJ1YXphUzU5b3dhZHJlQk1YZjA3emppMWNJbHhXSlN6WmZHc2FTYTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
211
+ </td>
212
+ <td ><center>
213
+ <video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYYklMcGdIRFlDcXcwVEU5ZnR0VlBpRzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
214
+ </td>
215
+ </tr>
216
+ </table> -->
217
+
218
+
219
+ **<font color=red>&#10004; Some tips</font>**:
220
+
221
+ - > To run the model, **~21G** ~~36G~~ of GPU memory will be used. Even though our model was trained at 512x768 resolution, we observed that direct inference at 768x1216 usually works and produces satisfactory results. If this leads to inconsistent appearance, you can try a different seed or adjust the resolution to 512x768.
222
+
223
+ - > Although our model was not trained on 48 or 64 frames, we found that the model generalizes well to synthesis of these lengths.
224
+
225
+
226
+
227
+ In the `configs/UniAnimate_infer.yaml` configuration file, you can specify the data, adjust the video length via `max_frames`, and experiment with different diffusion settings, among other options.
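+ For example, instead of editing the YAML by hand you can derive a custom configuration programmatically. This is only a sketch and assumes PyYAML is installed; `UniAnimate_infer_custom.yaml` is a hypothetical output name:
+
+ ```
+ import yaml
+
+ with open("configs/UniAnimate_infer.yaml") as f:
+     cfg = yaml.safe_load(f)
+
+ cfg["max_frames"] = 16            # smaller clips for low-memory GPUs
+ cfg["resolution"] = [768, 1216]   # or keep [512, 768]
+ cfg["seed"] = 42
+
+ with open("configs/UniAnimate_infer_custom.yaml", "w") as f:
+     yaml.safe_dump(cfg, f, sort_keys=False)
+ # then run: python inference.py --cfg configs/UniAnimate_infer_custom.yaml
+ ```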
228
+
229
+
230
+
231
+ #### (4.3) Generating long videos
232
+
233
+
234
+
235
+ If you want to synthesize videos as long as the target pose sequence, you can execute the following command to generate long videos:
236
+ ```
237
+ python inference.py --cfg configs/UniAnimate_infer_long.yaml
238
+ ```
239
+ After this, long videos with 1216x768 resolution will be generated:
240
+
241
+
242
+
243
+ <table>
244
+ <center>
245
+ <tr>
246
+ <td ><center>
247
+ <image height="260" src="assets/5.gif"></image>
248
+ </center></td>
249
+ <td ><center>
250
+ <image height="260" src="assets/6.gif"></image>
251
+ </center></td>
252
+ </tr>
253
+ <tr>
254
+ <td ><center>
255
+ <p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYVmJKZUJSbDl6N1FXU01DYTlDRmJKTzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
256
+ </center></td>
257
+ <td ><center>
258
+ <p>Click <a href="https://cloud.video.taobao.com/vod/play/VUdZTUE5MWtST3VtNEdFaVpGbHN1U25nNEorTEc2SzZROUNiUjNncW5ycTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
259
+ </center></td>
260
+ </tr>
261
+ </center>
262
+ </table>
263
+ </center>
264
+
265
+ <!-- <table>
266
+ <center>
267
+ <tr>
268
+ <td ><center>
269
+ <video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYVmJKZUJSbDl6N1FXU01DYTlDRmJKTzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
270
+ </td>
271
+ <td ><center>
272
+ <video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/VUdZTUE5MWtST3VtNEdFaVpGbHN1U25nNEorTEc2SzZROUNiUjNncW5ycTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
273
+ </td>
274
+ </tr>
275
+ </table> -->
276
+
277
+
278
+
279
+
280
+ <table>
281
+ <center>
282
+ <tr>
283
+ <td ><center>
284
+ <image height="260" src="assets/7.gif"></image>
285
+ </center></td>
286
+ <td ><center>
287
+ <image height="260" src="assets/8.gif"></image>
288
+ </center></td>
289
+ </tr>
290
+ <tr>
291
+ <td ><center>
292
+ <p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYV04xKzd3eWFPVGZCQjVTUWdtbTFuQzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
293
+ </center></td>
294
+ <td ><center>
295
+ <p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYWGwxVkNMY1NXOHpWTVdNZDRxKzRuZTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
296
+ </center></td>
297
+ </tr>
298
+ </center>
299
+ </table>
300
+ </center>
301
+
302
+ <!-- <table>
303
+ <center>
304
+ <tr>
305
+ <td ><center>
306
+ <video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYV04xKzd3eWFPVGZCQjVTUWdtbTFuQzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
307
+ </td>
308
+ <td ><center>
309
+ <video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYWGwxVkNMY1NXOHpWTVdNZDRxKzRuZTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
310
+ </td>
311
+ </tr>
312
+ </table> -->
313
+
314
+ In the `configs/UniAnimate_infer_long.yaml` configuration file, `test_list_path` should be in the format of `[frame_interval, reference image, driving pose sequence]`, where `frame_interval=1` means that all frames in the target pose sequence will be used to generate the video, and `frame_interval=2` means that one frame is sampled every two frames. `reference image` is the path where the reference image is saved, and `driving pose sequence` is the path where the driving pose sequence is saved.
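+ As a small illustration of this format (a sketch only, using the demo paths from above), each entry can be unpacked and sanity-checked like this:
+
+ ```
+ from pathlib import Path
+
+ # Format: [frame_interval, reference image, driving pose sequence]
+ test_list_path = [
+     [2, "data/images/musk.jpg", "data/saved_pose/musk"],
+     [1, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg",
+      "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"],
+ ]
+
+ for frame_interval, ref_image, pose_dir in test_list_path:
+     assert frame_interval >= 1, "frame_interval must be a positive integer"
+     assert Path(ref_image).is_file(), f"missing reference image: {ref_image}"
+     assert Path(pose_dir).is_dir(), f"missing pose directory: {pose_dir}"
+ ```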
315
+
316
+
317
+
318
+ **<font color=red>&#10004; Some tips</font>**:
319
+
320
+ - > If you find inconsistent appearance, you can change the resolution from 768x1216 to 512x768, or change the `context_overlap` from 8 to 16.
321
+ - > In the default setting of `configs/UniAnimate_infer_long.yaml`, a sliding-window strategy with temporal overlap is used. You can also generate a satisfactory video segment first, and then input the last frame of this segment into the model to generate the next segment and continue the video (see the sketch after these tips).
322
+
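+ The sketch below shows one way to implement the "continue the video" tip above: extract the last frame of a previously generated segment and reuse it as the reference image for the next run. The paths are assumptions; adjust them to wherever your segments are saved.
+
+ ```
+ import cv2
+
+ def save_last_frame(video_path, out_image):
+     # Seek to the final frame of the generated segment and write it as an image.
+     cap = cv2.VideoCapture(video_path)
+     cap.set(cv2.CAP_PROP_POS_FRAMES, cap.get(cv2.CAP_PROP_FRAME_COUNT) - 1)
+     ok, frame = cap.read()
+     cap.release()
+     if not ok:
+         raise RuntimeError(f"Could not read the last frame of {video_path}")
+     cv2.imwrite(out_image, frame)
+
+ # Hypothetical paths: a previously generated segment and the next reference image.
+ save_last_frame("outputs/segment_001.mp4", "data/images/segment_001_last.jpg")
+ ```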
323
+
324
+
325
+ ## Citation
326
+
327
+ If you find this codebase useful for your research, please cite the following paper:
328
+
329
+ ```
330
+ @article{wang2024unianimate,
331
+ title={UniAnimate: Taming Unified Video Diffusion Models for Consistent Human Image Animation},
332
+ author={Wang, Xiang and Zhang, Shiwei and Gao, Changxin and Wang, Jiayu and Zhou, Xiaoqiang and Zhang, Yingya and Yan, Luxin and Sang, Nong},
333
+ journal={arXiv preprint arXiv:2406.01188},
334
+ year={2024}
335
+ }
336
+ ```
337
+
338
+
339
+
340
+ ## Disclaimer
341
+
342
+
343
+ This open-source model is intended for <strong>RESEARCH/NON-COMMERCIAL USE ONLY</strong>.
344
+ We explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using the generative model. The project contributors have no legal affiliation with, nor accountability for, users' behaviors. It is imperative to use the generative model responsibly, adhering to both ethical and legal standards.
UniAnimate/configs/UniAnimate_infer.yaml ADDED
@@ -0,0 +1,98 @@
1
+ # manual setting
2
+ max_frames: 32
3
+ resolution: [512, 768] # or resolution: [768, 1216]
4
+ # resolution: [768, 1216]
5
+ round: 1
6
+ ddim_timesteps: 30 # typically between 25 and 50
7
+ seed: 11 # 7
8
+ test_list_path: [
9
+ # Format: [frame_interval, reference image, driving pose sequence]
10
+ [2, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"],
11
+ [2, "data/images/musk.jpg", "data/saved_pose/musk"],
12
+ [2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full"],
13
+ [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337"]
14
+ ]
15
+ partial_keys: [
16
+ ['image','local_image', "dwpose"], # reference image as the first frame of the generated video (optional)
17
+ ['image', 'randomref', "dwpose"],
18
+ ]
19
+
20
+
21
+
22
+ # default settings
23
+ TASK_TYPE: inference_unianimate_entrance
24
+ guide_scale: 2.5
25
+ vit_resolution: [224, 224]
26
+ use_fp16: True
27
+ batch_size: 1
28
+ latent_random_ref: True
29
+ chunk_size: 2
30
+ decoder_bs: 2
31
+ scale: 8
32
+ use_fps_condition: False
33
+ test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth
34
+ embedder: {
35
+ 'type': 'FrozenOpenCLIPTextVisualEmbedder',
36
+ 'layer': 'penultimate',
37
+ 'pretrained': 'checkpoints/open_clip_pytorch_model.bin'
38
+ }
39
+
40
+
41
+ auto_encoder: {
42
+ 'type': 'AutoencoderKL',
43
+ 'ddconfig': {
44
+ 'double_z': True,
45
+ 'z_channels': 4,
46
+ 'resolution': 256,
47
+ 'in_channels': 3,
48
+ 'out_ch': 3,
49
+ 'ch': 128,
50
+ 'ch_mult': [1, 2, 4, 4],
51
+ 'num_res_blocks': 2,
52
+ 'attn_resolutions': [],
53
+ 'dropout': 0.0,
54
+ 'video_kernel_size': [3, 1, 1]
55
+ },
56
+ 'embed_dim': 4,
57
+ 'pretrained': 'checkpoints/v2-1_512-ema-pruned.ckpt'
58
+ }
59
+
60
+ UNet: {
61
+ 'type': 'UNetSD_UniAnimate',
62
+ 'config': None,
63
+ 'in_dim': 4,
64
+ 'dim': 320,
65
+ 'y_dim': 1024,
66
+ 'context_dim': 1024,
67
+ 'out_dim': 4,
68
+ 'dim_mult': [1, 2, 4, 4],
69
+ 'num_heads': 8,
70
+ 'head_dim': 64,
71
+ 'num_res_blocks': 2,
72
+ 'dropout': 0.1,
73
+ 'temporal_attention': True,
74
+ 'num_tokens': 4,
75
+ 'temporal_attn_times': 1,
76
+ 'use_checkpoint': True,
77
+ 'use_fps_condition': False,
78
+ 'use_sim_mask': False
79
+ }
80
+ video_compositions: ['image', 'local_image', 'dwpose', 'randomref', 'randomref_pose']
81
+ Diffusion: {
82
+ 'type': 'DiffusionDDIM',
83
+ 'schedule': 'linear_sd',
84
+ 'schedule_param': {
85
+ 'num_timesteps': 1000,
86
+ "init_beta": 0.00085,
87
+ "last_beta": 0.0120,
88
+ 'zero_terminal_snr': True,
89
+ },
90
+ 'mean_type': 'v',
91
+ 'loss_type': 'mse',
92
+ 'var_type': 'fixed_small', # 'fixed_large',
93
+ 'rescale_timesteps': False,
94
+ 'noise_strength': 0.1
95
+ }
96
+ use_DiffusionDPM: False
97
+ CPU_CLIP_VAE: True
98
+ noise_prior_value: 949 # or 999, 949
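For context on the `Diffusion` block above: a `linear_sd` schedule with `zero_terminal_snr: True` is commonly constructed as sketched below. This is only an illustration under those assumptions; the repository's `tools/modules/diffusions/schedules.py` may implement it differently.

```
import numpy as np

def linear_sd_betas(num_timesteps=1000, init_beta=0.00085, last_beta=0.0120):
    # "Scaled linear" schedule: linear in sqrt(beta), as used by Stable Diffusion.
    return np.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps) ** 2

def rescale_zero_terminal_snr(betas):
    # Shift and scale sqrt(alpha_bar) so the final timestep has exactly zero SNR,
    # keeping the first value unchanged.
    alphas_bar = np.cumprod(1.0 - betas)
    sqrt_ab = np.sqrt(alphas_bar)
    ab_0, ab_T = sqrt_ab[0], sqrt_ab[-1]
    sqrt_ab = (sqrt_ab - ab_T) * ab_0 / (ab_0 - ab_T)
    alphas_bar = sqrt_ab ** 2
    alphas = np.concatenate([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas

betas = rescale_zero_terminal_snr(linear_sd_betas())
print(betas[0], betas[-1])  # the last beta becomes 1.0, i.e. zero terminal SNR
```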
UniAnimate/configs/UniAnimate_infer_long.yaml ADDED
@@ -0,0 +1,101 @@
1
+ # manual setting
2
+ # resolution: [512, 768] # or [768, 1216]
3
+ resolution: [768, 1216]
4
+ round: 1
5
+ ddim_timesteps: 30 # typically between 25 and 50
6
+ context_size: 32
7
+ context_stride: 1
8
+ context_overlap: 8
9
+ seed: 7
10
+ max_frames: "None" # 64, 96, "None" mean the length of original pose sequence
11
+ test_list_path: [
12
+ # Format: [frame_interval, reference image, driving pose sequence]
13
+ [2, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"],
14
+ [2, "data/images/musk.jpg", "data/saved_pose/musk"],
15
+ [2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full"],
16
+ [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337"],
17
+ [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337_dance"],
18
+ [2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full_dance"]
19
+ ]
20
+
21
+
22
+ # default settings
23
+ TASK_TYPE: inference_unianimate_long_entrance
24
+ guide_scale: 2.5
25
+ vit_resolution: [224, 224]
26
+ use_fp16: True
27
+ batch_size: 1
28
+ latent_random_ref: True
29
+ chunk_size: 2
30
+ decoder_bs: 2
31
+ scale: 8
32
+ use_fps_condition: False
33
+ test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth
34
+ partial_keys: [
35
+ ['image', 'randomref', "dwpose"],
36
+ ]
37
+ embedder: {
38
+ 'type': 'FrozenOpenCLIPTextVisualEmbedder',
39
+ 'layer': 'penultimate',
40
+ 'pretrained': 'checkpoints/open_clip_pytorch_model.bin'
41
+ }
42
+
43
+
44
+ auto_encoder: {
45
+ 'type': 'AutoencoderKL',
46
+ 'ddconfig': {
47
+ 'double_z': True,
48
+ 'z_channels': 4,
49
+ 'resolution': 256,
50
+ 'in_channels': 3,
51
+ 'out_ch': 3,
52
+ 'ch': 128,
53
+ 'ch_mult': [1, 2, 4, 4],
54
+ 'num_res_blocks': 2,
55
+ 'attn_resolutions': [],
56
+ 'dropout': 0.0,
57
+ 'video_kernel_size': [3, 1, 1]
58
+ },
59
+ 'embed_dim': 4,
60
+ 'pretrained': 'checkpoints/v2-1_512-ema-pruned.ckpt'
61
+ }
62
+
63
+ UNet: {
64
+ 'type': 'UNetSD_UniAnimate',
65
+ 'config': None,
66
+ 'in_dim': 4,
67
+ 'dim': 320,
68
+ 'y_dim': 1024,
69
+ 'context_dim': 1024,
70
+ 'out_dim': 4,
71
+ 'dim_mult': [1, 2, 4, 4],
72
+ 'num_heads': 8,
73
+ 'head_dim': 64,
74
+ 'num_res_blocks': 2,
75
+ 'dropout': 0.1,
76
+ 'temporal_attention': True,
77
+ 'num_tokens': 4,
78
+ 'temporal_attn_times': 1,
79
+ 'use_checkpoint': True,
80
+ 'use_fps_condition': False,
81
+ 'use_sim_mask': False
82
+ }
83
+ video_compositions: ['image', 'local_image', 'dwpose', 'randomref', 'randomref_pose']
84
+ Diffusion: {
85
+ 'type': 'DiffusionDDIMLong',
86
+ 'schedule': 'linear_sd',
87
+ 'schedule_param': {
88
+ 'num_timesteps': 1000,
89
+ "init_beta": 0.00085,
90
+ "last_beta": 0.0120,
91
+ 'zero_terminal_snr': True,
92
+ },
93
+ 'mean_type': 'v',
94
+ 'loss_type': 'mse',
95
+ 'var_type': 'fixed_small',
96
+ 'rescale_timesteps': False,
97
+ 'noise_strength': 0.1
98
+ }
99
+ CPU_CLIP_VAE: True
100
+ context_batch_size: 1
101
+ noise_prior_value: 939 # or 999, 949
UniAnimate/dwpose/__init__.py ADDED
File without changes
UniAnimate/dwpose/onnxdet.py ADDED
@@ -0,0 +1,127 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ import onnxruntime
5
+
6
+ def nms(boxes, scores, nms_thr):
7
+ """Single class NMS implemented in Numpy."""
8
+ x1 = boxes[:, 0]
9
+ y1 = boxes[:, 1]
10
+ x2 = boxes[:, 2]
11
+ y2 = boxes[:, 3]
12
+
13
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
14
+ order = scores.argsort()[::-1]
15
+
16
+ keep = []
17
+ while order.size > 0:
18
+ i = order[0]
19
+ keep.append(i)
20
+ xx1 = np.maximum(x1[i], x1[order[1:]])
21
+ yy1 = np.maximum(y1[i], y1[order[1:]])
22
+ xx2 = np.minimum(x2[i], x2[order[1:]])
23
+ yy2 = np.minimum(y2[i], y2[order[1:]])
24
+
25
+ w = np.maximum(0.0, xx2 - xx1 + 1)
26
+ h = np.maximum(0.0, yy2 - yy1 + 1)
27
+ inter = w * h
28
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
29
+
30
+ inds = np.where(ovr <= nms_thr)[0]
31
+ order = order[inds + 1]
32
+
33
+ return keep
34
+
35
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
36
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
37
+ final_dets = []
38
+ num_classes = scores.shape[1]
39
+ for cls_ind in range(num_classes):
40
+ cls_scores = scores[:, cls_ind]
41
+ valid_score_mask = cls_scores > score_thr
42
+ if valid_score_mask.sum() == 0:
43
+ continue
44
+ else:
45
+ valid_scores = cls_scores[valid_score_mask]
46
+ valid_boxes = boxes[valid_score_mask]
47
+ keep = nms(valid_boxes, valid_scores, nms_thr)
48
+ if len(keep) > 0:
49
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
50
+ dets = np.concatenate(
51
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
52
+ )
53
+ final_dets.append(dets)
54
+ if len(final_dets) == 0:
55
+ return None
56
+ return np.concatenate(final_dets, 0)
57
+
58
+ def demo_postprocess(outputs, img_size, p6=False):
59
+ grids = []
60
+ expanded_strides = []
61
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
62
+
63
+ hsizes = [img_size[0] // stride for stride in strides]
64
+ wsizes = [img_size[1] // stride for stride in strides]
65
+
66
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
67
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
68
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
69
+ grids.append(grid)
70
+ shape = grid.shape[:2]
71
+ expanded_strides.append(np.full((*shape, 1), stride))
72
+
73
+ grids = np.concatenate(grids, 1)
74
+ expanded_strides = np.concatenate(expanded_strides, 1)
75
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
76
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
77
+
78
+ return outputs
79
+
80
+ def preprocess(img, input_size, swap=(2, 0, 1)):
81
+ if len(img.shape) == 3:
82
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
83
+ else:
84
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
85
+
86
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
87
+ resized_img = cv2.resize(
88
+ img,
89
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
90
+ interpolation=cv2.INTER_LINEAR,
91
+ ).astype(np.uint8)
92
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
93
+
94
+ padded_img = padded_img.transpose(swap)
95
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
96
+ return padded_img, r
97
+
98
+ def inference_detector(session, oriImg):
99
+ input_shape = (640,640)
100
+ img, ratio = preprocess(oriImg, input_shape)
101
+
102
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
103
+
104
+ output = session.run(None, ort_inputs)
105
+
106
+ predictions = demo_postprocess(output[0], input_shape)[0]
107
+
108
+ boxes = predictions[:, :4]
109
+ scores = predictions[:, 4:5] * predictions[:, 5:]
110
+
111
+ boxes_xyxy = np.ones_like(boxes)
112
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
113
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
114
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
115
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
116
+ boxes_xyxy /= ratio
117
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
118
+ if dets is not None:
119
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
120
+ isscore = final_scores>0.3
121
+ iscat = final_cls_inds == 0
122
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
123
+ final_boxes = final_boxes[isbbox]
124
+ else:
125
+ final_boxes = np.array([])
126
+
127
+ return final_boxes
UniAnimate/dwpose/onnxpose.py ADDED
@@ -0,0 +1,360 @@
1
+ from typing import List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+
7
+ def preprocess(
8
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
9
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
10
+ """Do preprocessing for RTMPose model inference.
11
+
12
+ Args:
13
+ img (np.ndarray): Input image in shape.
14
+ input_size (tuple): Input image size in shape (w, h).
15
+
16
+ Returns:
17
+ tuple:
18
+ - resized_img (np.ndarray): Preprocessed image.
19
+ - center (np.ndarray): Center of image.
20
+ - scale (np.ndarray): Scale of image.
21
+ """
22
+ # get shape of image
23
+ img_shape = img.shape[:2]
24
+ out_img, out_center, out_scale = [], [], []
25
+ if len(out_bbox) == 0:
26
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
27
+ for i in range(len(out_bbox)):
28
+ x0 = out_bbox[i][0]
29
+ y0 = out_bbox[i][1]
30
+ x1 = out_bbox[i][2]
31
+ y1 = out_bbox[i][3]
32
+ bbox = np.array([x0, y0, x1, y1])
33
+
34
+ # get center and scale
35
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
36
+
37
+ # do affine transformation
38
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
39
+
40
+ # normalize image
41
+ mean = np.array([123.675, 116.28, 103.53])
42
+ std = np.array([58.395, 57.12, 57.375])
43
+ resized_img = (resized_img - mean) / std
44
+
45
+ out_img.append(resized_img)
46
+ out_center.append(center)
47
+ out_scale.append(scale)
48
+
49
+ return out_img, out_center, out_scale
50
+
51
+
52
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
53
+ """Inference RTMPose model.
54
+
55
+ Args:
56
+ sess (ort.InferenceSession): ONNXRuntime session.
57
+ img (np.ndarray): Input image in shape.
58
+
59
+ Returns:
60
+ outputs (np.ndarray): Output of RTMPose model.
61
+ """
62
+ all_out = []
63
+ # build input
64
+ for i in range(len(img)):
65
+ input = [img[i].transpose(2, 0, 1)]
66
+
67
+ # build output
68
+ sess_input = {sess.get_inputs()[0].name: input}
69
+ sess_output = []
70
+ for out in sess.get_outputs():
71
+ sess_output.append(out.name)
72
+
73
+ # run model
74
+ outputs = sess.run(sess_output, sess_input)
75
+ all_out.append(outputs)
76
+
77
+ return all_out
78
+
79
+
80
+ def postprocess(outputs: List[np.ndarray],
81
+ model_input_size: Tuple[int, int],
82
+ center: Tuple[int, int],
83
+ scale: Tuple[int, int],
84
+ simcc_split_ratio: float = 2.0
85
+ ) -> Tuple[np.ndarray, np.ndarray]:
86
+ """Postprocess for RTMPose model output.
87
+
88
+ Args:
89
+ outputs (np.ndarray): Output of RTMPose model.
90
+ model_input_size (tuple): RTMPose model Input image size.
91
+ center (tuple): Center of bbox in shape (x, y).
92
+ scale (tuple): Scale of bbox in shape (w, h).
93
+ simcc_split_ratio (float): Split ratio of simcc.
94
+
95
+ Returns:
96
+ tuple:
97
+ - keypoints (np.ndarray): Rescaled keypoints.
98
+ - scores (np.ndarray): Model predict scores.
99
+ """
100
+ all_key = []
101
+ all_score = []
102
+ for i in range(len(outputs)):
103
+ # use simcc to decode
104
+ simcc_x, simcc_y = outputs[i]
105
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
106
+
107
+ # rescale keypoints
108
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
109
+ all_key.append(keypoints[0])
110
+ all_score.append(scores[0])
111
+
112
+ return np.array(all_key), np.array(all_score)
113
+
114
+
115
+ def bbox_xyxy2cs(bbox: np.ndarray,
116
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
117
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
118
+
119
+ Args:
120
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
121
+ as (left, top, right, bottom)
122
+ padding (float): BBox padding factor that will be multilied to scale.
123
+ Default: 1.0
124
+
125
+ Returns:
126
+ tuple: A tuple containing center and scale.
127
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
128
+ (n, 2)
129
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
130
+ (n, 2)
131
+ """
132
+ # convert single bbox from (4, ) to (1, 4)
133
+ dim = bbox.ndim
134
+ if dim == 1:
135
+ bbox = bbox[None, :]
136
+
137
+ # get bbox center and scale
138
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
139
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
140
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
141
+
142
+ if dim == 1:
143
+ center = center[0]
144
+ scale = scale[0]
145
+
146
+ return center, scale
147
+
148
+
149
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
150
+ aspect_ratio: float) -> np.ndarray:
151
+ """Extend the scale to match the given aspect ratio.
152
+
153
+ Args:
154
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
155
+ aspect_ratio (float): The ratio of ``w/h``
156
+
157
+ Returns:
158
+ np.ndarray: The reshaped image scale in (2, )
159
+ """
160
+ w, h = np.hsplit(bbox_scale, [1])
161
+ bbox_scale = np.where(w > h * aspect_ratio,
162
+ np.hstack([w, w / aspect_ratio]),
163
+ np.hstack([h * aspect_ratio, h]))
164
+ return bbox_scale
165
+
166
+
167
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
168
+ """Rotate a point by an angle.
169
+
170
+ Args:
171
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
172
+ angle_rad (float): rotation angle in radian
173
+
174
+ Returns:
175
+ np.ndarray: Rotated point in shape (2, )
176
+ """
177
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
178
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
179
+ return rot_mat @ pt
180
+
181
+
182
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
183
+ """To calculate the affine matrix, three pairs of points are required. This
184
+ function is used to get the 3rd point, given 2D points a & b.
185
+
186
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
187
+ anticlockwise, using b as the rotation center.
188
+
189
+ Args:
190
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
191
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
192
+
193
+ Returns:
194
+ np.ndarray: The 3rd point.
195
+ """
196
+ direction = a - b
197
+ c = b + np.r_[-direction[1], direction[0]]
198
+ return c
199
+
200
+
201
+ def get_warp_matrix(center: np.ndarray,
202
+ scale: np.ndarray,
203
+ rot: float,
204
+ output_size: Tuple[int, int],
205
+ shift: Tuple[float, float] = (0., 0.),
206
+ inv: bool = False) -> np.ndarray:
207
+ """Calculate the affine transformation matrix that can warp the bbox area
208
+ in the input image to the output size.
209
+
210
+ Args:
211
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
212
+ scale (np.ndarray[2, ]): Scale of the bounding box
213
+ wrt [width, height].
214
+ rot (float): Rotation angle (degree).
215
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
216
+ destination heatmaps.
217
+ shift (0-100%): Shift translation ratio wrt the width/height.
218
+ Default (0., 0.).
219
+ inv (bool): Option to inverse the affine transform direction.
220
+ (inv=False: src->dst or inv=True: dst->src)
221
+
222
+ Returns:
223
+ np.ndarray: A 2x3 transformation matrix
224
+ """
225
+ shift = np.array(shift)
226
+ src_w = scale[0]
227
+ dst_w = output_size[0]
228
+ dst_h = output_size[1]
229
+
230
+ # compute transformation matrix
231
+ rot_rad = np.deg2rad(rot)
232
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
233
+ dst_dir = np.array([0., dst_w * -0.5])
234
+
235
+ # get four corners of the src rectangle in the original image
236
+ src = np.zeros((3, 2), dtype=np.float32)
237
+ src[0, :] = center + scale * shift
238
+ src[1, :] = center + src_dir + scale * shift
239
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
240
+
241
+ # get four corners of the dst rectangle in the input image
242
+ dst = np.zeros((3, 2), dtype=np.float32)
243
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
244
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
245
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
246
+
247
+ if inv:
248
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
249
+ else:
250
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
251
+
252
+ return warp_mat
253
+
254
+
255
+ def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
256
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
257
+ """Get the bbox image as the model input by affine transform.
258
+
259
+ Args:
260
+ input_size (dict): The input size of the model.
261
+ bbox_scale (dict): The bbox scale of the img.
262
+ bbox_center (dict): The bbox center of the img.
263
+ img (np.ndarray): The original image.
264
+
265
+ Returns:
266
+ tuple: A tuple containing center and scale.
267
+ - np.ndarray[float32]: img after affine transform.
268
+ - np.ndarray[float32]: bbox scale after affine transform.
269
+ """
270
+ w, h = input_size
271
+ warp_size = (int(w), int(h))
272
+
273
+ # reshape bbox to fixed aspect ratio
274
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
275
+
276
+ # get the affine matrix
277
+ center = bbox_center
278
+ scale = bbox_scale
279
+ rot = 0
280
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
281
+
282
+ # do affine transform
283
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
284
+
285
+ return img, bbox_scale
286
+
287
+
288
+ def get_simcc_maximum(simcc_x: np.ndarray,
289
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
290
+ """Get maximum response location and value from simcc representations.
291
+
292
+ Note:
293
+ instance number: N
294
+ num_keypoints: K
295
+ heatmap height: H
296
+ heatmap width: W
297
+
298
+ Args:
299
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
300
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
301
+
302
+ Returns:
303
+ tuple:
304
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
305
+ (K, 2) or (N, K, 2)
306
+ - vals (np.ndarray): values of maximum heatmap responses in shape
307
+ (K,) or (N, K)
308
+ """
309
+ N, K, Wx = simcc_x.shape
310
+ simcc_x = simcc_x.reshape(N * K, -1)
311
+ simcc_y = simcc_y.reshape(N * K, -1)
312
+
313
+ # get maximum value locations
314
+ x_locs = np.argmax(simcc_x, axis=1)
315
+ y_locs = np.argmax(simcc_y, axis=1)
316
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
317
+ max_val_x = np.amax(simcc_x, axis=1)
318
+ max_val_y = np.amax(simcc_y, axis=1)
319
+
320
+ # get maximum value across x and y axis
321
+ mask = max_val_x > max_val_y
322
+ max_val_x[mask] = max_val_y[mask]
323
+ vals = max_val_x
324
+ locs[vals <= 0.] = -1
325
+
326
+ # reshape
327
+ locs = locs.reshape(N, K, 2)
328
+ vals = vals.reshape(N, K)
329
+
330
+ return locs, vals
331
+
332
+
333
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
334
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
335
+ """Modulate simcc distribution with Gaussian.
336
+
337
+ Args:
338
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
339
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
340
+ simcc_split_ratio (int): The split ratio of simcc.
341
+
342
+ Returns:
343
+ tuple: A tuple containing center and scale.
344
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
345
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
346
+ """
347
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
348
+ keypoints /= simcc_split_ratio
349
+
350
+ return keypoints, scores
351
+
352
+
353
+ def inference_pose(session, out_bbox, oriImg):
354
+ h, w = session.get_inputs()[0].shape[2:]
355
+ model_input_size = (w, h)
356
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
357
+ outputs = inference(session, resized_img)
358
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
359
+
360
+ return keypoints, scores
UniAnimate/dwpose/util.py ADDED
@@ -0,0 +1,336 @@
1
+ import math
2
+ import numpy as np
3
+ import matplotlib
4
+ import cv2
5
+
6
+
7
+ eps = 0.01
8
+
9
+
10
+ def smart_resize(x, s):
11
+ Ht, Wt = s
12
+ if x.ndim == 2:
13
+ Ho, Wo = x.shape
14
+ Co = 1
15
+ else:
16
+ Ho, Wo, Co = x.shape
17
+ if Co == 3 or Co == 1:
18
+ k = float(Ht + Wt) / float(Ho + Wo)
19
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
20
+ else:
21
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
22
+
23
+
24
+ def smart_resize_k(x, fx, fy):
25
+ if x.ndim == 2:
26
+ Ho, Wo = x.shape
27
+ Co = 1
28
+ else:
29
+ Ho, Wo, Co = x.shape
30
+ Ht, Wt = Ho * fy, Wo * fx
31
+ if Co == 3 or Co == 1:
32
+ k = float(Ht + Wt) / float(Ho + Wo)
33
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
34
+ else:
35
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
36
+
37
+
38
+ def padRightDownCorner(img, stride, padValue):
39
+ h = img.shape[0]
40
+ w = img.shape[1]
41
+
42
+ pad = 4 * [None]
43
+ pad[0] = 0 # up
44
+ pad[1] = 0 # left
45
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
46
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
47
+
48
+ img_padded = img
49
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
50
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
51
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
52
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
53
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
54
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
55
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
56
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
57
+
58
+ return img_padded, pad
59
+
60
+
61
+ def transfer(model, model_weights):
62
+ transfered_model_weights = {}
63
+ for weights_name in model.state_dict().keys():
64
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
65
+ return transfered_model_weights
66
+
67
+
68
+ def draw_bodypose(canvas, candidate, subset):
69
+ H, W, C = canvas.shape
70
+ candidate = np.array(candidate)
71
+ subset = np.array(subset)
72
+
73
+ stickwidth = 4
74
+
75
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
76
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
77
+ [1, 16], [16, 18], [3, 17], [6, 18]]
78
+
79
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
80
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
81
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
82
+
83
+ for i in range(17):
84
+ for n in range(len(subset)):
85
+ index = subset[n][np.array(limbSeq[i]) - 1]
86
+ if -1 in index:
87
+ continue
88
+ Y = candidate[index.astype(int), 0] * float(W)
89
+ X = candidate[index.astype(int), 1] * float(H)
90
+ mX = np.mean(X)
91
+ mY = np.mean(Y)
92
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
93
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
94
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
95
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
96
+
97
+ canvas = (canvas * 0.6).astype(np.uint8)
98
+
99
+ for i in range(18):
100
+ for n in range(len(subset)):
101
+ index = int(subset[n][i])
102
+ if index == -1:
103
+ continue
104
+ x, y = candidate[index][0:2]
105
+ x = int(x * W)
106
+ y = int(y * H)
107
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
108
+
109
+ return canvas
110
+
111
+
112
+ def draw_body_and_foot(canvas, candidate, subset):
113
+ H, W, C = canvas.shape
114
+ candidate = np.array(candidate)
115
+ subset = np.array(subset)
116
+
117
+ stickwidth = 4
118
+
119
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
120
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
121
+ [1, 16], [16, 18], [14,19], [11, 20]]
122
+
123
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
124
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
125
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [170, 255, 255], [255, 255, 0]]
126
+
127
+ for i in range(19):
128
+ for n in range(len(subset)):
129
+ index = subset[n][np.array(limbSeq[i]) - 1]
130
+ if -1 in index:
131
+ continue
132
+ Y = candidate[index.astype(int), 0] * float(W)
133
+ X = candidate[index.astype(int), 1] * float(H)
134
+ mX = np.mean(X)
135
+ mY = np.mean(Y)
136
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
137
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
138
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
139
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
140
+
141
+ canvas = (canvas * 0.6).astype(np.uint8)
142
+
143
+ for i in range(20):
144
+ for n in range(len(subset)):
145
+ index = int(subset[n][i])
146
+ if index == -1:
147
+ continue
148
+ x, y = candidate[index][0:2]
149
+ x = int(x * W)
150
+ y = int(y * H)
151
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
152
+
153
+ return canvas
154
+
155
+
156
+ def draw_handpose(canvas, all_hand_peaks):
157
+ H, W, C = canvas.shape
158
+
159
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
160
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
161
+
162
+ for peaks in all_hand_peaks:
163
+ peaks = np.array(peaks)
164
+
165
+ for ie, e in enumerate(edges):
166
+ x1, y1 = peaks[e[0]]
167
+ x2, y2 = peaks[e[1]]
168
+ x1 = int(x1 * W)
169
+ y1 = int(y1 * H)
170
+ x2 = int(x2 * W)
171
+ y2 = int(y2 * H)
172
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
173
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
174
+
175
+ for i, keyponit in enumerate(peaks):
176
+ x, y = keyponit
177
+ x = int(x * W)
178
+ y = int(y * H)
179
+ if x > eps and y > eps:
180
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
181
+ return canvas
182
+
183
+
184
+ def draw_facepose(canvas, all_lmks):
185
+ H, W, C = canvas.shape
186
+ for lmks in all_lmks:
187
+ lmks = np.array(lmks)
188
+ for lmk in lmks:
189
+ x, y = lmk
190
+ x = int(x * W)
191
+ y = int(y * H)
192
+ if x > eps and y > eps:
193
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
194
+ return canvas
195
+
196
+
197
+ # detect hand according to body pose keypoints
198
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
199
+ def handDetect(candidate, subset, oriImg):
200
+ # right hand: wrist 4, elbow 3, shoulder 2
201
+ # left hand: wrist 7, elbow 6, shoulder 5
202
+ ratioWristElbow = 0.33
203
+ detect_result = []
204
+ image_height, image_width = oriImg.shape[0:2]
205
+ for person in subset.astype(int):
206
+ # if any of three not detected
207
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
208
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
209
+ if not (has_left or has_right):
210
+ continue
211
+ hands = []
212
+ #left hand
213
+ if has_left:
214
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
215
+ x1, y1 = candidate[left_shoulder_index][:2]
216
+ x2, y2 = candidate[left_elbow_index][:2]
217
+ x3, y3 = candidate[left_wrist_index][:2]
218
+ hands.append([x1, y1, x2, y2, x3, y3, True])
219
+ # right hand
220
+ if has_right:
221
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
222
+ x1, y1 = candidate[right_shoulder_index][:2]
223
+ x2, y2 = candidate[right_elbow_index][:2]
224
+ x3, y3 = candidate[right_wrist_index][:2]
225
+ hands.append([x1, y1, x2, y2, x3, y3, False])
226
+
227
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
228
+
229
+ x = x3 + ratioWristElbow * (x3 - x2)
230
+ y = y3 + ratioWristElbow * (y3 - y2)
231
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
232
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
233
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
234
+ # x-y refers to the center --> offset to topLeft point
235
+ # handRectangle.x -= handRectangle.width / 2.f;
236
+ # handRectangle.y -= handRectangle.height / 2.f;
237
+ x -= width / 2
238
+ y -= width / 2 # width = height
239
+ # overflow the image
240
+ if x < 0: x = 0
241
+ if y < 0: y = 0
242
+ width1 = width
243
+ width2 = width
244
+ if x + width > image_width: width1 = image_width - x
245
+ if y + width > image_height: width2 = image_height - y
246
+ width = min(width1, width2)
247
+ # the max hand box value is 20 pixels
248
+ if width >= 20:
249
+ detect_result.append([int(x), int(y), int(width), is_left])
250
+
251
+ '''
252
+ return value: [[x, y, w, True if left hand else False]].
253
+ width=height since the network requires a square input.
254
+ x, y are the coordinates of the top-left corner
255
+ '''
256
+ return detect_result
257
+
258
+
259
+ # Written by Lvmin
260
+ def faceDetect(candidate, subset, oriImg):
261
+ # left right eye ear 14 15 16 17
262
+ detect_result = []
263
+ image_height, image_width = oriImg.shape[0:2]
264
+ for person in subset.astype(int):
265
+ has_head = person[0] > -1
266
+ if not has_head:
267
+ continue
268
+
269
+ has_left_eye = person[14] > -1
270
+ has_right_eye = person[15] > -1
271
+ has_left_ear = person[16] > -1
272
+ has_right_ear = person[17] > -1
273
+
274
+ if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
275
+ continue
276
+
277
+ head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
278
+
279
+ width = 0.0
280
+ x0, y0 = candidate[head][:2]
281
+
282
+ if has_left_eye:
283
+ x1, y1 = candidate[left_eye][:2]
284
+ d = max(abs(x0 - x1), abs(y0 - y1))
285
+ width = max(width, d * 3.0)
286
+
287
+ if has_right_eye:
288
+ x1, y1 = candidate[right_eye][:2]
289
+ d = max(abs(x0 - x1), abs(y0 - y1))
290
+ width = max(width, d * 3.0)
291
+
292
+ if has_left_ear:
293
+ x1, y1 = candidate[left_ear][:2]
294
+ d = max(abs(x0 - x1), abs(y0 - y1))
295
+ width = max(width, d * 1.5)
296
+
297
+ if has_right_ear:
298
+ x1, y1 = candidate[right_ear][:2]
299
+ d = max(abs(x0 - x1), abs(y0 - y1))
300
+ width = max(width, d * 1.5)
301
+
302
+ x, y = x0, y0
303
+
304
+ x -= width
305
+ y -= width
306
+
307
+ if x < 0:
308
+ x = 0
309
+
310
+ if y < 0:
311
+ y = 0
312
+
313
+ width1 = width * 2
314
+ width2 = width * 2
315
+
316
+ if x + width > image_width:
317
+ width1 = image_width - x
318
+
319
+ if y + width > image_height:
320
+ width2 = image_height - y
321
+
322
+ width = min(width1, width2)
323
+
324
+ if width >= 20:
325
+ detect_result.append([int(x), int(y), int(width)])
326
+
327
+ return detect_result
328
+
329
+
330
+ # get max index of 2d array
331
+ def npmax(array):
332
+ arrayindex = array.argmax(1)
333
+ arrayvalue = array.max(1)
334
+ i = arrayvalue.argmax()
335
+ j = arrayindex[i]
336
+ return i, j
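
The hand box estimated by handDetect above follows the OpenPose heuristic: the box is centred on the wrist extrapolated one third of the way past the elbow, and its side is 1.5 * max(wrist-elbow distance, 0.9 * elbow-shoulder distance). A small numeric sketch of that geometry (the keypoint coordinates below are made up for illustration):

import math

# Hypothetical pixel coordinates for one arm (shoulder, elbow, wrist).
shoulder, elbow, wrist = (100.0, 100.0), (150.0, 160.0), (190.0, 210.0)

ratio_wrist_elbow = 0.33
cx = wrist[0] + ratio_wrist_elbow * (wrist[0] - elbow[0])  # box centre x
cy = wrist[1] + ratio_wrist_elbow * (wrist[1] - elbow[1])  # box centre y

d_wrist_elbow = math.dist(wrist, elbow)
d_elbow_shoulder = math.dist(elbow, shoulder)
side = 1.5 * max(d_wrist_elbow, 0.9 * d_elbow_shoulder)    # square side length

x, y = cx - side / 2, cy - side / 2                        # shift centre to top-left corner
print(int(x), int(y), int(side))                           # kept only if side >= 20
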
UniAnimate/dwpose/wholebody.py ADDED
@@ -0,0 +1,48 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ import onnxruntime as ort
5
+ from dwpose.onnxdet import inference_detector
6
+ from dwpose.onnxpose import inference_pose
7
+
8
+ class Wholebody:
9
+ def __init__(self):
10
+ device = 'cuda'  # set to 'cpu' to run without a GPU
11
+ providers = ['CPUExecutionProvider'
12
+ ] if device == 'cpu' else ['CUDAExecutionProvider']
13
+ onnx_det = 'checkpoints/yolox_l.onnx'
14
+ onnx_pose = 'checkpoints/dw-ll_ucoco_384.onnx'
15
+
16
+ self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
17
+ self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
18
+
19
+ def __call__(self, oriImg):
20
+ det_result = inference_detector(self.session_det, oriImg)
21
+ keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
22
+
23
+ keypoints_info = np.concatenate(
24
+ (keypoints, scores[..., None]), axis=-1)
25
+ # compute neck joint
26
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
27
+ # the neck is marked visible only if both shoulders score above 0.3
28
+ neck[:, 2:4] = np.logical_and(
29
+ keypoints_info[:, 5, 2:4] > 0.3,
30
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
31
+ new_keypoints_info = np.insert(
32
+ keypoints_info, 17, neck, axis=1)
33
+ mmpose_idx = [
34
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
35
+ ]
36
+ openpose_idx = [
37
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
38
+ ]
39
+ new_keypoints_info[:, openpose_idx] = \
40
+ new_keypoints_info[:, mmpose_idx]
41
+ keypoints_info = new_keypoints_info
42
+
43
+ keypoints, scores = keypoints_info[
44
+ ..., :2], keypoints_info[..., 2]
45
+
46
+ return keypoints, scores
47
+
48
+
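
A minimal usage sketch for the Wholebody wrapper above; it assumes the two ONNX checkpoints referenced in __init__ have already been downloaded to checkpoints/, and the image path is only an example:

import cv2
from dwpose.wholebody import Wholebody

estimator = Wholebody()                         # loads yolox_l.onnx and dw-ll_ucoco_384.onnx
frame = cv2.imread('data/images/example.jpg')   # hypothetical BGR input image
keypoints, scores = estimator(frame)

# keypoints: (num_person, K, 2) image coordinates, where K is the model's
# whole-body keypoint count plus the neck joint inserted at index 17, with
# the body joints remapped to OpenPose order; scores: (num_person, K).
print(keypoints.shape, scores.shape)
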
UniAnimate/environment.yaml ADDED
@@ -0,0 +1,236 @@
1
+ name: /mnt/user/miniconda3/envs/dtrans
2
+ channels:
3
+ - http://mirrors.aliyun.com/anaconda/pkgs/main
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - _openmp_mutex=5.1=1_gnu
8
+ - ca-certificates=2023.12.12=h06a4308_0
9
+ - ld_impl_linux-64=2.38=h1181459_1
10
+ - libffi=3.4.4=h6a678d5_0
11
+ - libgcc-ng=11.2.0=h1234567_1
12
+ - libgomp=11.2.0=h1234567_1
13
+ - libstdcxx-ng=11.2.0=h1234567_1
14
+ - ncurses=6.4=h6a678d5_0
15
+ - openssl=3.0.12=h7f8727e_0
16
+ - pip=23.3.1=py39h06a4308_0
17
+ - python=3.9.18=h955ad1f_0
18
+ - readline=8.2=h5eee18b_0
19
+ - setuptools=68.2.2=py39h06a4308_0
20
+ - sqlite=3.41.2=h5eee18b_0
21
+ - tk=8.6.12=h1ccaba5_0
22
+ - wheel=0.41.2=py39h06a4308_0
23
+ - xz=5.4.5=h5eee18b_0
24
+ - zlib=1.2.13=h5eee18b_0
25
+ - pip:
26
+ - aiofiles==23.2.1
27
+ - aiohttp==3.9.1
28
+ - aiosignal==1.3.1
29
+ - aliyun-python-sdk-core==2.14.0
30
+ - aliyun-python-sdk-kms==2.16.2
31
+ - altair==5.2.0
32
+ - annotated-types==0.6.0
33
+ - antlr4-python3-runtime==4.9.3
34
+ - anyio==4.2.0
35
+ - argparse==1.4.0
36
+ - asttokens==2.4.1
37
+ - async-timeout==4.0.3
38
+ - attrs==23.2.0
39
+ - automat==22.10.0
40
+ - beartype==0.16.4
41
+ - blessed==1.20.0
42
+ - buildtools==1.0.6
43
+ - causal-conv1d==1.1.3.post1
44
+ - certifi==2023.11.17
45
+ - cffi==1.16.0
46
+ - chardet==5.2.0
47
+ - charset-normalizer==3.3.2
48
+ - clean-fid==0.1.35
49
+ - click==8.1.7
50
+ - clip==1.0
51
+ - cmake==3.28.1
52
+ - colorama==0.4.6
53
+ - coloredlogs==15.0.1
54
+ - constantly==23.10.4
55
+ - contourpy==1.2.0
56
+ - crcmod==1.7
57
+ - cryptography==41.0.7
58
+ - cycler==0.12.1
59
+ - decorator==5.1.1
60
+ - decord==0.6.0
61
+ - diffusers==0.26.3
62
+ - docopt==0.6.2
63
+ - easydict==1.11
64
+ - einops==0.7.0
65
+ - exceptiongroup==1.2.0
66
+ - executing==2.0.1
67
+ - fairscale==0.4.13
68
+ - fastapi==0.109.0
69
+ - ffmpeg==1.4
70
+ - ffmpy==0.3.1
71
+ - filelock==3.13.1
72
+ - flatbuffers==24.3.25
73
+ - fonttools==4.47.2
74
+ - frozenlist==1.4.1
75
+ - fsspec==2023.12.2
76
+ - ftfy==6.1.3
77
+ - furl==2.1.3
78
+ - gpustat==1.1.1
79
+ - gradio==4.14.0
80
+ - gradio-client==0.8.0
81
+ - greenlet==3.0.3
82
+ - h11==0.14.0
83
+ - httpcore==1.0.2
84
+ - httpx==0.26.0
85
+ - huggingface-hub==0.20.2
86
+ - humanfriendly==10.0
87
+ - hyperlink==21.0.0
88
+ - idna==3.6
89
+ - imageio==2.33.1
90
+ - imageio-ffmpeg==0.4.9
91
+ - importlib-metadata==7.0.1
92
+ - importlib-resources==6.1.1
93
+ - incremental==22.10.0
94
+ - ipdb==0.13.13
95
+ - ipython==8.18.1
96
+ - jedi==0.19.1
97
+ - jinja2==3.1.3
98
+ - jmespath==0.10.0
99
+ - joblib==1.3.2
100
+ - jsonschema==4.21.0
101
+ - jsonschema-specifications==2023.12.1
102
+ - kiwisolver==1.4.5
103
+ - kornia==0.7.1
104
+ - lazy-loader==0.3
105
+ - lightning-utilities==0.10.0
106
+ - lit==17.0.6
107
+ - lpips==0.1.4
108
+ - mamba-ssm==1.1.4
109
+ - markdown-it-py==3.0.0
110
+ - markupsafe==2.1.3
111
+ - matplotlib==3.8.2
112
+ - matplotlib-inline==0.1.6
113
+ - mdurl==0.1.2
114
+ - motion-vector-extractor==1.0.6
115
+ - mpmath==1.3.0
116
+ - multidict==6.0.4
117
+ - mypy-extensions==1.0.0
118
+ - networkx==3.2.1
119
+ - ninja==1.11.1.1
120
+ - numpy==1.26.3
121
+ - nvidia-cublas-cu11==11.10.3.66
122
+ - nvidia-cublas-cu12==12.1.3.1
123
+ - nvidia-cuda-cupti-cu11==11.7.101
124
+ - nvidia-cuda-cupti-cu12==12.1.105
125
+ - nvidia-cuda-nvrtc-cu11==11.7.99
126
+ - nvidia-cuda-nvrtc-cu12==12.1.105
127
+ - nvidia-cuda-runtime-cu11==11.7.99
128
+ - nvidia-cuda-runtime-cu12==12.1.105
129
+ - nvidia-cudnn-cu11==8.5.0.96
130
+ - nvidia-cudnn-cu12==8.9.2.26
131
+ - nvidia-cufft-cu11==10.9.0.58
132
+ - nvidia-cufft-cu12==11.0.2.54
133
+ - nvidia-curand-cu11==10.2.10.91
134
+ - nvidia-curand-cu12==10.3.2.106
135
+ - nvidia-cusolver-cu11==11.4.0.1
136
+ - nvidia-cusolver-cu12==11.4.5.107
137
+ - nvidia-cusparse-cu11==11.7.4.91
138
+ - nvidia-cusparse-cu12==12.1.0.106
139
+ - nvidia-ml-py==12.535.133
140
+ - nvidia-nccl-cu11==2.14.3
141
+ - nvidia-nccl-cu12==2.19.3
142
+ - nvidia-nvjitlink-cu12==12.3.101
143
+ - nvidia-nvtx-cu11==11.7.91
144
+ - nvidia-nvtx-cu12==12.1.105
145
+ - omegaconf==2.3.0
146
+ - onnxruntime==1.18.0
147
+ - open-clip-torch==2.24.0
148
+ - opencv-python==4.5.3.56
149
+ - opencv-python-headless==4.9.0.80
150
+ - orderedmultidict==1.0.1
151
+ - orjson==3.9.10
152
+ - oss2==2.18.4
153
+ - packaging==23.2
154
+ - pandas==2.1.4
155
+ - parso==0.8.3
156
+ - pexpect==4.9.0
157
+ - pillow==10.2.0
158
+ - piq==0.8.0
159
+ - pkgconfig==1.5.5
160
+ - prompt-toolkit==3.0.43
161
+ - protobuf==4.25.2
162
+ - psutil==5.9.8
163
+ - ptflops==0.7.2.2
164
+ - ptyprocess==0.7.0
165
+ - pure-eval==0.2.2
166
+ - pycparser==2.21
167
+ - pycryptodome==3.20.0
168
+ - pydantic==2.5.3
169
+ - pydantic-core==2.14.6
170
+ - pydub==0.25.1
171
+ - pygments==2.17.2
172
+ - pynvml==11.5.0
173
+ - pyparsing==3.1.1
174
+ - pyre-extensions==0.0.29
175
+ - python-dateutil==2.8.2
176
+ - python-multipart==0.0.6
177
+ - pytorch-lightning==2.1.3
178
+ - pytz==2023.3.post1
179
+ - pyyaml==6.0.1
180
+ - redo==2.0.4
181
+ - referencing==0.32.1
182
+ - regex==2023.12.25
183
+ - requests==2.31.0
184
+ - rich==13.7.0
185
+ - rotary-embedding-torch==0.5.3
186
+ - rpds-py==0.17.1
187
+ - ruff==0.2.0
188
+ - safetensors==0.4.1
189
+ - scikit-image==0.22.0
190
+ - scikit-learn==1.4.0
191
+ - scipy==1.11.4
192
+ - semantic-version==2.10.0
193
+ - sentencepiece==0.1.99
194
+ - shellingham==1.5.4
195
+ - simplejson==3.19.2
196
+ - six==1.16.0
197
+ - sk-video==1.1.10
198
+ - sniffio==1.3.0
199
+ - sqlalchemy==2.0.27
200
+ - stack-data==0.6.3
201
+ - starlette==0.35.1
202
+ - sympy==1.12
203
+ - thop==0.1.1-2209072238
204
+ - threadpoolctl==3.2.0
205
+ - tifffile==2023.12.9
206
+ - timm==0.9.12
207
+ - tokenizers==0.15.0
208
+ - tomli==2.0.1
209
+ - tomlkit==0.12.0
210
+ - toolz==0.12.0
211
+ - torch==2.0.1+cu118
212
+ - torchaudio==2.0.2+cu118
213
+ - torchdiffeq==0.2.3
214
+ - torchmetrics==1.3.0.post0
215
+ - torchsde==0.2.6
216
+ - torchvision==0.15.2+cu118
217
+ - tqdm==4.66.1
218
+ - traitlets==5.14.1
219
+ - trampoline==0.1.2
220
+ - transformers==4.36.2
221
+ - triton==2.0.0
222
+ - twisted==23.10.0
223
+ - typer==0.9.0
224
+ - typing-extensions==4.9.0
225
+ - typing-inspect==0.9.0
226
+ - tzdata==2023.4
227
+ - urllib3==2.1.0
228
+ - uvicorn==0.26.0
229
+ - wcwidth==0.2.13
230
+ - websockets==11.0.3
231
+ - xformers==0.0.20
232
+ - yarl==1.9.4
233
+ - zipp==3.17.0
234
+ - zope-interface==6.2
235
+ - onnxruntime-gpu==1.13.1
236
+ prefix: /mnt/user/miniconda3/envs/dtrans
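
The environment pins CUDA 11.8 builds of torch/torchvision alongside both onnxruntime and onnxruntime-gpu, so after installation it is worth checking that the GPU providers are actually visible. A quick, optional sanity check:

import torch
import onnxruntime as ort

print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
# CUDAExecutionProvider should appear here for GPU pose extraction with DWPose.
print(ort.get_available_providers())
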
UniAnimate/inference.py ADDED
@@ -0,0 +1,18 @@
1
+ import os
2
+ import sys
3
+ import copy
4
+ import json
5
+ import math
6
+ import random
7
+ import logging
8
+ import itertools
9
+ import numpy as np
10
+
11
+ from utils.config import Config
12
+ from utils.registry_class import INFER_ENGINE
13
+
14
+ from tools import *
15
+
16
+ if __name__ == '__main__':
17
+ cfg_update = Config(load=True)
18
+ INFER_ENGINE.build(dict(type=cfg_update.TASK_TYPE), cfg_update=cfg_update.cfg_dict)
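
Config(load=True) parses the command-line arguments (including the YAML config) into cfg_update, and INFER_ENGINE.build then dispatches on the TASK_TYPE string, presumably to one of the entrance functions under tools/inferences/. A rough stand-in for that registry pattern, using simplified names (DemoRegistry, run_demo) that are not part of the repository:

class DemoRegistry:
    """Simplified illustration of a type-dispatching registry."""
    def __init__(self):
        self._fns = {}

    def register_function(self, fn):
        self._fns[fn.__name__] = fn
        return fn

    def build(self, spec, **extra):
        spec = dict(spec)
        fn = self._fns[spec.pop('type')]   # the 'type' key selects the entry point
        return fn(**spec, **extra)

INFER_ENGINE_DEMO = DemoRegistry()

@INFER_ENGINE_DEMO.register_function
def run_demo(cfg_update=None):
    print('would launch inference with config:', cfg_update)

INFER_ENGINE_DEMO.build(dict(type='run_demo'), cfg_update={'seed': 13})
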
UniAnimate/requirements.txt ADDED
@@ -0,0 +1,201 @@
1
+ # antlr4-python3-runtime==4.9.3
2
+ # anyio==4.2.0
3
+ # asttokens==2.4.1
4
+ # async-timeout==4.0.3
5
+ # attrs==23.2.0
6
+ # Automat==22.10.0
7
+ # beartype==0.16.4
8
+ # blessed==1.20.0
9
+ # buildtools==1.0.6
10
+ # # causal-conv1d==1.1.3.post1
11
+ # certifi==2023.11.17
12
+ # cffi==1.16.0
13
+ # chardet==5.2.0
14
+ # charset-normalizer==3.3.2
15
+ # clean-fid==0.1.35
16
+ # click==8.1.7
17
+ # # clip==1.0
18
+ # cmake==3.28.1
19
+ # colorama==0.4.6
20
+ # coloredlogs==15.0.1
21
+ # constantly==23.10.4
22
+ # contourpy==1.2.0
23
+ # crcmod==1.7
24
+ # cryptography==41.0.7
25
+ # cycler==0.12.1
26
+ # decorator==5.1.1
27
+ # decord==0.6.0
28
+ # diffusers==0.26.3
29
+ # docopt==0.6.2
30
+ easydict==1.11
31
+ einops==0.7.0
32
+ # exceptiongroup==1.2.0
33
+ # executing==2.0.1
34
+ fairscale==0.4.13
35
+ # fastapi==0.109.0
36
+ # ffmpeg==1.4
37
+ # ffmpy==0.3.1
38
+ # filelock==3.13.1
39
+ # flatbuffers==24.3.25
40
+ # fonttools==4.47.2
41
+ # frozenlist==1.4.1
42
+ # fsspec==2023.12.2
43
+ # ftfy==6.1.3
44
+ # furl==2.1.3
45
+ # gpustat==1.1.1
46
+ # gradio==4.14.0
47
+ # gradio_client==0.8.0
48
+ # greenlet==3.0.3
49
+ # h11==0.14.0
50
+ # httpcore==1.0.2
51
+ # httpx==0.26.0
52
+ # huggingface-hub==0.20.2
53
+ # humanfriendly==10.0
54
+ # hyperlink==21.0.0
55
+ # idna==3.6
56
+ imageio==2.33.1
57
+ imageio-ffmpeg==0.4.9
58
+ # importlib-metadata==7.0.1
59
+ # importlib-resources==6.1.1
60
+ # incremental==22.10.0
61
+ # ipdb==0.13.13
62
+ # ipython==8.18.1
63
+ # jedi==0.19.1
64
+ # Jinja2==3.1.3
65
+ # jmespath==0.10.0
66
+ # joblib==1.3.2
67
+ # jsonschema==4.21.0
68
+ # jsonschema-specifications==2023.12.1
69
+ # kiwisolver==1.4.5
70
+ # kornia==0.7.1
71
+ # lazy_loader==0.3
72
+ # lightning-utilities==0.10.0
73
+ # lit==17.0.6
74
+ # lpips==0.1.4
75
+ # markdown-it-py==3.0.0
76
+ # MarkupSafe==2.1.3
77
+ matplotlib==3.8.2
78
+ matplotlib-inline==0.1.6
79
+ # mdurl==0.1.2
80
+ # # motion-vector-extractor==1.0.6
81
+ # mpmath==1.3.0
82
+ # multidict==6.0.4
83
+ # mypy-extensions==1.0.0
84
+ # networkx==3.2.1
85
+ # ninja==1.11.1.1
86
+ # numpy==1.26.3
87
+ # nvidia-cublas-cu11==11.10.3.66
88
+ # nvidia-cublas-cu12==12.1.3.1
89
+ # nvidia-cuda-cupti-cu11==11.7.101
90
+ # nvidia-cuda-cupti-cu12==12.1.105
91
+ # nvidia-cuda-nvrtc-cu11==11.7.99
92
+ # nvidia-cuda-nvrtc-cu12==12.1.105
93
+ # nvidia-cuda-runtime-cu11==11.7.99
94
+ # nvidia-cuda-runtime-cu12==12.1.105
95
+ # nvidia-cudnn-cu11==8.5.0.96
96
+ # nvidia-cudnn-cu12==8.9.2.26
97
+ # nvidia-cufft-cu11==10.9.0.58
98
+ # nvidia-cufft-cu12==11.0.2.54
99
+ # nvidia-curand-cu11==10.2.10.91
100
+ # nvidia-curand-cu12==10.3.2.106
101
+ # nvidia-cusolver-cu11==11.4.0.1
102
+ # nvidia-cusolver-cu12==11.4.5.107
103
+ # nvidia-cusparse-cu11==11.7.4.91
104
+ # nvidia-cusparse-cu12==12.1.0.106
105
+ # nvidia-ml-py==12.535.133
106
+ # nvidia-nccl-cu11==2.14.3
107
+ # nvidia-nccl-cu12==2.19.3
108
+ # nvidia-nvjitlink-cu12==12.3.101
109
+ # nvidia-nvtx-cu11==11.7.91
110
+ # nvidia-nvtx-cu12==12.1.105
111
+ # omegaconf==2.3.0
112
+ onnxruntime==1.18.0
113
+ open-clip-torch==2.24.0
114
+ opencv-python==4.5.3.56
115
+ opencv-python-headless==4.9.0.80
116
+ # orderedmultidict==1.0.1
117
+ # orjson==3.9.10
118
+ oss2==2.18.4
119
+ # # packaging==23.2
120
+ # pandas==2.1.4
121
+ # parso==0.8.3
122
+ # pexpect==4.9.0
123
+ pillow==10.2.0
124
+ # piq==0.8.0
125
+ # pkgconfig==1.5.5
126
+ # prompt-toolkit==3.0.43
127
+ # protobuf==4.25.2
128
+ # psutil==5.9.8
129
+ ptflops==0.7.2.2
130
+ # ptyprocess==0.7.0
131
+ # pure-eval==0.2.2
132
+ # pycparser==2.21
133
+ # pycryptodome==3.20.0
134
+ # pydantic==2.5.3
135
+ # pydantic_core==2.14.6
136
+ # pydub==0.25.1
137
+ # Pygments==2.17.2
138
+ pynvml==11.5.0
139
+ # pyparsing==3.1.1
140
+ # pyre-extensions==0.0.29
141
+ # python-dateutil==2.8.2
142
+ # python-multipart==0.0.6
143
+ # pytorch-lightning==2.1.3
144
+ # pytz==2023.3.post1
145
+ PyYAML==6.0.1
146
+ # redo==2.0.4
147
+ # referencing==0.32.1
148
+ # regex==2023.12.25
149
+ requests==2.31.0
150
+ # rich==13.7.0
151
+ rotary-embedding-torch==0.5.3
152
+ # rpds-py==0.17.1
153
+ # ruff==0.2.0
154
+ # safetensors==0.4.1
155
+ # scikit-image==0.22.0
156
+ # scikit-learn==1.4.0
157
+ # scipy==1.11.4
158
+ # semantic-version==2.10.0
159
+ # sentencepiece==0.1.99
160
+ # shellingham==1.5.4
161
+ simplejson==3.19.2
162
+ # six==1.16.0
163
+ # sk-video==1.1.10
164
+ # sniffio==1.3.0
165
+ # SQLAlchemy==2.0.27
166
+ # stack-data==0.6.3
167
+ # starlette==0.35.1
168
+ # sympy==1.12
169
+ thop==0.1.1.post2209072238
170
+ # threadpoolctl==3.2.0
171
+ # tifffile==2023.12.9
172
+ # timm==0.9.12
173
+ # tokenizers==0.15.0
174
+ # tomli==2.0.1
175
+ # tomlkit==0.12.0
176
+ # toolz==0.12.0
177
+ torch==2.0.1+cu118
178
+ # torchaudio==2.0.2+cu118
179
+ # torchdiffeq==0.2.3
180
+ # torchmetrics==1.3.0.post0
181
+ torchsde==0.2.6
182
+ torchvision==0.15.2+cu118
183
+ tqdm==4.66.1
184
+ # traitlets==5.14.1
185
+ # trampoline==0.1.2
186
+ # transformers==4.36.2
187
+ # triton==2.0.0
188
+ # Twisted==23.10.0
189
+ # typer==0.9.0
190
+ typing-inspect==0.9.0
191
+ typing_extensions==4.9.0
192
+ # tzdata==2023.4
193
+ # urllib3==2.1.0
194
+ # uvicorn==0.26.0
195
+ # wcwidth==0.2.13
196
+ # websockets==11.0.3
197
+ xformers==0.0.20
198
+ # yarl==1.9.4
199
+ # zipp==3.17.0
200
+ # zope.interface==6.2
201
+ onnxruntime-gpu==1.13.1
UniAnimate/run_align_pose.py ADDED
@@ -0,0 +1,712 @@
1
+ # Openpose
2
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4
+ # 3rd Edited by ControlNet
5
+ # 4th Edited by ControlNet (added face and correct hands)
6
+
7
+ import os
8
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
9
+ import cv2
10
+ import torch
11
+ import numpy as np
12
+ import json
13
+ import copy
14
+ import torch
15
+ import random
16
+ import argparse
17
+ import shutil
18
+ import tempfile
19
+ import subprocess
20
+ import numpy as np
21
+ import math
22
+
23
+ import torch.multiprocessing as mp
24
+ import torch.distributed as dist
25
+ import pickle
26
+ import logging
27
+ from io import BytesIO
28
+ import oss2 as oss
29
+ import os.path as osp
30
+
31
+ import sys
32
+ import dwpose.util as util
33
+ from dwpose.wholebody import Wholebody
34
+
35
+
36
+ def smoothing_factor(t_e, cutoff):
37
+ r = 2 * math.pi * cutoff * t_e
38
+ return r / (r + 1)
39
+
40
+
41
+ def exponential_smoothing(a, x, x_prev):
42
+ return a * x + (1 - a) * x_prev
43
+
44
+
45
+ class OneEuroFilter:
46
+ def __init__(self, t0, x0, dx0=0.0, min_cutoff=1.0, beta=0.0,
47
+ d_cutoff=1.0):
48
+ """Initialize the one euro filter."""
49
+ # The parameters.
50
+ self.min_cutoff = float(min_cutoff)
51
+ self.beta = float(beta)
52
+ self.d_cutoff = float(d_cutoff)
53
+ # Previous values.
54
+ self.x_prev = x0
55
+ self.dx_prev = float(dx0)
56
+ self.t_prev = float(t0)
57
+
58
+ def __call__(self, t, x):
59
+ """Compute the filtered signal."""
60
+ t_e = t - self.t_prev
61
+
62
+ # The filtered derivative of the signal.
63
+ a_d = smoothing_factor(t_e, self.d_cutoff)
64
+ dx = (x - self.x_prev) / t_e
65
+ dx_hat = exponential_smoothing(a_d, dx, self.dx_prev)
66
+
67
+ # The filtered signal.
68
+ cutoff = self.min_cutoff + self.beta * abs(dx_hat)
69
+ a = smoothing_factor(t_e, cutoff)
70
+ x_hat = exponential_smoothing(a, x, self.x_prev)
71
+
72
+ # Memorize the previous values.
73
+ self.x_prev = x_hat
74
+ self.dx_prev = dx_hat
75
+ self.t_prev = t
76
+
77
+ return x_hat
78
+
79
+
80
+ def get_logger(name="essmc2"):
81
+ logger = logging.getLogger(name)
82
+ logger.propagate = False
83
+ if len(logger.handlers) == 0:
84
+ std_handler = logging.StreamHandler(sys.stdout)
85
+ formatter = logging.Formatter(
86
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
87
+ std_handler.setFormatter(formatter)
88
+ std_handler.setLevel(logging.INFO)
89
+ logger.setLevel(logging.INFO)
90
+ logger.addHandler(std_handler)
91
+ return logger
92
+
93
+ class DWposeDetector:
94
+ def __init__(self):
95
+
96
+ self.pose_estimation = Wholebody()
97
+
98
+ def __call__(self, oriImg):
99
+ oriImg = oriImg.copy()
100
+ H, W, C = oriImg.shape
101
+ with torch.no_grad():
102
+ candidate, subset = self.pose_estimation(oriImg)
103
+ candidate = candidate[0][np.newaxis, :, :]
104
+ subset = subset[0][np.newaxis, :]
105
+ nums, keys, locs = candidate.shape
106
+ candidate[..., 0] /= float(W)
107
+ candidate[..., 1] /= float(H)
108
+ body = candidate[:,:18].copy()
109
+ body = body.reshape(nums*18, locs)
110
+ score = subset[:,:18].copy()
111
+
112
+ for i in range(len(score)):
113
+ for j in range(len(score[i])):
114
+ if score[i][j] > 0.3:
115
+ score[i][j] = int(18*i+j)
116
+ else:
117
+ score[i][j] = -1
118
+
119
+ un_visible = subset<0.3
120
+ candidate[un_visible] = -1
121
+
122
+ bodyfoot_score = subset[:,:24].copy()
123
+ for i in range(len(bodyfoot_score)):
124
+ for j in range(len(bodyfoot_score[i])):
125
+ if bodyfoot_score[i][j] > 0.3:
126
+ bodyfoot_score[i][j] = int(18*i+j)
127
+ else:
128
+ bodyfoot_score[i][j] = -1
129
+ if -1 not in bodyfoot_score[:,18] and -1 not in bodyfoot_score[:,19]:
130
+ bodyfoot_score[:,18] = np.array([18.])
131
+ else:
132
+ bodyfoot_score[:,18] = np.array([-1.])
133
+ if -1 not in bodyfoot_score[:,21] and -1 not in bodyfoot_score[:,22]:
134
+ bodyfoot_score[:,19] = np.array([19.])
135
+ else:
136
+ bodyfoot_score[:,19] = np.array([-1.])
137
+ bodyfoot_score = bodyfoot_score[:, :20]
138
+
139
+ bodyfoot = candidate[:,:24].copy()
140
+
141
+ for i in range(nums):
142
+ if -1 not in bodyfoot[i][18] and -1 not in bodyfoot[i][19]:
143
+ bodyfoot[i][18] = (bodyfoot[i][18]+bodyfoot[i][19])/2
144
+ else:
145
+ bodyfoot[i][18] = np.array([-1., -1.])
146
+ if -1 not in bodyfoot[i][21] and -1 not in bodyfoot[i][22]:
147
+ bodyfoot[i][19] = (bodyfoot[i][21]+bodyfoot[i][22])/2
148
+ else:
149
+ bodyfoot[i][19] = np.array([-1., -1.])
150
+
151
+ bodyfoot = bodyfoot[:,:20,:]
152
+ bodyfoot = bodyfoot.reshape(nums*20, locs)
153
+
154
+ foot = candidate[:,18:24]
155
+
156
+ faces = candidate[:,24:92]
157
+
158
+ hands = candidate[:,92:113]
159
+ hands = np.vstack([hands, candidate[:,113:]])
160
+
161
+ # bodies = dict(candidate=body, subset=score)
162
+ bodies = dict(candidate=bodyfoot, subset=bodyfoot_score)
163
+ pose = dict(bodies=bodies, hands=hands, faces=faces)
164
+
165
+ # return draw_pose(pose, H, W)
166
+ return pose
167
+
168
+ def draw_pose(pose, H, W):
169
+ bodies = pose['bodies']
170
+ faces = pose['faces']
171
+ hands = pose['hands']
172
+ candidate = bodies['candidate']
173
+ subset = bodies['subset']
174
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
175
+
176
+ canvas = util.draw_body_and_foot(canvas, candidate, subset)
177
+
178
+ canvas = util.draw_handpose(canvas, hands)
179
+
180
+ canvas_without_face = copy.deepcopy(canvas)
181
+
182
+ canvas = util.draw_facepose(canvas, faces)
183
+
184
+ return canvas_without_face, canvas
185
+
186
+ def dw_func(_id, frame, dwpose_model, dwpose_woface_folder='tmp_dwpose_wo_face', dwpose_withface_folder='tmp_dwpose_with_face'):
187
+
188
+ # frame = cv2.imread(frame_name, cv2.IMREAD_COLOR)
189
+ pose = dwpose_model(frame)
190
+
191
+ return pose
192
+
193
+
194
+ def mp_main(args):
195
+
196
+ if args.source_video_paths.endswith('mp4'):
197
+ video_paths = [args.source_video_paths]
198
+ else:
199
+ # video list
200
+ video_paths = [os.path.join(args.source_video_paths, frame_name) for frame_name in os.listdir(args.source_video_paths)]
201
+
202
+
203
+ logger.info("There are {} videos for extracting poses".format(len(video_paths)))
204
+
205
+ logger.info('LOAD: DW Pose Model')
206
+ dwpose_model = DWposeDetector()
207
+
208
+ results_vis = []
209
+ for i, file_path in enumerate(video_paths):
210
+ logger.info(f"{i}/{len(video_paths)}, {file_path}")
211
+ videoCapture = cv2.VideoCapture(file_path)
212
+ while videoCapture.isOpened():
213
+ # get a frame
214
+ ret, frame = videoCapture.read()
215
+ if ret:
216
+ pose = dw_func(i, frame, dwpose_model)
217
+ results_vis.append(pose)
218
+ else:
219
+ break
220
+ logger.info(f'all frames in {file_path} have been read.')
221
+ videoCapture.release()
222
+
223
+ # added
224
+ # results_vis = results_vis[8:]
225
+ print(len(results_vis))
226
+
227
+ ref_name = args.ref_name
228
+ save_motion = args.saved_pose_dir
229
+ os.system(f'rm -rf {save_motion}');
230
+ os.makedirs(save_motion, exist_ok=True)
231
+ save_warp = args.saved_pose_dir
232
+ # os.makedirs(save_warp, exist_ok=True)
233
+
234
+ ref_frame = cv2.imread(ref_name, cv2.IMREAD_COLOR)
235
+ pose_ref = dw_func(i, ref_frame, dwpose_model)
236
+
237
+ bodies = results_vis[0]['bodies']
238
+ faces = results_vis[0]['faces']
239
+ hands = results_vis[0]['hands']
240
+ candidate = bodies['candidate']
241
+
242
+ ref_bodies = pose_ref['bodies']
243
+ ref_faces = pose_ref['faces']
244
+ ref_hands = pose_ref['hands']
245
+ ref_candidate = ref_bodies['candidate']
246
+
247
+
248
+ ref_2_x = ref_candidate[2][0]
249
+ ref_2_y = ref_candidate[2][1]
250
+ ref_5_x = ref_candidate[5][0]
251
+ ref_5_y = ref_candidate[5][1]
252
+ ref_8_x = ref_candidate[8][0]
253
+ ref_8_y = ref_candidate[8][1]
254
+ ref_11_x = ref_candidate[11][0]
255
+ ref_11_y = ref_candidate[11][1]
256
+ ref_center1 = 0.5*(ref_candidate[2]+ref_candidate[5])
257
+ ref_center2 = 0.5*(ref_candidate[8]+ref_candidate[11])
258
+
259
+ zero_2_x = candidate[2][0]
260
+ zero_2_y = candidate[2][1]
261
+ zero_5_x = candidate[5][0]
262
+ zero_5_y = candidate[5][1]
263
+ zero_8_x = candidate[8][0]
264
+ zero_8_y = candidate[8][1]
265
+ zero_11_x = candidate[11][0]
266
+ zero_11_y = candidate[11][1]
267
+ zero_center1 = 0.5*(candidate[2]+candidate[5])
268
+ zero_center2 = 0.5*(candidate[8]+candidate[11])
269
+
270
+ x_ratio = (ref_5_x-ref_2_x)/(zero_5_x-zero_2_x)
271
+ y_ratio = (ref_center2[1]-ref_center1[1])/(zero_center2[1]-zero_center1[1])
272
+
273
+ results_vis[0]['bodies']['candidate'][:,0] *= x_ratio
274
+ results_vis[0]['bodies']['candidate'][:,1] *= y_ratio
275
+ results_vis[0]['faces'][:,:,0] *= x_ratio
276
+ results_vis[0]['faces'][:,:,1] *= y_ratio
277
+ results_vis[0]['hands'][:,:,0] *= x_ratio
278
+ results_vis[0]['hands'][:,:,1] *= y_ratio
279
+
280
+ ########neck########
281
+ l_neck_ref = ((ref_candidate[0][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[0][1] - ref_candidate[1][1]) ** 2) ** 0.5
282
+ l_neck_0 = ((candidate[0][0] - candidate[1][0]) ** 2 + (candidate[0][1] - candidate[1][1]) ** 2) ** 0.5
283
+ neck_ratio = l_neck_ref / l_neck_0
284
+
285
+ x_offset_neck = (candidate[1][0]-candidate[0][0])*(1.-neck_ratio)
286
+ y_offset_neck = (candidate[1][1]-candidate[0][1])*(1.-neck_ratio)
287
+
288
+ results_vis[0]['bodies']['candidate'][0,0] += x_offset_neck
289
+ results_vis[0]['bodies']['candidate'][0,1] += y_offset_neck
290
+ results_vis[0]['bodies']['candidate'][14,0] += x_offset_neck
291
+ results_vis[0]['bodies']['candidate'][14,1] += y_offset_neck
292
+ results_vis[0]['bodies']['candidate'][15,0] += x_offset_neck
293
+ results_vis[0]['bodies']['candidate'][15,1] += y_offset_neck
294
+ results_vis[0]['bodies']['candidate'][16,0] += x_offset_neck
295
+ results_vis[0]['bodies']['candidate'][16,1] += y_offset_neck
296
+ results_vis[0]['bodies']['candidate'][17,0] += x_offset_neck
297
+ results_vis[0]['bodies']['candidate'][17,1] += y_offset_neck
298
+
299
+ ########shoulder2########
300
+ l_shoulder2_ref = ((ref_candidate[2][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[2][1] - ref_candidate[1][1]) ** 2) ** 0.5
301
+ l_shoulder2_0 = ((candidate[2][0] - candidate[1][0]) ** 2 + (candidate[2][1] - candidate[1][1]) ** 2) ** 0.5
302
+
303
+ shoulder2_ratio = l_shoulder2_ref / l_shoulder2_0
304
+
305
+ x_offset_shoulder2 = (candidate[1][0]-candidate[2][0])*(1.-shoulder2_ratio)
306
+ y_offset_shoulder2 = (candidate[1][1]-candidate[2][1])*(1.-shoulder2_ratio)
307
+
308
+ results_vis[0]['bodies']['candidate'][2,0] += x_offset_shoulder2
309
+ results_vis[0]['bodies']['candidate'][2,1] += y_offset_shoulder2
310
+ results_vis[0]['bodies']['candidate'][3,0] += x_offset_shoulder2
311
+ results_vis[0]['bodies']['candidate'][3,1] += y_offset_shoulder2
312
+ results_vis[0]['bodies']['candidate'][4,0] += x_offset_shoulder2
313
+ results_vis[0]['bodies']['candidate'][4,1] += y_offset_shoulder2
314
+ results_vis[0]['hands'][1,:,0] += x_offset_shoulder2
315
+ results_vis[0]['hands'][1,:,1] += y_offset_shoulder2
316
+
317
+ ########shoulder5########
318
+ l_shoulder5_ref = ((ref_candidate[5][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[5][1] - ref_candidate[1][1]) ** 2) ** 0.5
319
+ l_shoulder5_0 = ((candidate[5][0] - candidate[1][0]) ** 2 + (candidate[5][1] - candidate[1][1]) ** 2) ** 0.5
320
+
321
+ shoulder5_ratio = l_shoulder5_ref / l_shoulder5_0
322
+
323
+ x_offset_shoulder5 = (candidate[1][0]-candidate[5][0])*(1.-shoulder5_ratio)
324
+ y_offset_shoulder5 = (candidate[1][1]-candidate[5][1])*(1.-shoulder5_ratio)
325
+
326
+ results_vis[0]['bodies']['candidate'][5,0] += x_offset_shoulder5
327
+ results_vis[0]['bodies']['candidate'][5,1] += y_offset_shoulder5
328
+ results_vis[0]['bodies']['candidate'][6,0] += x_offset_shoulder5
329
+ results_vis[0]['bodies']['candidate'][6,1] += y_offset_shoulder5
330
+ results_vis[0]['bodies']['candidate'][7,0] += x_offset_shoulder5
331
+ results_vis[0]['bodies']['candidate'][7,1] += y_offset_shoulder5
332
+ results_vis[0]['hands'][0,:,0] += x_offset_shoulder5
333
+ results_vis[0]['hands'][0,:,1] += y_offset_shoulder5
334
+
335
+ ########arm3########
336
+ l_arm3_ref = ((ref_candidate[3][0] - ref_candidate[2][0]) ** 2 + (ref_candidate[3][1] - ref_candidate[2][1]) ** 2) ** 0.5
337
+ l_arm3_0 = ((candidate[3][0] - candidate[2][0]) ** 2 + (candidate[3][1] - candidate[2][1]) ** 2) ** 0.5
338
+
339
+ arm3_ratio = l_arm3_ref / l_arm3_0
340
+
341
+ x_offset_arm3 = (candidate[2][0]-candidate[3][0])*(1.-arm3_ratio)
342
+ y_offset_arm3 = (candidate[2][1]-candidate[3][1])*(1.-arm3_ratio)
343
+
344
+ results_vis[0]['bodies']['candidate'][3,0] += x_offset_arm3
345
+ results_vis[0]['bodies']['candidate'][3,1] += y_offset_arm3
346
+ results_vis[0]['bodies']['candidate'][4,0] += x_offset_arm3
347
+ results_vis[0]['bodies']['candidate'][4,1] += y_offset_arm3
348
+ results_vis[0]['hands'][1,:,0] += x_offset_arm3
349
+ results_vis[0]['hands'][1,:,1] += y_offset_arm3
350
+
351
+ ########arm4########
352
+ l_arm4_ref = ((ref_candidate[4][0] - ref_candidate[3][0]) ** 2 + (ref_candidate[4][1] - ref_candidate[3][1]) ** 2) ** 0.5
353
+ l_arm4_0 = ((candidate[4][0] - candidate[3][0]) ** 2 + (candidate[4][1] - candidate[3][1]) ** 2) ** 0.5
354
+
355
+ arm4_ratio = l_arm4_ref / l_arm4_0
356
+
357
+ x_offset_arm4 = (candidate[3][0]-candidate[4][0])*(1.-arm4_ratio)
358
+ y_offset_arm4 = (candidate[3][1]-candidate[4][1])*(1.-arm4_ratio)
359
+
360
+ results_vis[0]['bodies']['candidate'][4,0] += x_offset_arm4
361
+ results_vis[0]['bodies']['candidate'][4,1] += y_offset_arm4
362
+ results_vis[0]['hands'][1,:,0] += x_offset_arm4
363
+ results_vis[0]['hands'][1,:,1] += y_offset_arm4
364
+
365
+ ########arm6########
366
+ l_arm6_ref = ((ref_candidate[6][0] - ref_candidate[5][0]) ** 2 + (ref_candidate[6][1] - ref_candidate[5][1]) ** 2) ** 0.5
367
+ l_arm6_0 = ((candidate[6][0] - candidate[5][0]) ** 2 + (candidate[6][1] - candidate[5][1]) ** 2) ** 0.5
368
+
369
+ arm6_ratio = l_arm6_ref / l_arm6_0
370
+
371
+ x_offset_arm6 = (candidate[5][0]-candidate[6][0])*(1.-arm6_ratio)
372
+ y_offset_arm6 = (candidate[5][1]-candidate[6][1])*(1.-arm6_ratio)
373
+
374
+ results_vis[0]['bodies']['candidate'][6,0] += x_offset_arm6
375
+ results_vis[0]['bodies']['candidate'][6,1] += y_offset_arm6
376
+ results_vis[0]['bodies']['candidate'][7,0] += x_offset_arm6
377
+ results_vis[0]['bodies']['candidate'][7,1] += y_offset_arm6
378
+ results_vis[0]['hands'][0,:,0] += x_offset_arm6
379
+ results_vis[0]['hands'][0,:,1] += y_offset_arm6
380
+
381
+ ########arm7########
382
+ l_arm7_ref = ((ref_candidate[7][0] - ref_candidate[6][0]) ** 2 + (ref_candidate[7][1] - ref_candidate[6][1]) ** 2) ** 0.5
383
+ l_arm7_0 = ((candidate[7][0] - candidate[6][0]) ** 2 + (candidate[7][1] - candidate[6][1]) ** 2) ** 0.5
384
+
385
+ arm7_ratio = l_arm7_ref / l_arm7_0
386
+
387
+ x_offset_arm7 = (candidate[6][0]-candidate[7][0])*(1.-arm7_ratio)
388
+ y_offset_arm7 = (candidate[6][1]-candidate[7][1])*(1.-arm7_ratio)
389
+
390
+ results_vis[0]['bodies']['candidate'][7,0] += x_offset_arm7
391
+ results_vis[0]['bodies']['candidate'][7,1] += y_offset_arm7
392
+ results_vis[0]['hands'][0,:,0] += x_offset_arm7
393
+ results_vis[0]['hands'][0,:,1] += y_offset_arm7
394
+
395
+ ########head14########
396
+ l_head14_ref = ((ref_candidate[14][0] - ref_candidate[0][0]) ** 2 + (ref_candidate[14][1] - ref_candidate[0][1]) ** 2) ** 0.5
397
+ l_head14_0 = ((candidate[14][0] - candidate[0][0]) ** 2 + (candidate[14][1] - candidate[0][1]) ** 2) ** 0.5
398
+
399
+ head14_ratio = l_head14_ref / l_head14_0
400
+
401
+ x_offset_head14 = (candidate[0][0]-candidate[14][0])*(1.-head14_ratio)
402
+ y_offset_head14 = (candidate[0][1]-candidate[14][1])*(1.-head14_ratio)
403
+
404
+ results_vis[0]['bodies']['candidate'][14,0] += x_offset_head14
405
+ results_vis[0]['bodies']['candidate'][14,1] += y_offset_head14
406
+ results_vis[0]['bodies']['candidate'][16,0] += x_offset_head14
407
+ results_vis[0]['bodies']['candidate'][16,1] += y_offset_head14
408
+
409
+ ########head15########
410
+ l_head15_ref = ((ref_candidate[15][0] - ref_candidate[0][0]) ** 2 + (ref_candidate[15][1] - ref_candidate[0][1]) ** 2) ** 0.5
411
+ l_head15_0 = ((candidate[15][0] - candidate[0][0]) ** 2 + (candidate[15][1] - candidate[0][1]) ** 2) ** 0.5
412
+
413
+ head15_ratio = l_head15_ref / l_head15_0
414
+
415
+ x_offset_head15 = (candidate[0][0]-candidate[15][0])*(1.-head15_ratio)
416
+ y_offset_head15 = (candidate[0][1]-candidate[15][1])*(1.-head15_ratio)
417
+
418
+ results_vis[0]['bodies']['candidate'][15,0] += x_offset_head15
419
+ results_vis[0]['bodies']['candidate'][15,1] += y_offset_head15
420
+ results_vis[0]['bodies']['candidate'][17,0] += x_offset_head15
421
+ results_vis[0]['bodies']['candidate'][17,1] += y_offset_head15
422
+
423
+ ########head16########
424
+ l_head16_ref = ((ref_candidate[16][0] - ref_candidate[14][0]) ** 2 + (ref_candidate[16][1] - ref_candidate[14][1]) ** 2) ** 0.5
425
+ l_head16_0 = ((candidate[16][0] - candidate[14][0]) ** 2 + (candidate[16][1] - candidate[14][1]) ** 2) ** 0.5
426
+
427
+ head16_ratio = l_head16_ref / l_head16_0
428
+
429
+ x_offset_head16 = (candidate[14][0]-candidate[16][0])*(1.-head16_ratio)
430
+ y_offset_head16 = (candidate[14][1]-candidate[16][1])*(1.-head16_ratio)
431
+
432
+ results_vis[0]['bodies']['candidate'][16,0] += x_offset_head16
433
+ results_vis[0]['bodies']['candidate'][16,1] += y_offset_head16
434
+
435
+ ########head17########
436
+ l_head17_ref = ((ref_candidate[17][0] - ref_candidate[15][0]) ** 2 + (ref_candidate[17][1] - ref_candidate[15][1]) ** 2) ** 0.5
437
+ l_head17_0 = ((candidate[17][0] - candidate[15][0]) ** 2 + (candidate[17][1] - candidate[15][1]) ** 2) ** 0.5
438
+
439
+ head17_ratio = l_head17_ref / l_head17_0
440
+
441
+ x_offset_head17 = (candidate[15][0]-candidate[17][0])*(1.-head17_ratio)
442
+ y_offset_head17 = (candidate[15][1]-candidate[17][1])*(1.-head17_ratio)
443
+
444
+ results_vis[0]['bodies']['candidate'][17,0] += x_offset_head17
445
+ results_vis[0]['bodies']['candidate'][17,1] += y_offset_head17
446
+
447
+ ########MovingAverage########
448
+
449
+ ########left leg########
450
+ l_ll1_ref = ((ref_candidate[8][0] - ref_candidate[9][0]) ** 2 + (ref_candidate[8][1] - ref_candidate[9][1]) ** 2) ** 0.5
451
+ l_ll1_0 = ((candidate[8][0] - candidate[9][0]) ** 2 + (candidate[8][1] - candidate[9][1]) ** 2) ** 0.5
452
+ ll1_ratio = l_ll1_ref / l_ll1_0
453
+
454
+ x_offset_ll1 = (candidate[9][0]-candidate[8][0])*(ll1_ratio-1.)
455
+ y_offset_ll1 = (candidate[9][1]-candidate[8][1])*(ll1_ratio-1.)
456
+
457
+ results_vis[0]['bodies']['candidate'][9,0] += x_offset_ll1
458
+ results_vis[0]['bodies']['candidate'][9,1] += y_offset_ll1
459
+ results_vis[0]['bodies']['candidate'][10,0] += x_offset_ll1
460
+ results_vis[0]['bodies']['candidate'][10,1] += y_offset_ll1
461
+ results_vis[0]['bodies']['candidate'][19,0] += x_offset_ll1
462
+ results_vis[0]['bodies']['candidate'][19,1] += y_offset_ll1
463
+
464
+ l_ll2_ref = ((ref_candidate[9][0] - ref_candidate[10][0]) ** 2 + (ref_candidate[9][1] - ref_candidate[10][1]) ** 2) ** 0.5
465
+ l_ll2_0 = ((candidate[9][0] - candidate[10][0]) ** 2 + (candidate[9][1] - candidate[10][1]) ** 2) ** 0.5
466
+ ll2_ratio = l_ll2_ref / l_ll2_0
467
+
468
+ x_offset_ll2 = (candidate[10][0]-candidate[9][0])*(ll2_ratio-1.)
469
+ y_offset_ll2 = (candidate[10][1]-candidate[9][1])*(ll2_ratio-1.)
470
+
471
+ results_vis[0]['bodies']['candidate'][10,0] += x_offset_ll2
472
+ results_vis[0]['bodies']['candidate'][10,1] += y_offset_ll2
473
+ results_vis[0]['bodies']['candidate'][19,0] += x_offset_ll2
474
+ results_vis[0]['bodies']['candidate'][19,1] += y_offset_ll2
475
+
476
+ ########right leg########
477
+ l_rl1_ref = ((ref_candidate[11][0] - ref_candidate[12][0]) ** 2 + (ref_candidate[11][1] - ref_candidate[12][1]) ** 2) ** 0.5
478
+ l_rl1_0 = ((candidate[11][0] - candidate[12][0]) ** 2 + (candidate[11][1] - candidate[12][1]) ** 2) ** 0.5
479
+ rl1_ratio = l_rl1_ref / l_rl1_0
480
+
481
+ x_offset_rl1 = (candidate[12][0]-candidate[11][0])*(rl1_ratio-1.)
482
+ y_offset_rl1 = (candidate[12][1]-candidate[11][1])*(rl1_ratio-1.)
483
+
484
+ results_vis[0]['bodies']['candidate'][12,0] += x_offset_rl1
485
+ results_vis[0]['bodies']['candidate'][12,1] += y_offset_rl1
486
+ results_vis[0]['bodies']['candidate'][13,0] += x_offset_rl1
487
+ results_vis[0]['bodies']['candidate'][13,1] += y_offset_rl1
488
+ results_vis[0]['bodies']['candidate'][18,0] += x_offset_rl1
489
+ results_vis[0]['bodies']['candidate'][18,1] += y_offset_rl1
490
+
491
+ l_rl2_ref = ((ref_candidate[12][0] - ref_candidate[13][0]) ** 2 + (ref_candidate[12][1] - ref_candidate[13][1]) ** 2) ** 0.5
492
+ l_rl2_0 = ((candidate[12][0] - candidate[13][0]) ** 2 + (candidate[12][1] - candidate[13][1]) ** 2) ** 0.5
493
+ rl2_ratio = l_rl2_ref / l_rl2_0
494
+
495
+ x_offset_rl2 = (candidate[13][0]-candidate[12][0])*(rl2_ratio-1.)
496
+ y_offset_rl2 = (candidate[13][1]-candidate[12][1])*(rl2_ratio-1.)
497
+
498
+ results_vis[0]['bodies']['candidate'][13,0] += x_offset_rl2
499
+ results_vis[0]['bodies']['candidate'][13,1] += y_offset_rl2
500
+ results_vis[0]['bodies']['candidate'][18,0] += x_offset_rl2
501
+ results_vis[0]['bodies']['candidate'][18,1] += y_offset_rl2
502
+
503
+ offset = ref_candidate[1] - results_vis[0]['bodies']['candidate'][1]
504
+
505
+ results_vis[0]['bodies']['candidate'] += offset[np.newaxis, :]
506
+ results_vis[0]['faces'] += offset[np.newaxis, np.newaxis, :]
507
+ results_vis[0]['hands'] += offset[np.newaxis, np.newaxis, :]
508
+
509
+ for i in range(1, len(results_vis)):
510
+ results_vis[i]['bodies']['candidate'][:,0] *= x_ratio
511
+ results_vis[i]['bodies']['candidate'][:,1] *= y_ratio
512
+ results_vis[i]['faces'][:,:,0] *= x_ratio
513
+ results_vis[i]['faces'][:,:,1] *= y_ratio
514
+ results_vis[i]['hands'][:,:,0] *= x_ratio
515
+ results_vis[i]['hands'][:,:,1] *= y_ratio
516
+
517
+ ########neck########
518
+ x_offset_neck = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][0][0])*(1.-neck_ratio)
519
+ y_offset_neck = (results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][0][1])*(1.-neck_ratio)
520
+
521
+ results_vis[i]['bodies']['candidate'][0,0] += x_offset_neck
522
+ results_vis[i]['bodies']['candidate'][0,1] += y_offset_neck
523
+ results_vis[i]['bodies']['candidate'][14,0] += x_offset_neck
524
+ results_vis[i]['bodies']['candidate'][14,1] += y_offset_neck
525
+ results_vis[i]['bodies']['candidate'][15,0] += x_offset_neck
526
+ results_vis[i]['bodies']['candidate'][15,1] += y_offset_neck
527
+ results_vis[i]['bodies']['candidate'][16,0] += x_offset_neck
528
+ results_vis[i]['bodies']['candidate'][16,1] += y_offset_neck
529
+ results_vis[i]['bodies']['candidate'][17,0] += x_offset_neck
530
+ results_vis[i]['bodies']['candidate'][17,1] += y_offset_neck
531
+
532
+ ########shoulder2########
533
+
534
+
535
+ x_offset_shoulder2 = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][2][0])*(1.-shoulder2_ratio)
536
+ y_offset_shoulder2 = (results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][2][1])*(1.-shoulder2_ratio)
537
+
538
+ results_vis[i]['bodies']['candidate'][2,0] += x_offset_shoulder2
539
+ results_vis[i]['bodies']['candidate'][2,1] += y_offset_shoulder2
540
+ results_vis[i]['bodies']['candidate'][3,0] += x_offset_shoulder2
541
+ results_vis[i]['bodies']['candidate'][3,1] += y_offset_shoulder2
542
+ results_vis[i]['bodies']['candidate'][4,0] += x_offset_shoulder2
543
+ results_vis[i]['bodies']['candidate'][4,1] += y_offset_shoulder2
544
+ results_vis[i]['hands'][1,:,0] += x_offset_shoulder2
545
+ results_vis[i]['hands'][1,:,1] += y_offset_shoulder2
546
+
547
+ ########shoulder5########
548
+
549
+ x_offset_shoulder5 = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][5][0])*(1.-shoulder5_ratio)
550
+ y_offset_shoulder5 = (results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][5][1])*(1.-shoulder5_ratio)
551
+
552
+ results_vis[i]['bodies']['candidate'][5,0] += x_offset_shoulder5
553
+ results_vis[i]['bodies']['candidate'][5,1] += y_offset_shoulder5
554
+ results_vis[i]['bodies']['candidate'][6,0] += x_offset_shoulder5
555
+ results_vis[i]['bodies']['candidate'][6,1] += y_offset_shoulder5
556
+ results_vis[i]['bodies']['candidate'][7,0] += x_offset_shoulder5
557
+ results_vis[i]['bodies']['candidate'][7,1] += y_offset_shoulder5
558
+ results_vis[i]['hands'][0,:,0] += x_offset_shoulder5
559
+ results_vis[i]['hands'][0,:,1] += y_offset_shoulder5
560
+
561
+ ########arm3########
562
+
563
+ x_offset_arm3 = (results_vis[i]['bodies']['candidate'][2][0]-results_vis[i]['bodies']['candidate'][3][0])*(1.-arm3_ratio)
564
+ y_offset_arm3 = (results_vis[i]['bodies']['candidate'][2][1]-results_vis[i]['bodies']['candidate'][3][1])*(1.-arm3_ratio)
565
+
566
+ results_vis[i]['bodies']['candidate'][3,0] += x_offset_arm3
567
+ results_vis[i]['bodies']['candidate'][3,1] += y_offset_arm3
568
+ results_vis[i]['bodies']['candidate'][4,0] += x_offset_arm3
569
+ results_vis[i]['bodies']['candidate'][4,1] += y_offset_arm3
570
+ results_vis[i]['hands'][1,:,0] += x_offset_arm3
571
+ results_vis[i]['hands'][1,:,1] += y_offset_arm3
572
+
573
+ ########arm4########
574
+
575
+ x_offset_arm4 = (results_vis[i]['bodies']['candidate'][3][0]-results_vis[i]['bodies']['candidate'][4][0])*(1.-arm4_ratio)
576
+ y_offset_arm4 = (results_vis[i]['bodies']['candidate'][3][1]-results_vis[i]['bodies']['candidate'][4][1])*(1.-arm4_ratio)
577
+
578
+ results_vis[i]['bodies']['candidate'][4,0] += x_offset_arm4
579
+ results_vis[i]['bodies']['candidate'][4,1] += y_offset_arm4
580
+ results_vis[i]['hands'][1,:,0] += x_offset_arm4
581
+ results_vis[i]['hands'][1,:,1] += y_offset_arm4
582
+
583
+ ########arm6########
584
+
585
+ x_offset_arm6 = (results_vis[i]['bodies']['candidate'][5][0]-results_vis[i]['bodies']['candidate'][6][0])*(1.-arm6_ratio)
586
+ y_offset_arm6 = (results_vis[i]['bodies']['candidate'][5][1]-results_vis[i]['bodies']['candidate'][6][1])*(1.-arm6_ratio)
587
+
588
+ results_vis[i]['bodies']['candidate'][6,0] += x_offset_arm6
589
+ results_vis[i]['bodies']['candidate'][6,1] += y_offset_arm6
590
+ results_vis[i]['bodies']['candidate'][7,0] += x_offset_arm6
591
+ results_vis[i]['bodies']['candidate'][7,1] += y_offset_arm6
592
+ results_vis[i]['hands'][0,:,0] += x_offset_arm6
593
+ results_vis[i]['hands'][0,:,1] += y_offset_arm6
594
+
595
+ ########arm7########
596
+
597
+ x_offset_arm7 = (results_vis[i]['bodies']['candidate'][6][0]-results_vis[i]['bodies']['candidate'][7][0])*(1.-arm7_ratio)
598
+ y_offset_arm7 = (results_vis[i]['bodies']['candidate'][6][1]-results_vis[i]['bodies']['candidate'][7][1])*(1.-arm7_ratio)
599
+
600
+ results_vis[i]['bodies']['candidate'][7,0] += x_offset_arm7
601
+ results_vis[i]['bodies']['candidate'][7,1] += y_offset_arm7
602
+ results_vis[i]['hands'][0,:,0] += x_offset_arm7
603
+ results_vis[i]['hands'][0,:,1] += y_offset_arm7
604
+
605
+ ########head14########
606
+
607
+ x_offset_head14 = (results_vis[i]['bodies']['candidate'][0][0]-results_vis[i]['bodies']['candidate'][14][0])*(1.-head14_ratio)
608
+ y_offset_head14 = (results_vis[i]['bodies']['candidate'][0][1]-results_vis[i]['bodies']['candidate'][14][1])*(1.-head14_ratio)
609
+
610
+ results_vis[i]['bodies']['candidate'][14,0] += x_offset_head14
611
+ results_vis[i]['bodies']['candidate'][14,1] += y_offset_head14
612
+ results_vis[i]['bodies']['candidate'][16,0] += x_offset_head14
613
+ results_vis[i]['bodies']['candidate'][16,1] += y_offset_head14
614
+
615
+ ########head15########
616
+
617
+ x_offset_head15 = (results_vis[i]['bodies']['candidate'][0][0]-results_vis[i]['bodies']['candidate'][15][0])*(1.-head15_ratio)
618
+ y_offset_head15 = (results_vis[i]['bodies']['candidate'][0][1]-results_vis[i]['bodies']['candidate'][15][1])*(1.-head15_ratio)
619
+
620
+ results_vis[i]['bodies']['candidate'][15,0] += x_offset_head15
621
+ results_vis[i]['bodies']['candidate'][15,1] += y_offset_head15
622
+ results_vis[i]['bodies']['candidate'][17,0] += x_offset_head15
623
+ results_vis[i]['bodies']['candidate'][17,1] += y_offset_head15
624
+
625
+ ########head16########
626
+
627
+ x_offset_head16 = (results_vis[i]['bodies']['candidate'][14][0]-results_vis[i]['bodies']['candidate'][16][0])*(1.-head16_ratio)
628
+ y_offset_head16 = (results_vis[i]['bodies']['candidate'][14][1]-results_vis[i]['bodies']['candidate'][16][1])*(1.-head16_ratio)
629
+
630
+ results_vis[i]['bodies']['candidate'][16,0] += x_offset_head16
631
+ results_vis[i]['bodies']['candidate'][16,1] += y_offset_head16
632
+
633
+ ########head17########
634
+ x_offset_head17 = (results_vis[i]['bodies']['candidate'][15][0]-results_vis[i]['bodies']['candidate'][17][0])*(1.-head17_ratio)
635
+ y_offset_head17 = (results_vis[i]['bodies']['candidate'][15][1]-results_vis[i]['bodies']['candidate'][17][1])*(1.-head17_ratio)
636
+
637
+ results_vis[i]['bodies']['candidate'][17,0] += x_offset_head17
638
+ results_vis[i]['bodies']['candidate'][17,1] += y_offset_head17
639
+
640
+ # ########MovingAverage########
641
+
642
+ ########left leg########
643
+ x_offset_ll1 = (results_vis[i]['bodies']['candidate'][9][0]-results_vis[i]['bodies']['candidate'][8][0])*(ll1_ratio-1.)
644
+ y_offset_ll1 = (results_vis[i]['bodies']['candidate'][9][1]-results_vis[i]['bodies']['candidate'][8][1])*(ll1_ratio-1.)
645
+
646
+ results_vis[i]['bodies']['candidate'][9,0] += x_offset_ll1
647
+ results_vis[i]['bodies']['candidate'][9,1] += y_offset_ll1
648
+ results_vis[i]['bodies']['candidate'][10,0] += x_offset_ll1
649
+ results_vis[i]['bodies']['candidate'][10,1] += y_offset_ll1
650
+ results_vis[i]['bodies']['candidate'][19,0] += x_offset_ll1
651
+ results_vis[i]['bodies']['candidate'][19,1] += y_offset_ll1
652
+
653
+
654
+
655
+ x_offset_ll2 = (results_vis[i]['bodies']['candidate'][10][0]-results_vis[i]['bodies']['candidate'][9][0])*(ll2_ratio-1.)
656
+ y_offset_ll2 = (results_vis[i]['bodies']['candidate'][10][1]-results_vis[i]['bodies']['candidate'][9][1])*(ll2_ratio-1.)
657
+
658
+ results_vis[i]['bodies']['candidate'][10,0] += x_offset_ll2
659
+ results_vis[i]['bodies']['candidate'][10,1] += y_offset_ll2
660
+ results_vis[i]['bodies']['candidate'][19,0] += x_offset_ll2
661
+ results_vis[i]['bodies']['candidate'][19,1] += y_offset_ll2
662
+
663
+ ########right leg########
664
+
665
+ x_offset_rl1 = (results_vis[i]['bodies']['candidate'][12][0]-results_vis[i]['bodies']['candidate'][11][0])*(rl1_ratio-1.)
666
+ y_offset_rl1 = (results_vis[i]['bodies']['candidate'][12][1]-results_vis[i]['bodies']['candidate'][11][1])*(rl1_ratio-1.)
667
+
668
+ results_vis[i]['bodies']['candidate'][12,0] += x_offset_rl1
669
+ results_vis[i]['bodies']['candidate'][12,1] += y_offset_rl1
670
+ results_vis[i]['bodies']['candidate'][13,0] += x_offset_rl1
671
+ results_vis[i]['bodies']['candidate'][13,1] += y_offset_rl1
672
+ results_vis[i]['bodies']['candidate'][18,0] += x_offset_rl1
673
+ results_vis[i]['bodies']['candidate'][18,1] += y_offset_rl1
674
+
675
+
676
+ x_offset_rl2 = (results_vis[i]['bodies']['candidate'][13][0]-results_vis[i]['bodies']['candidate'][12][0])*(rl2_ratio-1.)
677
+ y_offset_rl2 = (results_vis[i]['bodies']['candidate'][13][1]-results_vis[i]['bodies']['candidate'][12][1])*(rl2_ratio-1.)
678
+
679
+ results_vis[i]['bodies']['candidate'][13,0] += x_offset_rl2
680
+ results_vis[i]['bodies']['candidate'][13,1] += y_offset_rl2
681
+ results_vis[i]['bodies']['candidate'][18,0] += x_offset_rl2
682
+ results_vis[i]['bodies']['candidate'][18,1] += y_offset_rl2
683
+
684
+ results_vis[i]['bodies']['candidate'] += offset[np.newaxis, :]
685
+ results_vis[i]['faces'] += offset[np.newaxis, np.newaxis, :]
686
+ results_vis[i]['hands'] += offset[np.newaxis, np.newaxis, :]
687
+
688
+ for i in range(len(results_vis)):
689
+ dwpose_woface, dwpose_wface = draw_pose(results_vis[i], H=768, W=512)
690
+ img_path = save_motion+'/' + str(i).zfill(4) + '.jpg'
691
+ cv2.imwrite(img_path, dwpose_woface)
692
+
693
+ dwpose_woface, dwpose_wface = draw_pose(pose_ref, H=768, W=512)
694
+ img_path = save_warp+'/' + 'ref_pose.jpg'
695
+ cv2.imwrite(img_path, dwpose_woface)
696
+
697
+
698
+ logger = get_logger('dw pose extraction')
699
+
700
+
701
+ if __name__=='__main__':
702
+ def parse_args():
703
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
704
+ parser.add_argument("--ref_name", type=str, default="data/images/IMG_20240514_104337.jpg",)
705
+ parser.add_argument("--source_video_paths", type=str, default="data/videos/source_video.mp4",)
706
+ parser.add_argument("--saved_pose_dir", type=str, default="data/saved_pose/IMG_20240514_104337",)
707
+ args = parser.parse_args()
708
+
709
+ return args
710
+
711
+ args = parse_args()
712
+ mp_main(args)
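
The OneEuroFilter defined near the top of this file is a standard low-pass filter for jittery keypoint tracks; it does not appear to be used in mp_main above, so the following is only a sketch of how it could be applied to a noisy 1-D coordinate sequence (the import path and the synthetic signal are illustrative):

import math
import random

from run_align_pose import OneEuroFilter   # the class defined above

random.seed(0)
timestamps = list(range(100))
# Synthetic noisy track standing in for one keypoint coordinate over time.
signal = [math.sin(0.1 * t) + random.gauss(0.0, 0.05) for t in timestamps]

one_euro = OneEuroFilter(t0=timestamps[0], x0=signal[0], min_cutoff=1.0, beta=0.01)
smoothed = [signal[0]] + [one_euro(t, x) for t, x in zip(timestamps[1:], signal[1:])]
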
UniAnimate/test_func/save_targer_keys.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import torch
5
+ import imageio
6
+ import numpy as np
7
+ import os.path as osp
8
+ sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
9
+ from thop import profile
10
+ from ptflops import get_model_complexity_info
11
+
12
+ import artist.data as data
13
+ from tools.modules.config import cfg
14
+ from tools.modules.unet.util import *
15
+ from utils.config import Config as pConfig
16
+ from utils.registry_class import ENGINE, MODEL
17
+
18
+
19
+ def save_temporal_key():
20
+ cfg_update = pConfig(load=True)
21
+
22
+ for k, v in cfg_update.cfg_dict.items():
23
+ if isinstance(v, dict) and k in cfg:
24
+ cfg[k].update(v)
25
+ else:
26
+ cfg[k] = v
27
+
28
+ model = MODEL.build(cfg.UNet)
29
+
30
+ temp_name = ''
31
+ temp_key_list = []
32
+ spth = 'workspace/module_list/UNetSD_I2V_vs_Text_temporal_key_list.json'
33
+ for name, module in model.named_modules():
34
+ if isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)):
35
+ temp_name = name
36
+ print(f'Model: {name}')
37
+ elif isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)):
38
+ temp_name = ''
39
+
40
+ if hasattr(module, 'weight'):
41
+ if temp_name != '' and (temp_name in name):
42
+ temp_key_list.append(name)
43
+ print(f'{name}')
44
+ # print(name)
45
+
46
+ save_module_list = []
47
+ for k, p in model.named_parameters():
48
+ for item in temp_key_list:
49
+ if item in k:
50
+ print(f'{item} --> {k}')
51
+ save_module_list.append(k)
52
+
53
+ print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
54
+
55
+ # spth = 'workspace/module_list/{}'
56
+ json.dump(save_module_list, open(spth, 'w'))
57
+ a = 0
58
+
59
+
60
+ def save_spatial_key():
61
+ cfg_update = pConfig(load=True)
62
+
63
+ for k, v in cfg_update.cfg_dict.items():
64
+ if isinstance(v, dict) and k in cfg:
65
+ cfg[k].update(v)
66
+ else:
67
+ cfg[k] = v
68
+
69
+ model = MODEL.build(cfg.UNet)
70
+ temp_name = ''
71
+ temp_key_list = []
72
+ spth = 'workspace/module_list/UNetSD_I2V_HQ_P_spatial_key_list.json'
73
+ for name, module in model.named_modules():
74
+ if isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)):
75
+ temp_name = name
76
+ print(f'Model: {name}')
77
+ elif isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)):
78
+ temp_name = ''
79
+
80
+ if hasattr(module, 'weight'):
81
+ if temp_name != '' and (temp_name in name):
82
+ temp_key_list.append(name)
83
+ print(f'{name}')
84
+ # print(name)
85
+
86
+ save_module_list = []
87
+ for k, p in model.named_parameters():
88
+ for item in temp_key_list:
89
+ if item in k:
90
+ print(f'{item} --> {k}')
91
+ save_module_list.append(k)
92
+
93
+ print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
94
+
95
+ # spth = 'workspace/module_list/{}'
96
+ json.dump(save_module_list, open(spth, 'w'))
97
+ a = 0
98
+
99
+
100
+ if __name__ == '__main__':
101
+ # save_temporal_key()
102
+ save_spatial_key()
103
+
104
+
105
+
106
+ # print([k for (k, _) in self.input_blocks.named_parameters()])
107
+
108
+
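
Both helpers above walk model.named_modules(), remember the prefixes of temporal (or spatial) blocks, and then collect every parameter name under those prefixes. The same filtering idea on a toy model, for illustration only (the toy block classes are assumptions, not modules from this repo):

import json
import torch.nn as nn

class ToyTemporalBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class ToySpatialBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

model = nn.ModuleDict({'temporal_1': ToyTemporalBlock(), 'spatial_1': ToySpatialBlock()})

# Prefixes of modules considered "temporal".
temporal_prefixes = [name for name, m in model.named_modules() if isinstance(m, ToyTemporalBlock)]
# Parameter names that live under one of those prefixes.
temporal_keys = [k for k, _ in model.named_parameters()
                 if any(k.startswith(p) for p in temporal_prefixes)]
print(json.dumps(temporal_keys, indent=2))   # ["temporal_1.proj.weight", "temporal_1.proj.bias"]
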
UniAnimate/test_func/test_EndDec.py ADDED
@@ -0,0 +1,95 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import imageio
5
+ import numpy as np
6
+ import os.path as osp
7
+ sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+ from einops import rearrange
11
+
12
+ from tools import *
13
+ import utils.transforms as data
14
+ from utils.seed import setup_seed
15
+ from tools.modules.config import cfg
16
+ from utils.config import Config as pConfig
17
+ from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER
18
+
19
+
20
+ def test_enc_dec(gpu=0):
21
+ setup_seed(0)
22
+ cfg_update = pConfig(load=True)
23
+
24
+ for k, v in cfg_update.cfg_dict.items():
25
+ if isinstance(v, dict) and k in cfg:
26
+ cfg[k].update(v)
27
+ else:
28
+ cfg[k] = v
29
+
30
+ save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type'])
31
+ os.system('rm -rf %s' % (save_dir))
32
+ os.makedirs(save_dir, exist_ok=True)
33
+
34
+ train_trans = data.Compose([
35
+ data.CenterCropWide(size=cfg.resolution),
36
+ data.ToTensor(),
37
+ data.Normalize(mean=cfg.mean, std=cfg.std)])
38
+
39
+ vit_trans = data.Compose([
40
+ data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0]>cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution),
41
+ data.Resize(cfg.vit_resolution),
42
+ data.ToTensor(),
43
+ data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
44
+
45
+ video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w
46
+ video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w
47
+
48
+ txt_size = cfg.resolution[1]
49
+ nc = int(38 * (txt_size / 256))
50
+ font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)
51
+
52
+ dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans)
53
+ print('There are %d videos' % (len(dataset)))
54
+
55
+ autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
56
+ autoencoder.eval() # freeze
57
+ for param in autoencoder.parameters():
58
+ param.requires_grad = False
59
+ autoencoder.to(gpu)
60
+ for idx, item in enumerate(dataset):
61
+ local_path = os.path.join(save_dir, '%04d.mp4' % idx)
62
+ # ref_frame, video_data, caption = item
63
+ ref_frame, vit_frame, video_data = item[:3]
64
+ video_data = video_data.to(gpu)
65
+
66
+ image_list = []
67
+ video_data_list = torch.chunk(video_data, video_data.shape[0]//cfg.chunk_size,dim=0)
68
+ with torch.no_grad():
69
+ decode_data = []
70
+ for chunk_data in video_data_list:
71
+ latent_z = autoencoder.encode_firsr_stage(chunk_data).detach()  # note: method name keeps the repository's original spelling
72
+ # latent_z = get_first_stage_encoding(encoder_posterior).detach()
73
+ kwargs = {"timesteps": chunk_data.shape[0]}
74
+ recons_data = autoencoder.decode(latent_z, **kwargs)
75
+
76
+ vis_data = torch.cat([chunk_data, recons_data], dim=2).cpu()
77
+ vis_data = vis_data.mul_(video_std).add_(video_mean) # 8x3x16x256x384
78
+ vis_data = vis_data.cpu()
79
+ vis_data.clamp_(0, 1)
80
+ vis_data = vis_data.permute(0, 2, 3, 1)
81
+ vis_data = [(image.numpy() * 255).astype('uint8') for image in vis_data]
82
+ image_list.extend(vis_data)
83
+
84
+ num_image = len(image_list)
85
+ frame_dir = os.path.join(save_dir, 'temp')
86
+ os.makedirs(frame_dir, exist_ok=True)
87
+ for idx in range(num_image):
88
+ tpth = os.path.join(frame_dir, '%04d.png' % (idx+1))
89
+ cv2.imwrite(tpth, image_list[idx][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
90
+ cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8 -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
91
+ os.system(cmd); os.system(f'rm -rf {frame_dir}')
92
+
93
+
94
+ if __name__ == '__main__':
95
+ test_enc_dec()
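
Usage note: the reconstruction test above only compares frames visually; a small PSNR helper (a sketch, assuming both tensors are already de-normalized to [0, 1] and share the same shape) can make the encode/decode round trip quantitative.

import torch

def psnr(original, reconstructed, eps=1e-8):
    # Peak signal-to-noise ratio in dB for tensors with values in [0, 1].
    mse = torch.mean((original - reconstructed) ** 2)
    return 10.0 * torch.log10(1.0 / (mse + eps))
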
UniAnimate/test_func/test_dataset.py ADDED
@@ -0,0 +1,152 @@
1
+ import os
2
+ import sys
3
+ import imageio
4
+ import numpy as np
5
+ import os.path as osp
6
+ sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import torchvision.transforms as T
9
+
10
+ import utils.transforms as data
11
+ from tools.modules.config import cfg
12
+ from utils.config import Config as pConfig
13
+ from utils.registry_class import ENGINE, DATASETS
14
+
15
+ from tools import *
16
+
17
+ def test_video_dataset():
18
+ cfg_update = pConfig(load=True)
19
+
20
+ for k, v in cfg_update.cfg_dict.items():
21
+ if isinstance(v, dict) and k in cfg:
22
+ cfg[k].update(v)
23
+ else:
24
+ cfg[k] = v
25
+
26
+ exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
27
+ save_dir = os.path.join('workspace', 'test_data/datasets', cfg.vid_dataset['type'], exp_name)
28
+ os.system('rm -rf %s' % (save_dir))
29
+ os.makedirs(save_dir, exist_ok=True)
30
+
31
+ train_trans = data.Compose([
32
+ data.CenterCropWide(size=cfg.resolution),
33
+ data.ToTensor(),
34
+ data.Normalize(mean=cfg.mean, std=cfg.std)])
35
+ vit_trans = T.Compose([
36
+ data.CenterCropWide(cfg.vit_resolution),
37
+ T.ToTensor(),
38
+ T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
39
+
40
+ video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w
41
+ video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w
42
+
43
+ img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w
44
+ img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w
45
+
46
+ vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w
47
+ vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w
48
+
49
+ txt_size = cfg.resolution[1]
50
+ nc = int(38 * (txt_size / 256))
51
+ font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)
52
+
53
+ dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps[0], transforms=train_trans, vit_transforms=vit_trans)
54
+ print('There are %d videos' % (len(dataset)))
55
+ for idx, item in enumerate(dataset):
56
+ ref_frame, vit_frame, video_data, caption, video_key = item
57
+
58
+ video_data = video_data.mul_(video_std).add_(video_mean)
59
+ video_data.clamp_(0, 1)
60
+ video_data = video_data.permute(0, 2, 3, 1)
61
+ video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]
62
+
63
+ # Single Image
64
+ ref_frame = ref_frame.mul_(img_std).add_(img_mean)
65
+ ref_frame.clamp_(0, 1)
66
+ ref_frame = ref_frame.permute(1, 2, 0)
67
+ ref_frame = (ref_frame.numpy() * 255).astype('uint8')
68
+
69
+ # Text image
70
+ txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
71
+ draw = ImageDraw.Draw(txt_img)
72
+ lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
73
+ draw.text((0, 0), lines, fill="black", font=font)
74
+ txt_img = np.array(txt_img)
75
+
76
+ video_data = [np.concatenate([ref_frame, u, txt_img], axis=1) for u in video_data]
77
+ spath = os.path.join(save_dir, '%04d.gif' % (idx))
78
+ imageio.mimwrite(spath, video_data, fps=8)
79
+
80
+ # if idx > 100: break
81
+
82
+
83
+ def test_vit_image(test_video_flag=True):
84
+ cfg_update = pConfig(load=True)
85
+
86
+ for k, v in cfg_update.cfg_dict.items():
87
+ if isinstance(v, dict) and k in cfg:
88
+ cfg[k].update(v)
89
+ else:
90
+ cfg[k] = v
91
+
92
+ exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
93
+ save_dir = os.path.join('workspace', 'test_data/datasets', cfg.img_dataset['type'], exp_name)
94
+ os.system('rm -rf %s' % (save_dir))
95
+ os.makedirs(save_dir, exist_ok=True)
96
+
97
+ train_trans = data.Compose([
98
+ data.CenterCropWide(size=cfg.resolution),
99
+ data.ToTensor(),
100
+ data.Normalize(mean=cfg.mean, std=cfg.std)])
101
+ vit_trans = data.Compose([
102
+ data.CenterCropWide(cfg.resolution),
103
+ data.Resize(cfg.vit_resolution),
104
+ data.ToTensor(),
105
+ data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
106
+
107
+ img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w
108
+ img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w
109
+
110
+ vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w
111
+ vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w
112
+
113
+ txt_size = cfg.resolution[1]
114
+ nc = int(38 * (txt_size / 256))
115
+ font = ImageFont.truetype('artist/font/DejaVuSans.ttf', size=13)
116
+
117
+ dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans)
118
+ print('There are %d videos' % (len(dataset)))
119
+ for idx, item in enumerate(dataset):
120
+ ref_frame, vit_frame, video_data, caption, video_key = item
121
+ video_data = video_data.mul_(img_std).add_(img_mean)
122
+ video_data.clamp_(0, 1)
123
+ video_data = video_data.permute(0, 2, 3, 1)
124
+ video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]
125
+
126
+ # Single Image
127
+ vit_frame = vit_frame.mul_(vit_std).add_(vit_mean)
128
+ vit_frame.clamp_(0, 1)
129
+ vit_frame = vit_frame.permute(1, 2, 0)
130
+ vit_frame = (vit_frame.numpy() * 255).astype('uint8')
131
+
132
+ zero_frame = np.zeros((cfg.resolution[1], cfg.resolution[1], 3), dtype=np.uint8)
133
+ zero_frame[:vit_frame.shape[0], :vit_frame.shape[1], :] = vit_frame
134
+
135
+ # Text image
136
+ txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
137
+ draw = ImageDraw.Draw(txt_img)
138
+ lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
139
+ draw.text((0, 0), lines, fill="black", font=font)
140
+ txt_img = np.array(txt_img)
141
+
142
+ video_data = [np.concatenate([zero_frame, u, txt_img], axis=1) for u in video_data]
143
+ spath = os.path.join(save_dir, '%04d.gif' % (idx))
144
+ imageio.mimwrite(spath, video_data, fps=8)
145
+
146
+ # if idx > 100: break
147
+
148
+
149
+ if __name__ == '__main__':
150
+ # test_video_dataset()
151
+ test_vit_image()
152
+
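
Usage note: both tests above undo data.Normalize inline with mul_/add_; the same step as a standalone helper, assuming per-channel mean/std lists like cfg.mean and cfg.std.

import torch

def denormalize(clip, mean, std):
    # clip: (F, C, H, W) tensor produced by data.Normalize; returns values clamped to [0, 1].
    mean = torch.tensor(mean).view(1, -1, 1, 1)
    std = torch.tensor(std).view(1, -1, 1, 1)
    return (clip * std + mean).clamp(0, 1)
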
UniAnimate/test_func/test_models.py ADDED
@@ -0,0 +1,56 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import imageio
5
+ import numpy as np
6
+ import os.path as osp
7
+ sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
8
+ from thop import profile
9
+ from ptflops import get_model_complexity_info
10
+
11
+ import artist.data as data
12
+ from tools.modules.config import cfg
13
+ from utils.config import Config as pConfig
14
+ from utils.registry_class import ENGINE, MODEL
15
+
16
+
17
+ def test_model():
18
+ cfg_update = pConfig(load=True)
19
+
20
+ for k, v in cfg_update.cfg_dict.items():
21
+ if isinstance(v, dict) and k in cfg:
22
+ cfg[k].update(v)
23
+ else:
24
+ cfg[k] = v
25
+
26
+ model = MODEL.build(cfg.UNet)
27
+ print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
28
+
29
+ # state_dict = torch.load('cache/pretrain_model/jiuniu_0600000.pth', map_location='cpu')
30
+ # model.load_state_dict(state_dict, strict=False)
31
+ model = model.cuda()
32
+
33
+ x = torch.Tensor(1, 4, 16, 32, 56).cuda()
34
+ t = torch.Tensor(1).cuda()
35
+ sims = torch.Tensor(1, 32).cuda()
36
+ fps = torch.Tensor([8]).cuda()
37
+ y = torch.Tensor(1, 1, 1024).cuda()
38
+ image = torch.Tensor(1, 3, 256, 448).cuda()
39
+
40
+ ret = model(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps)
41
+ print('Out shape is {}'.format(ret.shape))
42
+
43
+ # flops, params = profile(model=model, inputs=(x, t, y, image, sims, fps))
44
+ # print('Model: {:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))
45
+
46
+ def prepare_input(resolution):
47
+ return dict(x=[x, t, y, image, sims, fps])
48
+
49
+ flops, params = get_model_complexity_info(model, (1, 4, 16, 32, 56),
50
+ input_constructor = prepare_input,
51
+ as_strings=True, print_per_layer_stat=True)
52
+ print(' - Flops: ' + flops)
53
+ print(' - Params: ' + params)
54
+
55
+ if __name__ == '__main__':
56
+ test_model()
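
Usage note: the parameter count above divides by 1024 ** 2, so the printed 'M parameters' is in binary mega-units; a plain-millions variant would be the following sketch.

def count_parameters_m(model):
    # Total parameter count reported in millions (1e6), trainable and frozen alike.
    return sum(p.numel() for p in model.parameters()) / 1e6
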
UniAnimate/test_func/test_save_video.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+ cap = cv2.VideoCapture('workspace/img_dir/tst.mp4')
5
+
6
+ fourcc = cv2.VideoWriter_fourcc(*'H264')
7
+
8
+ ret, frame = cap.read()
9
+ vid_size = frame.shape[:2][::-1]
10
+
11
+ out = cv2.VideoWriter('workspace/img_dir/testwrite.mp4',fourcc, 8, vid_size)
12
+ out.write(frame)
13
+
14
+ while(cap.isOpened()):
15
+ ret, frame = cap.read()
16
+ if not ret: break
17
+ out.write(frame)
18
+
19
+
20
+ cap.release()
21
+ out.release()
22
+
23
+
24
+
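
Usage note: the loop above leaks the capture/writer handles if reading fails part-way, and the 'H264' fourcc is only available in some OpenCV builds. A more defensive variant (a sketch; 'mp4v' chosen for portability) could look like this.

import cv2

def reencode(src_path, dst_path, fps=8):
    # Copy every frame of src_path into a new mp4 at the given fps, releasing handles on any error.
    cap = cv2.VideoCapture(src_path)
    out = None
    try:
        ret, frame = cap.read()
        if not ret:
            return
        height, width = frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(dst_path, fourcc, fps, (width, height))
        while ret:
            out.write(frame)
            ret, frame = cap.read()
    finally:
        cap.release()
        if out is not None:
            out.release()
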
UniAnimate/tools/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .datasets import *
2
+ from .modules import *
3
+ from .inferences import *
UniAnimate/tools/datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .image_dataset import *
2
+ from .video_dataset import *
UniAnimate/tools/datasets/image_dataset.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import random
5
+ import logging
6
+ import tempfile
7
+ import numpy as np
8
+ from copy import copy
9
+ from PIL import Image
10
+ from io import BytesIO
11
+ from torch.utils.data import Dataset
12
+ from utils.registry_class import DATASETS
13
+
14
+ @DATASETS.register_class()
15
+ class ImageDataset(Dataset):
16
+ def __init__(self,
17
+ data_list,
18
+ data_dir_list,
19
+ max_words=1000,
20
+ vit_resolution=[224, 224],
21
+ resolution=(384, 256),
22
+ max_frames=1,
23
+ transforms=None,
24
+ vit_transforms=None,
25
+ **kwargs):
26
+
27
+ self.max_frames = max_frames
28
+ self.resolution = resolution
29
+ self.transforms = transforms
30
+ self.vit_resolution = vit_resolution
31
+ self.vit_transforms = vit_transforms
32
+
33
+ image_list = []
34
+ for item_path, data_dir in zip(data_list, data_dir_list):
35
+ lines = open(item_path, 'r').readlines()
36
+ lines = [[data_dir, item.strip()] for item in lines]
37
+ image_list.extend(lines)
38
+ self.image_list = image_list
39
+
40
+ def __len__(self):
41
+ return len(self.image_list)
42
+
43
+ def __getitem__(self, index):
44
+ data_dir, file_path = self.image_list[index]
45
+ img_key = file_path.split('|||')[0]
46
+ try:
47
+ ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path)
48
+ except Exception as e:
49
+ logging.info('{} get frames failed... with error: {}'.format(img_key, e))
50
+ caption = ''
51
+ img_key = ''
52
+ ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
53
+ vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
54
+ video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
55
+ return ref_frame, vit_frame, video_data, caption, img_key
56
+
57
+ def _get_image_data(self, data_dir, file_path):
58
+ frame_list = []
59
+ img_key, caption = file_path.split('|||')
60
+ file_path = os.path.join(data_dir, img_key)
61
+ for _ in range(5):
62
+ try:
63
+ image = Image.open(file_path)
64
+ if image.mode != 'RGB':
65
+ image = image.convert('RGB')
66
+ frame_list.append(image)
67
+ break
68
+ except Exception as e:
69
+ logging.info('{} read video frame failed with error: {}'.format(img_key, e))
70
+ continue
71
+
72
+ video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
73
+ try:
74
+ if len(frame_list) > 0:
75
+ mid_frame = frame_list[0]
76
+ vit_frame = self.vit_transforms(mid_frame)
77
+ frame_tensor = self.transforms(frame_list)
78
+ video_data[:len(frame_list), ...] = frame_tensor
79
+ else:
80
+ vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
81
+ except:
82
+ vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
83
+ ref_frame = copy(video_data[0])
84
+
85
+ return ref_frame, vit_frame, video_data, caption
86
+
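
Format note: the list files consumed by ImageDataset separate the image key from its caption with '|||'. A minimal illustration with a hypothetical record:

record = "images/0001.jpg|||a person standing on a beach"  # hypothetical list-file line
img_key, caption = record.strip().split('|||')
# img_key -> 'images/0001.jpg', caption -> 'a person standing on a beach'
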
UniAnimate/tools/datasets/video_dataset.py ADDED
@@ -0,0 +1,118 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import torch
5
+ import random
6
+ import logging
7
+ import tempfile
8
+ import numpy as np
9
+ from copy import copy
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+ from utils.registry_class import DATASETS
13
+
14
+
15
+ @DATASETS.register_class()
16
+ class VideoDataset(Dataset):
17
+ def __init__(self,
18
+ data_list,
19
+ data_dir_list,
20
+ max_words=1000,
21
+ resolution=(384, 256),
22
+ vit_resolution=(224, 224),
23
+ max_frames=16,
24
+ sample_fps=8,
25
+ transforms=None,
26
+ vit_transforms=None,
27
+ get_first_frame=False,
28
+ **kwargs):
29
+
30
+ self.max_words = max_words
31
+ self.max_frames = max_frames
32
+ self.resolution = resolution
33
+ self.vit_resolution = vit_resolution
34
+ self.sample_fps = sample_fps
35
+ self.transforms = transforms
36
+ self.vit_transforms = vit_transforms
37
+ self.get_first_frame = get_first_frame
38
+
39
+ image_list = []
40
+ for item_path, data_dir in zip(data_list, data_dir_list):
41
+ lines = open(item_path, 'r').readlines()
42
+ lines = [[data_dir, item] for item in lines]
43
+ image_list.extend(lines)
44
+ self.image_list = image_list
45
+
46
+
47
+ def __getitem__(self, index):
48
+ data_dir, file_path = self.image_list[index]
49
+ video_key = file_path.split('|||')[0]
50
+ try:
51
+ ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path)
52
+ except Exception as e:
53
+ logging.info('{} get frames failed... with error: {}'.format(video_key, e))
54
+ caption = ''
55
+ video_key = ''
56
+ ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
57
+ vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
58
+ video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
59
+ return ref_frame, vit_frame, video_data, caption, video_key
60
+
61
+
62
+ def _get_video_data(self, data_dir, file_path):
63
+ video_key, caption = file_path.split('|||')
64
+ file_path = os.path.join(data_dir, video_key)
65
+
66
+ for _ in range(5):
67
+ try:
68
+ capture = cv2.VideoCapture(file_path)
69
+ _fps = capture.get(cv2.CAP_PROP_FPS)
70
+ _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
71
+ stride = round(_fps / self.sample_fps)
72
+ cover_frame_num = (stride * self.max_frames)
73
+ if _total_frame_num < cover_frame_num + 5:
74
+ start_frame = 0
75
+ end_frame = _total_frame_num
76
+ else:
77
+ start_frame = random.randint(0, _total_frame_num-cover_frame_num-5)
78
+ end_frame = start_frame + cover_frame_num
79
+
80
+ pointer, frame_list = 0, []
81
+ while(True):
82
+ ret, frame = capture.read()
83
+ pointer +=1
84
+ if (not ret) or (frame is None): break
85
+ if pointer < start_frame: continue
86
+ if pointer >= end_frame - 1: break
87
+ if (pointer - start_frame) % stride == 0:
88
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
89
+ frame = Image.fromarray(frame)
90
+ frame_list.append(frame)
91
+ break
92
+ except Exception as e:
93
+ logging.info('{} read video frame failed with error: {}'.format(video_key, e))
94
+ continue
95
+
96
+ video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
97
+ if self.get_first_frame:
98
+ ref_idx = 0
99
+ else:
100
+ ref_idx = int(len(frame_list)/2)
101
+ try:
102
+ if len(frame_list)>0:
103
+ mid_frame = copy(frame_list[ref_idx])
104
+ vit_frame = self.vit_transforms(mid_frame)
105
+ frames = self.transforms(frame_list)
106
+ video_data[:len(frame_list), ...] = frames
107
+ else:
108
+ vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
109
+ except:
110
+ vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
111
+ ref_frame = copy(frames[ref_idx])
112
+
113
+ return ref_frame, vit_frame, video_data, caption
114
+
115
+ def __len__(self):
116
+ return len(self.image_list)
117
+
118
+
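
Sampling note: _get_video_data above derives a frame stride from the source fps and the target sample_fps; the same arithmetic in isolation (the random start offset used in the dataset is omitted here for clarity).

def sampled_indices(total_frames, fps, sample_fps, max_frames):
    # e.g. a 25 fps clip sampled at 8 fps gives stride 3.
    stride = max(round(fps / sample_fps), 1)
    end = min(stride * max_frames, total_frames)
    return list(range(0, end, stride))
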
UniAnimate/tools/inferences/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .inference_unianimate_entrance import *
2
+ from .inference_unianimate_long_entrance import *
UniAnimate/tools/inferences/inference_unianimate_entrance.py ADDED
@@ -0,0 +1,483 @@
1
+ '''
2
+ /*
3
+ *Copyright (c) 2021, Alibaba Group;
4
+ *Licensed under the Apache License, Version 2.0 (the "License");
5
+ *you may not use this file except in compliance with the License.
6
+ *You may obtain a copy of the License at
7
+
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ *Unless required by applicable law or agreed to in writing, software
11
+ *distributed under the License is distributed on an "AS IS" BASIS,
12
+ *WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ *See the License for the specific language governing permissions and
14
+ *limitations under the License.
15
+ */
16
+ '''
17
+
18
+ import os
19
+ import re
20
+ import os.path as osp
21
+ import sys
22
+ sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4]))
23
+ import json
24
+ import math
25
+ import torch
26
+ import pynvml
27
+ import logging
28
+ import cv2
29
+ import numpy as np
30
+ from PIL import Image
31
+ from tqdm import tqdm
32
+ import torch.cuda.amp as amp
33
+ from importlib import reload
34
+ import torch.distributed as dist
35
+ import torch.multiprocessing as mp
36
+ import random
37
+ from einops import rearrange
38
+ import torchvision.transforms as T
39
+ import torchvision.transforms.functional as TF
40
+ from torch.nn.parallel import DistributedDataParallel
41
+
42
+ import utils.transforms as data
43
+ from ..modules.config import cfg
44
+ from utils.seed import setup_seed
45
+ from utils.multi_port import find_free_port
46
+ from utils.assign_cfg import assign_signle_cfg
47
+ from utils.distributed import generalized_all_gather, all_reduce
48
+ from utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col
49
+ from tools.modules.autoencoder import get_first_stage_encoding
50
+ from utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION
51
+ from copy import copy
52
+ import cv2
53
+
54
+
55
+ @INFER_ENGINE.register_function()
56
+ def inference_unianimate_entrance(cfg_update, **kwargs):
57
+ for k, v in cfg_update.items():
58
+ if isinstance(v, dict) and k in cfg:
59
+ cfg[k].update(v)
60
+ else:
61
+ cfg[k] = v
62
+
63
+ if not 'MASTER_ADDR' in os.environ:
64
+ os.environ['MASTER_ADDR']='localhost'
65
+ os.environ['MASTER_PORT']= find_free_port()
66
+ cfg.pmi_rank = int(os.getenv('RANK', 0))
67
+ cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
68
+
69
+ if cfg.debug:
70
+ cfg.gpus_per_machine = 1
71
+ cfg.world_size = 1
72
+ else:
73
+ cfg.gpus_per_machine = torch.cuda.device_count()
74
+ cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine
75
+
76
+ if cfg.world_size == 1:
77
+ worker(0, cfg, cfg_update)
78
+ else:
79
+ mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update))
80
+ return cfg
81
+
82
+
83
+ def make_masked_images(imgs, masks):
84
+ masked_imgs = []
85
+ for i, mask in enumerate(masks):
86
+ # concatenation
87
+ masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1))
88
+ return torch.stack(masked_imgs, dim=0)
89
+
90
+ def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval = 1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]):
91
+ for _ in range(5):
92
+ try:
93
+ dwpose_all = {}
94
+ frames_all = {}
95
+ for ii_index in sorted(os.listdir(pose_file_path)):
96
+ if ii_index != "ref_pose.jpg":
97
+ dwpose_all[ii_index] = Image.open(os.path.join(pose_file_path, ii_index))
98
+ frames_all[ii_index] = Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path), cv2.COLOR_BGR2RGB))
99
+
100
+ pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))
101
+
102
+ # Sample max_frames poses for video generation
103
+ stride = frame_interval
104
+ total_frame_num = len(frames_all)
105
+ cover_frame_num = (stride * (max_frames - 1) + 1)
106
+
107
+ if total_frame_num < cover_frame_num:
108
+ print(f'_total_frame_num ({total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}), the sampled frame interval is changed')
109
+ start_frame = 0
110
+ end_frame = total_frame_num
111
+ stride = max((total_frame_num - 1) // (max_frames - 1), 1)
112
+ end_frame = stride * max_frames
113
+ else:
114
+ start_frame = 0
115
+ end_frame = start_frame + cover_frame_num
116
+
117
+ frame_list = []
118
+ dwpose_list = []
119
+ random_ref_frame = frames_all[list(frames_all.keys())[0]]
120
+ if random_ref_frame.mode != 'RGB':
121
+ random_ref_frame = random_ref_frame.convert('RGB')
122
+ random_ref_dwpose = pose_ref
123
+ if random_ref_dwpose.mode != 'RGB':
124
+ random_ref_dwpose = random_ref_dwpose.convert('RGB')
125
+
126
+ for i_index in range(start_frame, end_frame, stride):
127
+ if i_index < len(frames_all): # Check index within bounds
128
+ i_key = list(frames_all.keys())[i_index]
129
+ i_frame = frames_all[i_key]
130
+ if i_frame.mode != 'RGB':
131
+ i_frame = i_frame.convert('RGB')
132
+
133
+ i_dwpose = dwpose_all[i_key]
134
+ if i_dwpose.mode != 'RGB':
135
+ i_dwpose = i_dwpose.convert('RGB')
136
+ frame_list.append(i_frame)
137
+ dwpose_list.append(i_dwpose)
138
+
139
+ if frame_list:
140
+ middle_indix = 0
141
+ ref_frame = frame_list[middle_indix]
142
+ vit_frame = vit_transforms(ref_frame)
143
+ random_ref_frame_tmp = train_trans_pose(random_ref_frame)
144
+ random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
145
+ misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
146
+ video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
147
+ dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)
148
+
149
+ video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
150
+ dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
151
+ misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
152
+ random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
153
+ random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
154
+
155
+ video_data[:len(frame_list), ...] = video_data_tmp
156
+ misc_data[:len(frame_list), ...] = misc_data_tmp
157
+ dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
158
+ random_ref_frame_data[:, ...] = random_ref_frame_tmp
159
+ random_ref_dwpose_data[:, ...] = random_ref_dwpose_tmp
160
+
161
+ return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data
162
+
163
+ except Exception as e:
164
+ logging.info(f'Error reading video frame: {e}')
165
+ continue
166
+
167
+ return None, None, None, None, None, None
168
+
169
+ def worker(gpu, cfg, cfg_update):
170
+ '''
171
+ Inference worker for each gpu
172
+ '''
173
+ for k, v in cfg_update.items():
174
+ if isinstance(v, dict) and k in cfg:
175
+ cfg[k].update(v)
176
+ else:
177
+ cfg[k] = v
178
+
179
+ cfg.gpu = gpu
180
+ cfg.seed = int(cfg.seed)
181
+ cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu
182
+ setup_seed(cfg.seed + cfg.rank)
183
+
184
+ if not cfg.debug:
185
+ torch.cuda.set_device(gpu)
186
+ torch.backends.cudnn.benchmark = True
187
+ if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
188
+ torch.backends.cudnn.benchmark = False
189
+ dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)
190
+
191
+ # [Log] Save logging and make log dir
192
+ log_dir = generalized_all_gather(cfg.log_dir)[0]
193
+ inf_name = osp.basename(cfg.cfg_file).split('.')[0]
194
+ test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1]
195
+
196
+ cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name))
197
+ os.makedirs(cfg.log_dir, exist_ok=True)
198
+ log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank))
199
+ cfg.log_file = log_file
200
+ reload(logging)
201
+ logging.basicConfig(
202
+ level=logging.INFO,
203
+ format='[%(asctime)s] %(levelname)s: %(message)s',
204
+ handlers=[
205
+ logging.FileHandler(filename=log_file),
206
+ logging.StreamHandler(stream=sys.stdout)])
207
+ logging.info(cfg)
208
+ logging.info(f"Running UniAnimate inference on gpu {gpu}")
209
+
210
+ # [Diffusion]
211
+ diffusion = DIFFUSION.build(cfg.Diffusion)
212
+
213
+ # [Data] Data Transform
214
+ train_trans = data.Compose([
215
+ data.Resize(cfg.resolution),
216
+ data.ToTensor(),
217
+ data.Normalize(mean=cfg.mean, std=cfg.std)
218
+ ])
219
+
220
+ train_trans_pose = data.Compose([
221
+ data.Resize(cfg.resolution),
222
+ data.ToTensor(),
223
+ ]
224
+ )
225
+
226
+ vit_transforms = T.Compose([
227
+ data.Resize(cfg.vit_resolution),
228
+ T.ToTensor(),
229
+ T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
230
+
231
+ # [Model] embedder
232
+ clip_encoder = EMBEDDER.build(cfg.embedder)
233
+ clip_encoder.model.to(gpu)
234
+ with torch.no_grad():
235
+ _, _, zero_y = clip_encoder(text="")
236
+
237
+
238
+ # [Model] autoencoder
239
+ autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
240
+ autoencoder.eval() # freeze
241
+ for param in autoencoder.parameters():
242
+ param.requires_grad = False
243
+ autoencoder.cuda()
244
+
245
+ # [Model] UNet
246
+ if "config" in cfg.UNet:
247
+ cfg.UNet["config"] = cfg
248
+ cfg.UNet["zero_y"] = zero_y
249
+ model = MODEL.build(cfg.UNet)
250
+ state_dict = torch.load(cfg.test_model, map_location='cpu')
251
+ if 'state_dict' in state_dict:
252
+ state_dict = state_dict['state_dict']
253
+ if 'step' in state_dict:
254
+ resume_step = state_dict['step']
255
+ else:
256
+ resume_step = 0
257
+ status = model.load_state_dict(state_dict, strict=True)
258
+ logging.info('Load model from {} with status {}'.format(cfg.test_model, status))
259
+ model = model.to(gpu)
260
+ model.eval()
261
+ if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
262
+ model.to(torch.float16)
263
+ else:
264
+ model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
265
+ torch.cuda.empty_cache()
266
+
267
+
268
+
269
+ test_list = cfg.test_list_path
270
+ num_videos = len(test_list)
271
+ logging.info(f'There are {num_videos} videos, each sampled {cfg.round} times')
272
+ # test_list = [item for item in test_list for _ in range(cfg.round)]
273
+ test_list = [item for _ in range(cfg.round) for item in test_list]
274
+
275
+ for idx, file_path in enumerate(test_list):
276
+ cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]
277
+
278
+ manual_seed = int(cfg.seed + cfg.rank + idx//num_videos)
279
+ setup_seed(manual_seed)
280
+ logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key}, init seed {manual_seed} ...")
281
+
282
+
283
+ vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval =cfg.frame_interval, resolution=cfg.resolution)
284
+ misc_data = misc_data.unsqueeze(0).to(gpu)
285
+ vit_frame = vit_frame.unsqueeze(0).to(gpu)
286
+ dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
287
+ random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
288
+ random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
289
+
290
+ ### save for visualization
291
+ misc_backups = copy(misc_data)
292
+ frames_num = misc_data.shape[1]
293
+ misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
294
+ mv_data_video = []
295
+
296
+
297
+ ### local image (first frame)
298
+ image_local = []
299
+ if 'local_image' in cfg.video_compositions:
300
+ frames_num = misc_data.shape[1]
301
+ bs_vd_local = misc_data.shape[0]
302
+ image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1)
303
+ image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
304
+ image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
305
+ if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
306
+ with torch.no_grad():
307
+ temporal_length = frames_num
308
+ encoder_posterior = autoencoder.encode(video_data[:,0])
309
+ local_image_data = get_first_stage_encoding(encoder_posterior).detach()
310
+ image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
311
+
312
+
313
+
314
+ ### encode the video_data
315
+ bs_vd = misc_data.shape[0]
316
+ misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
317
+ misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0)
318
+
319
+
320
+ with torch.no_grad():
321
+
322
+ random_ref_frame = []
323
+ if 'randomref' in cfg.video_compositions:
324
+ random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
325
+ if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref:
326
+
327
+ temporal_length = random_ref_frame_data.shape[1]
328
+ encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5))
329
+ random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach()
330
+ random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
331
+
332
+ random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
333
+
334
+
335
+ if 'dwpose' in cfg.video_compositions:
336
+ bs_vd_local = dwpose_data.shape[0]
337
+ dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local)
338
+ if 'randomref_pose' in cfg.video_compositions:
339
+ dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1)
340
+ dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local)
341
+
342
+
343
+ y_visual = []
344
+ if 'image' in cfg.video_compositions:
345
+ with torch.no_grad():
346
+ vit_frame = vit_frame.squeeze(1)
347
+ y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
348
+ y_visual0 = y_visual.clone()
349
+
350
+
351
+ with amp.autocast(enabled=True):
352
+ pynvml.nvmlInit()
353
+ handle=pynvml.nvmlDeviceGetHandleByIndex(0)
354
+ meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle)
355
+ cur_seed = torch.initial_seed()
356
+ logging.info(f"Current seed {cur_seed} ...")
357
+
358
+ noise = torch.randn([1, 4, cfg.max_frames, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)])
359
+ noise = noise.to(gpu)
360
+
361
+ if hasattr(cfg.Diffusion, "noise_strength"):
362
+ b, c, f, _, _= noise.shape
363
+ offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
364
+ noise = noise + cfg.Diffusion.noise_strength * offset_noise
365
+
366
+ # add a noise prior
367
+ noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 949), noise=noise)
368
+
369
+ # construct model inputs (CFG)
370
+ full_model_kwargs=[{
371
+ 'y': None,
372
+ "local_image": None if len(image_local) == 0 else image_local[:],
373
+ 'image': None if len(y_visual) == 0 else y_visual0[:],
374
+ 'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:],
375
+ 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:],
376
+ },
377
+ {
378
+ 'y': None,
379
+ "local_image": None,
380
+ 'image': None,
381
+ 'randomref': None,
382
+ 'dwpose': None,
383
+ }]
384
+
385
+ # for visualization
386
+ full_model_kwargs_vis =[{
387
+ 'y': None,
388
+ "local_image": None if len(image_local) == 0 else image_local_clone[:],
389
+ 'image': None,
390
+ 'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:],
391
+ 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3],
392
+ },
393
+ {
394
+ 'y': None,
395
+ "local_image": None,
396
+ 'image': None,
397
+ 'randomref': None,
398
+ 'dwpose': None,
399
+ }]
400
+
401
+
402
+ partial_keys = [
403
+ ['image', 'randomref', "dwpose"],
404
+ ]
405
+ if hasattr(cfg, "partial_keys") and cfg.partial_keys:
406
+ partial_keys = cfg.partial_keys
407
+
408
+
409
+ for partial_keys_one in partial_keys:
410
+ model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one,
411
+ full_model_kwargs = full_model_kwargs,
412
+ use_fps_condition = cfg.use_fps_condition)
413
+ model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one,
414
+ full_model_kwargs = full_model_kwargs_vis,
415
+ use_fps_condition = cfg.use_fps_condition)
416
+ noise_one = noise
417
+
418
+ if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
419
+ clip_encoder.cpu() # add this line
420
+ autoencoder.cpu() # add this line
421
+ torch.cuda.empty_cache() # add this line
422
+
423
+ video_data = diffusion.ddim_sample_loop(
424
+ noise=noise_one,
425
+ model=model.eval(),
426
+ model_kwargs=model_kwargs_one,
427
+ guide_scale=cfg.guide_scale,
428
+ ddim_timesteps=cfg.ddim_timesteps,
429
+ eta=0.0)
430
+
431
+ if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
432
+ # if the autoencoder or clip_encoder needs to run again later, move it back to the GPU
433
+ clip_encoder.cuda()
434
+ autoencoder.cuda()
435
+ video_data = 1. / cfg.scale_factor * video_data
436
+ video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
437
+ chunk_size = min(cfg.decoder_bs, video_data.shape[0])
438
+ video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0)
439
+ decode_data = []
440
+ for vd_data in video_data_list:
441
+ gen_frames = autoencoder.decode(vd_data)
442
+ decode_data.append(gen_frames)
443
+ video_data = torch.cat(decode_data, dim=0)
444
+ video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float()
445
+
446
+ text_size = cfg.resolution[-1]
447
+ cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0]) # .replace(' ', '_')
448
+ name = f'seed_{cur_seed}'
449
+ for ii in partial_keys_one:
450
+ name = name + "_" + ii
451
+ file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4'
452
+ local_path = os.path.join(cfg.log_dir, f'{file_name}')
453
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
454
+ captions = "human"
455
+ del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
456
+ del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]
457
+
458
+ save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups,
459
+ cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps)
460
+
461
+ # try:
462
+ # save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
463
+ # logging.info('Save video to dir %s:' % (local_path))
464
+ # except Exception as e:
465
+ # logging.info(f'Step: save text or video error with {e}')
466
+
467
+ logging.info('Congratulations! The inference is completed!')
468
+ # synchronize to finish some processes
469
+ if not cfg.debug:
470
+ torch.cuda.synchronize()
471
+ dist.barrier()
472
+
473
+ def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False):
474
+
475
+ if use_fps_condition is True:
476
+ partial_keys.append('fps')
477
+
478
+ partial_model_kwargs = [{}, {}]
479
+ for partial_key in partial_keys:
480
+ partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
481
+ partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
482
+
483
+ return partial_model_kwargs
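
Note: prepare_model_kwargs above copies only the requested condition keys into the conditional / unconditional dict pair used for classifier-free guidance. The same filtering with dummy stand-ins (illustration only, not the real tensors):

full = [
    {'image': 'clip_feat', 'randomref': 'ref_latent', 'dwpose': 'pose_seq', 'local_image': None},
    {'image': None, 'randomref': None, 'dwpose': None, 'local_image': None},
]
partial = [{}, {}]
for key in ['image', 'randomref', 'dwpose']:
    partial[0][key] = full[0][key]   # conditional branch
    partial[1][key] = full[1][key]   # unconditional branch
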
UniAnimate/tools/inferences/inference_unianimate_long_entrance.py ADDED
@@ -0,0 +1,508 @@
1
+ '''
2
+ /*
3
+ *Copyright (c) 2021, Alibaba Group;
4
+ *Licensed under the Apache License, Version 2.0 (the "License");
5
+ *you may not use this file except in compliance with the License.
6
+ *You may obtain a copy of the License at
7
+
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ *Unless required by applicable law or agreed to in writing, software
11
+ *distributed under the License is distributed on an "AS IS" BASIS,
12
+ *WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ *See the License for the specific language governing permissions and
14
+ *limitations under the License.
15
+ */
16
+ '''
17
+
18
+ import os
19
+ import re
20
+ import os.path as osp
21
+ import sys
22
+ sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4]))
23
+ import json
24
+ import math
25
+ import torch
26
+ import pynvml
27
+ import logging
28
+ import cv2
29
+ import numpy as np
30
+ from PIL import Image
31
+ from tqdm import tqdm
32
+ import torch.cuda.amp as amp
33
+ from importlib import reload
34
+ import torch.distributed as dist
35
+ import torch.multiprocessing as mp
36
+ import random
37
+ from einops import rearrange
38
+ import torchvision.transforms as T
39
+ import torchvision.transforms.functional as TF
40
+ from torch.nn.parallel import DistributedDataParallel
41
+
42
+ import utils.transforms as data
43
+ from ..modules.config import cfg
44
+ from utils.seed import setup_seed
45
+ from utils.multi_port import find_free_port
46
+ from utils.assign_cfg import assign_signle_cfg
47
+ from utils.distributed import generalized_all_gather, all_reduce
48
+ from utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col
49
+ from tools.modules.autoencoder import get_first_stage_encoding
50
+ from utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION
51
+ from copy import copy
52
+ import cv2
53
+
54
+
55
+ @INFER_ENGINE.register_function()
56
+ def inference_unianimate_long_entrance(cfg_update, **kwargs):
57
+ for k, v in cfg_update.items():
58
+ if isinstance(v, dict) and k in cfg:
59
+ cfg[k].update(v)
60
+ else:
61
+ cfg[k] = v
62
+
63
+ if not 'MASTER_ADDR' in os.environ:
64
+ os.environ['MASTER_ADDR']='localhost'
65
+ os.environ['MASTER_PORT']= find_free_port()
66
+ cfg.pmi_rank = int(os.getenv('RANK', 0))
67
+ cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
68
+
69
+ if cfg.debug:
70
+ cfg.gpus_per_machine = 1
71
+ cfg.world_size = 1
72
+ else:
73
+ cfg.gpus_per_machine = torch.cuda.device_count()
74
+ cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine
75
+
76
+ if cfg.world_size == 1:
77
+ worker(0, cfg, cfg_update)
78
+ else:
79
+ mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update))
80
+ return cfg
81
+
82
+
83
+ def make_masked_images(imgs, masks):
84
+ masked_imgs = []
85
+ for i, mask in enumerate(masks):
86
+ # concatenation
87
+ masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1))
88
+ return torch.stack(masked_imgs, dim=0)
89
+
90
+ def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval = 1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]):
91
+
92
+ for _ in range(5):
93
+ try:
94
+ dwpose_all = {}
95
+ frames_all = {}
96
+ for ii_index in sorted(os.listdir(pose_file_path)):
97
+ if ii_index != "ref_pose.jpg":
98
+ dwpose_all[ii_index] = Image.open(pose_file_path+"/"+ii_index)
99
+ frames_all[ii_index] = Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path),cv2.COLOR_BGR2RGB))
100
+ # frames_all[ii_index] = Image.open(ref_image_path)
101
+
102
+ pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))
103
+ first_eq_ref = False
104
+
105
+ # sample max_frames poses for video generation
106
+ stride = frame_interval
107
+ _total_frame_num = len(frames_all)
108
+ if max_frames == "None":
109
+ max_frames = (_total_frame_num-1)//frame_interval + 1
110
+ cover_frame_num = (stride * (max_frames-1)+1)
111
+ if _total_frame_num < cover_frame_num:
112
+ print(f'_total_frame_num ({_total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}), the sampled frame interval is changed')
113
+ start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame
114
+ end_frame = _total_frame_num
115
+ stride = max((_total_frame_num - 1) // (max_frames - 1), 1)
116
+ end_frame = stride*max_frames
117
+ else:
118
+ start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame
119
+ end_frame = start_frame + cover_frame_num
120
+
121
+ frame_list = []
122
+ dwpose_list = []
123
+ random_ref_frame = frames_all[list(frames_all.keys())[0]]
124
+ if random_ref_frame.mode != 'RGB':
125
+ random_ref_frame = random_ref_frame.convert('RGB')
126
+ random_ref_dwpose = pose_ref
127
+ if random_ref_dwpose.mode != 'RGB':
128
+ random_ref_dwpose = random_ref_dwpose.convert('RGB')
129
+ for i_index in range(start_frame, end_frame, stride):
130
+ if i_index == start_frame and first_eq_ref:
131
+ i_key = list(frames_all.keys())[i_index]
132
+ i_frame = frames_all[i_key]
133
+
134
+ if i_frame.mode != 'RGB':
135
+ i_frame = i_frame.convert('RGB')
136
+ i_dwpose = frames_pose_ref
137
+ if i_dwpose.mode != 'RGB':
138
+ i_dwpose = i_dwpose.convert('RGB')
139
+ frame_list.append(i_frame)
140
+ dwpose_list.append(i_dwpose)
141
+ else:
142
+ # added
143
+ if first_eq_ref:
144
+ i_index = i_index - stride
145
+
146
+ i_key = list(frames_all.keys())[i_index]
147
+ i_frame = frames_all[i_key]
148
+ if i_frame.mode != 'RGB':
149
+ i_frame = i_frame.convert('RGB')
150
+ i_dwpose = dwpose_all[i_key]
151
+ if i_dwpose.mode != 'RGB':
152
+ i_dwpose = i_dwpose.convert('RGB')
153
+ frame_list.append(i_frame)
154
+ dwpose_list.append(i_dwpose)
155
+ have_frames = len(frame_list)>0
156
+ middle_indix = 0
157
+ if have_frames:
158
+ ref_frame = frame_list[middle_indix]
159
+ vit_frame = vit_transforms(ref_frame)
160
+ random_ref_frame_tmp = train_trans_pose(random_ref_frame)
161
+ random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
162
+ misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
163
+ video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
164
+ dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)
165
+
166
+ video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
167
+ dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
168
+ misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
169
+ random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) # [32, 3, 512, 768]
170
+ random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
171
+ if have_frames:
172
+ video_data[:len(frame_list), ...] = video_data_tmp
173
+ misc_data[:len(frame_list), ...] = misc_data_tmp
174
+ dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
175
+ random_ref_frame_data[:,...] = random_ref_frame_tmp
176
+ random_ref_dwpose_data[:,...] = random_ref_dwpose_tmp
177
+
178
+ break
179
+
180
+ except Exception as e:
181
+ logging.info('{} read video frame failed with error: {}'.format(pose_file_path, e))
182
+ continue
183
+
184
+ return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames
185
+
186
+
187
+
188
+ def worker(gpu, cfg, cfg_update):
189
+ '''
190
+ Inference worker for each gpu
191
+ '''
192
+ for k, v in cfg_update.items():
193
+ if isinstance(v, dict) and k in cfg:
194
+ cfg[k].update(v)
195
+ else:
196
+ cfg[k] = v
197
+
198
+ cfg.gpu = gpu
199
+ cfg.seed = int(cfg.seed)
200
+ cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu
201
+ setup_seed(cfg.seed + cfg.rank)
202
+
203
+ if not cfg.debug:
204
+ torch.cuda.set_device(gpu)
205
+ torch.backends.cudnn.benchmark = True
206
+ if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
207
+ torch.backends.cudnn.benchmark = False
208
+ dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)
209
+
210
+ # [Log] Save logging and make log dir
211
+ log_dir = generalized_all_gather(cfg.log_dir)[0]
212
+ inf_name = osp.basename(cfg.cfg_file).split('.')[0]
213
+ test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1]
214
+
215
+ cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name))
216
+ os.makedirs(cfg.log_dir, exist_ok=True)
217
+ log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank))
218
+ cfg.log_file = log_file
219
+ reload(logging)
220
+ logging.basicConfig(
221
+ level=logging.INFO,
222
+ format='[%(asctime)s] %(levelname)s: %(message)s',
223
+ handlers=[
224
+ logging.FileHandler(filename=log_file),
225
+ logging.StreamHandler(stream=sys.stdout)])
226
+ logging.info(cfg)
227
+ logging.info(f"Running UniAnimate inference on gpu {gpu}")
228
+
229
+ # [Diffusion]
230
+ diffusion = DIFFUSION.build(cfg.Diffusion)
231
+
232
+ # [Data] Data Transform
233
+ train_trans = data.Compose([
234
+ data.Resize(cfg.resolution),
235
+ data.ToTensor(),
236
+ data.Normalize(mean=cfg.mean, std=cfg.std)
237
+ ])
238
+
239
+ train_trans_pose = data.Compose([
240
+ data.Resize(cfg.resolution),
241
+ data.ToTensor(),
242
+ ]
243
+ )
244
+
245
+ vit_transforms = T.Compose([
246
+ data.Resize(cfg.vit_resolution),
247
+ T.ToTensor(),
248
+ T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
249
+
250
+ # [Model] embedder
251
+ clip_encoder = EMBEDDER.build(cfg.embedder)
252
+ clip_encoder.model.to(gpu)
253
+ with torch.no_grad():
254
+ _, _, zero_y = clip_encoder(text="")
255
+
256
+
257
+ # [Model] autoencoder
258
+ autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
259
+ autoencoder.eval() # freeze
260
+ for param in autoencoder.parameters():
261
+ param.requires_grad = False
262
+ autoencoder.cuda()
263
+
264
+ # [Model] UNet
265
+ if "config" in cfg.UNet:
266
+ cfg.UNet["config"] = cfg
267
+ cfg.UNet["zero_y"] = zero_y
268
+ model = MODEL.build(cfg.UNet)
269
+ state_dict = torch.load(cfg.test_model, map_location='cpu')
270
+ if 'state_dict' in state_dict:
271
+ state_dict = state_dict['state_dict']
272
+ if 'step' in state_dict:
273
+ resume_step = state_dict['step']
274
+ else:
275
+ resume_step = 0
276
+ status = model.load_state_dict(state_dict, strict=True)
277
+ logging.info('Load model from {} with status {}'.format(cfg.test_model, status))
278
+ model = model.to(gpu)
279
+ model.eval()
280
+ if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
281
+ model.to(torch.float16)
282
+ else:
283
+ model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
284
+ torch.cuda.empty_cache()
285
+
286
+
287
+
288
+ test_list = cfg.test_list_path
289
+ num_videos = len(test_list)
290
+ logging.info(f'There are {num_videos} videos, each sampled {cfg.round} times')
291
+ test_list = [item for _ in range(cfg.round) for item in test_list]
292
+
293
+ for idx, file_path in enumerate(test_list):
294
+ cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]
295
+
296
+ manual_seed = int(cfg.seed + cfg.rank + idx//num_videos)
297
+ setup_seed(manual_seed)
298
+ logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key}, init seed {manual_seed} ...")
299
+
300
+
301
+ vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval =cfg.frame_interval, resolution=cfg.resolution)
302
+ cfg.max_frames_new = max_frames
303
+ misc_data = misc_data.unsqueeze(0).to(gpu)
304
+ vit_frame = vit_frame.unsqueeze(0).to(gpu)
305
+ dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
306
+ random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
307
+ random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
308
+
309
+ ### save for visualization
310
+ misc_backups = copy(misc_data)
311
+ frames_num = misc_data.shape[1]
312
+ misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
313
+ mv_data_video = []
314
+
315
+
316
+ ### local image (first frame)
317
+ image_local = []
318
+ if 'local_image' in cfg.video_compositions:
319
+ frames_num = misc_data.shape[1]
320
+ bs_vd_local = misc_data.shape[0]
321
+ image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1)
322
+ image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
323
+ image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
324
+ if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
325
+ with torch.no_grad():
326
+ temporal_length = frames_num
327
+ encoder_posterior = autoencoder.encode(video_data[:,0])
328
+ local_image_data = get_first_stage_encoding(encoder_posterior).detach()
329
+ image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
330
+
331
+
332
+
333
+ ### encode the video_data
334
+ bs_vd = misc_data.shape[0]
335
+ misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
336
+ misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0)
337
+
338
+
339
+ with torch.no_grad():
340
+
341
+ random_ref_frame = []
342
+ if 'randomref' in cfg.video_compositions:
343
+ random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
344
+ if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref:
345
+
346
+ temporal_length = random_ref_frame_data.shape[1]
347
+ encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5))
348
+ random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach()
349
+ random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
350
+
351
+ random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
352
+
353
+
354
+ if 'dwpose' in cfg.video_compositions:
355
+ bs_vd_local = dwpose_data.shape[0]
356
+ dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local)
357
+ if 'randomref_pose' in cfg.video_compositions:
358
+ dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1)
359
+ dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local)
360
+
361
+
362
+ y_visual = []
363
+ if 'image' in cfg.video_compositions:
364
+ with torch.no_grad():
365
+ vit_frame = vit_frame.squeeze(1)
366
+ y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
367
+ y_visual0 = y_visual.clone()
368
+
369
+
370
+ with amp.autocast(enabled=True):
371
+ pynvml.nvmlInit()
372
+ handle=pynvml.nvmlDeviceGetHandleByIndex(0)
373
+ meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle)
374
+ cur_seed = torch.initial_seed()
375
+ logging.info(f"Current seed {cur_seed} ..., cfg.max_frames_new: {cfg.max_frames_new} ....")
376
+
377
+ noise = torch.randn([1, 4, cfg.max_frames_new, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)])
378
+ noise = noise.to(gpu)
379
+
380
+ # add a noise prior
381
+ noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise)
382
+
383
+ if hasattr(cfg.Diffusion, "noise_strength"):
384
+ b, c, f, _, _= noise.shape
385
+ offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
386
+ noise = noise + cfg.Diffusion.noise_strength * offset_noise
387
+
388
+ # construct model inputs (CFG)
389
+ full_model_kwargs=[{
390
+ 'y': None,
391
+ "local_image": None if len(image_local) == 0 else image_local[:],
392
+ 'image': None if len(y_visual) == 0 else y_visual0[:],
393
+ 'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:],
394
+ 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:],
395
+ },
396
+ {
397
+ 'y': None,
398
+ "local_image": None,
399
+ 'image': None,
400
+ 'randomref': None,
401
+ 'dwpose': None,
402
+ }]
403
+
404
+ # for visualization
405
+ full_model_kwargs_vis =[{
406
+ 'y': None,
407
+ "local_image": None if len(image_local) == 0 else image_local_clone[:],
408
+ 'image': None,
409
+ 'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:],
410
+ 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3],
411
+ },
412
+ {
413
+ 'y': None,
414
+ "local_image": None,
415
+ 'image': None,
416
+ 'randomref': None,
417
+ 'dwpose': None,
418
+ }]
419
+
420
+
421
+ partial_keys = [
422
+ ['image', 'randomref', "dwpose"],
423
+ ]
424
+ if hasattr(cfg, "partial_keys") and cfg.partial_keys:
425
+ partial_keys = cfg.partial_keys
426
+
427
+ for partial_keys_one in partial_keys:
428
+ model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one,
429
+ full_model_kwargs = full_model_kwargs,
430
+ use_fps_condition = cfg.use_fps_condition)
431
+ model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one,
432
+ full_model_kwargs = full_model_kwargs_vis,
433
+ use_fps_condition = cfg.use_fps_condition)
434
+ noise_one = noise
435
+
436
+ if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
437
+ clip_encoder.cpu() # offload the CLIP encoder to CPU to free GPU memory for sampling
438
+ autoencoder.cpu() # offload the VAE to CPU as well
439
+ torch.cuda.empty_cache() # release the cached GPU memory
440
+
441
+ video_data = diffusion.ddim_sample_loop(
442
+ noise=noise_one,
443
+ context_size=cfg.context_size,
444
+ context_stride=cfg.context_stride,
445
+ context_overlap=cfg.context_overlap,
446
+ model=model.eval(),
447
+ model_kwargs=model_kwargs_one,
448
+ guide_scale=cfg.guide_scale,
449
+ ddim_timesteps=cfg.ddim_timesteps,
450
+ eta=0.0,
451
+ context_batch_size=getattr(cfg, "context_batch_size", 1)
452
+ )
453
+
454
+ if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
455
+ # move the autoencoder and CLIP encoder back to the GPU before they are used again
456
+ clip_encoder.cuda()
457
+ autoencoder.cuda()
458
+
459
+
460
+ video_data = 1. / cfg.scale_factor * video_data # [1, 4, f, h, w]
461
+ video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
462
+ chunk_size = min(cfg.decoder_bs, video_data.shape[0])
463
+ video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0)
464
+ decode_data = []
465
+ for vd_data in video_data_list:
466
+ gen_frames = autoencoder.decode(vd_data)
467
+ decode_data.append(gen_frames)
468
+ video_data = torch.cat(decode_data, dim=0)
469
+ video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float()
470
+
471
+ text_size = cfg.resolution[-1]
472
+ cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0]) # .replace(' ', '_')
473
+ name = f'seed_{cur_seed}'
474
+ for ii in partial_keys_one:
475
+ name = name + "_" + ii
476
+ file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4'
477
+ local_path = os.path.join(cfg.log_dir, f'{file_name}')
478
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
479
+ captions = "human"
480
+ del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
481
+ del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]
482
+
483
+ save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups,
484
+ cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps)
485
+
486
+ # try:
487
+ # save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
488
+ # logging.info('Save video to dir %s:' % (local_path))
489
+ # except Exception as e:
490
+ # logging.info(f'Step: save text or video error with {e}')
491
+
492
+ logging.info('Congratulations! The inference is completed!')
493
+ # synchronize to finish some processes
494
+ if not cfg.debug:
495
+ torch.cuda.synchronize()
496
+ dist.barrier()
497
+
498
+ def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False):
499
+
500
+ if use_fps_condition is True:
501
+ partial_keys.append('fps')
502
+
503
+ partial_model_kwargs = [{}, {}]
504
+ for partial_key in partial_keys:
505
+ partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
506
+ partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
507
+
508
+ return partial_model_kwargs
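
Note: `prepare_model_kwargs` simply filters the conditional/unconditional kwargs pair down to the active condition keys for classifier-free guidance. A minimal sketch of its behavior, assuming the function above is in scope and using placeholder tensors (shapes are illustrative only, not the real encoder outputs):

    import torch

    full_kwargs = [
        {'y': None, 'local_image': None,
         'image': torch.zeros(1, 1, 1024),                 # placeholder CLIP image feature
         'dwpose': torch.zeros(1, 3, 17, 256, 160),        # placeholder pose frames
         'randomref': torch.zeros(1, 4, 16, 64, 40)},      # placeholder reference latent
        {'y': None, 'local_image': None, 'image': None, 'dwpose': None, 'randomref': None},
    ]
    cond, uncond = prepare_model_kwargs(['image', 'randomref', 'dwpose'], full_kwargs,
                                        use_fps_condition=False)
    assert set(cond) == {'image', 'randomref', 'dwpose'}   # only the requested conditions survive
    assert uncond['image'] is None                         # unconditional branch keeps the same keys, emptied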
UniAnimate/tools/modules/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .clip_embedder import FrozenOpenCLIPEmbedder
2
+ from .autoencoder import DiagonalGaussianDistribution, AutoencoderKL
3
+ from .clip_embedder import *
4
+ from .autoencoder import *
5
+ from .unet import *
6
+ from .diffusions import *
7
+ from .embedding_manager import *
UniAnimate/tools/modules/autoencoder.py ADDED
@@ -0,0 +1,690 @@
1
+ import torch
2
+ import logging
3
+ import collections
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from utils.registry_class import AUTO_ENCODER,DISTRIBUTION
9
+
10
+
11
+ def nonlinearity(x):
12
+ # swish
13
+ return x*torch.sigmoid(x)
14
+
15
+ def Normalize(in_channels, num_groups=32):
16
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
17
+
18
+
19
+ @torch.no_grad()
20
+ def get_first_stage_encoding(encoder_posterior, scale_factor=0.18215):
21
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
22
+ z = encoder_posterior.sample()
23
+ elif isinstance(encoder_posterior, torch.Tensor):
24
+ z = encoder_posterior
25
+ else:
26
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
27
+ return scale_factor * z
28
+
29
+
30
+ @AUTO_ENCODER.register_class()
31
+ class AutoencoderKL(nn.Module):
32
+ def __init__(self,
33
+ ddconfig,
34
+ embed_dim,
35
+ pretrained=None,
36
+ ignore_keys=[],
37
+ image_key="image",
38
+ colorize_nlabels=None,
39
+ monitor=None,
40
+ ema_decay=None,
41
+ learn_logvar=False,
42
+ use_vid_decoder=False,
43
+ **kwargs):
44
+ super().__init__()
45
+ self.learn_logvar = learn_logvar
46
+ self.image_key = image_key
47
+ self.encoder = Encoder(**ddconfig)
48
+ self.decoder = Decoder(**ddconfig)
49
+ assert ddconfig["double_z"]
50
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
51
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
52
+ self.embed_dim = embed_dim
53
+ if colorize_nlabels is not None:
54
+ assert type(colorize_nlabels)==int
55
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
56
+ if monitor is not None:
57
+ self.monitor = monitor
58
+
59
+ self.use_ema = ema_decay is not None
60
+
61
+ if pretrained is not None:
62
+ self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)
63
+
64
+ def init_from_ckpt(self, path, ignore_keys=list()):
65
+ sd = torch.load(path, map_location="cpu")["state_dict"]
66
+ keys = list(sd.keys())
67
+ sd_new = collections.OrderedDict()
68
+ for k in keys:
69
+ if k.find('first_stage_model') >= 0:
70
+ k_new = k.split('first_stage_model.')[-1]
71
+ sd_new[k_new] = sd[k]
72
+ self.load_state_dict(sd_new, strict=True)
73
+ logging.info(f"Restored from {path}")
74
+
75
+ def on_train_batch_end(self, *args, **kwargs):
76
+ if self.use_ema:
77
+ self.model_ema(self)
78
+
79
+ def encode(self, x):
80
+ h = self.encoder(x)
81
+ moments = self.quant_conv(h)
82
+ posterior = DiagonalGaussianDistribution(moments)
83
+ return posterior
84
+
85
+ def encode_firsr_stage(self, x, scale_factor=1.0):
86
+ h = self.encoder(x)
87
+ moments = self.quant_conv(h)
88
+ posterior = DiagonalGaussianDistribution(moments)
89
+ z = get_first_stage_encoding(posterior, scale_factor)
90
+ return z
91
+
92
+ def encode_ms(self, x):
93
+ hs = self.encoder(x, True)
94
+ h = hs[-1]
95
+ moments = self.quant_conv(h)
96
+ posterior = DiagonalGaussianDistribution(moments)
97
+ hs[-1] = h
98
+ return hs
99
+
100
+ def decode(self, z, **kwargs):
101
+ z = self.post_quant_conv(z)
102
+ dec = self.decoder(z, **kwargs)
103
+ return dec
104
+
105
+
106
+ def forward(self, input, sample_posterior=True):
107
+ posterior = self.encode(input)
108
+ if sample_posterior:
109
+ z = posterior.sample()
110
+ else:
111
+ z = posterior.mode()
112
+ dec = self.decode(z)
113
+ return dec, posterior
114
+
115
+ def get_input(self, batch, k):
116
+ x = batch[k]
117
+ if len(x.shape) == 3:
118
+ x = x[..., None]
119
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
120
+ return x
121
+
122
+ def get_last_layer(self):
123
+ return self.decoder.conv_out.weight
124
+
125
+ @torch.no_grad()
126
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
127
+ log = dict()
128
+ x = self.get_input(batch, self.image_key)
129
+ x = x.to(self.device)
130
+ if not only_inputs:
131
+ xrec, posterior = self(x)
132
+ if x.shape[1] > 3:
133
+ # colorize with random projection
134
+ assert xrec.shape[1] > 3
135
+ x = self.to_rgb(x)
136
+ xrec = self.to_rgb(xrec)
137
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
138
+ log["reconstructions"] = xrec
139
+ if log_ema or self.use_ema:
140
+ with self.ema_scope():
141
+ xrec_ema, posterior_ema = self(x)
142
+ if x.shape[1] > 3:
143
+ # colorize with random projection
144
+ assert xrec_ema.shape[1] > 3
145
+ xrec_ema = self.to_rgb(xrec_ema)
146
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
147
+ log["reconstructions_ema"] = xrec_ema
148
+ log["inputs"] = x
149
+ return log
150
+
151
+ def to_rgb(self, x):
152
+ assert self.image_key == "segmentation"
153
+ if not hasattr(self, "colorize"):
154
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
155
+ x = F.conv2d(x, weight=self.colorize)
156
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
157
+ return x
158
+
159
+
160
+ @AUTO_ENCODER.register_class()
161
+ class AutoencoderVideo(AutoencoderKL):
162
+ def __init__(self,
163
+ ddconfig,
164
+ embed_dim,
165
+ pretrained=None,
166
+ ignore_keys=[],
167
+ image_key="image",
168
+ colorize_nlabels=None,
169
+ monitor=None,
170
+ ema_decay=None,
171
+ use_vid_decoder=True,
172
+ learn_logvar=False,
173
+ **kwargs):
174
+ use_vid_decoder = True
175
+ super().__init__(ddconfig, embed_dim, pretrained, ignore_keys, image_key, colorize_nlabels, monitor, ema_decay, learn_logvar, use_vid_decoder, **kwargs)
176
+
177
+ def decode(self, z, **kwargs):
178
+ # z = self.post_quant_conv(z)
179
+ dec = self.decoder(z, **kwargs)
180
+ return dec
181
+
182
+ def encode(self, x):
183
+ h = self.encoder(x)
184
+ # moments = self.quant_conv(h)
185
+ moments = h
186
+ posterior = DiagonalGaussianDistribution(moments)
187
+ return posterior
188
+
189
+
190
+ class IdentityFirstStage(torch.nn.Module):
191
+ def __init__(self, *args, vq_interface=False, **kwargs):
192
+ self.vq_interface = vq_interface
193
+ super().__init__()
194
+
195
+ def encode(self, x, *args, **kwargs):
196
+ return x
197
+
198
+ def decode(self, x, *args, **kwargs):
199
+ return x
200
+
201
+ def quantize(self, x, *args, **kwargs):
202
+ if self.vq_interface:
203
+ return x, None, [None, None, None]
204
+ return x
205
+
206
+ def forward(self, x, *args, **kwargs):
207
+ return x
208
+
209
+
210
+
211
+ @DISTRIBUTION.register_class()
212
+ class DiagonalGaussianDistribution(object):
213
+ def __init__(self, parameters, deterministic=False):
214
+ self.parameters = parameters
215
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
216
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
217
+ self.deterministic = deterministic
218
+ self.std = torch.exp(0.5 * self.logvar)
219
+ self.var = torch.exp(self.logvar)
220
+ if self.deterministic:
221
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
222
+
223
+ def sample(self):
224
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
225
+ return x
226
+
227
+ def kl(self, other=None):
228
+ if self.deterministic:
229
+ return torch.Tensor([0.])
230
+ else:
231
+ if other is None:
232
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
233
+ + self.var - 1.0 - self.logvar,
234
+ dim=[1, 2, 3])
235
+ else:
236
+ return 0.5 * torch.sum(
237
+ torch.pow(self.mean - other.mean, 2) / other.var
238
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
239
+ dim=[1, 2, 3])
240
+
241
+ def nll(self, sample, dims=[1,2,3]):
242
+ if self.deterministic:
243
+ return torch.Tensor([0.])
244
+ logtwopi = np.log(2.0 * np.pi)
245
+ return 0.5 * torch.sum(
246
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
247
+ dim=dims)
248
+
249
+ def mode(self):
250
+ return self.mean
251
+
252
+
253
+ # -------------------------------modules--------------------------------
254
+
255
+ class Downsample(nn.Module):
256
+ def __init__(self, in_channels, with_conv):
257
+ super().__init__()
258
+ self.with_conv = with_conv
259
+ if self.with_conv:
260
+ # no asymmetric padding in torch conv, must do it ourselves
261
+ self.conv = torch.nn.Conv2d(in_channels,
262
+ in_channels,
263
+ kernel_size=3,
264
+ stride=2,
265
+ padding=0)
266
+
267
+ def forward(self, x):
268
+ if self.with_conv:
269
+ pad = (0,1,0,1)
270
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
271
+ x = self.conv(x)
272
+ else:
273
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
274
+ return x
275
+
276
+ class ResnetBlock(nn.Module):
277
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
278
+ dropout, temb_channels=512):
279
+ super().__init__()
280
+ self.in_channels = in_channels
281
+ out_channels = in_channels if out_channels is None else out_channels
282
+ self.out_channels = out_channels
283
+ self.use_conv_shortcut = conv_shortcut
284
+
285
+ self.norm1 = Normalize(in_channels)
286
+ self.conv1 = torch.nn.Conv2d(in_channels,
287
+ out_channels,
288
+ kernel_size=3,
289
+ stride=1,
290
+ padding=1)
291
+ if temb_channels > 0:
292
+ self.temb_proj = torch.nn.Linear(temb_channels,
293
+ out_channels)
294
+ self.norm2 = Normalize(out_channels)
295
+ self.dropout = torch.nn.Dropout(dropout)
296
+ self.conv2 = torch.nn.Conv2d(out_channels,
297
+ out_channels,
298
+ kernel_size=3,
299
+ stride=1,
300
+ padding=1)
301
+ if self.in_channels != self.out_channels:
302
+ if self.use_conv_shortcut:
303
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
304
+ out_channels,
305
+ kernel_size=3,
306
+ stride=1,
307
+ padding=1)
308
+ else:
309
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
310
+ out_channels,
311
+ kernel_size=1,
312
+ stride=1,
313
+ padding=0)
314
+
315
+ def forward(self, x, temb):
316
+ h = x
317
+ h = self.norm1(h)
318
+ h = nonlinearity(h)
319
+ h = self.conv1(h)
320
+
321
+ if temb is not None:
322
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
323
+
324
+ h = self.norm2(h)
325
+ h = nonlinearity(h)
326
+ h = self.dropout(h)
327
+ h = self.conv2(h)
328
+
329
+ if self.in_channels != self.out_channels:
330
+ if self.use_conv_shortcut:
331
+ x = self.conv_shortcut(x)
332
+ else:
333
+ x = self.nin_shortcut(x)
334
+
335
+ return x+h
336
+
337
+
338
+ class AttnBlock(nn.Module):
339
+ def __init__(self, in_channels):
340
+ super().__init__()
341
+ self.in_channels = in_channels
342
+
343
+ self.norm = Normalize(in_channels)
344
+ self.q = torch.nn.Conv2d(in_channels,
345
+ in_channels,
346
+ kernel_size=1,
347
+ stride=1,
348
+ padding=0)
349
+ self.k = torch.nn.Conv2d(in_channels,
350
+ in_channels,
351
+ kernel_size=1,
352
+ stride=1,
353
+ padding=0)
354
+ self.v = torch.nn.Conv2d(in_channels,
355
+ in_channels,
356
+ kernel_size=1,
357
+ stride=1,
358
+ padding=0)
359
+ self.proj_out = torch.nn.Conv2d(in_channels,
360
+ in_channels,
361
+ kernel_size=1,
362
+ stride=1,
363
+ padding=0)
364
+
365
+ def forward(self, x):
366
+ h_ = x
367
+ h_ = self.norm(h_)
368
+ q = self.q(h_)
369
+ k = self.k(h_)
370
+ v = self.v(h_)
371
+
372
+ # compute attention
373
+ b,c,h,w = q.shape
374
+ q = q.reshape(b,c,h*w)
375
+ q = q.permute(0,2,1) # b,hw,c
376
+ k = k.reshape(b,c,h*w) # b,c,hw
377
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
378
+ w_ = w_ * (int(c)**(-0.5))
379
+ w_ = torch.nn.functional.softmax(w_, dim=2)
380
+
381
+ # attend to values
382
+ v = v.reshape(b,c,h*w)
383
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
384
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
385
+ h_ = h_.reshape(b,c,h,w)
386
+
387
+ h_ = self.proj_out(h_)
388
+
389
+ return x+h_
390
+
391
+ class AttnBlock(nn.Module):
392
+ def __init__(self, in_channels):
393
+ super().__init__()
394
+ self.in_channels = in_channels
395
+
396
+ self.norm = Normalize(in_channels)
397
+ self.q = torch.nn.Conv2d(in_channels,
398
+ in_channels,
399
+ kernel_size=1,
400
+ stride=1,
401
+ padding=0)
402
+ self.k = torch.nn.Conv2d(in_channels,
403
+ in_channels,
404
+ kernel_size=1,
405
+ stride=1,
406
+ padding=0)
407
+ self.v = torch.nn.Conv2d(in_channels,
408
+ in_channels,
409
+ kernel_size=1,
410
+ stride=1,
411
+ padding=0)
412
+ self.proj_out = torch.nn.Conv2d(in_channels,
413
+ in_channels,
414
+ kernel_size=1,
415
+ stride=1,
416
+ padding=0)
417
+
418
+ def forward(self, x):
419
+ h_ = x
420
+ h_ = self.norm(h_)
421
+ q = self.q(h_)
422
+ k = self.k(h_)
423
+ v = self.v(h_)
424
+
425
+ # compute attention
426
+ b,c,h,w = q.shape
427
+ q = q.reshape(b,c,h*w)
428
+ q = q.permute(0,2,1) # b,hw,c
429
+ k = k.reshape(b,c,h*w) # b,c,hw
430
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
431
+ w_ = w_ * (int(c)**(-0.5))
432
+ w_ = torch.nn.functional.softmax(w_, dim=2)
433
+
434
+ # attend to values
435
+ v = v.reshape(b,c,h*w)
436
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
437
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
438
+ h_ = h_.reshape(b,c,h,w)
439
+
440
+ h_ = self.proj_out(h_)
441
+
442
+ return x+h_
443
+
444
+ class Upsample(nn.Module):
445
+ def __init__(self, in_channels, with_conv):
446
+ super().__init__()
447
+ self.with_conv = with_conv
448
+ if self.with_conv:
449
+ self.conv = torch.nn.Conv2d(in_channels,
450
+ in_channels,
451
+ kernel_size=3,
452
+ stride=1,
453
+ padding=1)
454
+
455
+ def forward(self, x):
456
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
457
+ if self.with_conv:
458
+ x = self.conv(x)
459
+ return x
460
+
461
+
462
+ class Downsample(nn.Module):
463
+ def __init__(self, in_channels, with_conv):
464
+ super().__init__()
465
+ self.with_conv = with_conv
466
+ if self.with_conv:
467
+ # no asymmetric padding in torch conv, must do it ourselves
468
+ self.conv = torch.nn.Conv2d(in_channels,
469
+ in_channels,
470
+ kernel_size=3,
471
+ stride=2,
472
+ padding=0)
473
+
474
+ def forward(self, x):
475
+ if self.with_conv:
476
+ pad = (0,1,0,1)
477
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
478
+ x = self.conv(x)
479
+ else:
480
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
481
+ return x
482
+
483
+ class Encoder(nn.Module):
484
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
485
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
486
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
487
+ **ignore_kwargs):
488
+ super().__init__()
489
+ if use_linear_attn: attn_type = "linear"
490
+ self.ch = ch
491
+ self.temb_ch = 0
492
+ self.num_resolutions = len(ch_mult)
493
+ self.num_res_blocks = num_res_blocks
494
+ self.resolution = resolution
495
+ self.in_channels = in_channels
496
+
497
+ # downsampling
498
+ self.conv_in = torch.nn.Conv2d(in_channels,
499
+ self.ch,
500
+ kernel_size=3,
501
+ stride=1,
502
+ padding=1)
503
+
504
+ curr_res = resolution
505
+ in_ch_mult = (1,)+tuple(ch_mult)
506
+ self.in_ch_mult = in_ch_mult
507
+ self.down = nn.ModuleList()
508
+ for i_level in range(self.num_resolutions):
509
+ block = nn.ModuleList()
510
+ attn = nn.ModuleList()
511
+ block_in = ch*in_ch_mult[i_level]
512
+ block_out = ch*ch_mult[i_level]
513
+ for i_block in range(self.num_res_blocks):
514
+ block.append(ResnetBlock(in_channels=block_in,
515
+ out_channels=block_out,
516
+ temb_channels=self.temb_ch,
517
+ dropout=dropout))
518
+ block_in = block_out
519
+ if curr_res in attn_resolutions:
520
+ attn.append(AttnBlock(block_in))
521
+ down = nn.Module()
522
+ down.block = block
523
+ down.attn = attn
524
+ if i_level != self.num_resolutions-1:
525
+ down.downsample = Downsample(block_in, resamp_with_conv)
526
+ curr_res = curr_res // 2
527
+ self.down.append(down)
528
+
529
+ # middle
530
+ self.mid = nn.Module()
531
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
532
+ out_channels=block_in,
533
+ temb_channels=self.temb_ch,
534
+ dropout=dropout)
535
+ self.mid.attn_1 = AttnBlock(block_in)
536
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
537
+ out_channels=block_in,
538
+ temb_channels=self.temb_ch,
539
+ dropout=dropout)
540
+
541
+ # end
542
+ self.norm_out = Normalize(block_in)
543
+ self.conv_out = torch.nn.Conv2d(block_in,
544
+ 2*z_channels if double_z else z_channels,
545
+ kernel_size=3,
546
+ stride=1,
547
+ padding=1)
548
+
549
+ def forward(self, x, return_feat=False):
550
+ # timestep embedding
551
+ temb = None
552
+
553
+ # downsampling
554
+ hs = [self.conv_in(x)]
555
+ for i_level in range(self.num_resolutions):
556
+ for i_block in range(self.num_res_blocks):
557
+ h = self.down[i_level].block[i_block](hs[-1], temb)
558
+ if len(self.down[i_level].attn) > 0:
559
+ h = self.down[i_level].attn[i_block](h)
560
+ hs.append(h)
561
+ if i_level != self.num_resolutions-1:
562
+ hs.append(self.down[i_level].downsample(hs[-1]))
563
+
564
+ # middle
565
+ h = hs[-1]
566
+ h = self.mid.block_1(h, temb)
567
+ h = self.mid.attn_1(h)
568
+ h = self.mid.block_2(h, temb)
569
+
570
+ # end
571
+ h = self.norm_out(h)
572
+ h = nonlinearity(h)
573
+ h = self.conv_out(h)
574
+ if return_feat:
575
+ hs[-1] = h
576
+ return hs
577
+ else:
578
+ return h
579
+
580
+
581
+ class Decoder(nn.Module):
582
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
583
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
584
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
585
+ attn_type="vanilla", **ignorekwargs):
586
+ super().__init__()
587
+ if use_linear_attn: attn_type = "linear"
588
+ self.ch = ch
589
+ self.temb_ch = 0
590
+ self.num_resolutions = len(ch_mult)
591
+ self.num_res_blocks = num_res_blocks
592
+ self.resolution = resolution
593
+ self.in_channels = in_channels
594
+ self.give_pre_end = give_pre_end
595
+ self.tanh_out = tanh_out
596
+
597
+ # compute in_ch_mult, block_in and curr_res at lowest res
598
+ in_ch_mult = (1,)+tuple(ch_mult)
599
+ block_in = ch*ch_mult[self.num_resolutions-1]
600
+ curr_res = resolution // 2**(self.num_resolutions-1)
601
+ self.z_shape = (1,z_channels, curr_res, curr_res)
602
+ # logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
603
+
604
+ # z to block_in
605
+ self.conv_in = torch.nn.Conv2d(z_channels,
606
+ block_in,
607
+ kernel_size=3,
608
+ stride=1,
609
+ padding=1)
610
+
611
+ # middle
612
+ self.mid = nn.Module()
613
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
614
+ out_channels=block_in,
615
+ temb_channels=self.temb_ch,
616
+ dropout=dropout)
617
+ self.mid.attn_1 = AttnBlock(block_in)
618
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
619
+ out_channels=block_in,
620
+ temb_channels=self.temb_ch,
621
+ dropout=dropout)
622
+
623
+ # upsampling
624
+ self.up = nn.ModuleList()
625
+ for i_level in reversed(range(self.num_resolutions)):
626
+ block = nn.ModuleList()
627
+ attn = nn.ModuleList()
628
+ block_out = ch*ch_mult[i_level]
629
+ for i_block in range(self.num_res_blocks+1):
630
+ block.append(ResnetBlock(in_channels=block_in,
631
+ out_channels=block_out,
632
+ temb_channels=self.temb_ch,
633
+ dropout=dropout))
634
+ block_in = block_out
635
+ if curr_res in attn_resolutions:
636
+ attn.append(AttnBlock(block_in))
637
+ up = nn.Module()
638
+ up.block = block
639
+ up.attn = attn
640
+ if i_level != 0:
641
+ up.upsample = Upsample(block_in, resamp_with_conv)
642
+ curr_res = curr_res * 2
643
+ self.up.insert(0, up) # prepend to get consistent order
644
+
645
+ # end
646
+ self.norm_out = Normalize(block_in)
647
+ self.conv_out = torch.nn.Conv2d(block_in,
648
+ out_ch,
649
+ kernel_size=3,
650
+ stride=1,
651
+ padding=1)
652
+
653
+ def forward(self, z, **kwargs):
654
+ #assert z.shape[1:] == self.z_shape[1:]
655
+ self.last_z_shape = z.shape
656
+
657
+ # timestep embedding
658
+ temb = None
659
+
660
+ # z to block_in
661
+ h = self.conv_in(z)
662
+
663
+ # middle
664
+ h = self.mid.block_1(h, temb)
665
+ h = self.mid.attn_1(h)
666
+ h = self.mid.block_2(h, temb)
667
+
668
+ # upsampling
669
+ for i_level in reversed(range(self.num_resolutions)):
670
+ for i_block in range(self.num_res_blocks+1):
671
+ h = self.up[i_level].block[i_block](h, temb)
672
+ if len(self.up[i_level].attn) > 0:
673
+ h = self.up[i_level].attn[i_block](h)
674
+ if i_level != 0:
675
+ h = self.up[i_level].upsample(h)
676
+
677
+ # end
678
+ if self.give_pre_end:
679
+ return h
680
+
681
+ h = self.norm_out(h)
682
+ h = nonlinearity(h)
683
+ h = self.conv_out(h)
684
+ if self.tanh_out:
685
+ h = torch.tanh(h)
686
+ return h
687
+
688
+
689
+
690
+
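
Note: a minimal round-trip sketch of the `AutoencoderKL` defined above, using the ddconfig values listed in tools/modules/config.py and random input data; the checkpoint path shown in the comment is the one from that config, so treat this as an illustration rather than a working entry point:

    import torch

    ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
    vae = AutoencoderKL(ddconfig, embed_dim=4)  # add pretrained='models/v2-1_512-ema-pruned.ckpt' to load weights

    x = torch.randn(1, 3, 256, 256)             # stands in for an image normalized to [-1, 1]
    with torch.no_grad():
        posterior = vae.encode(x)               # DiagonalGaussianDistribution over a 4-channel latent
        z = get_first_stage_encoding(posterior) # sampled latent scaled by 0.18215
        recon = vae.decode(z / 0.18215)         # undo the scale factor before decoding, as the sampler does
    # z: (1, 4, 32, 32), recon: (1, 3, 256, 256)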
UniAnimate/tools/modules/clip_embedder.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ import torch
3
+ import logging
4
+ import open_clip
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ import torchvision.transforms as T
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from utils.registry_class import EMBEDDER
10
+
11
+
12
+ @EMBEDDER.register_class()
13
+ class FrozenOpenCLIPEmbedder(nn.Module):
14
+ """
15
+ Uses the OpenCLIP transformer encoder for text
16
+ """
17
+ LAYERS = [
18
+ #"pooled",
19
+ "last",
20
+ "penultimate"
21
+ ]
22
+ def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77,
23
+ freeze=True, layer="last"):
24
+ super().__init__()
25
+ assert layer in self.LAYERS
26
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
27
+ del model.visual
28
+ self.model = model
29
+
30
+ self.device = device
31
+ self.max_length = max_length
32
+ if freeze:
33
+ self.freeze()
34
+ self.layer = layer
35
+ if self.layer == "last":
36
+ self.layer_idx = 0
37
+ elif self.layer == "penultimate":
38
+ self.layer_idx = 1
39
+ else:
40
+ raise NotImplementedError()
41
+
42
+ def freeze(self):
43
+ self.model = self.model.eval()
44
+ for param in self.parameters():
45
+ param.requires_grad = False
46
+
47
+ def forward(self, text):
48
+ tokens = open_clip.tokenize(text)
49
+ z = self.encode_with_transformer(tokens.to(self.device))
50
+ return z
51
+
52
+ def encode_with_transformer(self, text):
53
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
54
+ x = x + self.model.positional_embedding
55
+ x = x.permute(1, 0, 2) # NLD -> LND
56
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
57
+ x = x.permute(1, 0, 2) # LND -> NLD
58
+ x = self.model.ln_final(x)
59
+ return x
60
+
61
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
62
+ for i, r in enumerate(self.model.transformer.resblocks):
63
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
64
+ break
65
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
66
+ x = checkpoint(r, x, attn_mask)
67
+ else:
68
+ x = r(x, attn_mask=attn_mask)
69
+ return x
70
+
71
+ def encode(self, text):
72
+ return self(text)
73
+
74
+
75
+ @EMBEDDER.register_class()
76
+ class FrozenOpenCLIPVisualEmbedder(nn.Module):
77
+ """
78
+ Uses the OpenCLIP vision transformer to encode images
79
+ """
80
+ LAYERS = [
81
+ #"pooled",
82
+ "last",
83
+ "penultimate"
84
+ ]
85
+ def __init__(self, pretrained, vit_resolution=(224, 224), arch="ViT-H-14", device="cuda", max_length=77,
86
+ freeze=True, layer="last"):
87
+ super().__init__()
88
+ assert layer in self.LAYERS
89
+ model, _, preprocess = open_clip.create_model_and_transforms(
90
+ arch, device=torch.device('cpu'), pretrained=pretrained)
91
+
92
+ del model.transformer
93
+ self.model = model
94
+ data_white = np.ones((vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8)*255
95
+ self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
96
+
97
+ self.device = device
98
+ self.max_length = max_length # 77
99
+ if freeze:
100
+ self.freeze()
101
+ self.layer = layer # 'penultimate'
102
+ if self.layer == "last":
103
+ self.layer_idx = 0
104
+ elif self.layer == "penultimate":
105
+ self.layer_idx = 1
106
+ else:
107
+ raise NotImplementedError()
108
+
109
+ def freeze(self):
110
+ self.model = self.model.eval()
111
+ for param in self.parameters():
112
+ param.requires_grad = False
113
+
114
+ def forward(self, image):
115
+ # tokens = open_clip.tokenize(text)
116
+ z = self.model.encode_image(image.to(self.device))
117
+ return z
118
+
119
+ def encode_with_transformer(self, text):
120
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
121
+ x = x + self.model.positional_embedding
122
+ x = x.permute(1, 0, 2) # NLD -> LND
123
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
124
+ x = x.permute(1, 0, 2) # LND -> NLD
125
+ x = self.model.ln_final(x)
126
+
127
+ return x
128
+
129
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
130
+ for i, r in enumerate(self.model.transformer.resblocks):
131
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
132
+ break
133
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
134
+ x = checkpoint(r, x, attn_mask)
135
+ else:
136
+ x = r(x, attn_mask=attn_mask)
137
+ return x
138
+
139
+ def encode(self, text):
140
+ return self(text)
141
+
142
+
143
+
144
+ @EMBEDDER.register_class()
145
+ class FrozenOpenCLIPTextVisualEmbedder(nn.Module):
146
+ """
147
+ Uses the OpenCLIP transformer encoders for both text and images
148
+ """
149
+ LAYERS = [
150
+ #"pooled",
151
+ "last",
152
+ "penultimate"
153
+ ]
154
+ def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77,
155
+ freeze=True, layer="last", **kwargs):
156
+ super().__init__()
157
+ assert layer in self.LAYERS
158
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
159
+ self.model = model
160
+
161
+ self.device = device
162
+ self.max_length = max_length
163
+ if freeze:
164
+ self.freeze()
165
+ self.layer = layer
166
+ if self.layer == "last":
167
+ self.layer_idx = 0
168
+ elif self.layer == "penultimate":
169
+ self.layer_idx = 1
170
+ else:
171
+ raise NotImplementedError()
172
+
173
+ def freeze(self):
174
+ self.model = self.model.eval()
175
+ for param in self.parameters():
176
+ param.requires_grad = False
177
+
178
+
179
+ def forward(self, image=None, text=None):
180
+
181
+ xi = self.model.encode_image(image.to(self.device)) if image is not None else None
182
+ tokens = open_clip.tokenize(text)
183
+ xt, x = self.encode_with_transformer(tokens.to(self.device))
184
+ return xi, xt, x
185
+
186
+ def encode_with_transformer(self, text):
187
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
188
+ x = x + self.model.positional_embedding
189
+ x = x.permute(1, 0, 2) # NLD -> LND
190
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
191
+ x = x.permute(1, 0, 2) # LND -> NLD
192
+ x = self.model.ln_final(x)
193
+ xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
194
+ return xt, x
195
+
196
+
197
+ def encode_image(self, image):
198
+ return self.model.visual(image)
199
+
200
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
201
+ for i, r in enumerate(self.model.transformer.resblocks):
202
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
203
+ break
204
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
205
+ x = checkpoint(r, x, attn_mask)
206
+ else:
207
+ x = r(x, attn_mask=attn_mask)
208
+ return x
209
+
210
+ def encode(self, text):
211
+
212
+ return self(text)
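
Note: the embedders above are registered with the EMBEDDER registry and normally built from cfg.embedder; as a standalone sketch (the weight path follows cfg.embedder in tools/modules/config.py, adjust it to your environment):

    import torch

    embedder = FrozenOpenCLIPEmbedder(pretrained='models/open_clip_pytorch_model.bin',
                                      layer='penultimate').to('cuda')
    with torch.no_grad():
        context = embedder(['a person dancing, full body'])
    # context: (1, 77, 1024) token-level text features for the ViT-H-14 text tower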
UniAnimate/tools/modules/config.py ADDED
@@ -0,0 +1,206 @@
1
+ import torch
2
+ import logging
3
+ import os.path as osp
4
+ from datetime import datetime
5
+ from easydict import EasyDict
6
+ import os
7
+
8
+ cfg = EasyDict(__name__='Config: VideoLDM Decoder')
9
+
10
+ # -------------------------------distributed training--------------------------
11
+ pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
12
+ gpus_per_machine = torch.cuda.device_count()
13
+ world_size = pmi_world_size * gpus_per_machine
14
+ # -----------------------------------------------------------------------------
15
+
16
+
17
+ # ---------------------------Dataset Parameter---------------------------------
18
+ cfg.mean = [0.5, 0.5, 0.5]
19
+ cfg.std = [0.5, 0.5, 0.5]
20
+ cfg.max_words = 1000
21
+ cfg.num_workers = 8
22
+ cfg.prefetch_factor = 2
23
+
24
+ # PlaceHolder
25
+ cfg.resolution = [448, 256]
26
+ cfg.vit_out_dim = 1024
27
+ cfg.vit_resolution = 336
28
+ cfg.depth_clamp = 10.0
29
+ cfg.misc_size = 384
30
+ cfg.depth_std = 20.0
31
+
32
+ cfg.save_fps = 8
33
+
34
+ cfg.frame_lens = [32, 32, 32, 1]
35
+ cfg.sample_fps = [4, ]
36
+ cfg.vid_dataset = {
37
+ 'type': 'VideoBaseDataset',
38
+ 'data_list': [],
39
+ 'max_words': cfg.max_words,
40
+ 'resolution': cfg.resolution}
41
+ cfg.img_dataset = {
42
+ 'type': 'ImageBaseDataset',
43
+ 'data_list': ['laion_400m',],
44
+ 'max_words': cfg.max_words,
45
+ 'resolution': cfg.resolution}
46
+
47
+ cfg.batch_sizes = {
48
+ str(1):256,
49
+ str(4):4,
50
+ str(8):4,
51
+ str(16):4}
52
+ # -----------------------------------------------------------------------------
53
+
54
+
55
+ # ---------------------------Mode Parameters-----------------------------------
56
+ # Diffusion
57
+ cfg.Diffusion = {
58
+ 'type': 'DiffusionDDIM',
59
+ 'schedule': 'cosine', # cosine
60
+ 'schedule_param': {
61
+ 'num_timesteps': 1000,
62
+ 'cosine_s': 0.008,
63
+ 'zero_terminal_snr': True,
64
+ },
65
+ 'mean_type': 'v', # [v, eps]
66
+ 'loss_type': 'mse',
67
+ 'var_type': 'fixed_small',
68
+ 'rescale_timesteps': False,
69
+ 'noise_strength': 0.1,
70
+ 'ddim_timesteps': 50
71
+ }
72
+ cfg.ddim_timesteps = 50 # official: 250
73
+ cfg.use_div_loss = False
74
+ # classifier-free guidance
75
+ cfg.p_zero = 0.9
76
+ cfg.guide_scale = 3.0
77
+
78
+ # clip vision encoder
79
+ cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
80
+ cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
81
+
82
+ # sketch
83
+ cfg.sketch_mean = [0.485, 0.456, 0.406]
84
+ cfg.sketch_std = [0.229, 0.224, 0.225]
85
+ # cfg.misc_size = 256
86
+ cfg.depth_std = 20.0
87
+ cfg.depth_clamp = 10.0
88
+ cfg.hist_sigma = 10.0
89
+
90
+ # Model
91
+ cfg.scale_factor = 0.18215
92
+ cfg.use_checkpoint = True
93
+ cfg.use_sharded_ddp = False
94
+ cfg.use_fsdp = False
95
+ cfg.use_fp16 = True
96
+ cfg.temporal_attention = True
97
+
98
+ cfg.UNet = {
99
+ 'type': 'UNetSD',
100
+ 'in_dim': 4,
101
+ 'dim': 320,
102
+ 'y_dim': cfg.vit_out_dim,
103
+ 'context_dim': 1024,
104
+ 'out_dim': 8,
105
+ 'dim_mult': [1, 2, 4, 4],
106
+ 'num_heads': 8,
107
+ 'head_dim': 64,
108
+ 'num_res_blocks': 2,
109
+ 'attn_scales': [1 / 1, 1 / 2, 1 / 4],
110
+ 'dropout': 0.1,
111
+ 'temporal_attention': cfg.temporal_attention,
112
+ 'temporal_attn_times': 1,
113
+ 'use_checkpoint': cfg.use_checkpoint,
114
+ 'use_fps_condition': False,
115
+ 'use_sim_mask': False
116
+ }
117
+
118
+ # autoencoder from Stable Diffusion
119
+ cfg.guidances = []
120
+ cfg.auto_encoder = {
121
+ 'type': 'AutoencoderKL',
122
+ 'ddconfig': {
123
+ 'double_z': True,
124
+ 'z_channels': 4,
125
+ 'resolution': 256,
126
+ 'in_channels': 3,
127
+ 'out_ch': 3,
128
+ 'ch': 128,
129
+ 'ch_mult': [1, 2, 4, 4],
130
+ 'num_res_blocks': 2,
131
+ 'attn_resolutions': [],
132
+ 'dropout': 0.0,
133
+ 'video_kernel_size': [3, 1, 1]
134
+ },
135
+ 'embed_dim': 4,
136
+ 'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
137
+ }
138
+ # clip embedder
139
+ cfg.embedder = {
140
+ 'type': 'FrozenOpenCLIPEmbedder',
141
+ 'layer': 'penultimate',
142
+ 'pretrained': 'models/open_clip_pytorch_model.bin'
143
+ }
144
+ # -----------------------------------------------------------------------------
145
+
146
+ # ---------------------------Training Settings---------------------------------
147
+ # training and optimizer
148
+ cfg.ema_decay = 0.9999
149
+ cfg.num_steps = 600000
150
+ cfg.lr = 5e-5
151
+ cfg.weight_decay = 0.0
152
+ cfg.betas = (0.9, 0.999)
153
+ cfg.eps = 1.0e-8
154
+ cfg.chunk_size = 16
155
+ cfg.decoder_bs = 8
156
+ cfg.alpha = 0.7
157
+ cfg.save_ckp_interval = 1000
158
+
159
+ # scheduler
160
+ cfg.warmup_steps = 10
161
+ cfg.decay_mode = 'cosine'
162
+
163
+ # acceleration
164
+ cfg.use_ema = True
165
+ if world_size<2:
166
+ cfg.use_ema = False
167
+ cfg.load_from = None
168
+ # -----------------------------------------------------------------------------
169
+
170
+
171
+ # ----------------------------Pretrain Settings---------------------------------
172
+ cfg.Pretrain = {
173
+ 'type': 'pretrain_specific_strategies',
174
+ 'fix_weight': False,
175
+ 'grad_scale': 0.2,
176
+ 'resume_checkpoint': 'models/jiuniu_0267000.pth',
177
+ 'sd_keys_path': 'models/stable_diffusion_image_key_temporal_attention_x1.json',
178
+ }
179
+ # -----------------------------------------------------------------------------
180
+
181
+
182
+ # -----------------------------Visual-------------------------------------------
183
+ # Visual videos
184
+ cfg.viz_interval = 1000
185
+ cfg.visual_train = {
186
+ 'type': 'VisualTrainTextImageToVideo',
187
+ }
188
+ cfg.visual_inference = {
189
+ 'type': 'VisualGeneratedVideos',
190
+ }
191
+ cfg.inference_list_path = ''
192
+
193
+ # logging
194
+ cfg.log_interval = 100
195
+
196
+ ### Default log_dir
197
+ cfg.log_dir = 'outputs/'
198
+ # -----------------------------------------------------------------------------
199
+
200
+
201
+ # ---------------------------Others--------------------------------------------
202
+ # seed
203
+ cfg.seed = 8888
204
+ cfg.negative_prompt = 'Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms'
205
+ # -----------------------------------------------------------------------------
206
+
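
Note: `cfg` is a plain EasyDict of defaults, and the inference code above reads optional fields with `getattr(cfg, ..., default)` fallbacks, so experiment configs can simply overwrite attributes before the tools run. A minimal sketch (the field values here are examples only):

    from tools.modules.config import cfg

    cfg.ddim_timesteps = 30
    cfg.log_dir = 'outputs/demo'
    cfg.CPU_CLIP_VAE = True                                 # offload CLIP/VAE during sampling, as handled above
    noise_prior = getattr(cfg, 'noise_prior_value', 939)    # optional fields fall back to their defaults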
UniAnimate/tools/modules/diffusions/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .diffusion_ddim import *
UniAnimate/tools/modules/diffusions/diffusion_ddim.py ADDED
@@ -0,0 +1,1121 @@
1
+ import torch
2
+ import math
3
+
4
+ from utils.registry_class import DIFFUSION
5
+ from .schedules import beta_schedule, sigma_schedule
6
+ from .losses import kl_divergence, discretized_gaussian_log_likelihood
7
+ # from .dpm_solver import NoiseScheduleVP, model_wrapper_guided_diffusion, model_wrapper, DPM_Solver
8
+ from typing import Callable, List, Optional
9
+ import numpy as np
10
+
11
+ def _i(tensor, t, x):
12
+ r"""Index tensor using t and format the output according to x.
13
+ """
14
+ if tensor.device != x.device:
15
+ tensor = tensor.to(x.device)
16
+ shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
17
+ return tensor[t].view(shape).to(x)
18
+
19
+ @DIFFUSION.register_class()
20
+ class DiffusionDDIMSR(object):
21
+ def __init__(self, reverse_diffusion, forward_diffusion, **kwargs):
22
+ from .diffusion_gauss import GaussianDiffusion
23
+ self.reverse_diffusion = GaussianDiffusion(sigmas=sigma_schedule(reverse_diffusion.schedule, **reverse_diffusion.schedule_param),
24
+ prediction_type=reverse_diffusion.mean_type)
25
+ self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param),
26
+ prediction_type=forward_diffusion.mean_type)
27
+
28
+
29
+ @DIFFUSION.register_class()
30
+ class DiffusionDPM(object):
31
+ def __init__(self, forward_diffusion, **kwargs):
32
+ from .diffusion_gauss import GaussianDiffusion
33
+ self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param),
34
+ prediction_type=forward_diffusion.mean_type)
35
+
36
+
37
+ @DIFFUSION.register_class()
38
+ class DiffusionDDIM(object):
39
+ def __init__(self,
40
+ schedule='linear_sd',
41
+ schedule_param={},
42
+ mean_type='eps',
43
+ var_type='learned_range',
44
+ loss_type='mse',
45
+ epsilon = 1e-12,
46
+ rescale_timesteps=False,
47
+ noise_strength=0.0,
48
+ **kwargs):
49
+
50
+ assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v']
51
+ assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small']
52
+ assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier']
53
+
54
+ betas = beta_schedule(schedule, **schedule_param)
55
+ assert min(betas) > 0 and max(betas) <= 1
56
+
57
+ if not isinstance(betas, torch.DoubleTensor):
58
+ betas = torch.tensor(betas, dtype=torch.float64)
59
+
60
+ self.betas = betas
61
+ self.num_timesteps = len(betas)
62
+ self.mean_type = mean_type # eps
63
+ self.var_type = var_type # 'fixed_small'
64
+ self.loss_type = loss_type # mse
65
+ self.epsilon = epsilon # 1e-12
66
+ self.rescale_timesteps = rescale_timesteps # False
67
+ self.noise_strength = noise_strength # 0.0
68
+
69
+ # alphas
70
+ alphas = 1 - self.betas
71
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
72
+ self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]])
73
+ self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])])
74
+
75
+ # q(x_t | x_{t-1})
76
+ self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
77
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
78
+ self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
79
+ self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
80
+ self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
81
+
82
+ # q(x_{t-1} | x_t, x_0)
83
+ self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
84
+ self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20))
85
+ self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
86
+ self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod)
87
+
88
+
89
+ def sample_loss(self, x0, noise=None):
90
+ if noise is None:
91
+ noise = torch.randn_like(x0)
92
+ if self.noise_strength > 0:
93
+ b, c, f, _, _= x0.shape
94
+ offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device)
95
+ noise = noise + self.noise_strength * offset_noise
96
+ return noise
97
+
98
+
99
+ def q_sample(self, x0, t, noise=None):
100
+ r"""Sample from q(x_t | x_0).
101
+ """
102
+ # noise = torch.randn_like(x0) if noise is None else noise
103
+ noise = self.sample_loss(x0, noise)
104
+ return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
105
+ _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
106
+
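
Note: `q_sample` implements the closed form q(x_t | x_0) = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise. The sketch below rebuilds the cumulative-product terms for a toy linear beta schedule (not the project's `beta_schedule()`), purely to show the broadcasting that `_i` performs:

    import torch

    betas = torch.linspace(1e-4, 2e-2, 1000, dtype=torch.float64)
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)

    x0 = torch.randn(2, 4, 8, 32, 32)               # (b, c, f, h, w) latent video, toy data
    t = torch.tensor([10, 500])
    noise = torch.randn_like(x0)

    shape = (x0.size(0),) + (1,) * (x0.ndim - 1)    # same broadcasting shape as _i()
    xt = alphas_cumprod[t].sqrt().view(shape).to(x0) * x0 \
       + (1 - alphas_cumprod[t]).sqrt().view(shape).to(x0) * noise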
107
+ def q_mean_variance(self, x0, t):
108
+ r"""Distribution of q(x_t | x_0).
109
+ """
110
+ mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
111
+ var = _i(1.0 - self.alphas_cumprod, t, x0)
112
+ log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
113
+ return mu, var, log_var
114
+
115
+ def q_posterior_mean_variance(self, x0, xt, t):
116
+ r"""Distribution of q(x_{t-1} | x_t, x_0).
117
+ """
118
+ mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt
119
+ var = _i(self.posterior_variance, t, xt)
120
+ log_var = _i(self.posterior_log_variance_clipped, t, xt)
121
+ return mu, var, log_var
122
+
123
+ @torch.no_grad()
124
+ def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
125
+ r"""Sample from p(x_{t-1} | x_t).
126
+ - condition_fn: for classifier-based guidance (guided-diffusion).
127
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
128
+ """
129
+ # predict distribution of p(x_{t-1} | x_t)
130
+ mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
131
+
132
+ # random sample (with optional conditional function)
133
+ noise = torch.randn_like(xt)
134
+ mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0
135
+ if condition_fn is not None:
136
+ grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
137
+ mu = mu.float() + var * grad.float()
138
+ xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
139
+ return xt_1, x0
140
+
141
+ @torch.no_grad()
142
+ def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
143
+ r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
144
+ """
145
+ # prepare input
146
+ b = noise.size(0)
147
+ xt = noise
148
+
149
+ # diffusion process
150
+ for step in torch.arange(self.num_timesteps).flip(0):
151
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
152
+ xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale)
153
+ return xt
154
+
155
+ def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None):
156
+ r"""Distribution of p(x_{t-1} | x_t).
157
+ """
158
+ # predict distribution
159
+ if guide_scale is None:
160
+ out = model(xt, self._scale_timesteps(t), **model_kwargs)
161
+ else:
162
+ # classifier-free guidance
163
+ # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
164
+ assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
165
+ y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
166
+ u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
167
+ dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2
168
+ out = torch.cat([
169
+ u_out[:, :dim] + guide_scale * (y_out[:, :dim] - u_out[:, :dim]),
170
+ y_out[:, dim:]], dim=1) # guide_scale=9.0
171
+
172
+ # compute variance
173
+ if self.var_type == 'learned':
174
+ out, log_var = out.chunk(2, dim=1)
175
+ var = torch.exp(log_var)
176
+ elif self.var_type == 'learned_range':
177
+ out, fraction = out.chunk(2, dim=1)
178
+ min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
179
+ max_log_var = _i(torch.log(self.betas), t, xt)
180
+ fraction = (fraction + 1) / 2.0
181
+ log_var = fraction * max_log_var + (1 - fraction) * min_log_var
182
+ var = torch.exp(log_var)
183
+ elif self.var_type == 'fixed_large':
184
+ var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt)
185
+ log_var = torch.log(var)
186
+ elif self.var_type == 'fixed_small':
187
+ var = _i(self.posterior_variance, t, xt)
188
+ log_var = _i(self.posterior_log_variance_clipped, t, xt)
189
+
190
+ # compute mean and x0
191
+ if self.mean_type == 'x_{t-1}':
192
+ mu = out # x_{t-1}
193
+ x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
194
+ _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt
195
+ elif self.mean_type == 'x0':
196
+ x0 = out
197
+ mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
198
+ elif self.mean_type == 'eps':
199
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
200
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
201
+ mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
202
+ elif self.mean_type == 'v':
203
+ x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \
204
+ _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out
205
+ mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
206
+
207
+ # restrict the range of x0
208
+ if percentile is not None:
209
+ assert percentile > 0 and percentile <= 1 # e.g., 0.995
210
+ s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1)
211
+ x0 = torch.min(s, torch.max(-s, x0)) / s
212
+ elif clamp is not None:
213
+ x0 = x0.clamp(-clamp, clamp)
214
+ return mu, var, log_var, x0
215
+
216
+ @torch.no_grad()
217
+ def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
218
+ r"""Sample from p(x_{t-1} | x_t) using DDIM.
219
+ - condition_fn: for classifier-based guidance (guided-diffusion).
220
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
221
+ """
222
+ stride = self.num_timesteps // ddim_timesteps
223
+
224
+ # predict distribution of p(x_{t-1} | x_t)
225
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
226
+ if condition_fn is not None:
227
+ # x0 -> eps
228
+ alpha = _i(self.alphas_cumprod, t, xt)
229
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
230
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
231
+ eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
232
+
233
+ # eps -> x0
234
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
235
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
236
+
237
+ # derive variables
238
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
239
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
240
+ alphas = _i(self.alphas_cumprod, t, xt)
241
+ alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
242
+ sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
243
+
244
+ # random sample
245
+ noise = torch.randn_like(xt)
246
+ direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps
247
+ mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
248
+ xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
249
+ return xt_1, x0
250
+
251
+ @torch.no_grad()
252
+ def ddim_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
253
+ # prepare input
254
+ b = noise.size(0)
255
+ xt = noise
256
+
257
+ # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
258
+ steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
259
+ from tqdm import tqdm
260
+ for step in tqdm(steps):
261
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
262
+ xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta)
263
+ # from ipdb import set_trace; set_trace()
264
+ return xt
265
+
266
+ @torch.no_grad()
267
+ def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
268
+ r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
269
+ """
270
+ stride = self.num_timesteps // ddim_timesteps
271
+
272
+ # predict distribution of p(x_{t-1} | x_t)
273
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
274
+
275
+ # derive variables
276
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
277
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
278
+ alphas_next = _i(
279
+ torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]),
280
+ (t + stride).clamp(0, self.num_timesteps), xt)
281
+
282
+ # reverse sample
283
+ mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
284
+ return mu, x0
285
+
286
+ @torch.no_grad()
287
+ def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
288
+ # prepare input
289
+ b = x0.size(0)
290
+ xt = x0
291
+
292
+ # reconstruction steps
293
+ steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)
294
+ for step in steps:
295
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
296
+ xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps)
297
+ return xt
298
+
299
+ @torch.no_grad()
300
+ def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20, eps_cache=[]):
301
+ r"""Sample from p(x_{t-1} | x_t) using PLMS.
302
+ - condition_fn: for classifier-based guidance (guided-diffusion).
303
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
304
+ """
305
+ stride = self.num_timesteps // plms_timesteps
306
+
307
+ # function for compute eps
308
+ def compute_eps(xt, t):
309
+ # predict distribution of p(x_{t-1} | x_t)
310
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
311
+
312
+ # condition
313
+ if condition_fn is not None:
314
+ # x0 -> eps
315
+ alpha = _i(self.alphas_cumprod, t, xt)
316
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
317
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
318
+ eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
319
+
320
+ # eps -> x0
321
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
322
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
323
+
324
+ # derive eps
325
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
326
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
327
+ return eps
328
+
329
+ # function to compute x_0 and x_{t-1}
330
+ def compute_x0(eps, t):
331
+ # eps -> x0
332
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
333
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
334
+
335
+ # deterministic sample
336
+ alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
337
+ direction = torch.sqrt(1 - alphas_prev) * eps
338
+ mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
339
+ xt_1 = torch.sqrt(alphas_prev) * x0 + direction
340
+ return xt_1, x0
341
+
342
+ # PLMS sample
343
+ eps_cache = [] if eps_cache is None else eps_cache
+ eps = compute_eps(xt, t)
344
+ if len(eps_cache) == 0:
345
+ # 2nd order pseudo improved Euler
346
+ xt_1, x0 = compute_x0(eps, t)
347
+ eps_next = compute_eps(xt_1, (t - stride).clamp(0))
348
+ eps_prime = (eps + eps_next) / 2.0
349
+ elif len(eps_cache) == 1:
350
+ # 2nd order pseudo linear multistep (Adams-Bashforth)
351
+ eps_prime = (3 * eps - eps_cache[-1]) / 2.0
352
+ elif len(eps_cache) == 2:
353
+ # 3rd order pseudo linear multistep (Adams-Bashforth)
354
+ eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0
355
+ elif len(eps_cache) >= 3:
356
+ # 4th order pseudo linear multistep (Adams-Bashforth)
357
+ eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0
358
+ xt_1, x0 = compute_x0(eps_prime, t)
359
+ return xt_1, x0, eps
360
+
361
+ @torch.no_grad()
362
+ def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20):
363
+ # prepare input
364
+ b = noise.size(0)
365
+ xt = noise
366
+
367
+ # diffusion process
368
+ steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
369
+ eps_cache = []
370
+ for step in steps:
371
+ # PLMS sampling step
372
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
373
+ xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache)
374
+
375
+ # update eps cache
376
+ eps_cache.append(eps)
377
+ if len(eps_cache) >= 4:
378
+ eps_cache.pop(0)
379
+ return xt
380
+
381
+ def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None):
382
+
383
+ # noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32]
384
+ noise = self.sample_loss(x0, noise)
385
+
386
+ xt = self.q_sample(x0, t, noise=noise)
387
+
388
+ # compute loss
389
+ if self.loss_type in ['kl', 'rescaled_kl']:
390
+ loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs)
391
+ if self.loss_type == 'rescaled_kl':
392
+ loss = loss * self.num_timesteps
393
+ elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse
394
+ out = model(xt, self._scale_timesteps(t), **model_kwargs)
395
+
396
+ # VLB for variation
397
+ loss_vlb = 0.0
398
+ if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small'
399
+ out, var = out.chunk(2, dim=1)
400
+ frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean
401
+ loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
402
+ if self.loss_type.startswith('rescaled_'):
403
+ loss_vlb = loss_vlb * self.num_timesteps / 1000.0
404
+
405
+ # MSE/L1 for x0/eps
406
+ # target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type]
407
+ target = {
408
+ 'eps': noise,
409
+ 'x0': x0,
410
+ 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0],
411
+ 'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type]
412
+ if loss_mask is not None:
413
+ loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same)
414
+ loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w
415
+ # use masked diffusion
416
+ loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
417
+ else:
418
+ loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
419
+ if weight is not None:
420
+ loss = loss*weight
421
+
422
+ # div loss
423
+ if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1:
424
+
425
+ # derive x0
426
+ x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
427
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
428
+
429
+ # # derive xt_1, set eta=0 as ddim
430
+ # alphas_prev = _i(self.alphas_cumprod, (t - 1).clamp(0), xt)
431
+ # direction = torch.sqrt(1 - alphas_prev) * out
432
+ # xt_1 = torch.sqrt(alphas_prev) * x0_ + direction
433
+
434
+ # ncfhw, std on f
435
+ div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4)
436
+ # print(div_loss,loss)
437
+ loss = loss+div_loss
438
+
439
+ # total loss
440
+ loss = loss + loss_vlb
441
+ elif self.loss_type in ['charbonnier']:
442
+ out = model(xt, self._scale_timesteps(t), **model_kwargs)
443
+
444
+ # VLB for variation
445
+ loss_vlb = 0.0
446
+ if self.var_type in ['learned', 'learned_range']:
447
+ out, var = out.chunk(2, dim=1)
448
+ frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean
449
+ loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
450
+ if self.loss_type.startswith('rescaled_'):
451
+ loss_vlb = loss_vlb * self.num_timesteps / 1000.0
452
+
453
+ # MSE/L1 for x0/eps
454
+ target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type]
455
+ loss = torch.sqrt((out - target)**2 + self.epsilon)
456
+ if weight is not None:
457
+ loss = loss*weight
458
+ loss = loss.flatten(1).mean(dim=1)
459
+
460
+ # total loss
461
+ loss = loss + loss_vlb
462
+ return loss
463
+
464
+ def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None):
465
+ # compute groundtruth and predicted distributions
466
+ mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
467
+ mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile)
468
+
469
+ # compute KL loss
470
+ kl = kl_divergence(mu1, log_var1, mu2, log_var2)
471
+ kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
472
+
473
+ # compute discretized NLL loss (for p(x0 | x1) only)
474
+ nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2)
475
+ nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
476
+
477
+ # NLL for p(x0 | x1) and KL otherwise
478
+ vlb = torch.where(t == 0, nll, kl)
479
+ return vlb, x0
480
+
481
+ @torch.no_grad()
482
+ def variational_lower_bound_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None):
483
+ r"""Compute the entire variational lower bound, measured in bits-per-dim.
484
+ """
485
+ # prepare input and output
486
+ b = x0.size(0)
487
+ metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
488
+
489
+ # loop
490
+ for step in torch.arange(self.num_timesteps).flip(0):
491
+ # compute VLB
492
+ t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
493
+ # noise = torch.randn_like(x0)
494
+ noise = self.sample_loss(x0)
495
+ xt = self.q_sample(x0, t, noise)
496
+ vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile)
497
+
498
+ # predict eps from x0
499
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
500
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
501
+
502
+ # collect metrics
503
+ metrics['vlb'].append(vlb)
504
+ metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1))
505
+ metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1))
506
+ metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
507
+
508
+ # compute the prior KL term for VLB, measured in bits-per-dim
509
+ mu, _, log_var = self.q_mean_variance(x0, t)
510
+ kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var))
511
+ kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
512
+
513
+ # update metrics
514
+ metrics['prior_bits_per_dim'] = kl_prior
515
+ metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
516
+ return metrics
517
+
518
+ def _scale_timesteps(self, t):
519
+ if self.rescale_timesteps:
520
+ return t.float() * 1000.0 / self.num_timesteps
521
+ return t
522
+ #return t.float()
523
+
524
+
525
+
526
+
527
+
528
+
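For orientation, a minimal way to drive the sampler above might look as follows. This is an illustrative sketch only: it assumes the enclosing class is the registered DiffusionDDIM with the same constructor as DiffusionDDIMLong below, that the repository root is on PYTHONPATH, and it substitutes a zero-output stand-in for the UNet and made-up latent shapes; the actual entry points live under UniAnimate/tools/inferences/.

    import torch
    from tools.modules.diffusions.diffusion_ddim import DiffusionDDIM  # assumed import path / class name

    diffusion = DiffusionDDIM(
        schedule='linear_sd',
        schedule_param={'num_timesteps': 1000, 'init_beta': 0.00085, 'last_beta': 0.012},  # assumed values
        mean_type='v', var_type='fixed_small', loss_type='mse')

    model = lambda xt, t, **kw: torch.zeros_like(xt)   # stand-in denoiser: (b, c, f, h, w) -> same shape
    noise = torch.randn(1, 4, 16, 32, 32)              # latent video: batch, channels, frames, height, width
    model_kwargs = [{'y': None}, {'y': None}]          # conditional / unconditional branches for CFG
    latents = diffusion.ddim_sample_loop(
        noise, model, model_kwargs=model_kwargs,
        guide_scale=2.5, ddim_timesteps=20, eta=0.0)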
529
+ @DIFFUSION.register_class()
530
+ class DiffusionDDIMLong(object):
531
+ def __init__(self,
532
+ schedule='linear_sd',
533
+ schedule_param={},
534
+ mean_type='eps',
535
+ var_type='learned_range',
536
+ loss_type='mse',
537
+ epsilon = 1e-12,
538
+ rescale_timesteps=False,
539
+ noise_strength=0.0,
540
+ **kwargs):
541
+
542
+ assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v']
543
+ assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small']
544
+ assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier']
545
+
546
+ betas = beta_schedule(schedule, **schedule_param)
547
+ assert min(betas) > 0 and max(betas) <= 1
548
+
549
+ if not isinstance(betas, torch.DoubleTensor):
550
+ betas = torch.tensor(betas, dtype=torch.float64)
551
+
552
+ self.betas = betas
553
+ self.num_timesteps = len(betas)
554
+ self.mean_type = mean_type # v
555
+ self.var_type = var_type # 'fixed_small'
556
+ self.loss_type = loss_type # mse
557
+ self.epsilon = epsilon # 1e-12
558
+ self.rescale_timesteps = rescale_timesteps # False
559
+ self.noise_strength = noise_strength
560
+
561
+ # alphas
562
+ alphas = 1 - self.betas
563
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
564
+ self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]])
565
+ self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])])
566
+
567
+ # q(x_t | x_{t-1})
568
+ self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
569
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
570
+ self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
571
+ self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
572
+ self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
573
+
574
+ # q(x_{t-1} | x_t, x_0)
575
+ self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
576
+ self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20))
577
+ self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
578
+ self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod)
579
+
580
+
581
+ def sample_loss(self, x0, noise=None):
582
+ if noise is None:
583
+ noise = torch.randn_like(x0)
584
+ if self.noise_strength > 0:
585
+ b, c, f, _, _= x0.shape
586
+ offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device)
587
+ noise = noise + self.noise_strength * offset_noise
588
+ return noise
589
+
590
+
591
+ def q_sample(self, x0, t, noise=None):
592
+ r"""Sample from q(x_t | x_0).
593
+ """
594
+ # noise = torch.randn_like(x0) if noise is None else noise
595
+ noise = self.sample_loss(x0, noise)
596
+ return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
597
+ _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
598
+
599
+ def q_mean_variance(self, x0, t):
600
+ r"""Distribution of q(x_t | x_0).
601
+ """
602
+ mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
603
+ var = _i(1.0 - self.alphas_cumprod, t, x0)
604
+ log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
605
+ return mu, var, log_var
606
+
607
+ def q_posterior_mean_variance(self, x0, xt, t):
608
+ r"""Distribution of q(x_{t-1} | x_t, x_0).
609
+ """
610
+ mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt
611
+ var = _i(self.posterior_variance, t, xt)
612
+ log_var = _i(self.posterior_log_variance_clipped, t, xt)
613
+ return mu, var, log_var
614
+
615
+ @torch.no_grad()
616
+ def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
617
+ r"""Sample from p(x_{t-1} | x_t).
618
+ - condition_fn: for classifier-based guidance (guided-diffusion).
619
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
620
+ """
621
+ # predict distribution of p(x_{t-1} | x_t)
622
+ mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
623
+
624
+ # random sample (with optional conditional function)
625
+ noise = torch.randn_like(xt)
626
+ mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0
627
+ if condition_fn is not None:
628
+ grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
629
+ mu = mu.float() + var * grad.float()
630
+ xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
631
+ return xt_1, x0
632
+
633
+ @torch.no_grad()
634
+ def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
635
+ r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
636
+ """
637
+ # prepare input
638
+ b = noise.size(0)
639
+ xt = noise
640
+
641
+ # diffusion process
642
+ for step in torch.arange(self.num_timesteps).flip(0):
643
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
644
+ xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale)
645
+ return xt
646
+
647
+ def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1):
648
+ r"""Distribution of p(x_{t-1} | x_t).
649
+ """
650
+ noise = xt
651
+ context_queue = list(
652
+ context_scheduler(
653
+ 0,
654
+ 31,
655
+ noise.shape[2],
656
+ context_size=context_size,
657
+ context_stride=context_stride,
658
+ context_overlap=context_overlap,
659
+ )
660
+ )
661
+ context_step = min(
662
+ context_stride, int(np.ceil(np.log2(noise.shape[2] / context_size))) + 1
663
+ )
664
+ # replace the final segment to improve temporal consistency
665
+ num_frames = noise.shape[2]
666
+ context_queue[-1] = [
667
+ e % num_frames
668
+ for e in range(num_frames - context_size * context_step, num_frames, context_step)
669
+ ]
670
+
671
+ import math
672
+ # context_batch_size = 1
673
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
674
+ global_context = []
675
+ for i in range(num_context_batches):
676
+ global_context.append(
677
+ context_queue[
678
+ i * context_batch_size : (i + 1) * context_batch_size
679
+ ]
680
+ )
681
+ noise_pred = torch.zeros_like(noise)
682
+ noise_pred_uncond = torch.zeros_like(noise)
683
+ counter = torch.zeros(
684
+ (1, 1, xt.shape[2], 1, 1),
685
+ device=xt.device,
686
+ dtype=xt.dtype,
687
+ )
688
+
689
+ for i_index, context in enumerate(global_context):
690
+
691
+
692
+ latent_model_input = torch.cat([xt[:, :, c] for c in context])
693
+ bs_context = len(context)
694
+
695
+ model_kwargs_new = [{
696
+ 'y': None,
697
+ "local_image": None if not model_kwargs[0].__contains__('local_image') else torch.cat([model_kwargs[0]["local_image"][:, :, c] for c in context]),
698
+ 'image': None if not model_kwargs[0].__contains__('image') else model_kwargs[0]["image"].repeat(bs_context, 1, 1),
699
+ 'dwpose': None if not model_kwargs[0].__contains__('dwpose') else torch.cat([model_kwargs[0]["dwpose"][:, :, [0]+[ii+1 for ii in c]] for c in context]),
700
+ 'randomref': None if not model_kwargs[0].__contains__('randomref') else torch.cat([model_kwargs[0]["randomref"][:, :, c] for c in context]),
701
+ },
702
+ {
703
+ 'y': None,
704
+ "local_image": None,
705
+ 'image': None,
706
+ 'randomref': None,
707
+ 'dwpose': None,
708
+ }]
709
+
710
+ if guide_scale is None:
711
+ out = model(latent_model_input, self._scale_timesteps(t), **model_kwargs)
712
+ for j, c in enumerate(context):
713
+ noise_pred[:, :, c] = noise_pred[:, :, c] + out
714
+ counter[:, :, c] = counter[:, :, c] + 1
715
+ else:
716
+ # classifier-free guidance
717
+ # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
718
+ # assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
719
+ y_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[0])
720
+ u_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[1])
721
+ dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2
722
+ for j, c in enumerate(context):
723
+ noise_pred[:, :, c] = noise_pred[:, :, c] + y_out[j:j+1]
724
+ noise_pred_uncond[:, :, c] = noise_pred_uncond[:, :, c] + u_out[j:j+1]
725
+ counter[:, :, c] = counter[:, :, c] + 1
726
+
727
+ noise_pred = noise_pred / counter
728
+ noise_pred_uncond = noise_pred_uncond / counter
729
+ out = torch.cat([
730
+ noise_pred_uncond[:, :dim] + guide_scale * (noise_pred[:, :dim] - noise_pred_uncond[:, :dim]),
731
+ noise_pred[:, dim:]], dim=1) # guide_scale=2.5
732
+
733
+
734
+ # compute variance
735
+ if self.var_type == 'learned':
736
+ out, log_var = out.chunk(2, dim=1)
737
+ var = torch.exp(log_var)
738
+ elif self.var_type == 'learned_range':
739
+ out, fraction = out.chunk(2, dim=1)
740
+ min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
741
+ max_log_var = _i(torch.log(self.betas), t, xt)
742
+ fraction = (fraction + 1) / 2.0
743
+ log_var = fraction * max_log_var + (1 - fraction) * min_log_var
744
+ var = torch.exp(log_var)
745
+ elif self.var_type == 'fixed_large':
746
+ var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt)
747
+ log_var = torch.log(var)
748
+ elif self.var_type == 'fixed_small':
749
+ var = _i(self.posterior_variance, t, xt)
750
+ log_var = _i(self.posterior_log_variance_clipped, t, xt)
751
+
752
+ # compute mean and x0
753
+ if self.mean_type == 'x_{t-1}':
754
+ mu = out # x_{t-1}
755
+ x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
756
+ _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt
757
+ elif self.mean_type == 'x0':
758
+ x0 = out
759
+ mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
760
+ elif self.mean_type == 'eps':
761
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
762
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
763
+ mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
764
+ elif self.mean_type == 'v':
765
+ x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \
766
+ _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out
767
+ mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
768
+
769
+ # restrict the range of x0
770
+ if percentile is not None:
771
+ assert percentile > 0 and percentile <= 1 # e.g., 0.995
772
+ s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1)
773
+ x0 = torch.min(s, torch.max(-s, x0)) / s
774
+ elif clamp is not None:
775
+ x0 = x0.clamp(-clamp, clamp)
776
+ return mu, var, log_var, x0
777
+
778
+ @torch.no_grad()
779
+ def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1):
780
+ r"""Sample from p(x_{t-1} | x_t) using DDIM.
781
+ - condition_fn: for classifier-based guidance (guided-diffusion).
782
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
783
+ """
784
+ stride = self.num_timesteps // ddim_timesteps
785
+
786
+ # predict distribution of p(x_{t-1} | x_t)
787
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale, context_size, context_stride, context_overlap, context_batch_size)
788
+ if condition_fn is not None:
789
+ # x0 -> eps
790
+ alpha = _i(self.alphas_cumprod, t, xt)
791
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
792
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
793
+ eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
794
+
795
+ # eps -> x0
796
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
797
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
798
+
799
+ # derive variables
800
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
801
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
802
+ alphas = _i(self.alphas_cumprod, t, xt)
803
+ alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
804
+ sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
805
+
806
+ # random sample
807
+ noise = torch.randn_like(xt)
808
+ direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps
809
+ mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
810
+ xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
811
+ return xt_1, x0
812
+
813
+ @torch.no_grad()
814
+ def ddim_sample_loop(self, noise, context_size, context_stride, context_overlap, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_batch_size=1):
815
+ # prepare input
816
+ b = noise.size(0)
817
+ xt = noise
818
+
819
+ # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
820
+ steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
821
+ from tqdm import tqdm
822
+
823
+ for step in tqdm(steps):
824
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
825
+ xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta, context_size=context_size, context_stride=context_stride, context_overlap=context_overlap, context_batch_size=context_batch_size)
826
+ return xt
827
+
828
+ @torch.no_grad()
829
+ def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
830
+ r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
831
+ """
832
+ stride = self.num_timesteps // ddim_timesteps
833
+
834
+ # predict distribution of p(x_{t-1} | x_t)
835
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
836
+
837
+ # derive variables
838
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
839
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
840
+ alphas_next = _i(
841
+ torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]),
842
+ (t + stride).clamp(0, self.num_timesteps), xt)
843
+
844
+ # reverse sample
845
+ mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
846
+ return mu, x0
847
+
848
+ @torch.no_grad()
849
+ def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
850
+ # prepare input
851
+ b = x0.size(0)
852
+ xt = x0
853
+
854
+ # reconstruction steps
855
+ steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)
856
+ for step in steps:
857
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
858
+ xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps)
859
+ return xt
860
+
861
+ @torch.no_grad()
862
+ def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20, eps_cache=None):
863
+ r"""Sample from p(x_{t-1} | x_t) using PLMS.
864
+ - condition_fn: for classifier-based guidance (guided-diffusion).
865
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
866
+ """
867
+ stride = self.num_timesteps // plms_timesteps
868
+
869
+ # function to compute eps
870
+ def compute_eps(xt, t):
871
+ # predict distribution of p(x_{t-1} | x_t)
872
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
873
+
874
+ # condition
875
+ if condition_fn is not None:
876
+ # x0 -> eps
877
+ alpha = _i(self.alphas_cumprod, t, xt)
878
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
879
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
880
+ eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
881
+
882
+ # eps -> x0
883
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
884
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
885
+
886
+ # derive eps
887
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
888
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
889
+ return eps
890
+
891
+ # function to compute x_0 and x_{t-1}
892
+ def compute_x0(eps, t):
893
+ # eps -> x0
894
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
895
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
896
+
897
+ # deterministic sample
898
+ alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
899
+ direction = torch.sqrt(1 - alphas_prev) * eps
900
+ mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
901
+ xt_1 = torch.sqrt(alphas_prev) * x0 + direction
902
+ return xt_1, x0
903
+
904
+ # PLMS sample
905
+ eps_cache = [] if eps_cache is None else eps_cache
+ eps = compute_eps(xt, t)
906
+ if len(eps_cache) == 0:
907
+ # 2nd order pseudo improved Euler
908
+ xt_1, x0 = compute_x0(eps, t)
909
+ eps_next = compute_eps(xt_1, (t - stride).clamp(0))
910
+ eps_prime = (eps + eps_next) / 2.0
911
+ elif len(eps_cache) == 1:
912
+ # 2nd order pseudo linear multistep (Adams-Bashforth)
913
+ eps_prime = (3 * eps - eps_cache[-1]) / 2.0
914
+ elif len(eps_cache) == 2:
915
+ # 3rd order pseudo linear multistep (Adams-Bashforth)
916
+ eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0
917
+ elif len(eps_cache) >= 3:
918
+ # 4th order pseudo linear multistep (Adams-Bashforth)
919
+ eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0
920
+ xt_1, x0 = compute_x0(eps_prime, t)
921
+ return xt_1, x0, eps
922
+
923
+ @torch.no_grad()
924
+ def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20):
925
+ # prepare input
926
+ b = noise.size(0)
927
+ xt = noise
928
+
929
+ # diffusion process
930
+ steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
931
+ eps_cache = []
932
+ for step in steps:
933
+ # PLMS sampling step
934
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
935
+ xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache)
936
+
937
+ # update eps cache
938
+ eps_cache.append(eps)
939
+ if len(eps_cache) >= 4:
940
+ eps_cache.pop(0)
941
+ return xt
942
+
943
+ def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None):
944
+
945
+ # noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32]
946
+ noise = self.sample_loss(x0, noise)
947
+
948
+ xt = self.q_sample(x0, t, noise=noise)
949
+
950
+ # compute loss
951
+ if self.loss_type in ['kl', 'rescaled_kl']:
952
+ loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs)
953
+ if self.loss_type == 'rescaled_kl':
954
+ loss = loss * self.num_timesteps
955
+ elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse
956
+ out = model(xt, self._scale_timesteps(t), **model_kwargs)
957
+
958
+ # VLB for variation
959
+ loss_vlb = 0.0
960
+ if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small'
961
+ out, var = out.chunk(2, dim=1)
962
+ frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean
963
+ loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
964
+ if self.loss_type.startswith('rescaled_'):
965
+ loss_vlb = loss_vlb * self.num_timesteps / 1000.0
966
+
967
+ # MSE/L1 for x0/eps
968
+ # target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type]
969
+ target = {
970
+ 'eps': noise,
971
+ 'x0': x0,
972
+ 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0],
973
+ 'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type]
974
+ if loss_mask is not None:
975
+ loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same)
976
+ loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w
977
+ # use masked diffusion
978
+ loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
979
+ else:
980
+ loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
981
+ if weight is not None:
982
+ loss = loss*weight
983
+
984
+ # div loss
985
+ if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1:
986
+
987
+ # derive x0
988
+ x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
989
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
990
+
991
+
992
+ # ncfhw, std on f
993
+ div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4)
994
+ # print(div_loss,loss)
995
+ loss = loss+div_loss
996
+
997
+ # total loss
998
+ loss = loss + loss_vlb
999
+ elif self.loss_type in ['charbonnier']:
1000
+ out = model(xt, self._scale_timesteps(t), **model_kwargs)
1001
+
1002
+ # VLB for variation
1003
+ loss_vlb = 0.0
1004
+ if self.var_type in ['learned', 'learned_range']:
1005
+ out, var = out.chunk(2, dim=1)
1006
+ frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean
1007
+ loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
1008
+ if self.loss_type.startswith('rescaled_'):
1009
+ loss_vlb = loss_vlb * self.num_timesteps / 1000.0
1010
+
1011
+ # MSE/L1 for x0/eps
1012
+ target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type]
1013
+ loss = torch.sqrt((out - target)**2 + self.epsilon)
1014
+ if weight is not None:
1015
+ loss = loss*weight
1016
+ loss = loss.flatten(1).mean(dim=1)
1017
+
1018
+ # total loss
1019
+ loss = loss + loss_vlb
1020
+ return loss
1021
+
1022
+ def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None):
1023
+ # compute groundtruth and predicted distributions
1024
+ mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
1025
+ mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile)
1026
+
1027
+ # compute KL loss
1028
+ kl = kl_divergence(mu1, log_var1, mu2, log_var2)
1029
+ kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
1030
+
1031
+ # compute discretized NLL loss (for p(x0 | x1) only)
1032
+ nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2)
1033
+ nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
1034
+
1035
+ # NLL for p(x0 | x1) and KL otherwise
1036
+ vlb = torch.where(t == 0, nll, kl)
1037
+ return vlb, x0
1038
+
1039
+ @torch.no_grad()
1040
+ def variational_lower_bound_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None):
1041
+ r"""Compute the entire variational lower bound, measured in bits-per-dim.
1042
+ """
1043
+ # prepare input and output
1044
+ b = x0.size(0)
1045
+ metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
1046
+
1047
+ # loop
1048
+ for step in torch.arange(self.num_timesteps).flip(0):
1049
+ # compute VLB
1050
+ t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
1051
+ # noise = torch.randn_like(x0)
1052
+ noise = self.sample_loss(x0)
1053
+ xt = self.q_sample(x0, t, noise)
1054
+ vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile)
1055
+
1056
+ # predict eps from x0
1057
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
1058
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
1059
+
1060
+ # collect metrics
1061
+ metrics['vlb'].append(vlb)
1062
+ metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1))
1063
+ metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1))
1064
+ metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
1065
+
1066
+ # compute the prior KL term for VLB, measured in bits-per-dim
1067
+ mu, _, log_var = self.q_mean_variance(x0, t)
1068
+ kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var))
1069
+ kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
1070
+
1071
+ # update metrics
1072
+ metrics['prior_bits_per_dim'] = kl_prior
1073
+ metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
1074
+ return metrics
1075
+
1076
+ def _scale_timesteps(self, t):
1077
+ if self.rescale_timesteps:
1078
+ return t.float() * 1000.0 / self.num_timesteps
1079
+ return t
1080
+ #return t.float()
1081
+
1082
+
1083
+
1084
+ def ordered_halving(val):
1085
+ bin_str = f"{val:064b}"
1086
+ bin_flip = bin_str[::-1]
1087
+ as_int = int(bin_flip, 2)
1088
+
1089
+ return as_int / (1 << 64)
1090
+
1091
+
1092
+ def context_scheduler(
1093
+ step: int = ...,
1094
+ num_steps: Optional[int] = None,
1095
+ num_frames: int = ...,
1096
+ context_size: Optional[int] = None,
1097
+ context_stride: int = 3,
1098
+ context_overlap: int = 4,
1099
+ closed_loop: bool = False,
1100
+ ):
1101
+ if num_frames <= context_size:
1102
+ yield list(range(num_frames))
1103
+ return
1104
+
1105
+ context_stride = min(
1106
+ context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
1107
+ )
1108
+
1109
+ for context_step in 1 << np.arange(context_stride):
1110
+ pad = int(round(num_frames * ordered_halving(step)))
1111
+ for j in range(
1112
+ int(ordered_halving(step) * context_step) + pad,
1113
+ num_frames + pad + (0 if closed_loop else -context_overlap),
1114
+ (context_size * context_step - context_overlap),
1115
+ ):
1116
+
1117
+ yield [
1118
+ e % num_frames
1119
+ for e in range(j, j + context_size * context_step, context_step)
1120
+ ]
1121
+
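The two helpers above are what let DiffusionDDIMLong denoise videos longer than the training window: ordered_halving bit-reverses the step index to jitter window offsets across denoising steps, and context_scheduler yields overlapping frame-index windows whose per-window predictions are accumulated and averaged by the counter in p_mean_variance. A small, self-contained illustration (the frame counts are arbitrary assumptions chosen for readability, and the import assumes the repository root is on PYTHONPATH):

    from tools.modules.diffusions.diffusion_ddim import context_scheduler, ordered_halving

    print([round(ordered_halving(s), 3) for s in range(4)])   # 0.0, 0.5, 0.25, 0.75: bit-reversed offsets
    windows = list(context_scheduler(step=0, num_steps=None, num_frames=48,
                                     context_size=16, context_stride=1, context_overlap=4))
    for w in windows:
        print(w)   # overlapping 16-frame windows covering all 48 frames (indices wrap modulo 48)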
UniAnimate/tools/modules/diffusions/diffusion_gauss.py ADDED
@@ -0,0 +1,498 @@
1
+ """
2
+ GaussianDiffusion wraps operators for denoising diffusion models, including the
3
+ diffusion and denoising processes, as well as the loss evaluation.
4
+ """
5
+ import torch
6
+ import torchsde
7
+ import random
8
+ from tqdm.auto import trange
9
+
10
+
11
+ __all__ = ['GaussianDiffusion']
12
+
13
+
14
+ def _i(tensor, t, x):
15
+ """
16
+ Index tensor using t and format the output according to x.
17
+ """
18
+ shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
19
+ return tensor[t.to(tensor.device)].view(shape).to(x.device)
20
+
21
+
22
+ class BatchedBrownianTree:
23
+ """
24
+ A wrapper around torchsde.BrownianTree that enables batches of entropy.
25
+ """
26
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
27
+ t0, t1, self.sign = self.sort(t0, t1)
28
+ w0 = kwargs.get('w0', torch.zeros_like(x))
29
+ if seed is None:
30
+ seed = torch.randint(0, 2 ** 63 - 1, []).item()
31
+ self.batched = True
32
+ try:
33
+ assert len(seed) == x.shape[0]
34
+ w0 = w0[0]
35
+ except TypeError:
36
+ seed = [seed]
37
+ self.batched = False
38
+ self.trees = [torchsde.BrownianTree(
39
+ t0, w0, t1, entropy=s, **kwargs
40
+ ) for s in seed]
41
+
42
+ @staticmethod
43
+ def sort(a, b):
44
+ return (a, b, 1) if a < b else (b, a, -1)
45
+
46
+ def __call__(self, t0, t1):
47
+ t0, t1, sign = self.sort(t0, t1)
48
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
49
+ return w if self.batched else w[0]
50
+
51
+
52
+ class BrownianTreeNoiseSampler:
53
+ """
54
+ A noise sampler backed by a torchsde.BrownianTree.
55
+
56
+ Args:
57
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
58
+ random samples.
59
+ sigma_min (float): The low end of the valid interval.
60
+ sigma_max (float): The high end of the valid interval.
61
+ seed (int or List[int]): The random seed. If a list of seeds is
62
+ supplied instead of a single integer, then the noise sampler will
63
+ use one BrownianTree per batch item, each with its own seed.
64
+ transform (callable): A function that maps sigma to the sampler's
65
+ internal timestep.
66
+ """
67
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
68
+ self.transform = transform
69
+ t0 = self.transform(torch.as_tensor(sigma_min))
70
+ t1 = self.transform(torch.as_tensor(sigma_max))
71
+ self.tree = BatchedBrownianTree(x, t0, t1, seed)
72
+
73
+ def __call__(self, sigma, sigma_next):
74
+ t0 = self.transform(torch.as_tensor(sigma))
75
+ t1 = self.transform(torch.as_tensor(sigma_next))
76
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
77
+
78
+
79
+ def get_scalings(sigma):
80
+ c_out = -sigma
81
+ c_in = 1 / (sigma ** 2 + 1. ** 2) ** 0.5
82
+ return c_out, c_in
83
+
84
+
85
+ @torch.no_grad()
86
+ def sample_dpmpp_2m_sde(
87
+ noise,
88
+ model,
89
+ sigmas,
90
+ eta=1.,
91
+ s_noise=1.,
92
+ solver_type='midpoint',
93
+ show_progress=True
94
+ ):
95
+ """
96
+ DPM-Solver++ (2M) SDE.
97
+ """
98
+ assert solver_type in {'heun', 'midpoint'}
99
+
100
+ x = noise * sigmas[0]
101
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[sigmas < float('inf')].max()
102
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
103
+ old_denoised = None
104
+ h_last = None
105
+
106
+ for i in trange(len(sigmas) - 1, disable=not show_progress):
107
+ if sigmas[i] == float('inf'):
108
+ # Euler method
109
+ denoised = model(noise, sigmas[i])
110
+ x = denoised + sigmas[i + 1] * noise
111
+ else:
112
+ _, c_in = get_scalings(sigmas[i])
113
+ denoised = model(x * c_in, sigmas[i])
114
+ if sigmas[i + 1] == 0:
115
+ # Denoising step
116
+ x = denoised
117
+ else:
118
+ # DPM-Solver++(2M) SDE
119
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
120
+ h = s - t
121
+ eta_h = eta * h
122
+
123
+ x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
124
+ (-h - eta_h).expm1().neg() * denoised
125
+
126
+ if old_denoised is not None:
127
+ r = h_last / h
128
+ if solver_type == 'heun':
129
+ x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
130
+ (1 / r) * (denoised - old_denoised)
131
+ elif solver_type == 'midpoint':
132
+ x = x + 0.5 * (-h - eta_h).expm1().neg() * \
133
+ (1 / r) * (denoised - old_denoised)
134
+
135
+ x = x + noise_sampler(
136
+ sigmas[i],
137
+ sigmas[i + 1]
138
+ ) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
139
+
140
+ old_denoised = denoised
141
+ h_last = h
142
+ return x
143
+
144
+
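# One step of the DPM-Solver++(2M) SDE update above, written out as a sketch
# (h = log(sigma_i / sigma_{i+1}), D = the denoised prediction, z = Brownian-tree noise;
# the multistep correction that reuses the previous denoised output is omitted):
#
#   x <- (sigma_{i+1} / sigma_i) * exp(-eta * h) * x
#        + (1 - exp(-(1 + eta) * h)) * D
#        + sigma_{i+1} * sqrt(1 - exp(-2 * eta * h)) * z
#
# With eta = 0 the noise term vanishes and the step reduces to the deterministic DPM-Solver++(2M) update.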
145
+ class GaussianDiffusion(object):
146
+
147
+ def __init__(self, sigmas, prediction_type='eps'):
148
+ assert prediction_type in {'x0', 'eps', 'v'}
149
+ self.sigmas = sigmas.float() # noise coefficients
150
+ self.alphas = torch.sqrt(1 - sigmas ** 2).float() # signal coefficients
151
+ self.num_timesteps = len(sigmas)
152
+ self.prediction_type = prediction_type
153
+
154
+ def diffuse(self, x0, t, noise=None):
155
+ """
156
+ Add Gaussian noise to signal x0 according to:
157
+ q(x_t | x_0) = N(x_t | alpha_t x_0, sigma_t^2 I).
158
+ """
159
+ noise = torch.randn_like(x0) if noise is None else noise
160
+ xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
161
+ return xt
162
+
163
+ def denoise(
164
+ self,
165
+ xt,
166
+ t,
167
+ s,
168
+ model,
169
+ model_kwargs={},
170
+ guide_scale=None,
171
+ guide_rescale=None,
172
+ clamp=None,
173
+ percentile=None
174
+ ):
175
+ """
176
+ Apply one step of denoising from the posterior distribution q(x_s | x_t, x0).
177
+ Since x0 is not available, estimate the denoising results using the learned
178
+ distribution p(x_s | x_t, \hat{x}_0 == f(x_t)).
179
+ """
180
+ s = t - 1 if s is None else s
181
+
182
+ # hyperparams
183
+ sigmas = _i(self.sigmas, t, xt)
184
+ alphas = _i(self.alphas, t, xt)
185
+ alphas_s = _i(self.alphas, s.clamp(0), xt)
186
+ alphas_s[s < 0] = 1.
187
+ sigmas_s = torch.sqrt(1 - alphas_s ** 2)
188
+
189
+ # precompute variables
190
+ betas = 1 - (alphas / alphas_s) ** 2
191
+ coef1 = betas * alphas_s / sigmas ** 2
192
+ coef2 = (alphas * sigmas_s ** 2) / (alphas_s * sigmas ** 2)
193
+ var = betas * (sigmas_s / sigmas) ** 2
194
+ log_var = torch.log(var).clamp_(-20, 20)
195
+
196
+ # prediction
197
+ if guide_scale is None:
198
+ assert isinstance(model_kwargs, dict)
199
+ out = model(xt, t=t, **model_kwargs)
200
+ else:
201
+ # classifier-free guidance (arXiv:2207.12598)
202
+ # model_kwargs[0]: conditional kwargs
203
+ # model_kwargs[1]: non-conditional kwargs
204
+ assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
205
+ y_out = model(xt, t=t, **model_kwargs[0])
206
+ if guide_scale == 1.:
207
+ out = y_out
208
+ else:
209
+ u_out = model(xt, t=t, **model_kwargs[1])
210
+ out = u_out + guide_scale * (y_out - u_out)
211
+
212
+ # rescale the output according to arXiv:2305.08891
213
+ if guide_rescale is not None:
214
+ assert guide_rescale >= 0 and guide_rescale <= 1
215
+ ratio = (y_out.flatten(1).std(dim=1) / (
216
+ out.flatten(1).std(dim=1) + 1e-12
217
+ )).view((-1, ) + (1, ) * (y_out.ndim - 1))
218
+ out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
219
+
220
+ # compute x0
221
+ if self.prediction_type == 'x0':
222
+ x0 = out
223
+ elif self.prediction_type == 'eps':
224
+ x0 = (xt - sigmas * out) / alphas
225
+ elif self.prediction_type == 'v':
226
+ x0 = alphas * xt - sigmas * out
227
+ else:
228
+ raise NotImplementedError(
229
+ f'prediction_type {self.prediction_type} not implemented'
230
+ )
231
+
232
+ # restrict the range of x0
233
+ if percentile is not None:
234
+ # NOTE: percentile should only be used when data is within range [-1, 1]
235
+ assert percentile > 0 and percentile <= 1
236
+ s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
237
+ s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
238
+ x0 = torch.min(s, torch.max(-s, x0)) / s
239
+ elif clamp is not None:
240
+ x0 = x0.clamp(-clamp, clamp)
241
+
242
+ # recompute eps using the restricted x0
243
+ eps = (xt - alphas * x0) / sigmas
244
+
245
+ # compute mu (mean of posterior distribution) using the restricted x0
246
+ mu = coef1 * x0 + coef2 * xt
247
+ return mu, var, log_var, x0, eps
248
+
249
+ @torch.no_grad()
250
+ def sample(
251
+ self,
252
+ noise,
253
+ model,
254
+ model_kwargs={},
255
+ condition_fn=None,
256
+ guide_scale=None,
257
+ guide_rescale=None,
258
+ clamp=None,
259
+ percentile=None,
260
+ solver='euler_a',
261
+ steps=20,
262
+ t_max=None,
263
+ t_min=None,
264
+ discretization=None,
265
+ discard_penultimate_step=None,
266
+ return_intermediate=None,
267
+ show_progress=False,
268
+ seed=-1,
269
+ **kwargs
270
+ ):
271
+ # sanity check
272
+ assert isinstance(steps, (int, torch.LongTensor))
273
+ assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
274
+ assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
275
+ assert discretization in (None, 'leading', 'linspace', 'trailing')
276
+ assert discard_penultimate_step in (None, True, False)
277
+ assert return_intermediate in (None, 'x0', 'xt')
278
+
279
+ # function of diffusion solver
280
+ solver_fn = {
281
+ # 'heun': sample_heun,
282
+ 'dpmpp_2m_sde': sample_dpmpp_2m_sde
283
+ }[solver]
284
+
285
+ # options
286
+ schedule = 'karras' if 'karras' in solver else None
287
+ discretization = discretization or 'linspace'
288
+ seed = seed if seed >= 0 else random.randint(0, 2 ** 31)
289
+ if isinstance(steps, torch.LongTensor):
290
+ discard_penultimate_step = False
291
+ if discard_penultimate_step is None:
292
+ discard_penultimate_step = True if solver in (
293
+ 'dpm2',
294
+ 'dpm2_ancestral',
295
+ 'dpmpp_2m_sde',
296
+ 'dpm2_karras',
297
+ 'dpm2_ancestral_karras',
298
+ 'dpmpp_2m_sde_karras'
299
+ ) else False
300
+
301
+ # function for denoising xt to get x0
302
+ intermediates = []
303
+ def model_fn(xt, sigma):
304
+ # denoising
305
+ t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
306
+ x0 = self.denoise(
307
+ xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp,
308
+ percentile
309
+ )[-2]
310
+
311
+ # collect intermediate outputs
312
+ if return_intermediate == 'xt':
313
+ intermediates.append(xt)
314
+ elif return_intermediate == 'x0':
315
+ intermediates.append(x0)
316
+ return x0
317
+
318
+ # get timesteps
319
+ if isinstance(steps, int):
320
+ steps += 1 if discard_penultimate_step else 0
321
+ t_max = self.num_timesteps - 1 if t_max is None else t_max
322
+ t_min = 0 if t_min is None else t_min
323
+
324
+ # discretize timesteps
325
+ if discretization == 'leading':
326
+ steps = torch.arange(
327
+ t_min, t_max + 1, (t_max - t_min + 1) / steps
328
+ ).flip(0)
329
+ elif discretization == 'linspace':
330
+ steps = torch.linspace(t_max, t_min, steps)
331
+ elif discretization == 'trailing':
332
+ steps = torch.arange(t_max, t_min - 1, -((t_max - t_min + 1) / steps))
333
+ else:
334
+ raise NotImplementedError(
335
+ f'{discretization} discretization not implemented'
336
+ )
337
+ steps = steps.clamp_(t_min, t_max)
338
+ steps = torch.as_tensor(steps, dtype=torch.float32, device=noise.device)
339
+
340
+ # get sigmas
341
+ sigmas = self._t_to_sigma(steps)
342
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
343
+ if schedule == 'karras':
344
+ if sigmas[0] == float('inf'):
345
+ sigmas = karras_schedule(
346
+ n=len(steps) - 1,
347
+ sigma_min=sigmas[sigmas > 0].min().item(),
348
+ sigma_max=sigmas[sigmas < float('inf')].max().item(),
349
+ rho=7.
350
+ ).to(sigmas)
351
+ sigmas = torch.cat([
352
+ sigmas.new_tensor([float('inf')]), sigmas, sigmas.new_zeros([1])
353
+ ])
354
+ else:
355
+ sigmas = karras_schedule(
356
+ n=len(steps),
357
+ sigma_min=sigmas[sigmas > 0].min().item(),
358
+ sigma_max=sigmas.max().item(),
359
+ rho=7.
360
+ ).to(sigmas)
361
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
362
+ if discard_penultimate_step:
363
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
364
+
365
+ # sampling
366
+ x0 = solver_fn(
367
+ noise,
368
+ model_fn,
369
+ sigmas,
370
+ show_progress=show_progress,
371
+ **kwargs
372
+ )
373
+ return (x0, intermediates) if return_intermediate is not None else x0
374
+
375
+ @torch.no_grad()
376
+ def ddim_reverse_sample(
377
+ self,
378
+ xt,
379
+ t,
380
+ model,
381
+ model_kwargs={},
382
+ clamp=None,
383
+ percentile=None,
384
+ guide_scale=None,
385
+ guide_rescale=None,
386
+ ddim_timesteps=20,
387
+ reverse_steps=600
388
+ ):
389
+ r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
390
+ """
391
+ stride = reverse_steps // ddim_timesteps
392
+
393
+ # predict distribution of p(x_{t-1} | x_t)
394
+ _, _, _, x0, eps = self.denoise(
395
+ xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp,
396
+ percentile
397
+ )
398
+ # derive variables
399
+ s = (t + stride).clamp(0, reverse_steps-1)
400
+ # hyperparams
401
+ sigmas = _i(self.sigmas, t, xt)
402
+ alphas = _i(self.alphas, t, xt)
403
+ alphas_s = _i(self.alphas, s.clamp(0), xt)
404
+ alphas_s[s < 0] = 1.
405
+ sigmas_s = torch.sqrt(1 - alphas_s ** 2)
406
+
407
+ # reverse sample
408
+ mu = alphas_s * x0 + sigmas_s * eps
409
+ return mu, x0
410
+
411
+ @torch.no_grad()
412
+ def ddim_reverse_sample_loop(
413
+ self,
414
+ x0,
415
+ model,
416
+ model_kwargs={},
417
+ clamp=None,
418
+ percentile=None,
419
+ guide_scale=None,
420
+ guide_rescale=None,
421
+ ddim_timesteps=20,
422
+ reverse_steps=600
423
+ ):
424
+ # prepare input
425
+ b = x0.size(0)
426
+ xt = x0
427
+
428
+ # reconstruction steps
429
+ steps = torch.arange(0, reverse_steps, reverse_steps // ddim_timesteps)
430
+ for step in steps:
431
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
432
+ xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, guide_rescale, ddim_timesteps, reverse_steps)
433
+ return xt
434
+
435
+ def _sigma_to_t(self, sigma):
436
+ if sigma == float('inf'):
437
+ t = torch.full_like(sigma, len(self.sigmas) - 1)
438
+ else:
439
+ log_sigmas = torch.sqrt(
440
+ self.sigmas ** 2 / (1 - self.sigmas ** 2)
441
+ ).log().to(sigma)
442
+ log_sigma = sigma.log()
443
+ dists = log_sigma - log_sigmas[:, None]
444
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
445
+ max=log_sigmas.shape[0] - 2
446
+ )
447
+ high_idx = low_idx + 1
448
+ low, high = log_sigmas[low_idx], log_sigmas[high_idx]
449
+ w = (low - log_sigma) / (low - high)
450
+ w = w.clamp(0, 1)
451
+ t = (1 - w) * low_idx + w * high_idx
452
+ t = t.view(sigma.shape)
453
+ if t.ndim == 0:
454
+ t = t.unsqueeze(0)
455
+ return t
456
+
457
+ def _t_to_sigma(self, t):
458
+ t = t.float()
459
+ low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
460
+ log_sigmas = torch.sqrt(self.sigmas ** 2 / (1 - self.sigmas ** 2)).log().to(t)
461
+ log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
462
+ log_sigma[torch.isnan(log_sigma) | torch.isinf(log_sigma)] = float('inf')
463
+ return log_sigma.exp()
464
+
465
+ def prev_step(self, model_out, t, xt, inference_steps=50):
466
+ prev_t = t - self.num_timesteps // inference_steps
467
+
468
+ sigmas = _i(self.sigmas, t, xt)
469
+ alphas = _i(self.alphas, t, xt)
470
+ alphas_prev = _i(self.alphas, prev_t.clamp(0), xt)
471
+ alphas_prev[prev_t < 0] = 1.
472
+ sigmas_prev = torch.sqrt(1 - alphas_prev ** 2)
473
+
474
+ x0 = alphas * xt - sigmas * model_out
475
+ eps = (xt - alphas * x0) / sigmas
476
+ prev_sample = alphas_prev * x0 + sigmas_prev * eps
477
+ return prev_sample
478
+
479
+ def next_step(self, model_out, t, xt, inference_steps=50):
480
+ t, next_t = min(t - self.num_timesteps // inference_steps, 999), t
481
+
482
+ sigmas = _i(self.sigmas, t, xt)
483
+ alphas = _i(self.alphas, t, xt)
484
+ alphas_next = _i(self.alphas, next_t.clamp(0), xt)
485
+ alphas_next[next_t < 0] = 1.
486
+ sigmas_next = torch.sqrt(1 - alphas_next ** 2)
487
+
488
+ x0 = alphas * xt - sigmas * model_out
489
+ eps = (xt - alphas * x0) / sigmas
490
+ next_sample = alphas_next * x0 + sigmas_next * eps
491
+ return next_sample
492
+
493
+ def get_noise_pred_single(self, xt, t, model, model_kwargs):
494
+ assert isinstance(model_kwargs, dict)
495
+ out = model(xt, t=t, **model_kwargs)
496
+ return out
497
+
498
+
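A minimal sketch of driving the wrapper above end to end; the latent shape and the zero-output stand-in denoiser are assumptions for illustration, and the imports assume the repository root is on PYTHONPATH (sigma_schedule comes from the schedules module added below):

    import torch
    from tools.modules.diffusions.diffusion_gauss import GaussianDiffusion
    from tools.modules.diffusions.schedules import sigma_schedule

    diffusion = GaussianDiffusion(sigmas=sigma_schedule('cosine', num_timesteps=1000),
                                  prediction_type='v')
    model = lambda xt, t, **kw: torch.zeros_like(xt)       # stand-in v-prediction network
    x0 = diffusion.sample(noise=torch.randn(1, 4, 16, 32, 32), model=model, model_kwargs={},
                          solver='dpmpp_2m_sde', steps=20, show_progress=False)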
UniAnimate/tools/modules/diffusions/losses.py ADDED
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import math
3
+
4
+ __all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']
5
+
6
+ def kl_divergence(mu1, logvar1, mu2, logvar2):
7
+ return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2))
8
+
9
+ def standard_normal_cdf(x):
10
+ r"""A fast approximation of the cumulative distribution function of the standard normal.
11
+ """
12
+ return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
13
+
14
+ def discretized_gaussian_log_likelihood(x0, mean, log_scale):
15
+ assert x0.shape == mean.shape == log_scale.shape
16
+ cx = x0 - mean
17
+ inv_stdv = torch.exp(-log_scale)
18
+ cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
19
+ cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
20
+ log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
21
+ log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
22
+ cdf_delta = cdf_plus - cdf_min
23
+ log_probs = torch.where(
24
+ x0 < -0.999,
25
+ log_cdf_plus,
26
+ torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))))
27
+ assert log_probs.shape == x0.shape
28
+ return log_probs
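As a quick sanity check of the two helpers above (a sketch; the import assumes the repository root is on PYTHONPATH): the KL between identical Gaussians is zero, and the discretized log-likelihood scores how much probability mass the predicted Gaussian puts on the 2/255-wide pixel bin around x0.

    import torch
    from tools.modules.diffusions.losses import kl_divergence, discretized_gaussian_log_likelihood

    mu = torch.zeros(2, 3)
    logvar = torch.zeros(2, 3)
    print(kl_divergence(mu, logvar, mu, logvar))                                 # all zeros
    print(discretized_gaussian_log_likelihood(torch.zeros(2, 3), mean=mu, log_scale=logvar))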
UniAnimate/tools/modules/diffusions/schedules.py ADDED
@@ -0,0 +1,166 @@
1
+ import math
2
+ import torch
3
+
4
+
5
+ def beta_schedule(schedule='cosine',
6
+ num_timesteps=1000,
7
+ zero_terminal_snr=False,
8
+ **kwargs):
9
+ # compute betas
10
+ betas = {
11
+ # 'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
12
+ 'linear': linear_schedule,
13
+ 'linear_sd': linear_sd_schedule,
14
+ 'quadratic': quadratic_schedule,
15
+ 'cosine': cosine_schedule
16
+ }[schedule](num_timesteps, **kwargs)
17
+
18
+ if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001:
19
+ betas = rescale_zero_terminal_snr(betas)
20
+
21
+ return betas
22
+
23
+
24
+ def sigma_schedule(schedule='cosine',
25
+ num_timesteps=1000,
26
+ zero_terminal_snr=False,
27
+ **kwargs):
28
+ # compute betas
29
+ betas = {
30
+ 'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
31
+ 'linear': linear_schedule,
32
+ 'linear_sd': linear_sd_schedule,
33
+ 'quadratic': quadratic_schedule,
34
+ 'cosine': cosine_schedule
35
+ }[schedule](num_timesteps, **kwargs)
36
+ if schedule == 'logsnr_cosine_interp':
37
+ sigma = betas
38
+ else:
39
+ sigma = betas_to_sigmas(betas)
40
+ if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001:
41
+ sigma = rescale_zero_terminal_snr(sigma)
42
+
43
+ return sigma
44
+
45
+
46
+ def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs):
47
+ scale = 1000.0 / num_timesteps
48
+ init_beta = init_beta or scale * 0.0001
49
+ last_beta = last_beta or scale * 0.02
50
+ return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)
51
+
52
+ def logsnr_cosine_interp_schedule(
53
+ num_timesteps,
54
+ scale_min=2,
55
+ scale_max=4,
56
+ logsnr_min=-15,
57
+ logsnr_max=15,
58
+ **kwargs):
59
+ return logsnrs_to_sigmas(
60
+ _logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max))
61
+
62
+ def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs):
63
+ return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
64
+
65
+
66
+ def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs):
67
+ init_beta = init_beta or 0.0015
68
+ last_beta = last_beta or 0.0195
69
+ return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
70
+
71
+
72
+ def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs):
73
+ betas = []
74
+ for step in range(num_timesteps):
75
+ t1 = step / num_timesteps
76
+ t2 = (step + 1) / num_timesteps
77
+ fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2
78
+ betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
79
+ return torch.tensor(betas, dtype=torch.float64)
80
+
81
+
82
+ # def cosine_schedule(n, cosine_s=0.008, **kwargs):
83
+ # ramp = torch.linspace(0, 1, n + 1)
84
+ # square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2
85
+ # betas = (1 - square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999)
86
+ # return betas_to_sigmas(betas)
87
+
88
+
89
+ def betas_to_sigmas(betas):
90
+ return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
91
+
92
+
93
+ def sigmas_to_betas(sigmas):
94
+ square_alphas = 1 - sigmas**2
95
+ betas = 1 - torch.cat(
96
+ [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
97
+ return betas
98
+
99
+
100
+
101
+ def sigmas_to_logsnrs(sigmas):
102
+ square_sigmas = sigmas**2
103
+ return torch.log(square_sigmas / (1 - square_sigmas))
104
+
105
+
106
+ def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
107
+ t_min = math.atan(math.exp(-0.5 * logsnr_min))
108
+ t_max = math.atan(math.exp(-0.5 * logsnr_max))
109
+ t = torch.linspace(1, 0, n)
110
+ logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
111
+ return logsnrs
112
+
113
+
114
+ def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
115
+ logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
116
+ logsnrs += 2 * math.log(1 / scale)
117
+ return logsnrs
118
+
119
+ def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
120
+ ramp = torch.linspace(1, 0, n)
121
+ min_inv_rho = sigma_min**(1 / rho)
122
+ max_inv_rho = sigma_max**(1 / rho)
123
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
124
+ sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
125
+ return sigmas
126
+
127
+ def _logsnr_cosine_interp(n,
128
+ logsnr_min=-15,
129
+ logsnr_max=15,
130
+ scale_min=2,
131
+ scale_max=4):
132
+ t = torch.linspace(1, 0, n)
133
+ logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
134
+ logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
135
+ logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
136
+ return logsnrs
137
+
138
+
139
+ def logsnrs_to_sigmas(logsnrs):
140
+ return torch.sqrt(torch.sigmoid(-logsnrs))
141
+
142
+
143
+ def rescale_zero_terminal_snr(betas):
144
+ """
145
+ Rescale Schedule to Zero Terminal SNR
146
+ """
147
+ # Convert betas to alphas_bar_sqrt
148
+ alphas = 1 - betas
149
+ alphas_bar = alphas.cumprod(0)
150
+ alphas_bar_sqrt = alphas_bar.sqrt()
151
+
152
+ # Store old values.
153
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
154
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
155
+ # Shift so last timestep is zero.
156
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
157
+ # Scale so first timestep is back to old value.
158
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
159
+
160
+ # Convert alphas_bar_sqrt to betas
161
+ alphas_bar = alphas_bar_sqrt ** 2
162
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
163
+ alphas = torch.cat([alphas_bar[0:1], alphas])
164
+ betas = 1 - alphas
165
+ return betas
166
+
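
A minimal usage sketch for the schedule helpers above (illustrative only, not part of the commit; it assumes the functions defined in this module are in scope):

import torch

# 'linear' needs explicit init_beta/last_beta kwargs; passing None falls back to the scaled DDPM defaults.
betas = beta_schedule('linear', num_timesteps=1000, init_beta=None, last_beta=None)
sigmas = sigma_schedule('cosine', num_timesteps=1000)

# betas_to_sigmas implements sigma_t = sqrt(1 - prod_{s<=t} (1 - beta_s))
manual = torch.sqrt(1 - torch.cumprod(1 - cosine_schedule(1000), dim=0))
assert torch.allclose(sigmas, manual)
print(betas.shape, float(sigmas[-1]))  # torch.Size([1000]) and a terminal sigma close to 1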
UniAnimate/tools/modules/embedding_manager.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import open_clip
5
+
6
+ from functools import partial
7
+ from utils.registry_class import EMBEDMANAGER
8
+
9
+ DEFAULT_PLACEHOLDER_TOKEN = ["*"]
10
+
11
+ PROGRESSIVE_SCALE = 2000
12
+
13
+ per_img_token_list = [
14
+ 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
15
+ ]
16
+
17
+ def get_clip_token_for_string(string):
18
+ tokens = open_clip.tokenize(string)
19
+
20
+ return tokens[0, 1]
21
+
22
+ def get_embedding_for_clip_token(embedder, token):
23
+ return embedder(token.unsqueeze(0))[0]
24
+
25
+
26
+ @EMBEDMANAGER.register_class()
27
+ class EmbeddingManager(nn.Module):
28
+ def __init__(
29
+ self,
30
+ embedder,
31
+ placeholder_strings=None,
32
+ initializer_words=None,
33
+ per_image_tokens=False,
34
+ num_vectors_per_token=1,
35
+ progressive_words=False,
36
+ temporal_prompt_length=1,
37
+ token_dim=1024,
38
+ **kwargs
39
+ ):
40
+ super().__init__()
41
+
42
+ self.string_to_token_dict = {}
43
+
44
+ self.string_to_param_dict = nn.ParameterDict()
45
+
46
+ self.initial_embeddings = nn.ParameterDict() # These should not be optimized
47
+
48
+ self.progressive_words = progressive_words
49
+ self.progressive_counter = 0
50
+
51
+ self.max_vectors_per_token = num_vectors_per_token
52
+
53
+ get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.model.token_embedding.cpu())
54
+
55
+ if per_image_tokens:
56
+ placeholder_strings.extend(per_img_token_list)
57
+
58
+ for idx, placeholder_string in enumerate(placeholder_strings):
59
+
60
+ token = get_clip_token_for_string(placeholder_string)
61
+
62
+ if initializer_words and idx < len(initializer_words):
63
+ init_word_token = get_clip_token_for_string(initializer_words[idx])
64
+
65
+ with torch.no_grad():
66
+ init_word_embedding = get_embedding_for_tkn(init_word_token)
67
+
68
+ token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
69
+ self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
70
+ else:
71
+ token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
72
+
73
+ self.string_to_token_dict[placeholder_string] = token
74
+ self.string_to_param_dict[placeholder_string] = token_params
75
+
76
+
77
+ def forward(
78
+ self,
79
+ tokenized_text,
80
+ embedded_text,
81
+ ):
82
+ b, n, device = *tokenized_text.shape, tokenized_text.device
83
+
84
+ for placeholder_string, placeholder_token in self.string_to_token_dict.items():
85
+
86
+ placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
87
+
88
+ if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
89
+ placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
90
+ embedded_text[placeholder_idx] = placeholder_embedding
91
+ else: # otherwise, need to insert and keep track of changing indices
92
+ if self.progressive_words:
93
+ self.progressive_counter += 1
94
+ max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
95
+ else:
96
+ max_step_tokens = self.max_vectors_per_token
97
+
98
+ num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
99
+
100
+ placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
101
+
102
+ if placeholder_rows.nelement() == 0:
103
+ continue
104
+
105
+ sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
106
+ sorted_rows = placeholder_rows[sort_idx]
107
+
108
+ for idx in range(len(sorted_rows)):
109
+ row = sorted_rows[idx]
110
+ col = sorted_cols[idx]
111
+
112
+ new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
113
+ new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
114
+
115
+ embedded_text[row] = new_embed_row
116
+ tokenized_text[row] = new_token_row
117
+
118
+ return embedded_text
119
+
120
+ def forward_with_text_img(
121
+ self,
122
+ tokenized_text,
123
+ embedded_text,
124
+ embedded_img,
125
+ ):
126
+ device = tokenized_text.device
127
+ for placeholder_string, placeholder_token in self.string_to_token_dict.items():
128
+ placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
129
+ placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
130
+ embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + embedded_img + placeholder_embedding
131
+ return embedded_text
132
+
133
+ def forward_with_text(
134
+ self,
135
+ tokenized_text,
136
+ embedded_text
137
+ ):
138
+ device = tokenized_text.device
139
+ for placeholder_string, placeholder_token in self.string_to_token_dict.items():
140
+ placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
141
+ placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
142
+ embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + placeholder_embedding
143
+ return embedded_text
144
+
145
+ def save(self, ckpt_path):
146
+ torch.save({"string_to_token": self.string_to_token_dict,
147
+ "string_to_param": self.string_to_param_dict}, ckpt_path)
148
+
149
+ def load(self, ckpt_path):
150
+ ckpt = torch.load(ckpt_path, map_location='cpu')
151
+
152
+ string_to_token = ckpt["string_to_token"]
153
+ string_to_param = ckpt["string_to_param"]
154
+ for string, token in string_to_token.items():
155
+ self.string_to_token_dict[string] = token
156
+ for string, param in string_to_param.items():
157
+ self.string_to_param_dict[string] = param
158
+
159
+ def get_embedding_norms_squared(self):
160
+ all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
161
+ param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
162
+
163
+ return param_norm_squared
164
+
165
+ def embedding_parameters(self):
166
+ return self.string_to_param_dict.parameters()
167
+
168
+ def embedding_to_coarse_loss(self):
169
+
170
+ loss = 0.
171
+ num_embeddings = len(self.initial_embeddings)
172
+
173
+ for key in self.initial_embeddings:
174
+ optimized = self.string_to_param_dict[key]
175
+ coarse = self.initial_embeddings[key].clone().to(optimized.device)
176
+
177
+ loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
178
+
179
+ return loss
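
A hedged usage sketch for EmbeddingManager (not part of the commit). The open_clip model name, the small wrapper class, and the checkpoint path below are illustrative assumptions; the only interface the module actually relies on is an embedder exposing .model.token_embedding:

import open_clip

class ClipWrapper:
    # minimal adapter exposing .model.token_embedding, as EmbeddingManager.__init__ expects
    def __init__(self, model):
        self.model = model

# randomly initialised text tower, no weight download; ViT-H-14 matches the token_dim=1024 default
clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14')
manager = EmbeddingManager(
    embedder=ClipWrapper(clip_model),
    placeholder_strings=['*'],
    initializer_words=['person'],
    num_vectors_per_token=4,
    token_dim=clip_model.token_embedding.embedding_dim,
)
manager.save('placeholder_embeddings.pt')  # illustrative path
manager.load('placeholder_embeddings.pt')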
UniAnimate/tools/modules/unet/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .unet_unianimate import *
2
+
UniAnimate/tools/modules/unet/mha_flash.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.cuda.amp as amp
4
+ import torch.nn.functional as F
5
+ import math
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import random
10
+
11
+ from flash_attn.flash_attention import FlashAttention  # required by FlashAttentionBlock below (flash-attn package)
12
+ class FlashAttentionBlock(nn.Module):
13
+
14
+ def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
15
+ # consider head_dim first, then num_heads
16
+ num_heads = dim // head_dim if head_dim else num_heads
17
+ head_dim = dim // num_heads
18
+ assert num_heads * head_dim == dim
19
+ super(FlashAttentionBlock, self).__init__()
20
+ self.dim = dim
21
+ self.context_dim = context_dim
22
+ self.num_heads = num_heads
23
+ self.head_dim = head_dim
24
+ self.scale = math.pow(head_dim, -0.25)
25
+
26
+ # layers
27
+ self.norm = nn.GroupNorm(32, dim)
28
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
29
+ if context_dim is not None:
30
+ self.context_kv = nn.Linear(context_dim, dim * 2)
31
+ self.proj = nn.Conv2d(dim, dim, 1)
32
+
33
+ if self.head_dim <= 128 and (self.head_dim % 8) == 0:
34
+ new_scale = math.pow(head_dim, -0.5)
35
+ self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)
36
+
37
+ # zero out the last layer params
38
+ nn.init.zeros_(self.proj.weight)
39
+ # self.apply(self._init_weight)
40
+
41
+
42
+ def _init_weight(self, module):
43
+ if isinstance(module, nn.Linear):
44
+ module.weight.data.normal_(mean=0.0, std=0.15)
45
+ if module.bias is not None:
46
+ module.bias.data.zero_()
47
+ elif isinstance(module, nn.Conv2d):
48
+ module.weight.data.normal_(mean=0.0, std=0.15)
49
+ if module.bias is not None:
50
+ module.bias.data.zero_()
51
+
52
+ def forward(self, x, context=None):
53
+ r"""x: [B, C, H, W].
54
+ context: [B, L, C] or None.
55
+ """
56
+ identity = x
57
+ b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
58
+
59
+ # compute query, key, value
60
+ x = self.norm(x)
61
+ q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
62
+ if context is not None:
63
+ ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
64
+ k = torch.cat([ck, k], dim=-1)
65
+ v = torch.cat([cv, v], dim=-1)
66
+ cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
67
+ q = torch.cat([q, cq], dim=-1)
68
+
69
+ qkv = torch.cat([q,k,v], dim=1)
70
+ origin_dtype = qkv.dtype
71
+ qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous()
72
+ out, _ = self.flash_attn(qkv)
73
+ out = out.to(origin_dtype)  # Tensor.to is not in-place; reassign to restore the original dtype
74
+
75
+ if context is not None:
76
+ out = out[:, :-4, :, :]
77
+ out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)
78
+
79
+ # output
80
+ x = self.proj(out)
81
+ return x + identity
82
+
83
+ if __name__ == '__main__':
84
+ batch_size = 8
85
+ flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda()
86
+
87
+ x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
88
+ context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
89
+ # context = None
90
+ flash_net.eval()
91
+
92
+ with amp.autocast(enabled=True):
93
+ # warm up
94
+ for i in range(5):
95
+ y = flash_net(x, context)
96
+ torch.cuda.synchronize()
97
+ s1 = time.time()
98
+ for i in range(10):
99
+ y = flash_net(x, context)
100
+ torch.cuda.synchronize()
101
+ s2 = time.time()
102
+
103
+ print(f'Average cost time {(s2-s1)*1000/10} ms')
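
The block above depends on the external flash-attn package (see the import at the top of the file). As a hedged aside, not part of the commit: on PyTorch >= 2.0 an equivalent fused attention over the same packed qkv layout can be sketched with the built-in scaled_dot_product_attention, returning the (batch, seq, heads, head_dim) layout the surrounding forward expects:

import torch
import torch.nn.functional as F

def sdpa_from_packed_qkv(qkv: torch.Tensor) -> torch.Tensor:
    # qkv: (batch, seq_len, 3, heads, head_dim), as built in FlashAttentionBlock.forward
    q, k, v = (t.squeeze(2).permute(0, 2, 1, 3) for t in qkv.chunk(3, dim=2))
    out = F.scaled_dot_product_attention(q, k, v)   # (batch, heads, seq_len, head_dim)
    return out.permute(0, 2, 1, 3)                  # back to (batch, seq_len, heads, head_dim)

qkv = torch.randn(2, 64, 3, 8, 64)
print(sdpa_from_packed_qkv(qkv).shape)              # torch.Size([2, 64, 8, 64])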
UniAnimate/tools/modules/unet/unet_unianimate.py ADDED
@@ -0,0 +1,659 @@
1
+ import math
2
+ import torch
3
+ import xformers
4
+ import xformers.ops
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ import torch.nn.functional as F
8
+ from rotary_embedding_torch import RotaryEmbedding
9
+ from fairscale.nn.checkpoint import checkpoint_wrapper
10
+
11
+ from .util import *
12
+ # from .mha_flash import FlashAttentionBlock
13
+ from utils.registry_class import MODEL
14
+
15
+
16
+ USE_TEMPORAL_TRANSFORMER = True
17
+
18
+
19
+
20
+ class PreNormattention(nn.Module):
21
+ def __init__(self, dim, fn):
22
+ super().__init__()
23
+ self.norm = nn.LayerNorm(dim)
24
+ self.fn = fn
25
+ def forward(self, x, **kwargs):
26
+ return self.fn(self.norm(x), **kwargs) + x
27
+
28
+ class PreNormattention_qkv(nn.Module):
29
+ def __init__(self, dim, fn):
30
+ super().__init__()
31
+ self.norm = nn.LayerNorm(dim)
32
+ self.fn = fn
33
+ def forward(self, q, k, v, **kwargs):
34
+ return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
38
+ super().__init__()
39
+ inner_dim = dim_head * heads
40
+ project_out = not (heads == 1 and dim_head == dim)
41
+
42
+ self.heads = heads
43
+ self.scale = dim_head ** -0.5
44
+
45
+ self.attend = nn.Softmax(dim = -1)
46
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
47
+
48
+ self.to_out = nn.Sequential(
49
+ nn.Linear(inner_dim, dim),
50
+ nn.Dropout(dropout)
51
+ ) if project_out else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ b, n, _, h = *x.shape, self.heads
55
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
56
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
57
+
58
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
59
+
60
+ attn = self.attend(dots)
61
+
62
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
63
+ out = rearrange(out, 'b h n d -> b n (h d)')
64
+ return self.to_out(out)
65
+
66
+
67
+ class Attention_qkv(nn.Module):
68
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
69
+ super().__init__()
70
+ inner_dim = dim_head * heads
71
+ project_out = not (heads == 1 and dim_head == dim)
72
+
73
+ self.heads = heads
74
+ self.scale = dim_head ** -0.5
75
+
76
+ self.attend = nn.Softmax(dim = -1)
77
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
78
+ self.to_k = nn.Linear(dim, inner_dim, bias = False)
79
+ self.to_v = nn.Linear(dim, inner_dim, bias = False)
80
+
81
+ self.to_out = nn.Sequential(
82
+ nn.Linear(inner_dim, dim),
83
+ nn.Dropout(dropout)
84
+ ) if project_out else nn.Identity()
85
+
86
+ def forward(self, q, k, v):
87
+ b, n, _, h = *q.shape, self.heads
88
+ bk = k.shape[0]
89
+
90
+ q = self.to_q(q)
91
+ k = self.to_k(k)
92
+ v = self.to_v(v)
93
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
94
+ k = rearrange(k, 'b n (h d) -> b h n d', b=bk, h = h)
95
+ v = rearrange(v, 'b n (h d) -> b h n d', b=bk, h = h)
96
+
97
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
98
+
99
+ attn = self.attend(dots)
100
+
101
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
102
+ out = rearrange(out, 'b h n d -> b n (h d)')
103
+ return self.to_out(out)
104
+
105
+ class PostNormattention(nn.Module):
106
+ def __init__(self, dim, fn):
107
+ super().__init__()
108
+ self.norm = nn.LayerNorm(dim)
109
+ self.fn = fn
110
+ def forward(self, x, **kwargs):
111
+ return self.norm(self.fn(x, **kwargs) + x)
112
+
113
+
114
+
115
+
116
+ class Transformer_v2(nn.Module):
117
+ def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1):
118
+ super().__init__()
119
+ self.layers = nn.ModuleList([])
120
+ self.depth = depth
121
+ for _ in range(depth):
122
+ self.layers.append(nn.ModuleList([
123
+ PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)),
124
+ FeedForward(dim, mlp_dim, dropout = dropout_ffn),
125
+ ]))
126
+ def forward(self, x):
127
+ for attn, ff in self.layers[:1]:
128
+ x = attn(x)
129
+ x = ff(x) + x
130
+ if self.depth > 1:
131
+ for attn, ff in self.layers[1:]:
132
+ x = attn(x)
133
+ x = ff(x) + x
134
+ return x
135
+
136
+
137
+ class DropPath(nn.Module):
138
+ r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
139
+ """
140
+ def __init__(self, p):
141
+ super(DropPath, self).__init__()
142
+ self.p = p
143
+
144
+ def forward(self, *args, zero=None, keep=None):
145
+ if not self.training:
146
+ return args[0] if len(args) == 1 else args
147
+
148
+ # params
149
+ x = args[0]
150
+ b = x.size(0)
151
+ n = (torch.rand(b) < self.p).sum()
152
+
153
+ # non-zero and non-keep mask
154
+ mask = x.new_ones(b, dtype=torch.bool)
155
+ if keep is not None:
156
+ mask[keep] = False
157
+ if zero is not None:
158
+ mask[zero] = False
159
+
160
+ # drop-path index
161
+ index = torch.where(mask)[0]
162
+ index = index[torch.randperm(len(index))[:n]]
163
+ if zero is not None:
164
+ index = torch.cat([index, torch.where(zero)[0]], dim=0)
165
+
166
+ # drop-path multiplier
167
+ multiplier = x.new_ones(b)
168
+ multiplier[index] = 0.0
169
+ output = tuple(u * self.broadcast(multiplier, u) for u in args)
170
+ return output[0] if len(args) == 1 else output
171
+
172
+ def broadcast(self, src, dst):
173
+ assert src.size(0) == dst.size(0)
174
+ shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
175
+ return src.view(shape)
176
+
177
+
178
+
179
+
180
+ @MODEL.register_class()
181
+ class UNetSD_UniAnimate(nn.Module):
182
+
183
+ def __init__(self,
184
+ config=None,
185
+ in_dim=4,
186
+ dim=512,
187
+ y_dim=512,
188
+ context_dim=1024,
189
+ hist_dim = 156,
190
+ concat_dim = 8,
191
+ out_dim=6,
192
+ dim_mult=[1, 2, 3, 4],
193
+ num_heads=None,
194
+ head_dim=64,
195
+ num_res_blocks=3,
196
+ attn_scales=[1 / 2, 1 / 4, 1 / 8],
197
+ use_scale_shift_norm=True,
198
+ dropout=0.1,
199
+ temporal_attn_times=1,
200
+ temporal_attention = True,
201
+ use_checkpoint=False,
202
+ use_image_dataset=False,
203
+ use_fps_condition= False,
204
+ use_sim_mask = False,
205
+ misc_dropout = 0.5,
206
+ training=True,
207
+ inpainting=True,
208
+ p_all_zero=0.1,
209
+ p_all_keep=0.1,
210
+ zero_y = None,
211
+ black_image_feature = None,
212
+ adapter_transformer_layers = 1,
213
+ num_tokens=4,
214
+ **kwargs
215
+ ):
216
+ embed_dim = dim * 4
217
+ num_heads=num_heads if num_heads else dim//32
218
+ super(UNetSD_UniAnimate, self).__init__()
219
+ self.zero_y = zero_y
220
+ self.black_image_feature = black_image_feature
221
+ self.cfg = config
222
+ self.in_dim = in_dim
223
+ self.dim = dim
224
+ self.y_dim = y_dim
225
+ self.context_dim = context_dim
226
+ self.num_tokens = num_tokens
227
+ self.hist_dim = hist_dim
228
+ self.concat_dim = concat_dim
229
+ self.embed_dim = embed_dim
230
+ self.out_dim = out_dim
231
+ self.dim_mult = dim_mult
232
+
233
+ self.num_heads = num_heads
234
+
235
+ self.head_dim = head_dim
236
+ self.num_res_blocks = num_res_blocks
237
+ self.attn_scales = attn_scales
238
+ self.use_scale_shift_norm = use_scale_shift_norm
239
+ self.temporal_attn_times = temporal_attn_times
240
+ self.temporal_attention = temporal_attention
241
+ self.use_checkpoint = use_checkpoint
242
+ self.use_image_dataset = use_image_dataset
243
+ self.use_fps_condition = use_fps_condition
244
+ self.use_sim_mask = use_sim_mask
245
+ self.training=training
246
+ self.inpainting = inpainting
247
+ self.video_compositions = self.cfg.video_compositions
248
+ self.misc_dropout = misc_dropout
249
+ self.p_all_zero = p_all_zero
250
+ self.p_all_keep = p_all_keep
251
+
252
+ use_linear_in_temporal = False
253
+ transformer_depth = 1
254
+ disabled_sa = False
255
+ # params
256
+ enc_dims = [dim * u for u in [1] + dim_mult]
257
+ dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
258
+ shortcut_dims = []
259
+ scale = 1.0
260
+ self.resolution = config.resolution
261
+
262
+
263
+ # embeddings
264
+ self.time_embed = nn.Sequential(
265
+ nn.Linear(dim, embed_dim),
266
+ nn.SiLU(),
267
+ nn.Linear(embed_dim, embed_dim))
268
+ if 'image' in self.video_compositions:
269
+ self.pre_image_condition = nn.Sequential(
270
+ nn.Linear(self.context_dim, self.context_dim),
271
+ nn.SiLU(),
272
+ nn.Linear(self.context_dim, self.context_dim*self.num_tokens))
273
+
274
+
275
+ if 'local_image' in self.video_compositions:
276
+ self.local_image_embedding = nn.Sequential(
277
+ nn.Conv2d(3, concat_dim * 4, 3, padding=1),
278
+ nn.SiLU(),
279
+ nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
280
+ nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
281
+ nn.SiLU(),
282
+ nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
283
+ self.local_image_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
284
+
285
+ if 'dwpose' in self.video_compositions:
286
+ self.dwpose_embedding = nn.Sequential(
287
+ nn.Conv2d(3, concat_dim * 4, 3, padding=1),
288
+ nn.SiLU(),
289
+ nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
290
+ nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
291
+ nn.SiLU(),
292
+ nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
293
+ self.dwpose_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
294
+
295
+ if 'randomref_pose' in self.video_compositions:
296
+ randomref_dim = 4
297
+ self.randomref_pose2_embedding = nn.Sequential(
298
+ nn.Conv2d(3, concat_dim * 4, 3, padding=1),
299
+ nn.SiLU(),
300
+ nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
301
+ nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
302
+ nn.SiLU(),
303
+ nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=2, padding=1))
304
+ self.randomref_pose2_embedding_after = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
305
+
306
+ if 'randomref' in self.video_compositions:
307
+ randomref_dim = 4
308
+ self.randomref_embedding2 = nn.Sequential(
309
+ nn.Conv2d(randomref_dim, concat_dim * 4, 3, padding=1),
310
+ nn.SiLU(),
311
+ nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=1, padding=1),
312
+ nn.SiLU(),
313
+ nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=1, padding=1))
314
+ self.randomref_embedding_after2 = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
315
+
316
+ ### Condition Dropout
317
+ self.misc_dropout = DropPath(misc_dropout)
318
+
319
+
320
+ if temporal_attention and not USE_TEMPORAL_TRANSFORMER:
321
+ self.rotary_emb = RotaryEmbedding(min(32, head_dim))
322
+ self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32) # realistically will not be able to generate that many frames of video... yet
323
+
324
+ if self.use_fps_condition:
325
+ self.fps_embedding = nn.Sequential(
326
+ nn.Linear(dim, embed_dim),
327
+ nn.SiLU(),
328
+ nn.Linear(embed_dim, embed_dim))
329
+ nn.init.zeros_(self.fps_embedding[-1].weight)
330
+ nn.init.zeros_(self.fps_embedding[-1].bias)
331
+
332
+ # encoder
333
+ self.input_blocks = nn.ModuleList()
334
+ self.pre_image = nn.Sequential()
335
+ init_block = nn.ModuleList([nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)])
336
+
337
+ #### need an initial temporal attention?
338
+ if temporal_attention:
339
+ if USE_TEMPORAL_TRANSFORMER:
340
+ init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
341
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
342
+ else:
343
+ init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset))
344
+
345
+ self.input_blocks.append(init_block)
346
+ shortcut_dims.append(dim)
347
+ for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
348
+ for j in range(num_res_blocks):
349
+
350
+ block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,)])
351
+
352
+ if scale in attn_scales:
353
+ block.append(
354
+ SpatialTransformer(
355
+ out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
356
+ disable_self_attn=False, use_linear=True
357
+ )
358
+ )
359
+ if self.temporal_attention:
360
+ if USE_TEMPORAL_TRANSFORMER:
361
+ block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
362
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
363
+ else:
364
+ block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
365
+ in_dim = out_dim
366
+ self.input_blocks.append(block)
367
+ shortcut_dims.append(out_dim)
368
+
369
+ # downsample
370
+ if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
371
+ downsample = Downsample(
372
+ out_dim, True, dims=2, out_channels=out_dim
373
+ )
374
+ shortcut_dims.append(out_dim)
375
+ scale /= 2.0
376
+ self.input_blocks.append(downsample)
377
+
378
+ # middle
379
+ self.middle_block = nn.ModuleList([
380
+ ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,),
381
+ SpatialTransformer(
382
+ out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
383
+ disable_self_attn=False, use_linear=True
384
+ )])
385
+
386
+ if self.temporal_attention:
387
+ if USE_TEMPORAL_TRANSFORMER:
388
+ self.middle_block.append(
389
+ TemporalTransformer(
390
+ out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
391
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal,
392
+ multiply_zero=use_image_dataset,
393
+ )
394
+ )
395
+ else:
396
+ self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
397
+
398
+ self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
399
+
400
+
401
+ # decoder
402
+ self.output_blocks = nn.ModuleList()
403
+ for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
404
+ for j in range(num_res_blocks + 1):
405
+
406
+ block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )])
407
+ if scale in attn_scales:
408
+ block.append(
409
+ SpatialTransformer(
410
+ out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024,
411
+ disable_self_attn=False, use_linear=True
412
+ )
413
+ )
414
+ if self.temporal_attention:
415
+ if USE_TEMPORAL_TRANSFORMER:
416
+ block.append(
417
+ TemporalTransformer(
418
+ out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
419
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset
420
+ )
421
+ )
422
+ else:
423
+ block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
424
+ in_dim = out_dim
425
+
426
+ # upsample
427
+ if i != len(dim_mult) - 1 and j == num_res_blocks:
428
+ upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim)
429
+ scale *= 2.0
430
+ block.append(upsample)
431
+ self.output_blocks.append(block)
432
+
433
+ # head
434
+ self.out = nn.Sequential(
435
+ nn.GroupNorm(32, out_dim),
436
+ nn.SiLU(),
437
+ nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
438
+
439
+ # zero out the last layer params
440
+ nn.init.zeros_(self.out[-1].weight)
441
+
442
+ def forward(self,
443
+ x,
444
+ t,
445
+ y = None,
446
+ depth = None,
447
+ image = None,
448
+ motion = None,
449
+ local_image = None,
450
+ single_sketch = None,
451
+ masked = None,
452
+ canny = None,
453
+ sketch = None,
454
+ dwpose = None,
455
+ randomref = None,
456
+ histogram = None,
457
+ fps = None,
458
+ video_mask = None,
459
+ focus_present_mask = None,
460
+ prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
461
+ mask_last_frame_num = 0 # mask last frame num
462
+ ):
463
+
464
+
465
+ assert self.inpainting or masked is None, 'inpainting is not supported'
466
+
467
+ batch, c, f, h, w = x.shape
468
+ frames = f
469
+ device = x.device
470
+ self.batch = batch
471
+
472
+ #### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
473
+ if mask_last_frame_num > 0:
474
+ focus_present_mask = None
475
+ video_mask[-mask_last_frame_num:] = False
476
+ else:
477
+ focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device))
478
+
479
+ if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
480
+ time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)
481
+ else:
482
+ time_rel_pos_bias = None
483
+
484
+
485
+ # all-zero and all-keep masks
486
+ zero = torch.zeros(batch, dtype=torch.bool).to(x.device)
487
+ keep = torch.zeros(batch, dtype=torch.bool).to(x.device)
488
+ if self.training:
489
+ nzero = (torch.rand(batch) < self.p_all_zero).sum()
490
+ nkeep = (torch.rand(batch) < self.p_all_keep).sum()
491
+ index = torch.randperm(batch)
492
+ zero[index[0:nzero]] = True
493
+ keep[index[nzero:nzero + nkeep]] = True
494
+ assert not (zero & keep).any()
495
+ misc_dropout = partial(self.misc_dropout, zero = zero, keep = keep)
496
+
497
+
498
+ concat = x.new_zeros(batch, self.concat_dim, f, h, w)
499
+
500
+
501
+ # local_image_embedding (first frame)
502
+ if local_image is not None:
503
+ local_image = rearrange(local_image, 'b c f h w -> (b f) c h w')
504
+ local_image = self.local_image_embedding(local_image)
505
+
506
+ h = local_image.shape[2]
507
+ local_image = self.local_image_embedding_after(rearrange(local_image, '(b f) c h w -> (b h w) f c', b = batch))
508
+ local_image = rearrange(local_image, '(b h w) f c -> b c f h w', b = batch, h = h)
509
+
510
+ concat = concat + misc_dropout(local_image)
511
+
512
+ if dwpose is not None:
513
+ if 'randomref_pose' in self.video_compositions:
514
+ dwpose_random_ref = dwpose[:,:,:1].clone()
515
+ dwpose = dwpose[:,:,1:]
516
+ dwpose = rearrange(dwpose, 'b c f h w -> (b f) c h w')
517
+ dwpose = self.dwpose_embedding(dwpose)
518
+
519
+ h = dwpose.shape[2]
520
+ dwpose = self.dwpose_embedding_after(rearrange(dwpose, '(b f) c h w -> (b h w) f c', b = batch))
521
+ dwpose = rearrange(dwpose, '(b h w) f c -> b c f h w', b = batch, h = h)
522
+ concat = concat + misc_dropout(dwpose)
523
+
524
+ randomref_b = x.new_zeros(batch, self.concat_dim+4, 1, h, w)
525
+ if randomref is not None:
526
+ randomref = rearrange(randomref[:,:,:1,], 'b c f h w -> (b f) c h w')
527
+ randomref = self.randomref_embedding2(randomref)
528
+
529
+ h = randomref.shape[2]
530
+ randomref = self.randomref_embedding_after2(rearrange(randomref, '(b f) c h w -> (b h w) f c', b = batch))
531
+ if 'randomref_pose' in self.video_compositions:
532
+ dwpose_random_ref = rearrange(dwpose_random_ref, 'b c f h w -> (b f) c h w')
533
+ dwpose_random_ref = self.randomref_pose2_embedding(dwpose_random_ref)
534
+ dwpose_random_ref = self.randomref_pose2_embedding_after(rearrange(dwpose_random_ref, '(b f) c h w -> (b h w) f c', b = batch))
535
+ randomref = randomref + dwpose_random_ref
536
+
537
+ randomref_a = rearrange(randomref, '(b h w) f c -> b c f h w', b = batch, h = h)
538
+ randomref_b = randomref_b + randomref_a
539
+
540
+
541
+ x = torch.cat([randomref_b, torch.cat([x, concat], dim=1)], dim=2)
542
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
543
+ x = self.pre_image(x)
544
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
545
+
546
+ # embeddings
547
+ if self.use_fps_condition and fps is not None:
548
+ e = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim))
549
+ else:
550
+ e = self.time_embed(sinusoidal_embedding(t, self.dim))
551
+
552
+ context = x.new_zeros(batch, 0, self.context_dim)
553
+
554
+
555
+ if image is not None:
556
+ y_context = self.zero_y.repeat(batch, 1, 1)
557
+ context = torch.cat([context, y_context], dim=1)
558
+
559
+ image_context = misc_dropout(self.pre_image_condition(image).view(-1, self.num_tokens, self.context_dim)) # torch.cat([y[:,:-1,:], self.pre_image_condition(y[:,-1:,:]) ], dim=1)
560
+ context = torch.cat([context, image_context], dim=1)
561
+ else:
562
+ y_context = self.zero_y.repeat(batch, 1, 1)
563
+ context = torch.cat([context, y_context], dim=1)
564
+ image_context = torch.zeros_like(self.zero_y.repeat(batch, 1, 1))[:,:self.num_tokens]
565
+ context = torch.cat([context, image_context], dim=1)
566
+
567
+ # repeat f times for spatial e and context
568
+ e = e.repeat_interleave(repeats=f+1, dim=0)
569
+ context = context.repeat_interleave(repeats=f+1, dim=0)
570
+
571
+
572
+
573
+ ## always in shape (b f) c h w, except for temporal layer
574
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
575
+ # encoder
576
+ xs = []
577
+ for block in self.input_blocks:
578
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask)
579
+ xs.append(x)
580
+
581
+ # middle
582
+ for block in self.middle_block:
583
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask)
584
+
585
+ # decoder
586
+ for block in self.output_blocks:
587
+ x = torch.cat([x, xs.pop()], dim=1)
588
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None)
589
+
590
+ # head
591
+ x = self.out(x)
592
+
593
+ # reshape back to (b c f h w)
594
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
595
+ return x[:,:,1:]
596
+
597
+ def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None):
598
+ if isinstance(module, ResidualBlock):
599
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
600
+ x = x.contiguous()
601
+ x = module(x, e, reference)
602
+ elif isinstance(module, ResBlock):
603
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
604
+ x = x.contiguous()
605
+ x = module(x, e, self.batch)
606
+ elif isinstance(module, SpatialTransformer):
607
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
608
+ x = module(x, context)
609
+ elif isinstance(module, TemporalTransformer):
610
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
611
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
612
+ x = module(x, context)
613
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
614
+ elif isinstance(module, CrossAttention):
615
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
616
+ x = module(x, context)
617
+ elif isinstance(module, MemoryEfficientCrossAttention):
618
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
619
+ x = module(x, context)
620
+ elif isinstance(module, BasicTransformerBlock):
621
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
622
+ x = module(x, context)
623
+ elif isinstance(module, FeedForward):
624
+ x = module(x, context)
625
+ elif isinstance(module, Upsample):
626
+ x = module(x)
627
+ elif isinstance(module, Downsample):
628
+ x = module(x)
629
+ elif isinstance(module, Resample):
630
+ x = module(x, reference)
631
+ elif isinstance(module, TemporalAttentionBlock):
632
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
633
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
634
+ x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
635
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
636
+ elif isinstance(module, TemporalAttentionMultiBlock):
637
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
638
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
639
+ x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
640
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
641
+ elif isinstance(module, InitTemporalConvBlock):
642
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
643
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
644
+ x = module(x)
645
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
646
+ elif isinstance(module, TemporalConvBlock):
647
+ module = checkpoint_wrapper(module) if self.use_checkpoint else module
648
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
649
+ x = module(x)
650
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
651
+ elif isinstance(module, nn.ModuleList):
652
+ for block in module:
653
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference)
654
+ else:
655
+ x = module(x)
656
+ return x
657
+
658
+
659
+
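
A small illustrative sketch (not part of the commit) of the layout convention used by the forward pass and _forward_single above: spatial blocks see frames folded into the batch axis, temporal blocks see an explicit frame axis, and the prepended reference frame is dropped at the output. The shapes below are made-up example values.

import torch
from einops import rearrange

b, c, f, h, w = 2, 8, 17, 32, 16          # f = 1 reference frame + 16 video frames
x = torch.randn(b, c, f, h, w)

spatial = rearrange(x, 'b c f h w -> (b f) c h w')          # per-frame 2D processing
temporal = rearrange(spatial, '(b f) c h w -> b c f h w', b=b)

assert torch.equal(temporal, x)            # the round trip is lossless
print(temporal[:, :, 1:].shape)            # (2, 8, 16, 32, 16): mirrors `return x[:, :, 1:]`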
UniAnimate/tools/modules/unet/util.py ADDED
@@ -0,0 +1,1741 @@
1
+ import math
2
+ import torch
3
+ import xformers
4
+ import open_clip
5
+ import xformers.ops
6
+ import torch.nn as nn
7
+ from torch import einsum
8
+ from einops import rearrange
9
+ from functools import partial
10
+ import torch.nn.functional as F
11
+ import torch.nn.init as init
+ from typing import Any, Optional  # needed for the attention_op annotation in MemoryEfficientCrossAttention
12
+ from rotary_embedding_torch import RotaryEmbedding
13
+ from fairscale.nn.checkpoint import checkpoint_wrapper
14
+
15
+ # from .mha_flash import FlashAttentionBlock
16
+ from utils.registry_class import MODEL
17
+
18
+
19
+ ### load all keys started with prefix and replace them with new_prefix
20
+ def load_Block(state, prefix, new_prefix=None):
21
+ if new_prefix is None:
22
+ new_prefix = prefix
23
+
24
+ state_dict = {}
25
+ state = {key:value for key,value in state.items() if prefix in key}
26
+ for key,value in state.items():
27
+ new_key = key.replace(prefix, new_prefix)
28
+ state_dict[new_key]=value
29
+ return state_dict
30
+
31
+
32
+ def load_2d_pretrained_state_dict(state,cfg):
33
+
34
+ new_state_dict = {}
35
+
36
+ dim = cfg.unet_dim
37
+ num_res_blocks = cfg.unet_res_blocks
38
+ temporal_attention = cfg.temporal_attention
39
+ temporal_conv = cfg.temporal_conv
40
+ dim_mult = cfg.unet_dim_mult
41
+ attn_scales = cfg.unet_attn_scales
42
+
43
+ # params
44
+ enc_dims = [dim * u for u in [1] + dim_mult]
45
+ dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
46
+ shortcut_dims = []
47
+ scale = 1.0
48
+
49
+ #embeddings
50
+ state_dict = load_Block(state,prefix=f'time_embedding')
51
+ new_state_dict.update(state_dict)
52
+ state_dict = load_Block(state,prefix=f'y_embedding')
53
+ new_state_dict.update(state_dict)
54
+ state_dict = load_Block(state,prefix=f'context_embedding')
55
+ new_state_dict.update(state_dict)
56
+
57
+ encoder_idx = 0
58
+ ### init block
59
+ state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0')
60
+ new_state_dict.update(state_dict)
61
+ encoder_idx += 1
62
+
63
+ shortcut_dims.append(dim)
64
+ for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
65
+ for j in range(num_res_blocks):
66
+ # residual (+attention) blocks
67
+ idx = 0
68
+ idx_ = 0
69
+ # residual (+attention) blocks
70
+ state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}')
71
+ new_state_dict.update(state_dict)
72
+ idx += 1
73
+ idx_ = 2
74
+
75
+ if scale in attn_scales:
76
+ # block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim))
77
+ state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}')
78
+ new_state_dict.update(state_dict)
79
+ # if temporal_attention:
80
+ # block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
81
+ in_dim = out_dim
82
+ encoder_idx += 1
83
+ shortcut_dims.append(out_dim)
84
+
85
+ # downsample
86
+ if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
87
+ # downsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 0.5, dropout)
88
+ state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0')
89
+ new_state_dict.update(state_dict)
90
+
91
+ shortcut_dims.append(out_dim)
92
+ scale /= 2.0
93
+ encoder_idx += 1
94
+
95
+ # middle
96
+ # self.middle = nn.ModuleList([
97
+ # ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'),
98
+ # TemporalConvBlock(out_dim),
99
+ # AttentionBlock(out_dim, context_dim, num_heads, head_dim)])
100
+ # if temporal_attention:
101
+ # self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
102
+ # elif temporal_conv:
103
+ # self.middle.append(TemporalConvBlock(out_dim,dropout=dropout))
104
+ # self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'))
105
+ # self.middle.append(TemporalConvBlock(out_dim))
106
+
107
+
108
+ # middle
109
+ middle_idx = 0
110
+ # self.middle = nn.ModuleList([
111
+ # ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout),
112
+ # AttentionBlock(out_dim, context_dim, num_heads, head_dim)])
113
+ state_dict = load_Block(state,prefix=f'middle.{middle_idx}')
114
+ new_state_dict.update(state_dict)
115
+ middle_idx += 2
116
+
117
+ state_dict = load_Block(state,prefix=f'middle.1',new_prefix=f'middle.{middle_idx}')
118
+ new_state_dict.update(state_dict)
119
+ middle_idx += 1
120
+
121
+ for _ in range(cfg.temporal_attn_times):
122
+ # self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
123
+ middle_idx += 1
124
+
125
+ # self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout))
126
+ state_dict = load_Block(state,prefix=f'middle.2',new_prefix=f'middle.{middle_idx}')
127
+ new_state_dict.update(state_dict)
128
+ middle_idx += 2
129
+
130
+
131
+ decoder_idx = 0
132
+ for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
133
+ for j in range(num_res_blocks + 1):
134
+ idx = 0
135
+ idx_ = 0
136
+ # residual (+attention) blocks
137
+ # block = nn.ModuleList([ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)])
138
+ state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
139
+ new_state_dict.update(state_dict)
140
+ idx += 1
141
+ idx_ += 2
142
+ if scale in attn_scales:
143
+ # block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim))
144
+ state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
145
+ new_state_dict.update(state_dict)
146
+ idx += 1
147
+ idx_ += 1
148
+ for _ in range(cfg.temporal_attn_times):
149
+ # block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
150
+ idx_ +=1
151
+
152
+ in_dim = out_dim
153
+
154
+ # upsample
155
+ if i != len(dim_mult) - 1 and j == num_res_blocks:
156
+
157
+ # upsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 2.0, dropout)
158
+ state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
159
+ new_state_dict.update(state_dict)
160
+ idx += 1
161
+ idx_ += 2
162
+
163
+ scale *= 2.0
164
+ # block.append(upsample)
165
+ # self.decoder.append(block)
166
+ decoder_idx += 1
167
+
168
+ # head
169
+ # self.head = nn.Sequential(
170
+ # nn.GroupNorm(32, out_dim),
171
+ # nn.SiLU(),
172
+ # nn.Conv3d(out_dim, self.out_dim, (1,3,3), padding=(0,1,1)))
173
+ state_dict = load_Block(state,prefix=f'head')
174
+ new_state_dict.update(state_dict)
175
+
176
+ return new_state_dict
177
+
178
+ def sinusoidal_embedding(timesteps, dim):
179
+ # check input
180
+ half = dim // 2
181
+ timesteps = timesteps.float()
182
+
183
+ # compute sinusoidal embedding
184
+ sinusoid = torch.outer(
185
+ timesteps,
186
+ torch.pow(10000, -torch.arange(half).to(timesteps).div(half)))
187
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
188
+ if dim % 2 != 0:
189
+ x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
190
+ return x
191
+
192
+ def exists(x):
193
+ return x is not None
194
+
195
+ def default(val, d):
196
+ if exists(val):
197
+ return val
198
+ return d() if callable(d) else d
199
+
200
+ def prob_mask_like(shape, prob, device):
201
+ if prob == 1:
202
+ return torch.ones(shape, device = device, dtype = torch.bool)
203
+ elif prob == 0:
204
+ return torch.zeros(shape, device = device, dtype = torch.bool)
205
+ else:
206
+ mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
207
+ ### aviod mask all, which will cause find_unused_parameters error
208
+ if mask.all():
209
+ mask[0]=False
210
+ return mask
211
+
212
+
213
+ class MemoryEfficientCrossAttention(nn.Module):
214
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
215
+ def __init__(self, query_dim, max_bs=4096, context_dim=None, heads=8, dim_head=64, dropout=0.0):
216
+ super().__init__()
217
+ inner_dim = dim_head * heads
218
+ context_dim = default(context_dim, query_dim)
219
+
220
+ self.max_bs = max_bs
221
+ self.heads = heads
222
+ self.dim_head = dim_head
223
+
224
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
225
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
226
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
227
+
228
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
229
+ self.attention_op: Optional[Any] = None
230
+
231
+ def forward(self, x, context=None, mask=None):
232
+ q = self.to_q(x)
233
+ context = default(context, x)
234
+ k = self.to_k(context)
235
+ v = self.to_v(context)
236
+
237
+ b, _, _ = q.shape
238
+ q, k, v = map(
239
+ lambda t: t.unsqueeze(3)
240
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
241
+ .permute(0, 2, 1, 3)
242
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
243
+ .contiguous(),
244
+ (q, k, v),
245
+ )
246
+
247
+ # actually compute the attention, what we cannot get enough of
248
+ if q.shape[0] > self.max_bs:
249
+ q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
250
+ k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
251
+ v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
252
+ out_list = []
253
+ for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
254
+ out = xformers.ops.memory_efficient_attention(
255
+ q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
256
+ out_list.append(out)
257
+ out = torch.cat(out_list, dim=0)
258
+ else:
259
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
260
+
261
+ if exists(mask):
262
+ raise NotImplementedError
263
+ out = (
264
+ out.unsqueeze(0)
265
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
266
+ .permute(0, 2, 1, 3)
267
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
268
+ )
269
+ return self.to_out(out)
270
+
271
+ class RelativePositionBias(nn.Module):
272
+ def __init__(
273
+ self,
274
+ heads = 8,
275
+ num_buckets = 32,
276
+ max_distance = 128
277
+ ):
278
+ super().__init__()
279
+ self.num_buckets = num_buckets
280
+ self.max_distance = max_distance
281
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
282
+
283
+ @staticmethod
284
+ def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
285
+ ret = 0
286
+ n = -relative_position
287
+
288
+ num_buckets //= 2
289
+ ret += (n < 0).long() * num_buckets
290
+ n = torch.abs(n)
291
+
292
+ max_exact = num_buckets // 2
293
+ is_small = n < max_exact
294
+
295
+ val_if_large = max_exact + (
296
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
297
+ ).long()
298
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
299
+
300
+ ret += torch.where(is_small, n, val_if_large)
301
+ return ret
302
+
303
+ def forward(self, n, device):
304
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
305
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
306
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
307
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
308
+ values = self.relative_attention_bias(rp_bucket)
309
+ return rearrange(values, 'i j h -> h i j')
310
+
311
+ class SpatialTransformer(nn.Module):
312
+ """
313
+ Transformer block for image-like data.
314
+ First, project the input (aka embedding)
315
+ and reshape to b, t, d.
316
+ Then apply standard transformer action.
317
+ Finally, reshape to image
318
+ NEW: use_linear for more efficiency instead of the 1x1 convs
319
+ """
320
+ def __init__(self, in_channels, n_heads, d_head,
321
+ depth=1, dropout=0., context_dim=None,
322
+ disable_self_attn=False, use_linear=False,
323
+ use_checkpoint=True):
324
+ super().__init__()
325
+ if exists(context_dim) and not isinstance(context_dim, list):
326
+ context_dim = [context_dim]
327
+ self.in_channels = in_channels
328
+ inner_dim = n_heads * d_head
329
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
330
+ if not use_linear:
331
+ self.proj_in = nn.Conv2d(in_channels,
332
+ inner_dim,
333
+ kernel_size=1,
334
+ stride=1,
335
+ padding=0)
336
+ else:
337
+ self.proj_in = nn.Linear(in_channels, inner_dim)
338
+
339
+ self.transformer_blocks = nn.ModuleList(
340
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
341
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
342
+ for d in range(depth)]
343
+ )
344
+ if not use_linear:
345
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
346
+ in_channels,
347
+ kernel_size=1,
348
+ stride=1,
349
+ padding=0))
350
+ else:
351
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
352
+ self.use_linear = use_linear
353
+
354
+ def forward(self, x, context=None):
355
+ # note: if no context is given, cross-attention defaults to self-attention
356
+ if not isinstance(context, list):
357
+ context = [context]
358
+ b, c, h, w = x.shape
359
+ x_in = x
360
+ x = self.norm(x)
361
+ if not self.use_linear:
362
+ x = self.proj_in(x)
363
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
364
+ if self.use_linear:
365
+ x = self.proj_in(x)
366
+ for i, block in enumerate(self.transformer_blocks):
367
+ x = block(x, context=context[i])
368
+ if self.use_linear:
369
+ x = self.proj_out(x)
370
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
371
+ if not self.use_linear:
372
+ x = self.proj_out(x)
373
+ return x + x_in
374
+
375
+
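For orientation, a minimal shape-level sketch of driving SpatialTransformer (it assumes a CUDA device with xformers installed, since BasicTransformerBlock builds on MemoryEfficientCrossAttention; all sizes are illustrative, not taken from the configs):

import torch

st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=1024).cuda()
x = torch.randn(2, 320, 32, 32, device='cuda')       # 2 feature maps, attention runs over 32*32 tokens
context = torch.randn(2, 77, 1024, device='cuda')    # e.g. 77 conditioning tokens per sample
y = st(x, context=context)
print(y.shape)   # torch.Size([2, 320, 32, 32]); proj_out is zero-initialised, so y == x right after init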
376
+ class SpatialTransformerWithAdapter(nn.Module):
377
+ """
378
+ Transformer block for image-like data.
379
+ First, project the input (aka embedding)
380
+ and reshape to b, t, d.
381
+ Then apply standard transformer action.
382
+ Finally, reshape to image
383
+ NEW: use_linear for more efficiency instead of the 1x1 convs
384
+ """
385
+ def __init__(self, in_channels, n_heads, d_head,
386
+ depth=1, dropout=0., context_dim=None,
387
+ disable_self_attn=False, use_linear=False,
388
+ use_checkpoint=True,
389
+ adapter_list=[], adapter_position_list=['', 'parallel', ''],
390
+ adapter_hidden_dim=None):
391
+ super().__init__()
392
+ if exists(context_dim) and not isinstance(context_dim, list):
393
+ context_dim = [context_dim]
394
+ self.in_channels = in_channels
395
+ inner_dim = n_heads * d_head
396
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
397
+ if not use_linear:
398
+ self.proj_in = nn.Conv2d(in_channels,
399
+ inner_dim,
400
+ kernel_size=1,
401
+ stride=1,
402
+ padding=0)
403
+ else:
404
+ self.proj_in = nn.Linear(in_channels, inner_dim)
405
+
406
+ self.transformer_blocks = nn.ModuleList(
407
+ [BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
408
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint,
409
+ adapter_list=adapter_list, adapter_position_list=adapter_position_list,
410
+ adapter_hidden_dim=adapter_hidden_dim)
411
+ for d in range(depth)]
412
+ )
413
+ if not use_linear:
414
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
415
+ in_channels,
416
+ kernel_size=1,
417
+ stride=1,
418
+ padding=0))
419
+ else:
420
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
421
+ self.use_linear = use_linear
422
+
423
+ def forward(self, x, context=None):
424
+ # note: if no context is given, cross-attention defaults to self-attention
425
+ if not isinstance(context, list):
426
+ context = [context]
427
+ b, c, h, w = x.shape
428
+ x_in = x
429
+ x = self.norm(x)
430
+ if not self.use_linear:
431
+ x = self.proj_in(x)
432
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
433
+ if self.use_linear:
434
+ x = self.proj_in(x)
435
+ for i, block in enumerate(self.transformer_blocks):
436
+ x = block(x, context=context[i])
437
+ if self.use_linear:
438
+ x = self.proj_out(x)
439
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
440
+ if not self.use_linear:
441
+ x = self.proj_out(x)
442
+ return x + x_in
443
+
444
+ import os
445
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
446
+
447
+ class CrossAttention(nn.Module):
448
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
449
+ super().__init__()
450
+ inner_dim = dim_head * heads
451
+ context_dim = default(context_dim, query_dim)
452
+
453
+ self.scale = dim_head ** -0.5
454
+ self.heads = heads
455
+
456
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
457
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
458
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
459
+
460
+ self.to_out = nn.Sequential(
461
+ nn.Linear(inner_dim, query_dim),
462
+ nn.Dropout(dropout)
463
+ )
464
+
465
+ def forward(self, x, context=None, mask=None):
466
+ h = self.heads
467
+
468
+ q = self.to_q(x)
469
+ context = default(context, x)
470
+ k = self.to_k(context)
471
+ v = self.to_v(context)
472
+
473
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
474
+
475
+ # force cast to fp32 to avoid overflowing
476
+ if _ATTN_PRECISION =="fp32":
477
+ with torch.autocast(enabled=False, device_type = 'cuda'):
478
+ q, k = q.float(), k.float()
479
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
480
+ else:
481
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
482
+
483
+ del q, k
484
+
485
+ if exists(mask):
486
+ mask = rearrange(mask, 'b ... -> b (...)')
487
+ max_neg_value = -torch.finfo(sim.dtype).max
488
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
489
+ sim.masked_fill_(~mask, max_neg_value)
490
+
491
+ # attention, what we cannot get enough of
492
+ sim = sim.softmax(dim=-1)
493
+
494
+ out = torch.einsum('b i j, b j d -> b i d', sim, v)
495
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
496
+ return self.to_out(out)
497
+
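CrossAttention above is the plain-PyTorch fallback that needs no xformers; a minimal sketch with illustrative sizes:

import torch

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(1, 4096, 320)    # 64x64 spatial tokens, flattened
ctx = torch.randn(1, 77, 768)    # conditioning tokens
out = attn(x, context=ctx)       # omit context to get self-attention
print(out.shape)                 # torch.Size([1, 4096, 320])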
498
+
499
+ class Adapter(nn.Module):
500
+ def __init__(self, in_dim, hidden_dim, condition_dim=None):
501
+ super().__init__()
502
+ self.down_linear = nn.Linear(in_dim, hidden_dim)
503
+ self.up_linear = nn.Linear(hidden_dim, in_dim)
504
+ self.condition_dim = condition_dim
505
+ if condition_dim is not None:
506
+ self.condition_linear = nn.Linear(condition_dim, in_dim)
507
+
508
+ init.zeros_(self.up_linear.weight)
509
+ init.zeros_(self.up_linear.bias)
510
+
511
+ def forward(self, x, condition=None, condition_lam=1):
512
+ x_in = x
513
+ if self.condition_dim is not None and condition is not None:
514
+ x = x + condition_lam * self.condition_linear(condition)
515
+ x = self.down_linear(x)
516
+ x = F.gelu(x)
517
+ x = self.up_linear(x)
518
+ x += x_in
519
+ return x
520
+
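Because up_linear is zero-initialised, a freshly built Adapter is an exact identity, which is what makes it safe to attach to a pretrained block; a minimal sketch (dimensions are illustrative):

import torch

adapter = Adapter(in_dim=320, hidden_dim=160)
x = torch.randn(2, 77, 320)
assert torch.equal(adapter(x), x)   # identity right after init

cond_adapter = Adapter(in_dim=320, hidden_dim=160, condition_dim=512)
c = torch.randn(2, 77, 512)
y = cond_adapter(x, condition=c, condition_lam=0.5)   # condition is mixed in before the bottleneck
print(y.shape)   # torch.Size([2, 77, 320])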
521
+
522
+ class MemoryEfficientCrossAttention_attemask(nn.Module):
523
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
524
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
525
+ super().__init__()
526
+ inner_dim = dim_head * heads
527
+ context_dim = default(context_dim, query_dim)
528
+
529
+ self.heads = heads
530
+ self.dim_head = dim_head
531
+
532
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
533
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
534
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
535
+
536
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
537
+ self.attention_op: Optional[Any] = None
538
+
539
+ def forward(self, x, context=None, mask=None):
540
+ q = self.to_q(x)
541
+ context = default(context, x)
542
+ k = self.to_k(context)
543
+ v = self.to_v(context)
544
+
545
+ b, _, _ = q.shape
546
+ q, k, v = map(
547
+ lambda t: t.unsqueeze(3)
548
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
549
+ .permute(0, 2, 1, 3)
550
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
551
+ .contiguous(),
552
+ (q, k, v),
553
+ )
554
+
555
+ # actually compute the attention, what we cannot get enough of
556
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask(), op=self.attention_op)
557
+
558
+ if exists(mask):
559
+ raise NotImplementedError
560
+ out = (
561
+ out.unsqueeze(0)
562
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
563
+ .permute(0, 2, 1, 3)
564
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
565
+ )
566
+ return self.to_out(out)
567
+
568
+
569
+
570
+ class BasicTransformerBlock_attemask(nn.Module):
571
+ # ATTENTION_MODES = {
572
+ # "softmax": CrossAttention, # vanilla attention
573
+ # "softmax-xformers": MemoryEfficientCrossAttention
574
+ # }
575
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
576
+ disable_self_attn=False):
577
+ super().__init__()
578
+ # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
579
+ # assert attn_mode in self.ATTENTION_MODES
580
+ # attn_cls = CrossAttention
581
+ attn_cls = MemoryEfficientCrossAttention_attemask
582
+ self.disable_self_attn = disable_self_attn
583
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
584
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
585
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
586
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
587
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
588
+ self.norm1 = nn.LayerNorm(dim)
589
+ self.norm2 = nn.LayerNorm(dim)
590
+ self.norm3 = nn.LayerNorm(dim)
591
+ self.checkpoint = checkpoint
592
+
593
+ def forward_(self, x, context=None):
594
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
595
+
596
+ def forward(self, x, context=None):
597
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
598
+ x = self.attn2(self.norm2(x), context=context) + x
599
+ x = self.ff(self.norm3(x)) + x
600
+ return x
601
+
602
+
603
+ class BasicTransformerBlockWithAdapter(nn.Module):
604
+ # ATTENTION_MODES = {
605
+ # "softmax": CrossAttention, # vanilla attention
606
+ # "softmax-xformers": MemoryEfficientCrossAttention
607
+ # }
608
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False,
609
+ adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'], adapter_hidden_dim=None, adapter_condition_dim=None
610
+ ):
611
+ super().__init__()
612
+ # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
613
+ # assert attn_mode in self.ATTENTION_MODES
614
+ # attn_cls = CrossAttention
615
+ attn_cls = MemoryEfficientCrossAttention
616
+ self.disable_self_attn = disable_self_attn
617
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
618
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
619
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
620
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
621
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
622
+ self.norm1 = nn.LayerNorm(dim)
623
+ self.norm2 = nn.LayerNorm(dim)
624
+ self.norm3 = nn.LayerNorm(dim)
625
+ self.checkpoint = checkpoint
626
+ # adapter
627
+ self.adapter_list = adapter_list
628
+ self.adapter_position_list = adapter_position_list
629
+ hidden_dim = dim//2 if not adapter_hidden_dim else adapter_hidden_dim
630
+ if "self_attention" in adapter_list:
631
+ self.attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim)
632
+ if "cross_attention" in adapter_list:
633
+ self.cross_attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim)
634
+ if "feedforward" in adapter_list:
635
+ self.ff_adapter = Adapter(dim, hidden_dim, adapter_condition_dim)
636
+
637
+
638
+ def forward_(self, x, context=None, adapter_condition=None, adapter_condition_lam=1):
639
+ return checkpoint(self._forward, (x, context, adapter_condition, adapter_condition_lam), self.parameters(), self.checkpoint)
640
+
641
+ def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1):
642
+ if "self_attention" in self.adapter_list:
643
+ if self.adapter_position_list[0] == 'parallel':
644
+ # parallel
645
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + self.attn_adapter(x, adapter_condition, adapter_condition_lam)
646
+ elif self.adapter_position_list[0] == 'serial':
647
+ # serial
648
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
649
+ x = self.attn_adapter(x, adapter_condition, adapter_condition_lam)
650
+ else:
651
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
652
+
653
+ if "cross_attention" in self.adapter_list:
654
+ if self.adapter_position_list[1] == 'parallel':
655
+ # parallel
656
+ x = self.attn2(self.norm2(x), context=context) + self.cross_attn_adapter(x, adapter_condition, adapter_condition_lam)
657
+ elif self.adapter_position_list[1] == 'serial':
658
+ x = self.attn2(self.norm2(x), context=context) + x
659
+ x = self.cross_attn_adapter(x, adapter_condition, adapter_condition_lam)
660
+ else:
661
+ x = self.attn2(self.norm2(x), context=context) + x
662
+
663
+ if "feedforward" in self.adapter_list:
664
+ if self.adapter_position_list[2] == 'parallel':
665
+ x = self.ff(self.norm3(x)) + self.ff_adapter(x, adapter_condition, adapter_condition_lam)
666
+ elif self.adapter_position_list[2] == 'serial':
667
+ x = self.ff(self.norm3(x)) + x
668
+ x = self.ff_adapter(x, adapter_condition, adapter_condition_lam)
669
+ else:
670
+ x = self.ff(self.norm3(x)) + x
671
+
672
+ return x
673
+
674
+ class BasicTransformerBlock(nn.Module):
675
+ # ATTENTION_MODES = {
676
+ # "softmax": CrossAttention, # vanilla attention
677
+ # "softmax-xformers": MemoryEfficientCrossAttention
678
+ # }
679
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
680
+ disable_self_attn=False):
681
+ super().__init__()
682
+ # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
683
+ # assert attn_mode in self.ATTENTION_MODES
684
+ # attn_cls = CrossAttention
685
+ attn_cls = MemoryEfficientCrossAttention
686
+ self.disable_self_attn = disable_self_attn
687
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
688
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
689
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
690
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
691
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
692
+ self.norm1 = nn.LayerNorm(dim)
693
+ self.norm2 = nn.LayerNorm(dim)
694
+ self.norm3 = nn.LayerNorm(dim)
695
+ self.checkpoint = checkpoint
696
+
697
+ def forward_(self, x, context=None):
698
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
699
+
700
+ def forward(self, x, context=None):
701
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
702
+ x = self.attn2(self.norm2(x), context=context) + x
703
+ x = self.ff(self.norm3(x)) + x
704
+ return x
705
+
706
+ # feedforward
707
+ class GEGLU(nn.Module):
708
+ def __init__(self, dim_in, dim_out):
709
+ super().__init__()
710
+ self.proj = nn.Linear(dim_in, dim_out * 2)
711
+
712
+ def forward(self, x):
713
+ x, gate = self.proj(x).chunk(2, dim=-1)
714
+ return x * F.gelu(gate)
715
+
716
+ def zero_module(module):
717
+ """
718
+ Zero out the parameters of a module and return it.
719
+ """
720
+ for p in module.parameters():
721
+ p.detach().zero_()
722
+ return module
723
+
724
+ class FeedForward(nn.Module):
725
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
726
+ super().__init__()
727
+ inner_dim = int(dim * mult)
728
+ dim_out = default(dim_out, dim)
729
+ project_in = nn.Sequential(
730
+ nn.Linear(dim, inner_dim),
731
+ nn.GELU()
732
+ ) if not glu else GEGLU(dim, inner_dim)
733
+
734
+ self.net = nn.Sequential(
735
+ project_in,
736
+ nn.Dropout(dropout),
737
+ nn.Linear(inner_dim, dim_out)
738
+ )
739
+
740
+ def forward(self, x):
741
+ return self.net(x)
742
+
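FeedForward(..., glu=True) selects the GEGLU path, which doubles the inner projection and gates one half with GELU of the other; a short sketch with illustrative sizes:

import torch

ff = FeedForward(dim=320, mult=4, glu=True, dropout=0.0)   # GEGLU: Linear(320 -> 2*1280), gate, Linear(1280 -> 320)
x = torch.randn(2, 77, 320)
print(ff(x).shape)   # torch.Size([2, 77, 320])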
743
+ class Upsample(nn.Module):
744
+ """
745
+ An upsampling layer with an optional convolution.
746
+ :param channels: channels in the inputs and outputs.
747
+ :param use_conv: a bool determining if a convolution is applied.
748
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
749
+ upsampling occurs in the inner-two dimensions.
750
+ """
751
+
752
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
753
+ super().__init__()
754
+ self.channels = channels
755
+ self.out_channels = out_channels or channels
756
+ self.use_conv = use_conv
757
+ self.dims = dims
758
+ if use_conv:
759
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
760
+
761
+ def forward(self, x):
762
+ assert x.shape[1] == self.channels
763
+ if self.dims == 3:
764
+ x = F.interpolate(
765
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
766
+ )
767
+ else:
768
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
769
+ if self.use_conv:
770
+ x = self.conv(x)
771
+ return x
772
+
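Upsample doubles only the two spatial dimensions with nearest-neighbour interpolation (the frame axis is left alone when dims=3); a one-line sketch with illustrative sizes:

import torch

up = Upsample(channels=320, use_conv=True)
print(up(torch.randn(1, 320, 32, 32)).shape)   # torch.Size([1, 320, 64, 64])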
773
+
774
+ class UpsampleSR600(nn.Module):
775
+ """
776
+ An upsampling layer with an optional convolution.
777
+ :param channels: channels in the inputs and outputs.
778
+ :param use_conv: a bool determining if a convolution is applied.
779
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
780
+ upsampling occurs in the inner-two dimensions.
781
+ """
782
+
783
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
784
+ super().__init__()
785
+ self.channels = channels
786
+ self.out_channels = out_channels or channels
787
+ self.use_conv = use_conv
788
+ self.dims = dims
789
+ if use_conv:
790
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
791
+
792
+ def forward(self, x):
793
+ assert x.shape[1] == self.channels
794
+ if self.dims == 3:
795
+ x = F.interpolate(
796
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
797
+ )
798
+ else:
799
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
800
+ # TODO: to match the spatial size of input_blocks, trim the first and last rows of the height dimension
801
+ x = x[..., 1:-1, :]
802
+ if self.use_conv:
803
+ x = self.conv(x)
804
+ return x
805
+
806
+
807
+ class ResBlock(nn.Module):
808
+ """
809
+ A residual block that can optionally change the number of channels.
810
+ :param channels: the number of input channels.
811
+ :param emb_channels: the number of timestep embedding channels.
812
+ :param dropout: the rate of dropout.
813
+ :param out_channels: if specified, the number of out channels.
814
+ :param use_conv: if True and out_channels is specified, use a spatial
815
+ convolution instead of a smaller 1x1 convolution to change the
816
+ channels in the skip connection.
817
+ :param dims: determines if the signal is 1D, 2D, or 3D.
818
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
819
+ :param up: if True, use this block for upsampling.
820
+ :param down: if True, use this block for downsampling.
821
+ """
822
+ def __init__(
823
+ self,
824
+ channels,
825
+ emb_channels,
826
+ dropout,
827
+ out_channels=None,
828
+ use_conv=False,
829
+ use_scale_shift_norm=False,
830
+ dims=2,
831
+ up=False,
832
+ down=False,
833
+ use_temporal_conv=True,
834
+ use_image_dataset=False,
835
+ ):
836
+ super().__init__()
837
+ self.channels = channels
838
+ self.emb_channels = emb_channels
839
+ self.dropout = dropout
840
+ self.out_channels = out_channels or channels
841
+ self.use_conv = use_conv
842
+ self.use_scale_shift_norm = use_scale_shift_norm
843
+ self.use_temporal_conv = use_temporal_conv
844
+
845
+ self.in_layers = nn.Sequential(
846
+ nn.GroupNorm(32, channels),
847
+ nn.SiLU(),
848
+ nn.Conv2d(channels, self.out_channels, 3, padding=1),
849
+ )
850
+
851
+ self.updown = up or down
852
+
853
+ if up:
854
+ self.h_upd = Upsample(channels, False, dims)
855
+ self.x_upd = Upsample(channels, False, dims)
856
+ elif down:
857
+ self.h_upd = Downsample(channels, False, dims)
858
+ self.x_upd = Downsample(channels, False, dims)
859
+ else:
860
+ self.h_upd = self.x_upd = nn.Identity()
861
+
862
+ self.emb_layers = nn.Sequential(
863
+ nn.SiLU(),
864
+ nn.Linear(
865
+ emb_channels,
866
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
867
+ ),
868
+ )
869
+ self.out_layers = nn.Sequential(
870
+ nn.GroupNorm(32, self.out_channels),
871
+ nn.SiLU(),
872
+ nn.Dropout(p=dropout),
873
+ zero_module(
874
+ nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)
875
+ ),
876
+ )
877
+
878
+ if self.out_channels == channels:
879
+ self.skip_connection = nn.Identity()
880
+ elif use_conv:
881
+ self.skip_connection = conv_nd(
882
+ dims, channels, self.out_channels, 3, padding=1
883
+ )
884
+ else:
885
+ self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
886
+
887
+ if self.use_temporal_conv:
888
+ self.temopral_conv = TemporalConvBlock_v2(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset)
889
+ # self.temopral_conv_2 = TemporalConvBlock(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset)
890
+
891
+ def forward(self, x, emb, batch_size):
892
+ """
893
+ Apply the block to a Tensor, conditioned on a timestep embedding.
894
+ :param x: an [N x C x ...] Tensor of features.
895
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
896
+ :return: an [N x C x ...] Tensor of outputs.
897
+ """
898
+ return self._forward(x, emb, batch_size)
899
+
900
+ def _forward(self, x, emb, batch_size):
901
+ if self.updown:
902
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
903
+ h = in_rest(x)
904
+ h = self.h_upd(h)
905
+ x = self.x_upd(x)
906
+ h = in_conv(h)
907
+ else:
908
+ h = self.in_layers(x)
909
+ emb_out = self.emb_layers(emb).type(h.dtype)
910
+ while len(emb_out.shape) < len(h.shape):
911
+ emb_out = emb_out[..., None]
912
+ if self.use_scale_shift_norm:
913
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
914
+ scale, shift = th.chunk(emb_out, 2, dim=1)
915
+ h = out_norm(h) * (1 + scale) + shift
916
+ h = out_rest(h)
917
+ else:
918
+ h = h + emb_out
919
+ h = self.out_layers(h)
920
+ h = self.skip_connection(x) + h
921
+
922
+ if self.use_temporal_conv:
923
+ h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
924
+ h = self.temopral_conv(h)
925
+ # h = self.temopral_conv_2(h)
926
+ h = rearrange(h, 'b c f h w -> (b f) c h w')
927
+ return h
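ResBlock keeps frames flattened into the batch axis and only regroups them for the temporal convolution, so the caller has to pass the true batch size separately; a minimal sketch (sizes are illustrative):

import torch

block = ResBlock(channels=320, emb_channels=1280, dropout=0.0, use_temporal_conv=True)
b, f = 2, 16                            # 2 clips of 16 frames, flattened to (b*f)
x = torch.randn(b * f, 320, 32, 32)
emb = torch.randn(b * f, 1280)          # timestep embedding, repeated per frame
h = block(x, emb, batch_size=b)
print(h.shape)   # torch.Size([32, 320, 32, 32])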
928
+
929
+ class Downsample(nn.Module):
930
+ """
931
+ A downsampling layer with an optional convolution.
932
+ :param channels: channels in the inputs and outputs.
933
+ :param use_conv: a bool determining if a convolution is applied.
934
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
935
+ downsampling occurs in the inner-two dimensions.
936
+ """
937
+
938
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
939
+ super().__init__()
940
+ self.channels = channels
941
+ self.out_channels = out_channels or channels
942
+ self.use_conv = use_conv
943
+ self.dims = dims
944
+ stride = 2 if dims != 3 else (1, 2, 2)
945
+ if use_conv:
946
+ self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
947
+ else:
948
+ assert self.channels == self.out_channels
949
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
950
+
951
+ def forward(self, x):
952
+ assert x.shape[1] == self.channels
953
+ return self.op(x)
954
+
955
+ class Resample(nn.Module):
956
+
957
+ def __init__(self, in_dim, out_dim, mode):
958
+ assert mode in ['none', 'upsample', 'downsample']
959
+ super(Resample, self).__init__()
960
+ self.in_dim = in_dim
961
+ self.out_dim = out_dim
962
+ self.mode = mode
963
+
964
+ def forward(self, x, reference=None):
965
+ if self.mode == 'upsample':
966
+ assert reference is not None
967
+ x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
968
+ elif self.mode == 'downsample':
969
+ x = F.adaptive_avg_pool2d(x, output_size=tuple(u // 2 for u in x.shape[-2:]))
970
+ return x
971
+
972
+ class ResidualBlock(nn.Module):
973
+
974
+ def __init__(self, in_dim, embed_dim, out_dim, use_scale_shift_norm=True,
975
+ mode='none', dropout=0.0):
976
+ super(ResidualBlock, self).__init__()
977
+ self.in_dim = in_dim
978
+ self.embed_dim = embed_dim
979
+ self.out_dim = out_dim
980
+ self.use_scale_shift_norm = use_scale_shift_norm
981
+ self.mode = mode
982
+
983
+ # layers
984
+ self.layer1 = nn.Sequential(
985
+ nn.GroupNorm(32, in_dim),
986
+ nn.SiLU(),
987
+ nn.Conv2d(in_dim, out_dim, 3, padding=1))
988
+ self.resample = Resample(in_dim, in_dim, mode)
989
+ self.embedding = nn.Sequential(
990
+ nn.SiLU(),
991
+ nn.Linear(embed_dim, out_dim * 2 if use_scale_shift_norm else out_dim))
992
+ self.layer2 = nn.Sequential(
993
+ nn.GroupNorm(32, out_dim),
994
+ nn.SiLU(),
995
+ nn.Dropout(dropout),
996
+ nn.Conv2d(out_dim, out_dim, 3, padding=1))
997
+ self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(in_dim, out_dim, 1)
998
+
999
+ # zero out the last layer params
1000
+ nn.init.zeros_(self.layer2[-1].weight)
1001
+
1002
+ def forward(self, x, e, reference=None):
1003
+ identity = self.resample(x, reference)
1004
+ x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
1005
+ e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
1006
+ if self.use_scale_shift_norm:
1007
+ scale, shift = e.chunk(2, dim=1)
1008
+ x = self.layer2[0](x) * (1 + scale) + shift
1009
+ x = self.layer2[1:](x)
1010
+ else:
1011
+ x = x + e
1012
+ x = self.layer2(x)
1013
+ x = x + self.shortcut(identity)
1014
+ return x
1015
+
1016
+ class AttentionBlock(nn.Module):
1017
+
1018
+ def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
1019
+ # consider head_dim first, then num_heads
1020
+ num_heads = dim // head_dim if head_dim else num_heads
1021
+ head_dim = dim // num_heads
1022
+ assert num_heads * head_dim == dim
1023
+ super(AttentionBlock, self).__init__()
1024
+ self.dim = dim
1025
+ self.context_dim = context_dim
1026
+ self.num_heads = num_heads
1027
+ self.head_dim = head_dim
1028
+ self.scale = math.pow(head_dim, -0.25)
1029
+
1030
+ # layers
1031
+ self.norm = nn.GroupNorm(32, dim)
1032
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
1033
+ if context_dim is not None:
1034
+ self.context_kv = nn.Linear(context_dim, dim * 2)
1035
+ self.proj = nn.Conv2d(dim, dim, 1)
1036
+
1037
+ # zero out the last layer params
1038
+ nn.init.zeros_(self.proj.weight)
1039
+
1040
+ def forward(self, x, context=None):
1041
+ r"""x: [B, C, H, W].
1042
+ context: [B, L, C] or None.
1043
+ """
1044
+ identity = x
1045
+ b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
1046
+
1047
+ # compute query, key, value
1048
+ x = self.norm(x)
1049
+ q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
1050
+ if context is not None:
1051
+ ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
1052
+ k = torch.cat([ck, k], dim=-1)
1053
+ v = torch.cat([cv, v], dim=-1)
1054
+
1055
+ # compute attention
1056
+ attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
1057
+ attn = F.softmax(attn, dim=-1)
1058
+
1059
+ # gather context
1060
+ x = torch.matmul(v, attn.transpose(-1, -2))
1061
+ x = x.reshape(b, c, h, w)
1062
+
1063
+ # output
1064
+ x = self.proj(x)
1065
+ return x + identity
1066
+
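AttentionBlock runs full self-attention over the H*W spatial tokens and can optionally prepend key/value tokens projected from context; a minimal sketch (illustrative sizes):

import torch

blk = AttentionBlock(dim=512, num_heads=8)
x = torch.randn(1, 512, 16, 16)
print(blk(x).shape)   # torch.Size([1, 512, 16, 16])

blk_ctx = AttentionBlock(dim=512, context_dim=768, num_heads=8)
ctx = torch.randn(1, 77, 768)
print(blk_ctx(x, context=ctx).shape)   # torch.Size([1, 512, 16, 16]); proj is zero-initialised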
1067
+
1068
+ class TemporalAttentionBlock(nn.Module):
1069
+ def __init__(
1070
+ self,
1071
+ dim,
1072
+ heads = 4,
1073
+ dim_head = 32,
1074
+ rotary_emb = None,
1075
+ use_image_dataset = False,
1076
+ use_sim_mask = False
1077
+ ):
1078
+ super().__init__()
1079
+ # consider num_heads first, as pos_bias needs fixed num_heads
1080
+ # heads = dim // dim_head if dim_head else heads
1081
+ dim_head = dim // heads
1082
+ assert heads * dim_head == dim
1083
+ self.use_image_dataset = use_image_dataset
1084
+ self.use_sim_mask = use_sim_mask
1085
+
1086
+ self.scale = dim_head ** -0.5
1087
+ self.heads = heads
1088
+ hidden_dim = dim_head * heads
1089
+
1090
+ self.norm = nn.GroupNorm(32, dim)
1091
+ self.rotary_emb = rotary_emb
1092
+ self.to_qkv = nn.Linear(dim, hidden_dim * 3)#, bias = False)
1093
+ self.to_out = nn.Linear(hidden_dim, dim)#, bias = False)
1094
+
1095
+ # nn.init.zeros_(self.to_out.weight)
1096
+ # nn.init.zeros_(self.to_out.bias)
1097
+
1098
+ def forward(
1099
+ self,
1100
+ x,
1101
+ pos_bias = None,
1102
+ focus_present_mask = None,
1103
+ video_mask = None
1104
+ ):
1105
+
1106
+ identity = x
1107
+ n, height, device = x.shape[2], x.shape[-2], x.device
1108
+
1109
+ x = self.norm(x)
1110
+ x = rearrange(x, 'b c f h w -> b (h w) f c')
1111
+
1112
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
1113
+
1114
+ if exists(focus_present_mask) and focus_present_mask.all():
1115
+ # if all batch samples are focusing on present
1116
+ # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
1117
+ values = qkv[-1]
1118
+ out = self.to_out(values)
1119
+ out = rearrange(out, 'b (h w) f c -> b c f h w', h = height)
1120
+
1121
+ return out + identity
1122
+
1123
+ # split out heads
1124
+ # q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h = self.heads)
1125
+ # shape [b (hw) h n c/h], n=f
1126
+ q= rearrange(qkv[0], '... n (h d) -> ... h n d', h = self.heads)
1127
+ k= rearrange(qkv[1], '... n (h d) -> ... h n d', h = self.heads)
1128
+ v= rearrange(qkv[2], '... n (h d) -> ... h n d', h = self.heads)
1129
+
1130
+
1131
+ # scale
1132
+
1133
+ q = q * self.scale
1134
+
1135
+ # rotate positions into queries and keys for time attention
1136
+ if exists(self.rotary_emb):
1137
+ q = self.rotary_emb.rotate_queries_or_keys(q)
1138
+ k = self.rotary_emb.rotate_queries_or_keys(k)
1139
+
1140
+ # similarity
1141
+ # shape [b (hw) h n n], n=f
1142
+ sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
1143
+
1144
+ # relative positional bias
1145
+
1146
+ if exists(pos_bias):
1147
+ # print(sim.shape,pos_bias.shape)
1148
+ sim = sim + pos_bias
1149
+
1150
+ if (focus_present_mask is None and video_mask is not None):
1151
+ #video_mask: [B, n]
1152
+ mask = video_mask[:, None, :] * video_mask[:, :, None] # [b,n,n]
1153
+ mask = mask.unsqueeze(1).unsqueeze(1) #[b,1,1,n,n]
1154
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
1155
+ elif exists(focus_present_mask) and not (~focus_present_mask).all():
1156
+ attend_all_mask = torch.ones((n, n), device = device, dtype = torch.bool)
1157
+ attend_self_mask = torch.eye(n, device = device, dtype = torch.bool)
1158
+
1159
+ mask = torch.where(
1160
+ rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
1161
+ rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
1162
+ rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
1163
+ )
1164
+
1165
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
1166
+
1167
+ if self.use_sim_mask:
1168
+ sim_mask = torch.tril(torch.ones((n, n), device = device, dtype = torch.bool), diagonal=0)
1169
+ sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)
1170
+
1171
+ # numerical stability
1172
+ sim = sim - sim.amax(dim = -1, keepdim = True).detach()
1173
+ attn = sim.softmax(dim = -1)
1174
+
1175
+ # aggregate values
1176
+
1177
+ out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
1178
+ out = rearrange(out, '... h n d -> ... n (h d)')
1179
+ out = self.to_out(out)
1180
+
1181
+ out = rearrange(out, 'b (h w) f c -> b c f h w', h = height)
1182
+
1183
+ if self.use_image_dataset:
1184
+ out = identity + 0*out
1185
+ else:
1186
+ out = identity + out
1187
+ return out
1188
+
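TemporalAttentionBlock attends across the frame axis independently at every spatial location and expects the per-head bias produced by RelativePositionBias above; a minimal sketch (illustrative sizes, rotary embeddings and masks omitted):

import torch

frames, heads = 16, 8
temp_attn = TemporalAttentionBlock(dim=320, heads=heads)
pos_bias = RelativePositionBias(heads=heads, max_distance=32)(frames, device=torch.device('cpu'))
x = torch.randn(2, 320, frames, 16, 16)        # [b, c, f, h, w]
out = temp_attn(x, pos_bias=pos_bias)
print(out.shape)   # torch.Size([2, 320, 16, 16, 16])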
1189
+ class TemporalTransformer(nn.Module):
1190
+ """
1191
+ Transformer block for image-like data.
1192
+ First, project the input (aka embedding)
1193
+ and reshape to b, t, d.
1194
+ Then apply standard transformer action.
1195
+ Finally, reshape to image
1196
+ """
1197
+ def __init__(self, in_channels, n_heads, d_head,
1198
+ depth=1, dropout=0., context_dim=None,
1199
+ disable_self_attn=False, use_linear=False,
1200
+ use_checkpoint=True, only_self_att=True, multiply_zero=False):
1201
+ super().__init__()
1202
+ self.multiply_zero = multiply_zero
1203
+ self.only_self_att = only_self_att
1204
+ self.use_adaptor = False
1205
+ if self.only_self_att:
1206
+ context_dim = None
1207
+ if not isinstance(context_dim, list):
1208
+ context_dim = [context_dim]
1209
+ self.in_channels = in_channels
1210
+ inner_dim = n_heads * d_head
1211
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
1212
+ if not use_linear:
1213
+ self.proj_in = nn.Conv1d(in_channels,
1214
+ inner_dim,
1215
+ kernel_size=1,
1216
+ stride=1,
1217
+ padding=0)
1218
+ else:
1219
+ self.proj_in = nn.Linear(in_channels, inner_dim)
1220
+ if self.use_adaptor:
1221
+ self.adaptor_in = nn.Linear(frames, frames)
1222
+
1223
+ self.transformer_blocks = nn.ModuleList(
1224
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
1225
+ checkpoint=use_checkpoint)
1226
+ for d in range(depth)]
1227
+ )
1228
+ if not use_linear:
1229
+ self.proj_out = zero_module(nn.Conv1d(inner_dim,
1230
+ in_channels,
1231
+ kernel_size=1,
1232
+ stride=1,
1233
+ padding=0))
1234
+ else:
1235
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
1236
+ if self.use_adaptor:
1237
+ self.adaptor_out = nn.Linear(frames, frames)
1238
+ self.use_linear = use_linear
1239
+
1240
+ def forward(self, x, context=None):
1241
+ # note: if no context is given, cross-attention defaults to self-attention
1242
+ if self.only_self_att:
1243
+ context = None
1244
+ if not isinstance(context, list):
1245
+ context = [context]
1246
+ b, c, f, h, w = x.shape
1247
+ x_in = x
1248
+ x = self.norm(x)
1249
+
1250
+ if not self.use_linear:
1251
+ x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
1252
+ x = self.proj_in(x)
1253
+ # [16384, 16, 320]
1254
+ if self.use_linear:
1255
+ x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
1256
+ x = self.proj_in(x)
1257
+
1258
+ if self.only_self_att:
1259
+ x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
1260
+ for i, block in enumerate(self.transformer_blocks):
1261
+ x = block(x)
1262
+ x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
1263
+ else:
1264
+ x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
1265
+ for i, block in enumerate(self.transformer_blocks):
1266
+ # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
1267
+ context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
1268
+ # process each batch element one by one (some kernels cannot handle a dimension greater than 65,535)
1269
+ for j in range(b):
1270
+ context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
1271
+ x[j] = block(x[j], context=context_i_j)
1272
+
1273
+ if self.use_linear:
1274
+ x = self.proj_out(x)
1275
+ x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
1276
+ if not self.use_linear:
1277
+ # x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
1278
+ x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
1279
+ x = self.proj_out(x)
1280
+ x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
1281
+
1282
+ if self.multiply_zero:
1283
+ x = 0.0 * x + x_in
1284
+ else:
1285
+ x = x + x_in
1286
+ return x
1287
+
1288
+
1289
+ class TemporalTransformerWithAdapter(nn.Module):
1290
+ """
1291
+ Transformer block for image-like data.
1292
+ First, project the input (aka embedding)
1293
+ and reshape to b, t, d.
1294
+ Then apply standard transformer action.
1295
+ Finally, reshape to image
1296
+ """
1297
+ def __init__(self, in_channels, n_heads, d_head,
1298
+ depth=1, dropout=0., context_dim=None,
1299
+ disable_self_attn=False, use_linear=False,
1300
+ use_checkpoint=True, only_self_att=True, multiply_zero=False,
1301
+ adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'],
1302
+ adapter_hidden_dim=None, adapter_condition_dim=None):
1303
+ super().__init__()
1304
+ self.multiply_zero = multiply_zero
1305
+ self.only_self_att = only_self_att
1306
+ self.use_adaptor = False
1307
+ if self.only_self_att:
1308
+ context_dim = None
1309
+ if not isinstance(context_dim, list):
1310
+ context_dim = [context_dim]
1311
+ self.in_channels = in_channels
1312
+ inner_dim = n_heads * d_head
1313
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
1314
+ if not use_linear:
1315
+ self.proj_in = nn.Conv1d(in_channels,
1316
+ inner_dim,
1317
+ kernel_size=1,
1318
+ stride=1,
1319
+ padding=0)
1320
+ else:
1321
+ self.proj_in = nn.Linear(in_channels, inner_dim)
1322
+ if self.use_adaptor:
1323
+ self.adaptor_in = nn.Linear(frames, frames)
1324
+
1325
+ self.transformer_blocks = nn.ModuleList(
1326
+ [BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
1327
+ checkpoint=use_checkpoint, adapter_list=adapter_list, adapter_position_list=adapter_position_list,
1328
+ adapter_hidden_dim=adapter_hidden_dim, adapter_condition_dim=adapter_condition_dim)
1329
+ for d in range(depth)]
1330
+ )
1331
+ if not use_linear:
1332
+ self.proj_out = zero_module(nn.Conv1d(inner_dim,
1333
+ in_channels,
1334
+ kernel_size=1,
1335
+ stride=1,
1336
+ padding=0))
1337
+ else:
1338
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
1339
+ if self.use_adaptor:
1340
+ self.adaptor_out = nn.Linear(frames, frames)
1341
+ self.use_linear = use_linear
1342
+
1343
+ def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1):
1344
+ # note: if no context is given, cross-attention defaults to self-attention
1345
+ if self.only_self_att:
1346
+ context = None
1347
+ if not isinstance(context, list):
1348
+ context = [context]
1349
+ b, c, f, h, w = x.shape
1350
+ x_in = x
1351
+ x = self.norm(x)
1352
+
1353
+ if not self.use_linear:
1354
+ x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
1355
+ x = self.proj_in(x)
1356
+ # [16384, 16, 320]
1357
+ if self.use_linear:
1358
+ x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
1359
+ x = self.proj_in(x)
1360
+
1361
+ if adapter_condition is not None:
1362
+ b_cond, f_cond, c_cond = adapter_condition.shape
1363
+ adapter_condition = adapter_condition.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1, 1)
1364
+ adapter_condition = adapter_condition.reshape(b_cond*h*w, f_cond, c_cond)
1365
+
1366
+ if self.only_self_att:
1367
+ x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
1368
+ for i, block in enumerate(self.transformer_blocks):
1369
+ x = block(x, adapter_condition=adapter_condition, adapter_condition_lam=adapter_condition_lam)
1370
+ x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
1371
+ else:
1372
+ x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
1373
+ for i, block in enumerate(self.transformer_blocks):
1374
+ # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
1375
+ context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
1376
+ # process each batch element one by one (some kernels cannot handle a dimension greater than 65,535)
1377
+ for j in range(b):
1378
+ context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
1379
+ x[j] = block(x[j], context=context_i_j)
1380
+
1381
+ if self.use_linear:
1382
+ x = self.proj_out(x)
1383
+ x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
1384
+ if not self.use_linear:
1385
+ # x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
1386
+ x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
1387
+ x = self.proj_out(x)
1388
+ x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
1389
+
1390
+ if self.multiply_zero:
1391
+ x = 0.0 * x + x_in
1392
+ else:
1393
+ x = x + x_in
1394
+ return x
1395
+
1396
+ class Attention(nn.Module):
1397
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
1398
+ super().__init__()
1399
+ inner_dim = dim_head * heads
1400
+ project_out = not (heads == 1 and dim_head == dim)
1401
+
1402
+ self.heads = heads
1403
+ self.scale = dim_head ** -0.5
1404
+
1405
+ self.attend = nn.Softmax(dim = -1)
1406
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
1407
+
1408
+ self.to_out = nn.Sequential(
1409
+ nn.Linear(inner_dim, dim),
1410
+ nn.Dropout(dropout)
1411
+ ) if project_out else nn.Identity()
1412
+
1413
+ def forward(self, x):
1414
+ b, n, _, h = *x.shape, self.heads
1415
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
1416
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
1417
+
1418
+ dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
1419
+
1420
+ attn = self.attend(dots)
1421
+
1422
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
1423
+ out = rearrange(out, 'b h n d -> b n (h d)')
1424
+ return self.to_out(out)
1425
+
1426
+ class PreNormattention(nn.Module):
1427
+ def __init__(self, dim, fn):
1428
+ super().__init__()
1429
+ self.norm = nn.LayerNorm(dim)
1430
+ self.fn = fn
1431
+ def forward(self, x, **kwargs):
1432
+ return self.fn(self.norm(x), **kwargs) + x
1433
+
1434
+ class TransformerV2(nn.Module):
1435
+ def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1):
1436
+ super().__init__()
1437
+ self.layers = nn.ModuleList([])
1438
+ self.depth = depth
1439
+ for _ in range(depth):
1440
+ self.layers.append(nn.ModuleList([
1441
+ PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)),
1442
+ FeedForward(dim, mlp_dim, dropout = dropout_ffn),
1443
+ ]))
1444
+ def forward(self, x):
1445
+ # if self.depth
1446
+ for attn, ff in self.layers[:1]:
1447
+ x = attn(x)
1448
+ x = ff(x) + x
1449
+ if self.depth > 1:
1450
+ for attn, ff in self.layers[1:]:
1451
+ x = attn(x)
1452
+ x = ff(x) + x
1453
+ return x
1454
+
1455
+ class TemporalTransformer_attemask(nn.Module):
1456
+ """
1457
+ Transformer block for image-like data.
1458
+ First, project the input (aka embedding)
1459
+ and reshape to b, t, d.
1460
+ Then apply standard transformer action.
1461
+ Finally, reshape to image
1462
+ """
1463
+ def __init__(self, in_channels, n_heads, d_head,
1464
+ depth=1, dropout=0., context_dim=None,
1465
+ disable_self_attn=False, use_linear=False,
1466
+ use_checkpoint=True, only_self_att=True, multiply_zero=False):
1467
+ super().__init__()
1468
+ self.multiply_zero = multiply_zero
1469
+ self.only_self_att = only_self_att
1470
+ self.use_adaptor = False
1471
+ if self.only_self_att:
1472
+ context_dim = None
1473
+ if not isinstance(context_dim, list):
1474
+ context_dim = [context_dim]
1475
+ self.in_channels = in_channels
1476
+ inner_dim = n_heads * d_head
1477
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
1478
+ if not use_linear:
1479
+ self.proj_in = nn.Conv1d(in_channels,
1480
+ inner_dim,
1481
+ kernel_size=1,
1482
+ stride=1,
1483
+ padding=0)
1484
+ else:
1485
+ self.proj_in = nn.Linear(in_channels, inner_dim)
1486
+ if self.use_adaptor:
1487
+ self.adaptor_in = nn.Linear(frames, frames)
1488
+
1489
+ self.transformer_blocks = nn.ModuleList(
1490
+ [BasicTransformerBlock_attemask(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
1491
+ checkpoint=use_checkpoint)
1492
+ for d in range(depth)]
1493
+ )
1494
+ if not use_linear:
1495
+ self.proj_out = zero_module(nn.Conv1d(inner_dim,
1496
+ in_channels,
1497
+ kernel_size=1,
1498
+ stride=1,
1499
+ padding=0))
1500
+ else:
1501
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
1502
+ if self.use_adaptor:
1503
+ self.adaptor_out = nn.Linear(frames, frames)
1504
+ self.use_linear = use_linear
1505
+
1506
+ def forward(self, x, context=None):
1507
+ # note: if no context is given, cross-attention defaults to self-attention
1508
+ if self.only_self_att:
1509
+ context = None
1510
+ if not isinstance(context, list):
1511
+ context = [context]
1512
+ b, c, f, h, w = x.shape
1513
+ x_in = x
1514
+ x = self.norm(x)
1515
+
1516
+ if not self.use_linear:
1517
+ x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
1518
+ x = self.proj_in(x)
1519
+ # [16384, 16, 320]
1520
+ if self.use_linear:
1521
+ x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
1522
+ x = self.proj_in(x)
1523
+
1524
+ if self.only_self_att:
1525
+ x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
1526
+ for i, block in enumerate(self.transformer_blocks):
1527
+ x = block(x)
1528
+ x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
1529
+ else:
1530
+ x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
1531
+ for i, block in enumerate(self.transformer_blocks):
1532
+ # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
1533
+ context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
1534
+ # process each batch element one by one (some kernels cannot handle a dimension greater than 65,535)
1535
+ for j in range(b):
1536
+ context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
1537
+ x[j] = block(x[j], context=context_i_j)
1538
+
1539
+ if self.use_linear:
1540
+ x = self.proj_out(x)
1541
+ x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
1542
+ if not self.use_linear:
1543
+ # x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
1544
+ x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
1545
+ x = self.proj_out(x)
1546
+ x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
1547
+
1548
+ if self.multiply_zero:
1549
+ x = 0.0 * x + x_in
1550
+ else:
1551
+ x = x + x_in
1552
+ return x
1553
+
1554
+ class TemporalAttentionMultiBlock(nn.Module):
1555
+ def __init__(
1556
+ self,
1557
+ dim,
1558
+ heads=4,
1559
+ dim_head=32,
1560
+ rotary_emb=None,
1561
+ use_image_dataset=False,
1562
+ use_sim_mask=False,
1563
+ temporal_attn_times=1,
1564
+ ):
1565
+ super().__init__()
1566
+ self.att_layers = nn.ModuleList(
1567
+ [TemporalAttentionBlock(dim, heads, dim_head, rotary_emb, use_image_dataset, use_sim_mask)
1568
+ for _ in range(temporal_attn_times)]
1569
+ )
1570
+
1571
+ def forward(
1572
+ self,
1573
+ x,
1574
+ pos_bias = None,
1575
+ focus_present_mask = None,
1576
+ video_mask = None
1577
+ ):
1578
+ for layer in self.att_layers:
1579
+ x = layer(x, pos_bias, focus_present_mask, video_mask)
1580
+ return x
1581
+
1582
+
1583
+ class InitTemporalConvBlock(nn.Module):
1584
+
1585
+ def __init__(self, in_dim, out_dim=None, dropout=0.0,use_image_dataset=False):
1586
+ super(InitTemporalConvBlock, self).__init__()
1587
+ if out_dim is None:
1588
+ out_dim = in_dim#int(1.5*in_dim)
1589
+ self.in_dim = in_dim
1590
+ self.out_dim = out_dim
1591
+ self.use_image_dataset = use_image_dataset
1592
+
1593
+ # conv layers
1594
+ self.conv = nn.Sequential(
1595
+ nn.GroupNorm(32, out_dim),
1596
+ nn.SiLU(),
1597
+ nn.Dropout(dropout),
1598
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
1599
+
1600
+ # zero out the last layer params, so the conv block is an identity mapping at initialization
1601
+ # nn.init.zeros_(self.conv1[-1].weight)
1602
+ # nn.init.zeros_(self.conv1[-1].bias)
1603
+ nn.init.zeros_(self.conv[-1].weight)
1604
+ nn.init.zeros_(self.conv[-1].bias)
1605
+
1606
+ def forward(self, x):
1607
+ identity = x
1608
+ x = self.conv(x)
1609
+ if self.use_image_dataset:
1610
+ x = identity + 0*x
1611
+ else:
1612
+ x = identity + x
1613
+ return x
1614
+
1615
+ class TemporalConvBlock(nn.Module):
1616
+
1617
+ def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset= False):
1618
+ super(TemporalConvBlock, self).__init__()
1619
+ if out_dim is None:
1620
+ out_dim = in_dim#int(1.5*in_dim)
1621
+ self.in_dim = in_dim
1622
+ self.out_dim = out_dim
1623
+ self.use_image_dataset = use_image_dataset
1624
+
1625
+ # conv layers
1626
+ self.conv1 = nn.Sequential(
1627
+ nn.GroupNorm(32, in_dim),
1628
+ nn.SiLU(),
1629
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0)))
1630
+ self.conv2 = nn.Sequential(
1631
+ nn.GroupNorm(32, out_dim),
1632
+ nn.SiLU(),
1633
+ nn.Dropout(dropout),
1634
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
1635
+
1636
+ # zero out the last layer params, so the conv block is an identity mapping at initialization
1637
+ # nn.init.zeros_(self.conv1[-1].weight)
1638
+ # nn.init.zeros_(self.conv1[-1].bias)
1639
+ nn.init.zeros_(self.conv2[-1].weight)
1640
+ nn.init.zeros_(self.conv2[-1].bias)
1641
+
1642
+ def forward(self, x):
1643
+ identity = x
1644
+ x = self.conv1(x)
1645
+ x = self.conv2(x)
1646
+ if self.use_image_dataset:
1647
+ x = identity + 0*x
1648
+ else:
1649
+ x = identity + x
1650
+ return x
1651
+
1652
+ class TemporalConvBlock_v2(nn.Module):
1653
+ def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False):
1654
+ super(TemporalConvBlock_v2, self).__init__()
1655
+ if out_dim is None:
1656
+ out_dim = in_dim # int(1.5*in_dim)
1657
+ self.in_dim = in_dim
1658
+ self.out_dim = out_dim
1659
+ self.use_image_dataset = use_image_dataset
1660
+
1661
+ # conv layers
1662
+ self.conv1 = nn.Sequential(
1663
+ nn.GroupNorm(32, in_dim),
1664
+ nn.SiLU(),
1665
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0)))
1666
+ self.conv2 = nn.Sequential(
1667
+ nn.GroupNorm(32, out_dim),
1668
+ nn.SiLU(),
1669
+ nn.Dropout(dropout),
1670
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
1671
+ self.conv3 = nn.Sequential(
1672
+ nn.GroupNorm(32, out_dim),
1673
+ nn.SiLU(),
1674
+ nn.Dropout(dropout),
1675
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
1676
+ self.conv4 = nn.Sequential(
1677
+ nn.GroupNorm(32, out_dim),
1678
+ nn.SiLU(),
1679
+ nn.Dropout(dropout),
1680
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
1681
+
1682
+ # zero out the last layer params, so the conv block is an identity mapping at initialization
1683
+ nn.init.zeros_(self.conv4[-1].weight)
1684
+ nn.init.zeros_(self.conv4[-1].bias)
1685
+
1686
+ def forward(self, x):
1687
+ identity = x
1688
+ x = self.conv1(x)
1689
+ x = self.conv2(x)
1690
+ x = self.conv3(x)
1691
+ x = self.conv4(x)
1692
+
1693
+ if self.use_image_dataset:
1694
+ x = identity + 0.0 * x
1695
+ else:
1696
+ x = identity + x
1697
+ return x
1698
+
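As with the attention blocks, the last convolution of TemporalConvBlock_v2 is zero-initialised, so the whole stack starts out as an identity along the frame axis; a quick sketch with illustrative sizes:

import torch

tcb = TemporalConvBlock_v2(in_dim=320, dropout=0.1)
x = torch.randn(1, 320, 16, 8, 8)   # [b, c, f, h, w]
with torch.no_grad():
    y = tcb(x)
print(y.shape, torch.allclose(y, x))   # torch.Size([1, 320, 16, 8, 8]) True (conv4 is zero-initialised)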
1699
+
1700
+ class DropPath(nn.Module):
1701
+ r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
1702
+ """
1703
+ def __init__(self, p):
1704
+ super(DropPath, self).__init__()
1705
+ self.p = p
1706
+
1707
+ def forward(self, *args, zero=None, keep=None):
1708
+ if not self.training:
1709
+ return args[0] if len(args) == 1 else args
1710
+
1711
+ # params
1712
+ x = args[0]
1713
+ b = x.size(0)
1714
+ n = (torch.rand(b) < self.p).sum()
1715
+
1716
+ # non-zero and non-keep mask
1717
+ mask = x.new_ones(b, dtype=torch.bool)
1718
+ if keep is not None:
1719
+ mask[keep] = False
1720
+ if zero is not None:
1721
+ mask[zero] = False
1722
+
1723
+ # drop-path index
1724
+ index = torch.where(mask)[0]
1725
+ index = index[torch.randperm(len(index))[:n]]
1726
+ if zero is not None:
1727
+ index = torch.cat([index, torch.where(zero)[0]], dim=0)
1728
+
1729
+ # drop-path multiplier
1730
+ multiplier = x.new_ones(b)
1731
+ multiplier[index] = 0.0
1732
+ output = tuple(u * self.broadcast(multiplier, u) for u in args)
1733
+ return output[0] if len(args) == 1 else output
1734
+
1735
+ def broadcast(self, src, dst):
1736
+ assert src.size(0) == dst.size(0)
1737
+ shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
1738
+ return src.view(shape)
1739
+
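DropPath here zeroes a random subset of samples without the usual 1/(1-p) rescaling and is a no-op in eval mode; the optional zero/keep masks force particular samples to be dropped or kept. A small sketch:

import torch

drop = DropPath(p=0.5)
x = torch.randn(4, 320)

drop.train()
y = drop(x)                                        # roughly half of the 4 samples are zeroed
print((y.abs().sum(dim=1) == 0).sum().item())

drop.eval()
print(torch.equal(drop(x), x))                     # True: identity at inference time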
1740
+
1741
+
UniAnimate/utils/__init__.py ADDED
File without changes
UniAnimate/utils/assign_cfg.py ADDED
@@ -0,0 +1,78 @@
1
+ import os, yaml
2
+ from copy import deepcopy, copy
3
+
4
+
5
+ # build the prior and vldm configs
6
+ def assign_prior_mudule_cfg(cfg):
7
+ '''
8
+ '''
9
+ #
10
+ prior_cfg = deepcopy(cfg)
11
+ vldm_cfg = deepcopy(cfg)
12
+
13
+ with open(cfg.prior_cfg, 'r') as f:
14
+ _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
15
+ # _cfg_update = _cfg_update.cfg_dict
16
+ for k, v in _cfg_update.items():
17
+ if isinstance(v, dict) and k in cfg:
18
+ prior_cfg[k].update(v)
19
+ else:
20
+ prior_cfg[k] = v
21
+
22
+ with open(cfg.vldm_cfg, 'r') as f:
23
+ _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
24
+ # _cfg_update = _cfg_update.cfg_dict
25
+ for k, v in _cfg_update.items():
26
+ if isinstance(v, dict) and k in cfg:
27
+ vldm_cfg[k].update(v)
28
+ else:
29
+ vldm_cfg[k] = v
30
+
31
+ return prior_cfg, vldm_cfg
32
+
33
+
34
+ # build the vldm and vsr configs
35
+ def assign_vldm_vsr_mudule_cfg(cfg):
36
+ '''
37
+ '''
38
+ #
39
+ vldm_cfg = deepcopy(cfg)
40
+ vsr_cfg = deepcopy(cfg)
41
+
42
+ with open(cfg.vldm_cfg, 'r') as f:
43
+ _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
44
+ # _cfg_update = _cfg_update.cfg_dict
45
+ for k, v in _cfg_update.items():
46
+ if isinstance(v, dict) and k in cfg:
47
+ vldm_cfg[k].update(v)
48
+ else:
49
+ vldm_cfg[k] = v
50
+
51
+ with open(cfg.vsr_cfg, 'r') as f:
52
+ _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
53
+ # _cfg_update = _cfg_update.cfg_dict
54
+ for k, v in _cfg_update.items():
55
+ if isinstance(v, dict) and k in cfg:
56
+ vsr_cfg[k].update(v)
57
+ else:
58
+ vsr_cfg[k] = v
59
+
60
+ return vldm_cfg, vsr_cfg
61
+
62
+
63
+ # merge a single sub-config file (cfg[tname]) into a copy of the base config
64
+ def assign_signle_cfg(cfg, _cfg_update, tname):
65
+ '''
66
+ '''
67
+ #
68
+ vldm_cfg = deepcopy(cfg)
69
+ if os.path.exists(_cfg_update[tname]):
70
+ with open(_cfg_update[tname], 'r') as f:
71
+ _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
72
+ # _cfg_update = _cfg_update.cfg_dict
73
+ for k, v in _cfg_update.items():
74
+ if isinstance(v, dict) and k in cfg:
75
+ vldm_cfg[k].update(v)
76
+ else:
77
+ vldm_cfg[k] = v
78
+ return vldm_cfg
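All three helpers follow the same pattern: deep-copy the incoming config, then merge a YAML file on top (dict values are updated key by key, everything else is overwritten). A hypothetical sketch of assign_signle_cfg; the path and keys below are illustrative, not taken from the shipped configs:

# hypothetical usage; assumes the YAML path stored under cfg['vldm_cfg'] exists on disk
cfg = {'vldm_cfg': 'configs/UniAnimate_infer.yaml', 'seed': 8888, 'UNet': {'in_dim': 4}}
merged = assign_signle_cfg(cfg, cfg, 'vldm_cfg')
# 'seed' and 'UNet' keep their copied values unless the YAML overrides/updates them
print(merged['seed'])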
UniAnimate/utils/config.py ADDED
@@ -0,0 +1,230 @@
1
+ import os
2
+ import yaml
3
+ import json
4
+ import copy
5
+ import argparse
6
+
7
+ import utils.logging as logging
8
+ logger = logging.get_logger(__name__)
9
+
10
+ class Config(object):
11
+ def __init__(self, load=True, cfg_dict=None, cfg_level=None):
12
+ self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "")
13
+ if load:
14
+ self.args = self._parse_args()
15
+ logger.info("Loading config from {}.".format(self.args.cfg_file))
16
+ self.need_initialization = True
17
+ cfg_base = self._load_yaml(self.args) # self._initialize_cfg()
18
+ cfg_dict = self._load_yaml(self.args)
19
+ cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
20
+ cfg_dict = self._update_from_args(cfg_dict)
21
+ self.cfg_dict = cfg_dict
22
+ self._update_dict(cfg_dict)
23
+
24
+ def _parse_args(self):
25
+ parser = argparse.ArgumentParser(
26
+ description="Argparser for configuring [code base name to think of] codebase"
27
+ )
28
+ parser.add_argument(
29
+ "--cfg",
30
+ dest="cfg_file",
31
+ help="Path to the configuration file",
32
+ default='configs/UniAnimate_infer.yaml'
33
+ )
34
+ parser.add_argument(
35
+ "--init_method",
36
+ help="Initialization method, includes TCP or shared file-system",
37
+ default="tcp://localhost:9999",
38
+ type=str,
39
+ )
40
+ parser.add_argument(
41
+ '--debug',
42
+ action='store_true',
43
+ default=False,
44
+ help='Enable debug information'
45
+ )
46
+ parser.add_argument(
47
+ "opts",
48
+ help="other configurations",
49
+ default=None,
50
+ nargs=argparse.REMAINDER)
51
+ return parser.parse_args()
52
+
53
+ def _path_join(self, path_list):
54
+ path = ""
55
+ for p in path_list:
56
+ path+= p + '/'
57
+ return path[:-1]
58
+
59
+ def _update_from_args(self, cfg_dict):
60
+ args = self.args
61
+ for var in vars(args):
62
+ cfg_dict[var] = getattr(args, var)
63
+ return cfg_dict
64
+
65
+ def _initialize_cfg(self):
66
+ if self.need_initialization:
67
+ self.need_initialization = False
68
+ if os.path.exists('./configs/base.yaml'):
69
+ with open("./configs/base.yaml", 'r') as f:
70
+ cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
71
+ else:
72
+ with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f:
73
+ cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
74
+ return cfg
75
+
76
+ def _load_yaml(self, args, file_name=""):
77
+ assert args.cfg_file is not None
78
+ if not file_name == "": # reading from base file
79
+ with open(file_name, 'r') as f:
80
+ cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
81
+ else:
82
+ if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]:
83
+ args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./")
84
+ with open(args.cfg_file, 'r') as f:
85
+ cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
86
+ file_name = args.cfg_file
87
+
88
+ if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys():
89
+ # return cfg if the base file is being accessed
90
+ cfg = self._merge_cfg_from_command_update(args, cfg)
91
+ return cfg
92
+
93
+ if "_BASE" in cfg.keys():
94
+ if cfg["_BASE"][1] == '.':
95
+ prev_count = cfg["_BASE"].count('..')
96
+ cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:])
97
+ else:
98
+ cfg_base_file = cfg["_BASE"].replace(
99
+ "./",
100
+ args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
101
+ )
102
+ cfg_base = self._load_yaml(args, cfg_base_file)
103
+ cfg = self._merge_cfg_from_base(cfg_base, cfg)
104
+ else:
105
+ if "_BASE_RUN" in cfg.keys():
106
+ if cfg["_BASE_RUN"][1] == '.':
107
+ prev_count = cfg["_BASE_RUN"].count('..')
108
+ cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:])
109
+ else:
110
+ cfg_base_file = cfg["_BASE_RUN"].replace(
111
+ "./",
112
+ args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
113
+ )
114
+ cfg_base = self._load_yaml(args, cfg_base_file)
115
+ cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True)
116
+ if "_BASE_MODEL" in cfg.keys():
117
+ if cfg["_BASE_MODEL"][1] == '.':
118
+ prev_count = cfg["_BASE_MODEL"].count('..')
119
+ cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:])
120
+ else:
121
+ cfg_base_file = cfg["_BASE_MODEL"].replace(
122
+ "./",
123
+ args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
124
+ )
125
+ cfg_base = self._load_yaml(args, cfg_base_file)
126
+ cfg = self._merge_cfg_from_base(cfg_base, cfg)
127
+ cfg = self._merge_cfg_from_command(args, cfg)
128
+ return cfg
129
+
130
+ def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
131
+ for k,v in cfg_new.items():
132
+ if k in cfg_base.keys():
133
+ if isinstance(v, dict):
134
+ self._merge_cfg_from_base(cfg_base[k], v)
135
+ else:
136
+ cfg_base[k] = v
137
+ else:
138
+ if "BASE" not in k or preserve_base:
139
+ cfg_base[k] = v
140
+ return cfg_base
141
+
142
+ def _merge_cfg_from_command_update(self, args, cfg):
143
+ if len(args.opts) == 0:
144
+ return cfg
145
+
146
+ assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
147
+ args.opts, len(args.opts)
148
+ )
149
+ keys = args.opts[0::2]
150
+ vals = args.opts[1::2]
151
+
152
+ for key, val in zip(keys, vals):
153
+ cfg[key] = val
154
+
155
+ return cfg
156
+
157
+ def _merge_cfg_from_command(self, args, cfg):
158
+ assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
159
+ args.opts, len(args.opts)
160
+ )
161
+ keys = args.opts[0::2]
162
+ vals = args.opts[1::2]
163
+
164
+ # maximum supported depth 3
165
+ for idx, key in enumerate(keys):
166
+ key_split = key.split('.')
167
+ assert len(key_split) <= 4, 'Key depth error.\nMaximum depth: 3\nGot depth: {}'.format(
168
+ len(key_split)
169
+ )
170
+ assert key_split[0] in cfg.keys(), 'Non-existent key: {}.'.format(
171
+ key_split[0]
172
+ )
173
+ if len(key_split) == 2:
174
+ assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
175
+ key
176
+ )
177
+ elif len(key_split) == 3:
178
+ assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
179
+ key
180
+ )
181
+ assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existent key: {}.'.format(
182
+ key
183
+ )
184
+ elif len(key_split) == 4:
185
+ assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existent key: {}.'.format(
186
+ key
187
+ )
188
+ assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existent key: {}.'.format(
189
+ key
190
+ )
191
+ assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existent key: {}.'.format(
192
+ key
193
+ )
194
+ if len(key_split) == 1:
195
+ cfg[key_split[0]] = vals[idx]
196
+ elif len(key_split) == 2:
197
+ cfg[key_split[0]][key_split[1]] = vals[idx]
198
+ elif len(key_split) == 3:
199
+ cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
200
+ elif len(key_split) == 4:
201
+ cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx]
202
+ return cfg
203
+
204
+ def _update_dict(self, cfg_dict):
205
+ def recur(key, elem):
206
+ if type(elem) is dict:
207
+ return key, Config(load=False, cfg_dict=elem, cfg_level=key)
208
+ else:
209
+ # cast strings written in scientific notation (e.g. "1e-4") to float
+ if type(elem) is str and elem[1:3]=="e-":
210
+ elem = float(elem)
211
+ return key, elem
212
+ dic = dict(recur(k, v) for k, v in cfg_dict.items())
213
+ self.__dict__.update(dic)
214
+
215
+ def get_args(self):
216
+ return self.args
217
+
218
+ def __repr__(self):
219
+ return "{}\n".format(self.dump())
220
+
221
+ def dump(self):
222
+ return json.dumps(self.cfg_dict, indent=2)
223
+
224
+ def deep_copy(self):
225
+ return copy.deepcopy(self)
226
+
227
+ if __name__ == '__main__':
228
+ # debug
229
+ cfg = Config(load=True)
230
+ print(cfg.DATA)
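
As a usage sketch (run from the `UniAnimate/` directory; the trailing `seed 11` override pair is illustrative, and like any `opts` pair it is stored as a string):

```python
import sys
from utils.config import Config

# simulate: python inference.py --cfg configs/UniAnimate_infer.yaml seed 11
sys.argv = ['inference.py', '--cfg', 'configs/UniAnimate_infer.yaml', 'seed', '11']

cfg = Config(load=True)
print(cfg.cfg_file)   # argparse fields are merged into the config dict
print(cfg.seed)       # '11' -- command-line overrides are kept as strings
print(cfg.dump())     # JSON view of the merged configuration
```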
UniAnimate/utils/distributed.py ADDED
@@ -0,0 +1,430 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.distributed as dist
7
+ import functools
8
+ import pickle
+ import logging
9
+ import numpy as np
10
+ from collections import OrderedDict
11
+ from torch.autograd import Function
12
+
13
+ __all__ = ['is_dist_initialized',
14
+ 'get_world_size',
15
+ 'get_rank',
16
+ 'new_group',
17
+ 'destroy_process_group',
18
+ 'barrier',
19
+ 'broadcast',
20
+ 'all_reduce',
21
+ 'reduce',
22
+ 'gather',
23
+ 'all_gather',
24
+ 'reduce_dict',
25
+ 'get_global_gloo_group',
26
+ 'generalized_all_gather',
27
+ 'generalized_gather',
28
+ 'scatter',
29
+ 'reduce_scatter',
30
+ 'send',
31
+ 'recv',
32
+ 'isend',
33
+ 'irecv',
34
+ 'shared_random_seed',
35
+ 'diff_all_gather',
36
+ 'diff_all_reduce',
37
+ 'diff_scatter',
38
+ 'diff_copy',
39
+ 'spherical_kmeans',
40
+ 'sinkhorn']
41
+
42
+ #-------------------------------- Distributed operations --------------------------------#
43
+
44
+ def is_dist_initialized():
45
+ return dist.is_available() and dist.is_initialized()
46
+
47
+ def get_world_size(group=None):
48
+ return dist.get_world_size(group) if is_dist_initialized() else 1
49
+
50
+ def get_rank(group=None):
51
+ return dist.get_rank(group) if is_dist_initialized() else 0
52
+
53
+ def new_group(ranks=None, **kwargs):
54
+ if is_dist_initialized():
55
+ return dist.new_group(ranks, **kwargs)
56
+ return None
57
+
58
+ def destroy_process_group():
59
+ if is_dist_initialized():
60
+ dist.destroy_process_group()
61
+
62
+ def barrier(group=None, **kwargs):
63
+ if get_world_size(group) > 1:
64
+ dist.barrier(group, **kwargs)
65
+
66
+ def broadcast(tensor, src, group=None, **kwargs):
67
+ if get_world_size(group) > 1:
68
+ return dist.broadcast(tensor, src, group, **kwargs)
69
+
70
+ def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
71
+ if get_world_size(group) > 1:
72
+ return dist.all_reduce(tensor, op, group, **kwargs)
73
+
74
+ def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
75
+ if get_world_size(group) > 1:
76
+ return dist.reduce(tensor, dst, op, group, **kwargs)
77
+
78
+ def gather(tensor, dst=0, group=None, **kwargs):
79
+ rank = get_rank() # global rank
80
+ world_size = get_world_size(group)
81
+ if world_size == 1:
82
+ return [tensor]
83
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None
84
+ dist.gather(tensor, tensor_list, dst, group, **kwargs)
85
+ return tensor_list
86
+
87
+ def all_gather(tensor, uniform_size=True, group=None, **kwargs):
88
+ world_size = get_world_size(group)
89
+ if world_size == 1:
90
+ return [tensor]
91
+ assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()'
92
+
93
+ if uniform_size:
94
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
95
+ dist.all_gather(tensor_list, tensor, group, **kwargs)
96
+ return tensor_list
97
+ else:
98
+ # collect tensor shapes across GPUs
99
+ shape = tuple(tensor.shape)
100
+ shape_list = generalized_all_gather(shape, group)
101
+
102
+ # flatten the tensor
103
+ tensor = tensor.reshape(-1)
104
+ size = int(np.prod(shape))
105
+ size_list = [int(np.prod(u)) for u in shape_list]
106
+ max_size = max(size_list)
107
+
108
+ # pad to maximum size
109
+ if size != max_size:
110
+ padding = tensor.new_zeros(max_size - size)
111
+ tensor = torch.cat([tensor, padding], dim=0)
112
+
113
+ # all_gather
114
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
115
+ dist.all_gather(tensor_list, tensor, group, **kwargs)
116
+
117
+ # reshape tensors
118
+ tensor_list = [t[:n].view(s) for t, n, s in zip(
119
+ tensor_list, size_list, shape_list)]
120
+ return tensor_list
121
+
122
+ @torch.no_grad()
123
+ def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
124
+ assert reduction in ['mean', 'sum']
125
+ world_size = get_world_size(group)
126
+ if world_size == 1:
127
+ return input_dict
128
+
129
+ # ensure that the orders of keys are consistent across processes
130
+ if isinstance(input_dict, OrderedDict):
131
+ keys = list(input_dict.keys())
132
+ else:
133
+ keys = sorted(input_dict.keys())
134
+ vals = [input_dict[key] for key in keys]
135
+ vals = torch.stack(vals, dim=0)
136
+ dist.reduce(vals, dst=0, group=group, **kwargs)
137
+ if dist.get_rank(group) == 0 and reduction == 'mean':
138
+ vals /= world_size
139
+ dist.broadcast(vals, src=0, group=group, **kwargs)
140
+ reduced_dict = type(input_dict)([
141
+ (key, val) for key, val in zip(keys, vals)])
142
+ return reduced_dict
143
+
144
+ @functools.lru_cache()
145
+ def get_global_gloo_group():
146
+ backend = dist.get_backend()
147
+ assert backend in ['gloo', 'nccl']
148
+ if backend == 'nccl':
149
+ return dist.new_group(backend='gloo')
150
+ else:
151
+ return dist.group.WORLD
152
+
153
+ def _serialize_to_tensor(data, group):
154
+ backend = dist.get_backend(group)
155
+ assert backend in ['gloo', 'nccl']
156
+ device = torch.device('cpu' if backend == 'gloo' else 'cuda')
157
+
158
+ buffer = pickle.dumps(data)
159
+ if len(buffer) > 1024 ** 3:
160
+ logger = logging.getLogger(__name__)
161
+ logger.warning(
162
+ 'Rank {} trying to all-gather {:.2f} GB of data on device'
163
+ '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
164
+ storage = torch.ByteStorage.from_buffer(buffer)
165
+ tensor = torch.ByteTensor(storage).to(device=device)
166
+ return tensor
167
+
168
+ def _pad_to_largest_tensor(tensor, group):
169
+ world_size = dist.get_world_size(group=group)
170
+ assert world_size >= 1, \
171
+ 'gather/all_gather must be called from ranks within' \
172
+ 'the give group!'
173
+ local_size = torch.tensor(
174
+ [tensor.numel()], dtype=torch.int64, device=tensor.device)
175
+ size_list = [torch.zeros(
176
+ [1], dtype=torch.int64, device=tensor.device)
177
+ for _ in range(world_size)]
178
+
179
+ # gather tensors and compute the maximum size
180
+ dist.all_gather(size_list, local_size, group=group)
181
+ size_list = [int(size.item()) for size in size_list]
182
+ max_size = max(size_list)
183
+
184
+ # pad tensors to the same size
185
+ if local_size != max_size:
186
+ padding = torch.zeros(
187
+ (max_size - local_size, ),
188
+ dtype=torch.uint8, device=tensor.device)
189
+ tensor = torch.cat((tensor, padding), dim=0)
190
+ return size_list, tensor
191
+
192
+ def generalized_all_gather(data, group=None):
193
+ if get_world_size(group) == 1:
194
+ return [data]
195
+ if group is None:
196
+ group = get_global_gloo_group()
197
+
198
+ tensor = _serialize_to_tensor(data, group)
199
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
200
+ max_size = max(size_list)
201
+
202
+ # receiving tensors from all ranks
203
+ tensor_list = [torch.empty(
204
+ (max_size, ), dtype=torch.uint8, device=tensor.device)
205
+ for _ in size_list]
206
+ dist.all_gather(tensor_list, tensor, group=group)
207
+
208
+ data_list = []
209
+ for size, tensor in zip(size_list, tensor_list):
210
+ buffer = tensor.cpu().numpy().tobytes()[:size]
211
+ data_list.append(pickle.loads(buffer))
212
+ return data_list
213
+
214
+ def generalized_gather(data, dst=0, group=None):
215
+ world_size = get_world_size(group)
216
+ if world_size == 1:
217
+ return [data]
218
+ if group is None:
219
+ group = get_global_gloo_group()
220
+ rank = dist.get_rank() # global rank
221
+
222
+ tensor = _serialize_to_tensor(data, group)
223
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
224
+
225
+ # receiving tensors from all ranks to dst
226
+ if rank == dst:
227
+ max_size = max(size_list)
228
+ tensor_list = [torch.empty(
229
+ (max_size, ), dtype=torch.uint8, device=tensor.device)
230
+ for _ in size_list]
231
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
232
+
233
+ data_list = []
234
+ for size, tensor in zip(size_list, tensor_list):
235
+ buffer = tensor.cpu().numpy().tobytes()[:size]
236
+ data_list.append(pickle.loads(buffer))
237
+ return data_list
238
+ else:
239
+ dist.gather(tensor, [], dst=dst, group=group)
240
+ return []
241
+
242
+ def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
243
+ r"""NOTE: only supports CPU tensor communication.
244
+ """
245
+ if get_world_size(group) > 1:
246
+ return dist.scatter(data, scatter_list, src, group, **kwargs)
247
+
248
+ def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
249
+ if get_world_size(group) > 1:
250
+ return dist.reduce_scatter(output, input_list, op, group, **kwargs)
251
+
252
+ def send(tensor, dst, group=None, **kwargs):
253
+ if get_world_size(group) > 1:
254
+ assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
255
+ return dist.send(tensor, dst, group, **kwargs)
256
+
257
+ def recv(tensor, src=None, group=None, **kwargs):
258
+ if get_world_size(group) > 1:
259
+ assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
260
+ return dist.recv(tensor, src, group, **kwargs)
261
+
262
+ def isend(tensor, dst, group=None, **kwargs):
263
+ if get_world_size(group) > 1:
264
+ assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
265
+ return dist.isend(tensor, dst, group, **kwargs)
266
+
267
+ def irecv(tensor, src=None, group=None, **kwargs):
268
+ if get_world_size(group) > 1:
269
+ assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
270
+ return dist.irecv(tensor, src, group, **kwargs)
271
+
272
+ def shared_random_seed(group=None):
273
+ seed = np.random.randint(2 ** 31)
274
+ all_seeds = generalized_all_gather(seed, group)
275
+ return all_seeds[0]
276
+
277
+ #-------------------------------- Differentiable operations --------------------------------#
278
+
279
+ def _all_gather(x):
280
+ if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
281
+ return x
282
+ rank = dist.get_rank()
283
+ world_size = dist.get_world_size()
284
+ tensors = [torch.empty_like(x) for _ in range(world_size)]
285
+ tensors[rank] = x
286
+ dist.all_gather(tensors, x)
287
+ return torch.cat(tensors, dim=0).contiguous()
288
+
289
+ def _all_reduce(x):
290
+ if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
291
+ return x
292
+ dist.all_reduce(x)
293
+ return x
294
+
295
+ def _split(x):
296
+ if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
297
+ return x
298
+ rank = dist.get_rank()
299
+ world_size = dist.get_world_size()
300
+ return x.chunk(world_size, dim=0)[rank].contiguous()
301
+
302
+ class DiffAllGather(Function):
303
+ r"""Differentiable all-gather.
304
+ """
305
+ @staticmethod
306
+ def symbolic(graph, input):
307
+ return _all_gather(input)
308
+
309
+ @staticmethod
310
+ def forward(ctx, input):
311
+ return _all_gather(input)
312
+
313
+ @staticmethod
314
+ def backward(ctx, grad_output):
315
+ return _split(grad_output)
316
+
317
+ class DiffAllReduce(Function):
318
+ r"""Differentiable all-reducd.
319
+ """
320
+ @staticmethod
321
+ def symbolic(graph, input):
322
+ return _all_reduce(input)
323
+
324
+ @staticmethod
325
+ def forward(ctx, input):
326
+ return _all_reduce(input)
327
+
328
+ @staticmethod
329
+ def backward(ctx, grad_output):
330
+ return grad_output
331
+
332
+ class DiffScatter(Function):
333
+ r"""Differentiable scatter.
334
+ """
335
+ @staticmethod
336
+ def symbolic(graph, input):
337
+ return _split(input)
338
+
339
+ @staticmethod
340
+ def forward(ctx, input):
341
+ return _split(input)
342
+
343
+ @staticmethod
344
+ def backward(ctx, grad_output):
345
+ return _all_gather(grad_output)
346
+
347
+ class DiffCopy(Function):
348
+ r"""Differentiable copy that reduces all gradients during backward.
349
+ """
350
+ @staticmethod
351
+ def symbolic(graph, input):
352
+ return input
353
+
354
+ @staticmethod
355
+ def forward(ctx, input):
356
+ return input
357
+
358
+ @staticmethod
359
+ def backward(ctx, grad_output):
360
+ return _all_reduce(grad_output)
361
+
362
+ diff_all_gather = DiffAllGather.apply
363
+ diff_all_reduce = DiffAllReduce.apply
364
+ diff_scatter = DiffScatter.apply
365
+ diff_copy = DiffCopy.apply
366
+
367
+ #-------------------------------- Distributed algorithms --------------------------------#
368
+
369
+ @torch.no_grad()
370
+ def spherical_kmeans(feats, num_clusters, num_iters=10):
371
+ k, n, c = num_clusters, *feats.size()
372
+ ones = feats.new_ones(n, dtype=torch.long)
373
+
374
+ # distributed settings
375
+ rank = get_rank()
376
+ world_size = get_world_size()
377
+
378
+ # init clusters
379
+ rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
380
+ clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]
381
+
382
+ # variables
383
+ new_clusters = feats.new_zeros(k, c)
384
+ counts = feats.new_zeros(k, dtype=torch.long)
385
+
386
+ # iterative Expectation-Maximization
387
+ for step in range(num_iters + 1):
388
+ # Expectation step
389
+ simmat = torch.mm(feats, clusters.t())
390
+ scores, assigns = simmat.max(dim=1)
391
+ if step == num_iters:
392
+ break
393
+
394
+ # Maximization step
395
+ new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats)
396
+ all_reduce(new_clusters)
397
+
398
+ counts.zero_()
399
+ counts.index_add_(0, assigns, ones)
400
+ all_reduce(counts)
401
+
402
+ mask = (counts > 0)
403
+ clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
404
+ clusters = F.normalize(clusters, p=2, dim=1)
405
+ return clusters, assigns, scores
406
+
407
+ @torch.no_grad()
408
+ def sinkhorn(Q, eps=0.5, num_iters=3):
409
+ # normalize Q
410
+ Q = torch.exp(Q / eps).t()
411
+ sum_Q = Q.sum()
412
+ all_reduce(sum_Q)
413
+ Q /= sum_Q
414
+
415
+ # variables
416
+ n, m = Q.size()
417
+ u = Q.new_zeros(n)
418
+ r = Q.new_ones(n) / n
419
+ c = Q.new_ones(m) / (m * get_world_size())
420
+
421
+ # iterative update
422
+ cur_sum = Q.sum(dim=1)
423
+ all_reduce(cur_sum)
424
+ for i in range(num_iters):
425
+ u = cur_sum
426
+ Q *= (r / u).unsqueeze(1)
427
+ Q *= (c / Q.sum(dim=0)).unsqueeze(0)
428
+ cur_sum = Q.sum(dim=1)
429
+ all_reduce(cur_sum)
430
+ return (Q / Q.sum(dim=0, keepdim=True)).t().float()
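
These wrappers all degrade to single-process no-ops when `torch.distributed` has not been initialized, which makes them easy to smoke-test locally. A minimal sketch (assuming the import is done from the `UniAnimate/` directory):

```python
import torch
import utils.distributed as du

x = torch.randn(4, 3)
print(du.is_dist_initialized())                 # False outside init_process_group
print(du.get_world_size(), du.get_rank())       # 1 0
print(len(du.all_gather(x.contiguous())))       # 1 -- just [x] with a single process
print(du.generalized_all_gather({'step': 1}))   # picklable objects pass straight through
du.all_reduce(x)                                # no-op when world_size == 1
```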
UniAnimate/utils/logging.py ADDED
@@ -0,0 +1,90 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+
4
+ """Logging."""
5
+
6
+ import builtins
7
+ import decimal
8
+ import functools
9
+ import logging
10
+ import os
11
+ import sys
12
+ import simplejson
13
+ # from fvcore.common.file_io import PathManager
14
+
15
+ import utils.distributed as du
16
+
17
+
18
+ def _suppress_print():
19
+ """
20
+ Suppresses printing from the current process.
21
+ """
22
+
23
+ def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
24
+ pass
25
+
26
+ builtins.print = print_pass
27
+
28
+
29
+ # @functools.lru_cache(maxsize=None)
30
+ # def _cached_log_stream(filename):
31
+ # return PathManager.open(filename, "a")
32
+
33
+
34
+ def setup_logging(cfg, log_file):
35
+ """
36
+ Sets up the logging for multiple processes. Only enable the logging for the
37
+ master process, and suppress logging for the non-master processes.
38
+ """
39
+ if du.is_master_proc():
40
+ # Enable logging for the master process.
41
+ logging.root.handlers = []
42
+ else:
43
+ # Suppress logging for non-master processes.
44
+ _suppress_print()
45
+
46
+ logger = logging.getLogger()
47
+ logger.setLevel(logging.INFO)
48
+ logger.propagate = False
49
+ plain_formatter = logging.Formatter(
50
+ "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
51
+ datefmt="%m/%d %H:%M:%S",
52
+ )
53
+
54
+ if du.is_master_proc():
55
+ ch = logging.StreamHandler(stream=sys.stdout)
56
+ ch.setLevel(logging.DEBUG)
57
+ ch.setFormatter(plain_formatter)
58
+ logger.addHandler(ch)
59
+
60
+ if log_file is not None and du.is_master_proc(du.get_world_size()):
61
+ filename = os.path.join(cfg.OUTPUT_DIR, log_file)
62
+ fh = logging.FileHandler(filename)
63
+ fh.setLevel(logging.DEBUG)
64
+ fh.setFormatter(plain_formatter)
65
+ logger.addHandler(fh)
66
+
67
+
68
+ def get_logger(name):
69
+ """
70
+ Retrieve the logger with the specified name or, if name is None, return a
71
+ logger which is the root logger of the hierarchy.
72
+ Args:
73
+ name (string): name of the logger.
74
+ """
75
+ return logging.getLogger(name)
76
+
77
+
78
+ def log_json_stats(stats):
79
+ """
80
+ Logs json stats.
81
+ Args:
82
+ stats (dict): a dictionary of statistical information to log.
83
+ """
84
+ stats = {
85
+ k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
86
+ for k, v in stats.items()
87
+ }
88
+ json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
89
+ logger = get_logger(__name__)
90
+ logger.info("{:s}".format(json_stats))
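
Outside the multi-process setup, the helpers can be exercised with a plain `logging` configuration. A minimal sketch (the format string simply mirrors the one used above):

```python
import logging
from utils.logging import get_logger, log_json_stats

logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s][%(levelname)s] %(name)s: %(message)s",
    datefmt="%m/%d %H:%M:%S",
)
logger = get_logger(__name__)
logger.info("config loaded")
log_json_stats({"step": 10, "loss": 0.123456789})  # floats are rounded to 6 decimals
```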
UniAnimate/utils/mp4_to_gif.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+
3
+
4
+
5
+ # source_mp4_dir = "outputs/UniAnimate_infer"
6
+ # target_gif_dir = "outputs/UniAnimate_infer_gif"
7
+
8
+ source_mp4_dir = "outputs/UniAnimate_infer_long"
9
+ target_gif_dir = "outputs/UniAnimate_infer_long_gif"
10
+
11
+ os.makedirs(target_gif_dir, exist_ok=True)
12
+ for video in os.listdir(source_mp4_dir):
13
+ video_dir = os.path.join(source_mp4_dir, video)
14
+ gif_dir = os.path.join(target_gif_dir, video.replace(".mp4", ".gif"))
15
+ cmd = f'ffmpeg -i {video_dir} {gif_dir}'
16
+ os.system(cmd)
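
If shell quoting or GIF size becomes a problem, an alternative (not part of the repo) is to call ffmpeg through `subprocess` with explicit list arguments and a downscaling filter:

```python
import os
import subprocess

source_mp4_dir = "outputs/UniAnimate_infer_long"
target_gif_dir = "outputs/UniAnimate_infer_long_gif"
os.makedirs(target_gif_dir, exist_ok=True)

for video in os.listdir(source_mp4_dir):
    if not video.endswith(".mp4"):
        continue
    src = os.path.join(source_mp4_dir, video)
    dst = os.path.join(target_gif_dir, video.replace(".mp4", ".gif"))
    # fps/scale filters keep the GIF small; list args avoid shell quoting issues
    subprocess.run(["ffmpeg", "-y", "-i", src, "-vf", "fps=15,scale=512:-1", dst], check=True)
```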
UniAnimate/utils/multi_port.py ADDED
@@ -0,0 +1,9 @@
1
+ import socket
2
+ from contextlib import closing
3
+
4
+ def find_free_port():
5
+ """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """
6
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
7
+ s.bind(('', 0))
8
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
9
+ return str(s.getsockname()[1])
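
A typical use is building a unique TCP `init_method` when several jobs share one host; the URL format below mirrors the `--init_method` default in `utils/config.py`:

```python
from utils.multi_port import find_free_port

init_method = f"tcp://localhost:{find_free_port()}"
print(init_method)  # e.g. tcp://localhost:53211 (port differs per call)
```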
UniAnimate/utils/optim/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .lr_scheduler import *
2
+ from .adafactor import *
UniAnimate/utils/optim/adafactor.py ADDED
@@ -0,0 +1,230 @@
1
+ import math
2
+ import torch
3
+ from torch.optim import Optimizer
4
+ from torch.optim.lr_scheduler import LambdaLR
5
+
6
+ __all__ = ['Adafactor']
7
+
8
+ class Adafactor(Optimizer):
9
+ """
10
+ AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
11
+ https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
12
+ Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
13
+ this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
14
+ `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
15
+ `relative_step=False`.
16
+ Arguments:
17
+ params (`Iterable[nn.parameter.Parameter]`):
18
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
19
+ lr (`float`, *optional*):
20
+ The external learning rate.
21
+ eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)):
22
+ Regularization constants for square gradient and parameter scale respectively
23
+ clip_threshold (`float`, *optional*, defaults 1.0):
24
+ Threshold of root mean square of final gradient update
25
+ decay_rate (`float`, *optional*, defaults to -0.8):
26
+ Coefficient used to compute running averages of square
27
+ beta1 (`float`, *optional*):
28
+ Coefficient used for computing running averages of gradient
29
+ weight_decay (`float`, *optional*, defaults to 0):
30
+ Weight decay (L2 penalty)
31
+ scale_parameter (`bool`, *optional*, defaults to `True`):
32
+ If True, learning rate is scaled by root mean square
33
+ relative_step (`bool`, *optional*, defaults to `True`):
34
+ If True, time-dependent learning rate is computed instead of external learning rate
35
+ warmup_init (`bool`, *optional*, defaults to `False`):
36
+ Time-dependent learning rate computation depends on whether warm-up initialization is being used
37
+ This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
38
+ Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
39
+ - Training without LR warmup or clip_threshold is not recommended.
40
+ - use scheduled LR warm-up to fixed LR
41
+ - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
42
+ - Disable relative updates
43
+ - Use scale_parameter=False
44
+ - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
45
+ Example:
46
+ ```python
47
+ Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
48
+ ```
49
+ Others reported the following combination to work well:
50
+ ```python
51
+ Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
52
+ ```
53
+ When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
54
+ scheduler as following:
55
+ ```python
56
+ from transformers.optimization import Adafactor, AdafactorSchedule
57
+ optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
58
+ lr_scheduler = AdafactorSchedule(optimizer)
59
+ trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
60
+ ```
61
+ Usage:
62
+ ```python
63
+ # replace AdamW with Adafactor
64
+ optimizer = Adafactor(
65
+ model.parameters(),
66
+ lr=1e-3,
67
+ eps=(1e-30, 1e-3),
68
+ clip_threshold=1.0,
69
+ decay_rate=-0.8,
70
+ beta1=None,
71
+ weight_decay=0.0,
72
+ relative_step=False,
73
+ scale_parameter=False,
74
+ warmup_init=False,
75
+ )
76
+ ```"""
77
+
78
+ def __init__(
79
+ self,
80
+ params,
81
+ lr=None,
82
+ eps=(1e-30, 1e-3),
83
+ clip_threshold=1.0,
84
+ decay_rate=-0.8,
85
+ beta1=None,
86
+ weight_decay=0.0,
87
+ scale_parameter=True,
88
+ relative_step=True,
89
+ warmup_init=False,
90
+ ):
91
+ r"""require_version("torch>=1.5.0") # add_ with alpha
92
+ """
93
+ if lr is not None and relative_step:
94
+ raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
95
+ if warmup_init and not relative_step:
96
+ raise ValueError("`warmup_init=True` requires `relative_step=True`")
97
+
98
+ defaults = dict(
99
+ lr=lr,
100
+ eps=eps,
101
+ clip_threshold=clip_threshold,
102
+ decay_rate=decay_rate,
103
+ beta1=beta1,
104
+ weight_decay=weight_decay,
105
+ scale_parameter=scale_parameter,
106
+ relative_step=relative_step,
107
+ warmup_init=warmup_init,
108
+ )
109
+ super().__init__(params, defaults)
110
+
111
+ @staticmethod
112
+ def _get_lr(param_group, param_state):
113
+ rel_step_sz = param_group["lr"]
114
+ if param_group["relative_step"]:
115
+ min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
116
+ rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
117
+ param_scale = 1.0
118
+ if param_group["scale_parameter"]:
119
+ param_scale = max(param_group["eps"][1], param_state["RMS"])
120
+ return param_scale * rel_step_sz
121
+
122
+ @staticmethod
123
+ def _get_options(param_group, param_shape):
124
+ factored = len(param_shape) >= 2
125
+ use_first_moment = param_group["beta1"] is not None
126
+ return factored, use_first_moment
127
+
128
+ @staticmethod
129
+ def _rms(tensor):
130
+ return tensor.norm(2) / (tensor.numel() ** 0.5)
131
+
132
+ @staticmethod
133
+ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
134
+ # copy from fairseq's adafactor implementation:
135
+ # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
136
+ r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
137
+ c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
138
+ return torch.mul(r_factor, c_factor)
139
+
140
+ def step(self, closure=None):
141
+ """
142
+ Performs a single optimization step
143
+ Arguments:
144
+ closure (callable, optional): A closure that reevaluates the model
145
+ and returns the loss.
146
+ """
147
+ loss = None
148
+ if closure is not None:
149
+ loss = closure()
150
+
151
+ for group in self.param_groups:
152
+ for p in group["params"]:
153
+ if p.grad is None:
154
+ continue
155
+ grad = p.grad.data
156
+ if grad.dtype in {torch.float16, torch.bfloat16}:
157
+ grad = grad.float()
158
+ if grad.is_sparse:
159
+ raise RuntimeError("Adafactor does not support sparse gradients.")
160
+
161
+ state = self.state[p]
162
+ grad_shape = grad.shape
163
+
164
+ factored, use_first_moment = self._get_options(group, grad_shape)
165
+ # State Initialization
166
+ if len(state) == 0:
167
+ state["step"] = 0
168
+
169
+ if use_first_moment:
170
+ # Exponential moving average of gradient values
171
+ state["exp_avg"] = torch.zeros_like(grad)
172
+ if factored:
173
+ state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
174
+ state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
175
+ else:
176
+ state["exp_avg_sq"] = torch.zeros_like(grad)
177
+
178
+ state["RMS"] = 0
179
+ else:
180
+ if use_first_moment:
181
+ state["exp_avg"] = state["exp_avg"].to(grad)
182
+ if factored:
183
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
184
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
185
+ else:
186
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
187
+
188
+ p_data_fp32 = p.data
189
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
190
+ p_data_fp32 = p_data_fp32.float()
191
+
192
+ state["step"] += 1
193
+ state["RMS"] = self._rms(p_data_fp32)
194
+ lr = self._get_lr(group, state)
195
+
196
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
197
+ update = (grad**2) + group["eps"][0]
198
+ if factored:
199
+ exp_avg_sq_row = state["exp_avg_sq_row"]
200
+ exp_avg_sq_col = state["exp_avg_sq_col"]
201
+
202
+ exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
203
+ exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
204
+
205
+ # Approximation of exponential moving average of square of gradient
206
+ update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
207
+ update.mul_(grad)
208
+ else:
209
+ exp_avg_sq = state["exp_avg_sq"]
210
+
211
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
212
+ update = exp_avg_sq.rsqrt().mul_(grad)
213
+
214
+ update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
215
+ update.mul_(lr)
216
+
217
+ if use_first_moment:
218
+ exp_avg = state["exp_avg"]
219
+ exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
220
+ update = exp_avg
221
+
222
+ if group["weight_decay"] != 0:
223
+ p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
224
+
225
+ p_data_fp32.add_(-update)
226
+
227
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
228
+ p.data.copy_(p_data_fp32)
229
+
230
+ return loss
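
A quick smoke test of this port with an external learning rate, following the recommended settings in the docstring above (the toy parameter is illustrative):

```python
import torch
from utils.optim import Adafactor

w = torch.nn.Parameter(torch.randn(8, 4))
opt = Adafactor([w], lr=1e-3, relative_step=False,
                scale_parameter=False, warmup_init=False)

loss = (w ** 2).sum()
loss.backward()
opt.step()       # factored second-moment statistics for the 2D parameter
opt.zero_grad()
print(w.norm().item())
```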
UniAnimate/utils/optim/lr_scheduler.py ADDED
@@ -0,0 +1,58 @@
1
+ import math
2
+ from torch.optim.lr_scheduler import _LRScheduler
3
+
4
+ __all__ = ['AnnealingLR']
5
+
6
+ class AnnealingLR(_LRScheduler):
7
+
8
+ def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', min_lr=0.0, last_step=-1):
9
+ assert decay_mode in ['linear', 'cosine', 'none']
10
+ self.optimizer = optimizer
11
+ self.base_lr = base_lr
12
+ self.warmup_steps = warmup_steps
13
+ self.total_steps = total_steps
14
+ self.decay_mode = decay_mode
15
+ self.min_lr = min_lr
16
+ self.current_step = last_step + 1
17
+ self.step(self.current_step)
18
+
19
+ def get_lr(self):
20
+ if self.warmup_steps > 0 and self.current_step <= self.warmup_steps:
21
+ return self.base_lr * self.current_step / self.warmup_steps
22
+ else:
23
+ ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
24
+ ratio = min(1.0, max(0.0, ratio))
25
+ if self.decay_mode == 'linear':
26
+ return self.base_lr * (1 - ratio)
27
+ elif self.decay_mode == 'cosine':
28
+ return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0
29
+ else:
30
+ return self.base_lr
31
+
32
+ def step(self, current_step=None):
33
+ if current_step is None:
34
+ current_step = self.current_step + 1
35
+ self.current_step = current_step
36
+ new_lr = max(self.min_lr, self.get_lr())
37
+ if isinstance(self.optimizer, list):
38
+ for o in self.optimizer:
39
+ for group in o.param_groups:
40
+ group['lr'] = new_lr
41
+ else:
42
+ for group in self.optimizer.param_groups:
43
+ group['lr'] = new_lr
44
+
45
+ def state_dict(self):
46
+ return {
47
+ 'base_lr': self.base_lr,
48
+ 'warmup_steps': self.warmup_steps,
49
+ 'total_steps': self.total_steps,
50
+ 'decay_mode': self.decay_mode,
51
+ 'current_step': self.current_step}
52
+
53
+ def load_state_dict(self, state_dict):
54
+ self.base_lr = state_dict['base_lr']
55
+ self.warmup_steps = state_dict['warmup_steps']
56
+ self.total_steps = state_dict['total_steps']
57
+ self.decay_mode = state_dict['decay_mode']
58
+ self.current_step = state_dict['current_step']
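
Putting the two pieces of `utils/optim` together, a sketch of a cosine schedule with warm-up (the step counts are illustrative):

```python
import torch
from utils.optim import Adafactor, AnnealingLR

params = [torch.nn.Parameter(torch.randn(4, 4))]
opt = Adafactor(params, lr=1e-3, relative_step=False, scale_parameter=False)
sched = AnnealingLR(opt, base_lr=1e-3, warmup_steps=10, total_steps=100,
                    decay_mode='cosine', min_lr=1e-6)

for step in range(1, 101):
    # forward / backward / opt.step() would go here
    sched.step(step)

print(opt.param_groups[0]['lr'])  # decays to min_lr at the end of the schedule
```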