Spaces:
Running
Running
Evgeny Zhukov
commited on
Commit
·
2ba4412
1
Parent(s):
45e557f
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- UniAnimate/.gitignore +18 -0
- UniAnimate/README.md +344 -0
- UniAnimate/configs/UniAnimate_infer.yaml +98 -0
- UniAnimate/configs/UniAnimate_infer_long.yaml +101 -0
- UniAnimate/dwpose/__init__.py +0 -0
- UniAnimate/dwpose/onnxdet.py +127 -0
- UniAnimate/dwpose/onnxpose.py +360 -0
- UniAnimate/dwpose/util.py +336 -0
- UniAnimate/dwpose/wholebody.py +48 -0
- UniAnimate/environment.yaml +236 -0
- UniAnimate/inference.py +18 -0
- UniAnimate/requirements.txt +201 -0
- UniAnimate/run_align_pose.py +712 -0
- UniAnimate/test_func/save_targer_keys.py +108 -0
- UniAnimate/test_func/test_EndDec.py +95 -0
- UniAnimate/test_func/test_dataset.py +152 -0
- UniAnimate/test_func/test_models.py +56 -0
- UniAnimate/test_func/test_save_video.py +24 -0
- UniAnimate/tools/__init__.py +3 -0
- UniAnimate/tools/datasets/__init__.py +2 -0
- UniAnimate/tools/datasets/image_dataset.py +86 -0
- UniAnimate/tools/datasets/video_dataset.py +118 -0
- UniAnimate/tools/inferences/__init__.py +2 -0
- UniAnimate/tools/inferences/inference_unianimate_entrance.py +483 -0
- UniAnimate/tools/inferences/inference_unianimate_long_entrance.py +508 -0
- UniAnimate/tools/modules/__init__.py +7 -0
- UniAnimate/tools/modules/autoencoder.py +690 -0
- UniAnimate/tools/modules/clip_embedder.py +212 -0
- UniAnimate/tools/modules/config.py +206 -0
- UniAnimate/tools/modules/diffusions/__init__.py +1 -0
- UniAnimate/tools/modules/diffusions/diffusion_ddim.py +1121 -0
- UniAnimate/tools/modules/diffusions/diffusion_gauss.py +498 -0
- UniAnimate/tools/modules/diffusions/losses.py +28 -0
- UniAnimate/tools/modules/diffusions/schedules.py +166 -0
- UniAnimate/tools/modules/embedding_manager.py +179 -0
- UniAnimate/tools/modules/unet/__init__.py +2 -0
- UniAnimate/tools/modules/unet/mha_flash.py +103 -0
- UniAnimate/tools/modules/unet/unet_unianimate.py +659 -0
- UniAnimate/tools/modules/unet/util.py +1741 -0
- UniAnimate/utils/__init__.py +0 -0
- UniAnimate/utils/assign_cfg.py +78 -0
- UniAnimate/utils/config.py +230 -0
- UniAnimate/utils/distributed.py +430 -0
- UniAnimate/utils/logging.py +90 -0
- UniAnimate/utils/mp4_to_gif.py +16 -0
- UniAnimate/utils/multi_port.py +9 -0
- UniAnimate/utils/optim/__init__.py +2 -0
- UniAnimate/utils/optim/adafactor.py +230 -0
- UniAnimate/utils/optim/lr_scheduler.py +58 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
|
37 |
+
UniAnimate/data/** filter=lfs diff=lfs merge=lfs -text
|
UniAnimate/.gitignore
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pkl
|
2 |
+
*.pt
|
3 |
+
*.mov
|
4 |
+
*.pth
|
5 |
+
*.mov
|
6 |
+
*.npz
|
7 |
+
*.npy
|
8 |
+
*.boj
|
9 |
+
*.onnx
|
10 |
+
*.tar
|
11 |
+
*.bin
|
12 |
+
cache*
|
13 |
+
.DS_Store
|
14 |
+
*DS_Store
|
15 |
+
outputs/
|
16 |
+
**/__pycache__
|
17 |
+
***/__pycache__
|
18 |
+
*/__pycache__
|
UniAnimate/README.md
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- main documents -->
|
2 |
+
|
3 |
+
|
4 |
+
<div align="center">
|
5 |
+
|
6 |
+
|
7 |
+
# UniAnimate: Taming Unified Video Diffusion Models for Consistent Human Image Animation
|
8 |
+
|
9 |
+
[Xiang Wang](https://scholar.google.com.hk/citations?user=cQbXvkcAAAAJ&hl=zh-CN&oi=sra)<sup>1</sup>, [Shiwei Zhang](https://scholar.google.com.hk/citations?user=ZO3OQ-8AAAAJ&hl=zh-CN)<sup>2</sup>, [Changxin Gao](https://scholar.google.com.hk/citations?user=4tku-lwAAAAJ&hl=zh-CN)<sup>1</sup>, [Jiayu Wang](#)<sup>2</sup>, [Xiaoqiang Zhou](https://scholar.google.com.hk/citations?user=Z2BTkNIAAAAJ&hl=zh-CN&oi=ao)<sup>3</sup>, [Yingya Zhang](https://scholar.google.com.hk/citations?user=16RDSEUAAAAJ&hl=zh-CN)<sup>2</sup> , [Luxin Yan](#)<sup>1</sup> , [Nong Sang](https://scholar.google.com.hk/citations?user=ky_ZowEAAAAJ&hl=zh-CN)<sup>1</sup>
|
10 |
+
<sup>1</sup>HUST <sup>2</sup>Alibaba Group <sup>3</sup>USTC
|
11 |
+
|
12 |
+
|
13 |
+
[🎨 Project Page](https://unianimate.github.io/)
|
14 |
+
|
15 |
+
|
16 |
+
<p align="middle">
|
17 |
+
<img src='https://img.alicdn.com/imgextra/i4/O1CN01bW2Y491JkHAUK4W0i_!!6000000001066-2-tps-2352-1460.png' width='784'>
|
18 |
+
|
19 |
+
Demo cases generated by the proposed UniAnimate
|
20 |
+
</p>
|
21 |
+
|
22 |
+
|
23 |
+
</div>
|
24 |
+
|
25 |
+
## 🔥 News
|
26 |
+
- **[2024/07/19]** 🔥 We added a **<font color=red>noise prior</font>** to the code (refer to line 381: `noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise)` in `tools/inferences/inference_unianimate_long_entrance.py`), which can help achieve better appearance preservation (such as background), especially in long video generation. In addition, we are considering releasing an upgraded version of UniAnimate if we obtain an open source license from the company.
|
27 |
+
- **[2024/06/26]** For cards with large GPU memory, such as A100 GPU, we support multiple segments parallel denoising to accelerate long video inference. You can change `context_batch_size: 1` in `configs/UniAnimate_infer_long.yaml` to other values greater than 1, such as `context_batch_size: 4`. The inference speed will be improved to a certain extent.
|
28 |
+
- **[2024/06/15]** 🔥🔥🔥 By offloading CLIP and VAE and explicitly adding torch.float16 (i.e., set `CPU_CLIP_VAE: True` in `configs/UniAnimate_infer.yaml`), the GPU memory can be greatly reduced. Now generating a 32x768x512 video clip only requires **~12G GPU memory**. Refer to [this issue](https://github.com/ali-vilab/UniAnimate/issues/10) for more details. Thanks to [@blackight](https://github.com/blackight) for the contribution!
|
29 |
+
- **[2024/06/13]** **🔥🔥🔥 <font color=red>We released the code and models for human image animation, enjoy it!</font>**
|
30 |
+
- **[2024/06/13]** We have submitted the code to the company for approval, and **the code is expected to be released today or tomorrow**.
|
31 |
+
- **[2024/06/03]** We initialized this github repository and planed to release the paper.
|
32 |
+
|
33 |
+
|
34 |
+
## TODO
|
35 |
+
|
36 |
+
- [x] Release the models and inference code, and pose alignment code.
|
37 |
+
- [x] Support generating both short and long videos.
|
38 |
+
- [ ] Release the models for longer video generation in one batch.
|
39 |
+
- [ ] Release models based on VideoLCM for faster video synthesis.
|
40 |
+
- [ ] Training the models on higher resolution videos.
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
## Introduction
|
45 |
+
|
46 |
+
<div align="center">
|
47 |
+
<p align="middle">
|
48 |
+
<img src='https://img.alicdn.com/imgextra/i3/O1CN01VvncFJ1ueRudiMOZu_!!6000000006062-2-tps-2654-1042.png' width='784'>
|
49 |
+
|
50 |
+
Overall framework of UniAnimate
|
51 |
+
</p>
|
52 |
+
</div>
|
53 |
+
|
54 |
+
Recent diffusion-based human image animation techniques have demonstrated impressive success in synthesizing videos that faithfully follow a given reference identity and a sequence of desired movement poses. Despite this, there are still two limitations: i) an extra reference model is required to align the identity image with the main video branch, which significantly increases the optimization burden and model parameters; ii) the generated video is usually short in time (e.g., 24 frames), hampering practical applications. To address these shortcomings, we present a UniAnimate framework to enable efficient and long-term human video generation. First, to reduce the optimization difficulty and ensure temporal coherence, we map the reference image along with the posture guidance and noise video into a common feature space by incorporating a unified video diffusion model. Second, we propose a unified noise input that supports random noised input as well as first frame conditioned input, which enhances the ability to generate long-term video. Finally, to further efficiently handle long sequences, we explore an alternative temporal modeling architecture based on state space model to replace the original computation-consuming temporal Transformer. Extensive experimental results indicate that UniAnimate achieves superior synthesis results over existing state-of-the-art counterparts in both quantitative and qualitative evaluations. Notably, UniAnimate can even generate highly consistent one-minute videos by iteratively employing the first frame conditioning strategy.
|
55 |
+
|
56 |
+
|
57 |
+
## Getting Started with UniAnimate
|
58 |
+
|
59 |
+
|
60 |
+
### (1) Installation
|
61 |
+
|
62 |
+
Installation the python dependencies:
|
63 |
+
|
64 |
+
```
|
65 |
+
git clone https://github.com/ali-vilab/UniAnimate.git
|
66 |
+
cd UniAnimate
|
67 |
+
conda create -n UniAnimate python=3.9
|
68 |
+
conda activate UniAnimate
|
69 |
+
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
|
70 |
+
pip install -r requirements.txt
|
71 |
+
```
|
72 |
+
We also provide all the dependencies in `environment.yaml`.
|
73 |
+
|
74 |
+
**Note**: for Windows operating system, you can refer to [this issue](https://github.com/ali-vilab/UniAnimate/issues/11) to install the dependencies. Thanks to [@zephirusgit](https://github.com/zephirusgit) for the contribution. If you encouter the problem of `The shape of the 2D attn_mask is torch.Size([77, 77]), but should be (1, 1).`, please refer to [this issue](https://github.com/ali-vilab/UniAnimate/issues/61) to solve it, thanks to [@Isi-dev](https://github.com/Isi-dev) for the contribution.
|
75 |
+
|
76 |
+
### (2) Download the pretrained checkpoints
|
77 |
+
|
78 |
+
Download models:
|
79 |
+
```
|
80 |
+
!pip install modelscope
|
81 |
+
from modelscope.hub.snapshot_download import snapshot_download
|
82 |
+
model_dir = snapshot_download('iic/unianimate', cache_dir='checkpoints/')
|
83 |
+
```
|
84 |
+
Then you might need the following command to move the checkpoints to the "checkpoints/" directory:
|
85 |
+
```
|
86 |
+
mv ./checkpoints/iic/unianimate/* ./checkpoints/
|
87 |
+
```
|
88 |
+
|
89 |
+
Finally, the model weights will be organized in `./checkpoints/` as follows:
|
90 |
+
```
|
91 |
+
./checkpoints/
|
92 |
+
|---- dw-ll_ucoco_384.onnx
|
93 |
+
|---- open_clip_pytorch_model.bin
|
94 |
+
|---- unianimate_16f_32f_non_ema_223000.pth
|
95 |
+
|---- v2-1_512-ema-pruned.ckpt
|
96 |
+
└---- yolox_l.onnx
|
97 |
+
```
|
98 |
+
|
99 |
+
### (3) Pose alignment **(Important)**
|
100 |
+
|
101 |
+
Rescale the target pose sequence to match the pose of the reference image:
|
102 |
+
```
|
103 |
+
# reference image 1
|
104 |
+
python run_align_pose.py --ref_name data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg --source_video_paths data/videos/source_video.mp4 --saved_pose_dir data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full
|
105 |
+
|
106 |
+
# reference image 2
|
107 |
+
python run_align_pose.py --ref_name data/images/musk.jpg --source_video_paths data/videos/source_video.mp4 --saved_pose_dir data/saved_pose/musk
|
108 |
+
|
109 |
+
# reference image 3
|
110 |
+
python run_align_pose.py --ref_name data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg --source_video_paths data/videos/source_video.mp4 --saved_pose_dir data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full
|
111 |
+
|
112 |
+
# reference image 4
|
113 |
+
python run_align_pose.py --ref_name data/images/IMG_20240514_104337.jpg --source_video_paths data/videos/source_video.mp4 --saved_pose_dir data/saved_pose/IMG_20240514_104337
|
114 |
+
```
|
115 |
+
We have already provided the processed target pose for demo videos in ```data/saved_pose```, if you run our demo video example, this step can be skipped. In addition, you need to install onnxruntime-gpu (`pip install onnxruntime-gpu==1.13.1`) to run pose alignment on GPU.
|
116 |
+
|
117 |
+
**<font color=red>✔ Some tips</font>**:
|
118 |
+
|
119 |
+
- > In pose alignment, the first frame in the target pose sequence is used to calculate the scale coefficient of the alignment. Therefore, if the first frame in the target pose sequence contains the entire face and pose (hand and foot), it can help obtain more accurate estimation and better video generation results.
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
### (4) Run the UniAnimate model to generate videos
|
124 |
+
|
125 |
+
#### (4.1) Generating video clips (32 frames with 768x512 resolution)
|
126 |
+
|
127 |
+
Execute the following command to generate video clips:
|
128 |
+
```
|
129 |
+
python inference.py --cfg configs/UniAnimate_infer.yaml
|
130 |
+
```
|
131 |
+
After this, 32-frame video clips with 768x512 resolution will be generated:
|
132 |
+
|
133 |
+
|
134 |
+
<table>
|
135 |
+
<center>
|
136 |
+
<tr>
|
137 |
+
<td ><center>
|
138 |
+
<image height="260" src="assets/1.gif"></image>
|
139 |
+
</center></td>
|
140 |
+
<td ><center>
|
141 |
+
<image height="260" src="assets/2.gif"></image>
|
142 |
+
</center></td>
|
143 |
+
</tr>
|
144 |
+
<tr>
|
145 |
+
<td ><center>
|
146 |
+
<p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYV1hTb2g3Zlpmb1E/Vk9HZHZkdDBUQzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
|
147 |
+
</center></td>
|
148 |
+
<td ><center>
|
149 |
+
<p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYYzNUUWRKR043c1FaZkVHSkpSMnpoeTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
|
150 |
+
</center></td>
|
151 |
+
</tr>
|
152 |
+
</center>
|
153 |
+
</table>
|
154 |
+
</center>
|
155 |
+
|
156 |
+
<!-- <table>
|
157 |
+
<center>
|
158 |
+
<tr>
|
159 |
+
<td ><center>
|
160 |
+
<video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYYzNUUWRKR043c1FaZkVHSkpSMnpoeTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
|
161 |
+
</td>
|
162 |
+
<td ><center>
|
163 |
+
<video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYV1hTb2g3Zlpmb1E/Vk9HZHZkdDBUQzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
|
164 |
+
</td>
|
165 |
+
</tr>
|
166 |
+
</table> -->
|
167 |
+
|
168 |
+
|
169 |
+
**<font color=red>✔ Some tips</font>**:
|
170 |
+
|
171 |
+
- > To run the model, **~12G** ~~26G~~ GPU memory will be used. If your GPU is smaller than this, you can change the `max_frames: 32` in `configs/UniAnimate_infer.yaml` to other values, e.g., 24, 16, and 8. Our model is compatible with all of them.
|
172 |
+
|
173 |
+
|
174 |
+
#### (4.2) Generating video clips (32 frames with 1216x768 resolution)
|
175 |
+
|
176 |
+
If you want to synthesize higher resolution results, you can change the `resolution: [512, 768]` in `configs/UniAnimate_infer.yaml` to `resolution: [768, 1216]`. And execute the following command to generate video clips:
|
177 |
+
```
|
178 |
+
python inference.py --cfg configs/UniAnimate_infer.yaml
|
179 |
+
```
|
180 |
+
After this, 32-frame video clips with 1216x768 resolution will be generated:
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
<table>
|
185 |
+
<center>
|
186 |
+
<tr>
|
187 |
+
<td ><center>
|
188 |
+
<image height="260" src="assets/3.gif"></image>
|
189 |
+
</center></td>
|
190 |
+
<td ><center>
|
191 |
+
<image height="260" src="assets/4.gif"></image>
|
192 |
+
</center></td>
|
193 |
+
</tr>
|
194 |
+
<tr>
|
195 |
+
<td ><center>
|
196 |
+
<p>Click <a href="https://cloud.video.taobao.com/vod/play/NTFJUWJ1YXphUzU5b3dhZHJlQk1YZjA3emppMWNJbHhXSlN6WmZHc2FTYTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
|
197 |
+
</center></td>
|
198 |
+
<td ><center>
|
199 |
+
<p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYYklMcGdIRFlDcXcwVEU5ZnR0VlBpRzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
|
200 |
+
</center></td>
|
201 |
+
</tr>
|
202 |
+
</center>
|
203 |
+
</table>
|
204 |
+
</center>
|
205 |
+
|
206 |
+
<!-- <table>
|
207 |
+
<center>
|
208 |
+
<tr>
|
209 |
+
<td ><center>
|
210 |
+
<video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/NTFJUWJ1YXphUzU5b3dhZHJlQk1YZjA3emppMWNJbHhXSlN6WmZHc2FTYTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
|
211 |
+
</td>
|
212 |
+
<td ><center>
|
213 |
+
<video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYYklMcGdIRFlDcXcwVEU5ZnR0VlBpRzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
|
214 |
+
</td>
|
215 |
+
</tr>
|
216 |
+
</table> -->
|
217 |
+
|
218 |
+
|
219 |
+
**<font color=red>✔ Some tips</font>**:
|
220 |
+
|
221 |
+
- > To run the model, **~21G** ~~36G~~ GPU memory will be used. Even though our model was trained on 512x768 resolution, we observed that direct inference on 768x1216 is usually allowed and produces satisfactory results. If this results in inconsistent apparence, you can try a different seed or adjust the resolution to 512x768.
|
222 |
+
|
223 |
+
- > Although our model was not trained on 48 or 64 frames, we found that the model generalizes well to synthesis of these lengths.
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
In the `configs/UniAnimate_infer.yaml` configuration file, you can specify the data, adjust the video length using `max_frames`, and validate your ideas with different Diffusion settings, and so on.
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
#### (4.3) Generating long videos
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
If you want to synthesize videos as long as the target pose sequence, you can execute the following command to generate long videos:
|
236 |
+
```
|
237 |
+
python inference.py --cfg configs/UniAnimate_infer_long.yaml
|
238 |
+
```
|
239 |
+
After this, long videos with 1216x768 resolution will be generated:
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
<table>
|
244 |
+
<center>
|
245 |
+
<tr>
|
246 |
+
<td ><center>
|
247 |
+
<image height="260" src="assets/5.gif"></image>
|
248 |
+
</center></td>
|
249 |
+
<td ><center>
|
250 |
+
<image height="260" src="assets/6.gif"></image>
|
251 |
+
</center></td>
|
252 |
+
</tr>
|
253 |
+
<tr>
|
254 |
+
<td ><center>
|
255 |
+
<p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYVmJKZUJSbDl6N1FXU01DYTlDRmJKTzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
|
256 |
+
</center></td>
|
257 |
+
<td ><center>
|
258 |
+
<p>Click <a href="https://cloud.video.taobao.com/vod/play/VUdZTUE5MWtST3VtNEdFaVpGbHN1U25nNEorTEc2SzZROUNiUjNncW5ycTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
|
259 |
+
</center></td>
|
260 |
+
</tr>
|
261 |
+
</center>
|
262 |
+
</table>
|
263 |
+
</center>
|
264 |
+
|
265 |
+
<!-- <table>
|
266 |
+
<center>
|
267 |
+
<tr>
|
268 |
+
<td ><center>
|
269 |
+
<video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYVmJKZUJSbDl6N1FXU01DYTlDRmJKTzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
|
270 |
+
</td>
|
271 |
+
<td ><center>
|
272 |
+
<video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/VUdZTUE5MWtST3VtNEdFaVpGbHN1U25nNEorTEc2SzZROUNiUjNncW5ycTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
|
273 |
+
</td>
|
274 |
+
</tr>
|
275 |
+
</table> -->
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
<table>
|
281 |
+
<center>
|
282 |
+
<tr>
|
283 |
+
<td ><center>
|
284 |
+
<image height="260" src="assets/7.gif"></image>
|
285 |
+
</center></td>
|
286 |
+
<td ><center>
|
287 |
+
<image height="260" src="assets/8.gif"></image>
|
288 |
+
</center></td>
|
289 |
+
</tr>
|
290 |
+
<tr>
|
291 |
+
<td ><center>
|
292 |
+
<p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYV04xKzd3eWFPVGZCQjVTUWdtbTFuQzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
|
293 |
+
</center></td>
|
294 |
+
<td ><center>
|
295 |
+
<p>Click <a href="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYWGwxVkNMY1NXOHpWTVdNZDRxKzRuZTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ">HERE</a> to view the generated video.</p>
|
296 |
+
</center></td>
|
297 |
+
</tr>
|
298 |
+
</center>
|
299 |
+
</table>
|
300 |
+
</center>
|
301 |
+
|
302 |
+
<!-- <table>
|
303 |
+
<center>
|
304 |
+
<tr>
|
305 |
+
<td ><center>
|
306 |
+
<video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYV04xKzd3eWFPVGZCQjVTUWdtbTFuQzZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
|
307 |
+
</td>
|
308 |
+
<td ><center>
|
309 |
+
<video height="260" controls autoplay loop src="https://cloud.video.taobao.com/vod/play/cEdJVkF4TXRTOTd2bTQ4andjMENYWGwxVkNMY1NXOHpWTVdNZDRxKzRuZTZQZWw1SnpKVVVCTlh4OVFON0V5UUVMUDduY1RJak82VE1sdXdHTjNOaHc9PQ" muted="false"></video>
|
310 |
+
</td>
|
311 |
+
</tr>
|
312 |
+
</table> -->
|
313 |
+
|
314 |
+
In the `configs/UniAnimate_infer_long.yaml` configuration file, `test_list_path` should in the format of `[frame_interval, reference image, driving pose sequence]`, where `frame_interval=1` means that all frames in the target pose sequence will be used to generate the video, and `frame_interval=2` means that one frame is sampled every two frames. `reference image` is the location where the reference image is saved, and `driving pose sequence` is the location where the driving pose sequence is saved.
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
+
**<font color=red>✔ Some tips</font>**:
|
319 |
+
|
320 |
+
- > If you find inconsistent appearance, you can change the resolution from 768x1216 to 512x768, or change the `context_overlap` from 8 to 16.
|
321 |
+
- > In the default setting of `configs/UniAnimate_infer_long.yaml`, the strategy of sliding window with temporal overlap is used. You can also generate a satisfactory video segment first, and then input the last frame of this segment into the model to generate the next segment to continue the video.
|
322 |
+
|
323 |
+
|
324 |
+
|
325 |
+
## Citation
|
326 |
+
|
327 |
+
If you find this codebase useful for your research, please cite the following paper:
|
328 |
+
|
329 |
+
```
|
330 |
+
@article{wang2024unianimate,
|
331 |
+
title={UniAnimate: Taming Unified Video Diffusion Models for Consistent Human Image Animation},
|
332 |
+
author={Wang, Xiang and Zhang, Shiwei and Gao, Changxin and Wang, Jiayu and Zhou, Xiaoqiang and Zhang, Yingya and Yan, Luxin and Sang, Nong},
|
333 |
+
journal={arXiv preprint arXiv:2406.01188},
|
334 |
+
year={2024}
|
335 |
+
}
|
336 |
+
```
|
337 |
+
|
338 |
+
|
339 |
+
|
340 |
+
## Disclaimer
|
341 |
+
|
342 |
+
|
343 |
+
This open-source model is intended for <strong>RESEARCH/NON-COMMERCIAL USE ONLY</strong>.
|
344 |
+
We explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using the generative model. The project contributors have no legal affiliation with, nor accountability for, users' behaviors. It is imperative to use the generative model responsibly, adhering to both ethical and legal standards.
|
UniAnimate/configs/UniAnimate_infer.yaml
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# manual setting
|
2 |
+
max_frames: 32
|
3 |
+
resolution: [512, 768] # or resolution: [768, 1216]
|
4 |
+
# resolution: [768, 1216]
|
5 |
+
round: 1
|
6 |
+
ddim_timesteps: 30 # among 25-50
|
7 |
+
seed: 11 # 7
|
8 |
+
test_list_path: [
|
9 |
+
# Format: [frame_interval, reference image, driving pose sequence]
|
10 |
+
[2, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"],
|
11 |
+
[2, "data/images/musk.jpg", "data/saved_pose/musk"],
|
12 |
+
[2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full"],
|
13 |
+
[2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337"]
|
14 |
+
]
|
15 |
+
partial_keys: [
|
16 |
+
['image','local_image', "dwpose"], # reference image as the first frame of the generated video (optional)
|
17 |
+
['image', 'randomref', "dwpose"],
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
# default settings
|
23 |
+
TASK_TYPE: inference_unianimate_entrance
|
24 |
+
guide_scale: 2.5
|
25 |
+
vit_resolution: [224, 224]
|
26 |
+
use_fp16: True
|
27 |
+
batch_size: 1
|
28 |
+
latent_random_ref: True
|
29 |
+
chunk_size: 2
|
30 |
+
decoder_bs: 2
|
31 |
+
scale: 8
|
32 |
+
use_fps_condition: False
|
33 |
+
test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth
|
34 |
+
embedder: {
|
35 |
+
'type': 'FrozenOpenCLIPTextVisualEmbedder',
|
36 |
+
'layer': 'penultimate',
|
37 |
+
'pretrained': 'checkpoints/open_clip_pytorch_model.bin'
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
auto_encoder: {
|
42 |
+
'type': 'AutoencoderKL',
|
43 |
+
'ddconfig': {
|
44 |
+
'double_z': True,
|
45 |
+
'z_channels': 4,
|
46 |
+
'resolution': 256,
|
47 |
+
'in_channels': 3,
|
48 |
+
'out_ch': 3,
|
49 |
+
'ch': 128,
|
50 |
+
'ch_mult': [1, 2, 4, 4],
|
51 |
+
'num_res_blocks': 2,
|
52 |
+
'attn_resolutions': [],
|
53 |
+
'dropout': 0.0,
|
54 |
+
'video_kernel_size': [3, 1, 1]
|
55 |
+
},
|
56 |
+
'embed_dim': 4,
|
57 |
+
'pretrained': 'checkpoints/v2-1_512-ema-pruned.ckpt'
|
58 |
+
}
|
59 |
+
|
60 |
+
UNet: {
|
61 |
+
'type': 'UNetSD_UniAnimate',
|
62 |
+
'config': None,
|
63 |
+
'in_dim': 4,
|
64 |
+
'dim': 320,
|
65 |
+
'y_dim': 1024,
|
66 |
+
'context_dim': 1024,
|
67 |
+
'out_dim': 4,
|
68 |
+
'dim_mult': [1, 2, 4, 4],
|
69 |
+
'num_heads': 8,
|
70 |
+
'head_dim': 64,
|
71 |
+
'num_res_blocks': 2,
|
72 |
+
'dropout': 0.1,
|
73 |
+
'temporal_attention': True,
|
74 |
+
'num_tokens': 4,
|
75 |
+
'temporal_attn_times': 1,
|
76 |
+
'use_checkpoint': True,
|
77 |
+
'use_fps_condition': False,
|
78 |
+
'use_sim_mask': False
|
79 |
+
}
|
80 |
+
video_compositions: ['image', 'local_image', 'dwpose', 'randomref', 'randomref_pose']
|
81 |
+
Diffusion: {
|
82 |
+
'type': 'DiffusionDDIM',
|
83 |
+
'schedule': 'linear_sd',
|
84 |
+
'schedule_param': {
|
85 |
+
'num_timesteps': 1000,
|
86 |
+
"init_beta": 0.00085,
|
87 |
+
"last_beta": 0.0120,
|
88 |
+
'zero_terminal_snr': True,
|
89 |
+
},
|
90 |
+
'mean_type': 'v',
|
91 |
+
'loss_type': 'mse',
|
92 |
+
'var_type': 'fixed_small', # 'fixed_large',
|
93 |
+
'rescale_timesteps': False,
|
94 |
+
'noise_strength': 0.1
|
95 |
+
}
|
96 |
+
use_DiffusionDPM: False
|
97 |
+
CPU_CLIP_VAE: True
|
98 |
+
noise_prior_value: 949 # or 999, 949
|
UniAnimate/configs/UniAnimate_infer_long.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# manual setting
|
2 |
+
# resolution: [512, 768] # or [768, 1216]
|
3 |
+
resolution: [768, 1216]
|
4 |
+
round: 1
|
5 |
+
ddim_timesteps: 30 # among 25-50
|
6 |
+
context_size: 32
|
7 |
+
context_stride: 1
|
8 |
+
context_overlap: 8
|
9 |
+
seed: 7
|
10 |
+
max_frames: "None" # 64, 96, "None" mean the length of original pose sequence
|
11 |
+
test_list_path: [
|
12 |
+
# Format: [frame_interval, reference image, driving pose sequence]
|
13 |
+
[2, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"],
|
14 |
+
[2, "data/images/musk.jpg", "data/saved_pose/musk"],
|
15 |
+
[2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full"],
|
16 |
+
[2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337"],
|
17 |
+
[2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337_dance"],
|
18 |
+
[2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full_dance"]
|
19 |
+
]
|
20 |
+
|
21 |
+
|
22 |
+
# default settings
|
23 |
+
TASK_TYPE: inference_unianimate_long_entrance
|
24 |
+
guide_scale: 2.5
|
25 |
+
vit_resolution: [224, 224]
|
26 |
+
use_fp16: True
|
27 |
+
batch_size: 1
|
28 |
+
latent_random_ref: True
|
29 |
+
chunk_size: 2
|
30 |
+
decoder_bs: 2
|
31 |
+
scale: 8
|
32 |
+
use_fps_condition: False
|
33 |
+
test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth
|
34 |
+
partial_keys: [
|
35 |
+
['image', 'randomref', "dwpose"],
|
36 |
+
]
|
37 |
+
embedder: {
|
38 |
+
'type': 'FrozenOpenCLIPTextVisualEmbedder',
|
39 |
+
'layer': 'penultimate',
|
40 |
+
'pretrained': 'checkpoints/open_clip_pytorch_model.bin'
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
auto_encoder: {
|
45 |
+
'type': 'AutoencoderKL',
|
46 |
+
'ddconfig': {
|
47 |
+
'double_z': True,
|
48 |
+
'z_channels': 4,
|
49 |
+
'resolution': 256,
|
50 |
+
'in_channels': 3,
|
51 |
+
'out_ch': 3,
|
52 |
+
'ch': 128,
|
53 |
+
'ch_mult': [1, 2, 4, 4],
|
54 |
+
'num_res_blocks': 2,
|
55 |
+
'attn_resolutions': [],
|
56 |
+
'dropout': 0.0,
|
57 |
+
'video_kernel_size': [3, 1, 1]
|
58 |
+
},
|
59 |
+
'embed_dim': 4,
|
60 |
+
'pretrained': 'checkpoints/v2-1_512-ema-pruned.ckpt'
|
61 |
+
}
|
62 |
+
|
63 |
+
UNet: {
|
64 |
+
'type': 'UNetSD_UniAnimate',
|
65 |
+
'config': None,
|
66 |
+
'in_dim': 4,
|
67 |
+
'dim': 320,
|
68 |
+
'y_dim': 1024,
|
69 |
+
'context_dim': 1024,
|
70 |
+
'out_dim': 4,
|
71 |
+
'dim_mult': [1, 2, 4, 4],
|
72 |
+
'num_heads': 8,
|
73 |
+
'head_dim': 64,
|
74 |
+
'num_res_blocks': 2,
|
75 |
+
'dropout': 0.1,
|
76 |
+
'temporal_attention': True,
|
77 |
+
'num_tokens': 4,
|
78 |
+
'temporal_attn_times': 1,
|
79 |
+
'use_checkpoint': True,
|
80 |
+
'use_fps_condition': False,
|
81 |
+
'use_sim_mask': False
|
82 |
+
}
|
83 |
+
video_compositions: ['image', 'local_image', 'dwpose', 'randomref', 'randomref_pose']
|
84 |
+
Diffusion: {
|
85 |
+
'type': 'DiffusionDDIMLong',
|
86 |
+
'schedule': 'linear_sd',
|
87 |
+
'schedule_param': {
|
88 |
+
'num_timesteps': 1000,
|
89 |
+
"init_beta": 0.00085,
|
90 |
+
"last_beta": 0.0120,
|
91 |
+
'zero_terminal_snr': True,
|
92 |
+
},
|
93 |
+
'mean_type': 'v',
|
94 |
+
'loss_type': 'mse',
|
95 |
+
'var_type': 'fixed_small',
|
96 |
+
'rescale_timesteps': False,
|
97 |
+
'noise_strength': 0.1
|
98 |
+
}
|
99 |
+
CPU_CLIP_VAE: True
|
100 |
+
context_batch_size: 1
|
101 |
+
noise_prior_value: 939 # or 999, 949
|
UniAnimate/dwpose/__init__.py
ADDED
File without changes
|
UniAnimate/dwpose/onnxdet.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import onnxruntime
|
5 |
+
|
6 |
+
def nms(boxes, scores, nms_thr):
|
7 |
+
"""Single class NMS implemented in Numpy."""
|
8 |
+
x1 = boxes[:, 0]
|
9 |
+
y1 = boxes[:, 1]
|
10 |
+
x2 = boxes[:, 2]
|
11 |
+
y2 = boxes[:, 3]
|
12 |
+
|
13 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
14 |
+
order = scores.argsort()[::-1]
|
15 |
+
|
16 |
+
keep = []
|
17 |
+
while order.size > 0:
|
18 |
+
i = order[0]
|
19 |
+
keep.append(i)
|
20 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
21 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
22 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
23 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
24 |
+
|
25 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
26 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
27 |
+
inter = w * h
|
28 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
29 |
+
|
30 |
+
inds = np.where(ovr <= nms_thr)[0]
|
31 |
+
order = order[inds + 1]
|
32 |
+
|
33 |
+
return keep
|
34 |
+
|
35 |
+
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
36 |
+
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
37 |
+
final_dets = []
|
38 |
+
num_classes = scores.shape[1]
|
39 |
+
for cls_ind in range(num_classes):
|
40 |
+
cls_scores = scores[:, cls_ind]
|
41 |
+
valid_score_mask = cls_scores > score_thr
|
42 |
+
if valid_score_mask.sum() == 0:
|
43 |
+
continue
|
44 |
+
else:
|
45 |
+
valid_scores = cls_scores[valid_score_mask]
|
46 |
+
valid_boxes = boxes[valid_score_mask]
|
47 |
+
keep = nms(valid_boxes, valid_scores, nms_thr)
|
48 |
+
if len(keep) > 0:
|
49 |
+
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
50 |
+
dets = np.concatenate(
|
51 |
+
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
|
52 |
+
)
|
53 |
+
final_dets.append(dets)
|
54 |
+
if len(final_dets) == 0:
|
55 |
+
return None
|
56 |
+
return np.concatenate(final_dets, 0)
|
57 |
+
|
58 |
+
def demo_postprocess(outputs, img_size, p6=False):
|
59 |
+
grids = []
|
60 |
+
expanded_strides = []
|
61 |
+
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
62 |
+
|
63 |
+
hsizes = [img_size[0] // stride for stride in strides]
|
64 |
+
wsizes = [img_size[1] // stride for stride in strides]
|
65 |
+
|
66 |
+
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
|
67 |
+
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
68 |
+
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
69 |
+
grids.append(grid)
|
70 |
+
shape = grid.shape[:2]
|
71 |
+
expanded_strides.append(np.full((*shape, 1), stride))
|
72 |
+
|
73 |
+
grids = np.concatenate(grids, 1)
|
74 |
+
expanded_strides = np.concatenate(expanded_strides, 1)
|
75 |
+
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
76 |
+
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
77 |
+
|
78 |
+
return outputs
|
79 |
+
|
80 |
+
def preprocess(img, input_size, swap=(2, 0, 1)):
|
81 |
+
if len(img.shape) == 3:
|
82 |
+
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
83 |
+
else:
|
84 |
+
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
85 |
+
|
86 |
+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
87 |
+
resized_img = cv2.resize(
|
88 |
+
img,
|
89 |
+
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
90 |
+
interpolation=cv2.INTER_LINEAR,
|
91 |
+
).astype(np.uint8)
|
92 |
+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
93 |
+
|
94 |
+
padded_img = padded_img.transpose(swap)
|
95 |
+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
96 |
+
return padded_img, r
|
97 |
+
|
98 |
+
def inference_detector(session, oriImg):
|
99 |
+
input_shape = (640,640)
|
100 |
+
img, ratio = preprocess(oriImg, input_shape)
|
101 |
+
|
102 |
+
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
103 |
+
|
104 |
+
output = session.run(None, ort_inputs)
|
105 |
+
|
106 |
+
predictions = demo_postprocess(output[0], input_shape)[0]
|
107 |
+
|
108 |
+
boxes = predictions[:, :4]
|
109 |
+
scores = predictions[:, 4:5] * predictions[:, 5:]
|
110 |
+
|
111 |
+
boxes_xyxy = np.ones_like(boxes)
|
112 |
+
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
|
113 |
+
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
|
114 |
+
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
|
115 |
+
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
|
116 |
+
boxes_xyxy /= ratio
|
117 |
+
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
118 |
+
if dets is not None:
|
119 |
+
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
120 |
+
isscore = final_scores>0.3
|
121 |
+
iscat = final_cls_inds == 0
|
122 |
+
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
|
123 |
+
final_boxes = final_boxes[isbbox]
|
124 |
+
else:
|
125 |
+
final_boxes = np.array([])
|
126 |
+
|
127 |
+
return final_boxes
|
UniAnimate/dwpose/onnxpose.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
|
7 |
+
def preprocess(
|
8 |
+
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
9 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
10 |
+
"""Do preprocessing for RTMPose model inference.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
img (np.ndarray): Input image in shape.
|
14 |
+
input_size (tuple): Input image size in shape (w, h).
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
tuple:
|
18 |
+
- resized_img (np.ndarray): Preprocessed image.
|
19 |
+
- center (np.ndarray): Center of image.
|
20 |
+
- scale (np.ndarray): Scale of image.
|
21 |
+
"""
|
22 |
+
# get shape of image
|
23 |
+
img_shape = img.shape[:2]
|
24 |
+
out_img, out_center, out_scale = [], [], []
|
25 |
+
if len(out_bbox) == 0:
|
26 |
+
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
27 |
+
for i in range(len(out_bbox)):
|
28 |
+
x0 = out_bbox[i][0]
|
29 |
+
y0 = out_bbox[i][1]
|
30 |
+
x1 = out_bbox[i][2]
|
31 |
+
y1 = out_bbox[i][3]
|
32 |
+
bbox = np.array([x0, y0, x1, y1])
|
33 |
+
|
34 |
+
# get center and scale
|
35 |
+
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
36 |
+
|
37 |
+
# do affine transformation
|
38 |
+
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
39 |
+
|
40 |
+
# normalize image
|
41 |
+
mean = np.array([123.675, 116.28, 103.53])
|
42 |
+
std = np.array([58.395, 57.12, 57.375])
|
43 |
+
resized_img = (resized_img - mean) / std
|
44 |
+
|
45 |
+
out_img.append(resized_img)
|
46 |
+
out_center.append(center)
|
47 |
+
out_scale.append(scale)
|
48 |
+
|
49 |
+
return out_img, out_center, out_scale
|
50 |
+
|
51 |
+
|
52 |
+
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
|
53 |
+
"""Inference RTMPose model.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sess (ort.InferenceSession): ONNXRuntime session.
|
57 |
+
img (np.ndarray): Input image in shape.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
outputs (np.ndarray): Output of RTMPose model.
|
61 |
+
"""
|
62 |
+
all_out = []
|
63 |
+
# build input
|
64 |
+
for i in range(len(img)):
|
65 |
+
input = [img[i].transpose(2, 0, 1)]
|
66 |
+
|
67 |
+
# build output
|
68 |
+
sess_input = {sess.get_inputs()[0].name: input}
|
69 |
+
sess_output = []
|
70 |
+
for out in sess.get_outputs():
|
71 |
+
sess_output.append(out.name)
|
72 |
+
|
73 |
+
# run model
|
74 |
+
outputs = sess.run(sess_output, sess_input)
|
75 |
+
all_out.append(outputs)
|
76 |
+
|
77 |
+
return all_out
|
78 |
+
|
79 |
+
|
80 |
+
def postprocess(outputs: List[np.ndarray],
|
81 |
+
model_input_size: Tuple[int, int],
|
82 |
+
center: Tuple[int, int],
|
83 |
+
scale: Tuple[int, int],
|
84 |
+
simcc_split_ratio: float = 2.0
|
85 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
86 |
+
"""Postprocess for RTMPose model output.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
outputs (np.ndarray): Output of RTMPose model.
|
90 |
+
model_input_size (tuple): RTMPose model Input image size.
|
91 |
+
center (tuple): Center of bbox in shape (x, y).
|
92 |
+
scale (tuple): Scale of bbox in shape (w, h).
|
93 |
+
simcc_split_ratio (float): Split ratio of simcc.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
tuple:
|
97 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
98 |
+
- scores (np.ndarray): Model predict scores.
|
99 |
+
"""
|
100 |
+
all_key = []
|
101 |
+
all_score = []
|
102 |
+
for i in range(len(outputs)):
|
103 |
+
# use simcc to decode
|
104 |
+
simcc_x, simcc_y = outputs[i]
|
105 |
+
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
106 |
+
|
107 |
+
# rescale keypoints
|
108 |
+
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
109 |
+
all_key.append(keypoints[0])
|
110 |
+
all_score.append(scores[0])
|
111 |
+
|
112 |
+
return np.array(all_key), np.array(all_score)
|
113 |
+
|
114 |
+
|
115 |
+
def bbox_xyxy2cs(bbox: np.ndarray,
|
116 |
+
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
117 |
+
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
118 |
+
|
119 |
+
Args:
|
120 |
+
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
121 |
+
as (left, top, right, bottom)
|
122 |
+
padding (float): BBox padding factor that will be multilied to scale.
|
123 |
+
Default: 1.0
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
tuple: A tuple containing center and scale.
|
127 |
+
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
128 |
+
(n, 2)
|
129 |
+
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
130 |
+
(n, 2)
|
131 |
+
"""
|
132 |
+
# convert single bbox from (4, ) to (1, 4)
|
133 |
+
dim = bbox.ndim
|
134 |
+
if dim == 1:
|
135 |
+
bbox = bbox[None, :]
|
136 |
+
|
137 |
+
# get bbox center and scale
|
138 |
+
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
139 |
+
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
140 |
+
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
141 |
+
|
142 |
+
if dim == 1:
|
143 |
+
center = center[0]
|
144 |
+
scale = scale[0]
|
145 |
+
|
146 |
+
return center, scale
|
147 |
+
|
148 |
+
|
149 |
+
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
150 |
+
aspect_ratio: float) -> np.ndarray:
|
151 |
+
"""Extend the scale to match the given aspect ratio.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
155 |
+
aspect_ratio (float): The ratio of ``w/h``
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
np.ndarray: The reshaped image scale in (2, )
|
159 |
+
"""
|
160 |
+
w, h = np.hsplit(bbox_scale, [1])
|
161 |
+
bbox_scale = np.where(w > h * aspect_ratio,
|
162 |
+
np.hstack([w, w / aspect_ratio]),
|
163 |
+
np.hstack([h * aspect_ratio, h]))
|
164 |
+
return bbox_scale
|
165 |
+
|
166 |
+
|
167 |
+
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
168 |
+
"""Rotate a point by an angle.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
172 |
+
angle_rad (float): rotation angle in radian
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
np.ndarray: Rotated point in shape (2, )
|
176 |
+
"""
|
177 |
+
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
178 |
+
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
179 |
+
return rot_mat @ pt
|
180 |
+
|
181 |
+
|
182 |
+
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
183 |
+
"""To calculate the affine matrix, three pairs of points are required. This
|
184 |
+
function is used to get the 3rd point, given 2D points a & b.
|
185 |
+
|
186 |
+
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
187 |
+
anticlockwise, using b as the rotation center.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
191 |
+
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
np.ndarray: The 3rd point.
|
195 |
+
"""
|
196 |
+
direction = a - b
|
197 |
+
c = b + np.r_[-direction[1], direction[0]]
|
198 |
+
return c
|
199 |
+
|
200 |
+
|
201 |
+
def get_warp_matrix(center: np.ndarray,
|
202 |
+
scale: np.ndarray,
|
203 |
+
rot: float,
|
204 |
+
output_size: Tuple[int, int],
|
205 |
+
shift: Tuple[float, float] = (0., 0.),
|
206 |
+
inv: bool = False) -> np.ndarray:
|
207 |
+
"""Calculate the affine transformation matrix that can warp the bbox area
|
208 |
+
in the input image to the output size.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
212 |
+
scale (np.ndarray[2, ]): Scale of the bounding box
|
213 |
+
wrt [width, height].
|
214 |
+
rot (float): Rotation angle (degree).
|
215 |
+
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
216 |
+
destination heatmaps.
|
217 |
+
shift (0-100%): Shift translation ratio wrt the width/height.
|
218 |
+
Default (0., 0.).
|
219 |
+
inv (bool): Option to inverse the affine transform direction.
|
220 |
+
(inv=False: src->dst or inv=True: dst->src)
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
np.ndarray: A 2x3 transformation matrix
|
224 |
+
"""
|
225 |
+
shift = np.array(shift)
|
226 |
+
src_w = scale[0]
|
227 |
+
dst_w = output_size[0]
|
228 |
+
dst_h = output_size[1]
|
229 |
+
|
230 |
+
# compute transformation matrix
|
231 |
+
rot_rad = np.deg2rad(rot)
|
232 |
+
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
233 |
+
dst_dir = np.array([0., dst_w * -0.5])
|
234 |
+
|
235 |
+
# get four corners of the src rectangle in the original image
|
236 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
237 |
+
src[0, :] = center + scale * shift
|
238 |
+
src[1, :] = center + src_dir + scale * shift
|
239 |
+
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
240 |
+
|
241 |
+
# get four corners of the dst rectangle in the input image
|
242 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
243 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
244 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
245 |
+
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
246 |
+
|
247 |
+
if inv:
|
248 |
+
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
249 |
+
else:
|
250 |
+
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
251 |
+
|
252 |
+
return warp_mat
|
253 |
+
|
254 |
+
|
255 |
+
def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
|
256 |
+
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
257 |
+
"""Get the bbox image as the model input by affine transform.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
input_size (dict): The input size of the model.
|
261 |
+
bbox_scale (dict): The bbox scale of the img.
|
262 |
+
bbox_center (dict): The bbox center of the img.
|
263 |
+
img (np.ndarray): The original image.
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
tuple: A tuple containing center and scale.
|
267 |
+
- np.ndarray[float32]: img after affine transform.
|
268 |
+
- np.ndarray[float32]: bbox scale after affine transform.
|
269 |
+
"""
|
270 |
+
w, h = input_size
|
271 |
+
warp_size = (int(w), int(h))
|
272 |
+
|
273 |
+
# reshape bbox to fixed aspect ratio
|
274 |
+
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
275 |
+
|
276 |
+
# get the affine matrix
|
277 |
+
center = bbox_center
|
278 |
+
scale = bbox_scale
|
279 |
+
rot = 0
|
280 |
+
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
281 |
+
|
282 |
+
# do affine transform
|
283 |
+
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
284 |
+
|
285 |
+
return img, bbox_scale
|
286 |
+
|
287 |
+
|
288 |
+
def get_simcc_maximum(simcc_x: np.ndarray,
|
289 |
+
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
290 |
+
"""Get maximum response location and value from simcc representations.
|
291 |
+
|
292 |
+
Note:
|
293 |
+
instance number: N
|
294 |
+
num_keypoints: K
|
295 |
+
heatmap height: H
|
296 |
+
heatmap width: W
|
297 |
+
|
298 |
+
Args:
|
299 |
+
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
300 |
+
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
tuple:
|
304 |
+
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
305 |
+
(K, 2) or (N, K, 2)
|
306 |
+
- vals (np.ndarray): values of maximum heatmap responses in shape
|
307 |
+
(K,) or (N, K)
|
308 |
+
"""
|
309 |
+
N, K, Wx = simcc_x.shape
|
310 |
+
simcc_x = simcc_x.reshape(N * K, -1)
|
311 |
+
simcc_y = simcc_y.reshape(N * K, -1)
|
312 |
+
|
313 |
+
# get maximum value locations
|
314 |
+
x_locs = np.argmax(simcc_x, axis=1)
|
315 |
+
y_locs = np.argmax(simcc_y, axis=1)
|
316 |
+
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
317 |
+
max_val_x = np.amax(simcc_x, axis=1)
|
318 |
+
max_val_y = np.amax(simcc_y, axis=1)
|
319 |
+
|
320 |
+
# get maximum value across x and y axis
|
321 |
+
mask = max_val_x > max_val_y
|
322 |
+
max_val_x[mask] = max_val_y[mask]
|
323 |
+
vals = max_val_x
|
324 |
+
locs[vals <= 0.] = -1
|
325 |
+
|
326 |
+
# reshape
|
327 |
+
locs = locs.reshape(N, K, 2)
|
328 |
+
vals = vals.reshape(N, K)
|
329 |
+
|
330 |
+
return locs, vals
|
331 |
+
|
332 |
+
|
333 |
+
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
334 |
+
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
335 |
+
"""Modulate simcc distribution with Gaussian.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
339 |
+
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
340 |
+
simcc_split_ratio (int): The split ratio of simcc.
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
tuple: A tuple containing center and scale.
|
344 |
+
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
345 |
+
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
346 |
+
"""
|
347 |
+
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
348 |
+
keypoints /= simcc_split_ratio
|
349 |
+
|
350 |
+
return keypoints, scores
|
351 |
+
|
352 |
+
|
353 |
+
def inference_pose(session, out_bbox, oriImg):
|
354 |
+
h, w = session.get_inputs()[0].shape[2:]
|
355 |
+
model_input_size = (w, h)
|
356 |
+
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
357 |
+
outputs = inference(session, resized_img)
|
358 |
+
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
359 |
+
|
360 |
+
return keypoints, scores
|
UniAnimate/dwpose/util.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
|
7 |
+
eps = 0.01
|
8 |
+
|
9 |
+
|
10 |
+
def smart_resize(x, s):
|
11 |
+
Ht, Wt = s
|
12 |
+
if x.ndim == 2:
|
13 |
+
Ho, Wo = x.shape
|
14 |
+
Co = 1
|
15 |
+
else:
|
16 |
+
Ho, Wo, Co = x.shape
|
17 |
+
if Co == 3 or Co == 1:
|
18 |
+
k = float(Ht + Wt) / float(Ho + Wo)
|
19 |
+
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
20 |
+
else:
|
21 |
+
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
|
22 |
+
|
23 |
+
|
24 |
+
def smart_resize_k(x, fx, fy):
|
25 |
+
if x.ndim == 2:
|
26 |
+
Ho, Wo = x.shape
|
27 |
+
Co = 1
|
28 |
+
else:
|
29 |
+
Ho, Wo, Co = x.shape
|
30 |
+
Ht, Wt = Ho * fy, Wo * fx
|
31 |
+
if Co == 3 or Co == 1:
|
32 |
+
k = float(Ht + Wt) / float(Ho + Wo)
|
33 |
+
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
34 |
+
else:
|
35 |
+
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
|
36 |
+
|
37 |
+
|
38 |
+
def padRightDownCorner(img, stride, padValue):
|
39 |
+
h = img.shape[0]
|
40 |
+
w = img.shape[1]
|
41 |
+
|
42 |
+
pad = 4 * [None]
|
43 |
+
pad[0] = 0 # up
|
44 |
+
pad[1] = 0 # left
|
45 |
+
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
46 |
+
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
47 |
+
|
48 |
+
img_padded = img
|
49 |
+
pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
|
50 |
+
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
51 |
+
pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
|
52 |
+
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
53 |
+
pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
|
54 |
+
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
55 |
+
pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
|
56 |
+
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
57 |
+
|
58 |
+
return img_padded, pad
|
59 |
+
|
60 |
+
|
61 |
+
def transfer(model, model_weights):
|
62 |
+
transfered_model_weights = {}
|
63 |
+
for weights_name in model.state_dict().keys():
|
64 |
+
transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
|
65 |
+
return transfered_model_weights
|
66 |
+
|
67 |
+
|
68 |
+
def draw_bodypose(canvas, candidate, subset):
|
69 |
+
H, W, C = canvas.shape
|
70 |
+
candidate = np.array(candidate)
|
71 |
+
subset = np.array(subset)
|
72 |
+
|
73 |
+
stickwidth = 4
|
74 |
+
|
75 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
76 |
+
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
77 |
+
[1, 16], [16, 18], [3, 17], [6, 18]]
|
78 |
+
|
79 |
+
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
80 |
+
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
81 |
+
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
82 |
+
|
83 |
+
for i in range(17):
|
84 |
+
for n in range(len(subset)):
|
85 |
+
index = subset[n][np.array(limbSeq[i]) - 1]
|
86 |
+
if -1 in index:
|
87 |
+
continue
|
88 |
+
Y = candidate[index.astype(int), 0] * float(W)
|
89 |
+
X = candidate[index.astype(int), 1] * float(H)
|
90 |
+
mX = np.mean(X)
|
91 |
+
mY = np.mean(Y)
|
92 |
+
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
93 |
+
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
94 |
+
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
95 |
+
cv2.fillConvexPoly(canvas, polygon, colors[i])
|
96 |
+
|
97 |
+
canvas = (canvas * 0.6).astype(np.uint8)
|
98 |
+
|
99 |
+
for i in range(18):
|
100 |
+
for n in range(len(subset)):
|
101 |
+
index = int(subset[n][i])
|
102 |
+
if index == -1:
|
103 |
+
continue
|
104 |
+
x, y = candidate[index][0:2]
|
105 |
+
x = int(x * W)
|
106 |
+
y = int(y * H)
|
107 |
+
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
108 |
+
|
109 |
+
return canvas
|
110 |
+
|
111 |
+
|
112 |
+
def draw_body_and_foot(canvas, candidate, subset):
|
113 |
+
H, W, C = canvas.shape
|
114 |
+
candidate = np.array(candidate)
|
115 |
+
subset = np.array(subset)
|
116 |
+
|
117 |
+
stickwidth = 4
|
118 |
+
|
119 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
120 |
+
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
121 |
+
[1, 16], [16, 18], [14,19], [11, 20]]
|
122 |
+
|
123 |
+
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
124 |
+
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
125 |
+
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [170, 255, 255], [255, 255, 0]]
|
126 |
+
|
127 |
+
for i in range(19):
|
128 |
+
for n in range(len(subset)):
|
129 |
+
index = subset[n][np.array(limbSeq[i]) - 1]
|
130 |
+
if -1 in index:
|
131 |
+
continue
|
132 |
+
Y = candidate[index.astype(int), 0] * float(W)
|
133 |
+
X = candidate[index.astype(int), 1] * float(H)
|
134 |
+
mX = np.mean(X)
|
135 |
+
mY = np.mean(Y)
|
136 |
+
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
137 |
+
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
138 |
+
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
139 |
+
cv2.fillConvexPoly(canvas, polygon, colors[i])
|
140 |
+
|
141 |
+
canvas = (canvas * 0.6).astype(np.uint8)
|
142 |
+
|
143 |
+
for i in range(20):
|
144 |
+
for n in range(len(subset)):
|
145 |
+
index = int(subset[n][i])
|
146 |
+
if index == -1:
|
147 |
+
continue
|
148 |
+
x, y = candidate[index][0:2]
|
149 |
+
x = int(x * W)
|
150 |
+
y = int(y * H)
|
151 |
+
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
152 |
+
|
153 |
+
return canvas
|
154 |
+
|
155 |
+
|
156 |
+
def draw_handpose(canvas, all_hand_peaks):
|
157 |
+
H, W, C = canvas.shape
|
158 |
+
|
159 |
+
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
|
160 |
+
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
161 |
+
|
162 |
+
for peaks in all_hand_peaks:
|
163 |
+
peaks = np.array(peaks)
|
164 |
+
|
165 |
+
for ie, e in enumerate(edges):
|
166 |
+
x1, y1 = peaks[e[0]]
|
167 |
+
x2, y2 = peaks[e[1]]
|
168 |
+
x1 = int(x1 * W)
|
169 |
+
y1 = int(y1 * H)
|
170 |
+
x2 = int(x2 * W)
|
171 |
+
y2 = int(y2 * H)
|
172 |
+
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
173 |
+
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
|
174 |
+
|
175 |
+
for i, keyponit in enumerate(peaks):
|
176 |
+
x, y = keyponit
|
177 |
+
x = int(x * W)
|
178 |
+
y = int(y * H)
|
179 |
+
if x > eps and y > eps:
|
180 |
+
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
181 |
+
return canvas
|
182 |
+
|
183 |
+
|
184 |
+
def draw_facepose(canvas, all_lmks):
|
185 |
+
H, W, C = canvas.shape
|
186 |
+
for lmks in all_lmks:
|
187 |
+
lmks = np.array(lmks)
|
188 |
+
for lmk in lmks:
|
189 |
+
x, y = lmk
|
190 |
+
x = int(x * W)
|
191 |
+
y = int(y * H)
|
192 |
+
if x > eps and y > eps:
|
193 |
+
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
194 |
+
return canvas
|
195 |
+
|
196 |
+
|
197 |
+
# detect hand according to body pose keypoints
|
198 |
+
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
199 |
+
def handDetect(candidate, subset, oriImg):
|
200 |
+
# right hand: wrist 4, elbow 3, shoulder 2
|
201 |
+
# left hand: wrist 7, elbow 6, shoulder 5
|
202 |
+
ratioWristElbow = 0.33
|
203 |
+
detect_result = []
|
204 |
+
image_height, image_width = oriImg.shape[0:2]
|
205 |
+
for person in subset.astype(int):
|
206 |
+
# if any of three not detected
|
207 |
+
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
208 |
+
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
209 |
+
if not (has_left or has_right):
|
210 |
+
continue
|
211 |
+
hands = []
|
212 |
+
#left hand
|
213 |
+
if has_left:
|
214 |
+
left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
|
215 |
+
x1, y1 = candidate[left_shoulder_index][:2]
|
216 |
+
x2, y2 = candidate[left_elbow_index][:2]
|
217 |
+
x3, y3 = candidate[left_wrist_index][:2]
|
218 |
+
hands.append([x1, y1, x2, y2, x3, y3, True])
|
219 |
+
# right hand
|
220 |
+
if has_right:
|
221 |
+
right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
|
222 |
+
x1, y1 = candidate[right_shoulder_index][:2]
|
223 |
+
x2, y2 = candidate[right_elbow_index][:2]
|
224 |
+
x3, y3 = candidate[right_wrist_index][:2]
|
225 |
+
hands.append([x1, y1, x2, y2, x3, y3, False])
|
226 |
+
|
227 |
+
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
228 |
+
|
229 |
+
x = x3 + ratioWristElbow * (x3 - x2)
|
230 |
+
y = y3 + ratioWristElbow * (y3 - y2)
|
231 |
+
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
232 |
+
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
233 |
+
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
234 |
+
# x-y refers to the center --> offset to topLeft point
|
235 |
+
# handRectangle.x -= handRectangle.width / 2.f;
|
236 |
+
# handRectangle.y -= handRectangle.height / 2.f;
|
237 |
+
x -= width / 2
|
238 |
+
y -= width / 2 # width = height
|
239 |
+
# overflow the image
|
240 |
+
if x < 0: x = 0
|
241 |
+
if y < 0: y = 0
|
242 |
+
width1 = width
|
243 |
+
width2 = width
|
244 |
+
if x + width > image_width: width1 = image_width - x
|
245 |
+
if y + width > image_height: width2 = image_height - y
|
246 |
+
width = min(width1, width2)
|
247 |
+
# the max hand box value is 20 pixels
|
248 |
+
if width >= 20:
|
249 |
+
detect_result.append([int(x), int(y), int(width), is_left])
|
250 |
+
|
251 |
+
'''
|
252 |
+
return value: [[x, y, w, True if left hand else False]].
|
253 |
+
width=height since the network require squared input.
|
254 |
+
x, y is the coordinate of top left
|
255 |
+
'''
|
256 |
+
return detect_result
|
257 |
+
|
258 |
+
|
259 |
+
# Written by Lvmin
|
260 |
+
def faceDetect(candidate, subset, oriImg):
|
261 |
+
# left right eye ear 14 15 16 17
|
262 |
+
detect_result = []
|
263 |
+
image_height, image_width = oriImg.shape[0:2]
|
264 |
+
for person in subset.astype(int):
|
265 |
+
has_head = person[0] > -1
|
266 |
+
if not has_head:
|
267 |
+
continue
|
268 |
+
|
269 |
+
has_left_eye = person[14] > -1
|
270 |
+
has_right_eye = person[15] > -1
|
271 |
+
has_left_ear = person[16] > -1
|
272 |
+
has_right_ear = person[17] > -1
|
273 |
+
|
274 |
+
if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
|
275 |
+
continue
|
276 |
+
|
277 |
+
head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
|
278 |
+
|
279 |
+
width = 0.0
|
280 |
+
x0, y0 = candidate[head][:2]
|
281 |
+
|
282 |
+
if has_left_eye:
|
283 |
+
x1, y1 = candidate[left_eye][:2]
|
284 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
285 |
+
width = max(width, d * 3.0)
|
286 |
+
|
287 |
+
if has_right_eye:
|
288 |
+
x1, y1 = candidate[right_eye][:2]
|
289 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
290 |
+
width = max(width, d * 3.0)
|
291 |
+
|
292 |
+
if has_left_ear:
|
293 |
+
x1, y1 = candidate[left_ear][:2]
|
294 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
295 |
+
width = max(width, d * 1.5)
|
296 |
+
|
297 |
+
if has_right_ear:
|
298 |
+
x1, y1 = candidate[right_ear][:2]
|
299 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
300 |
+
width = max(width, d * 1.5)
|
301 |
+
|
302 |
+
x, y = x0, y0
|
303 |
+
|
304 |
+
x -= width
|
305 |
+
y -= width
|
306 |
+
|
307 |
+
if x < 0:
|
308 |
+
x = 0
|
309 |
+
|
310 |
+
if y < 0:
|
311 |
+
y = 0
|
312 |
+
|
313 |
+
width1 = width * 2
|
314 |
+
width2 = width * 2
|
315 |
+
|
316 |
+
if x + width > image_width:
|
317 |
+
width1 = image_width - x
|
318 |
+
|
319 |
+
if y + width > image_height:
|
320 |
+
width2 = image_height - y
|
321 |
+
|
322 |
+
width = min(width1, width2)
|
323 |
+
|
324 |
+
if width >= 20:
|
325 |
+
detect_result.append([int(x), int(y), int(width)])
|
326 |
+
|
327 |
+
return detect_result
|
328 |
+
|
329 |
+
|
330 |
+
# get max index of 2d array
|
331 |
+
def npmax(array):
|
332 |
+
arrayindex = array.argmax(1)
|
333 |
+
arrayvalue = array.max(1)
|
334 |
+
i = arrayvalue.argmax()
|
335 |
+
j = arrayindex[i]
|
336 |
+
return i, j
|
UniAnimate/dwpose/wholebody.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import onnxruntime as ort
|
5 |
+
from dwpose.onnxdet import inference_detector
|
6 |
+
from dwpose.onnxpose import inference_pose
|
7 |
+
|
8 |
+
class Wholebody:
|
9 |
+
def __init__(self):
|
10 |
+
device = 'cuda' # 'cpu' #
|
11 |
+
providers = ['CPUExecutionProvider'
|
12 |
+
] if device == 'cpu' else ['CUDAExecutionProvider']
|
13 |
+
onnx_det = 'checkpoints/yolox_l.onnx'
|
14 |
+
onnx_pose = 'checkpoints/dw-ll_ucoco_384.onnx'
|
15 |
+
|
16 |
+
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
17 |
+
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
18 |
+
|
19 |
+
def __call__(self, oriImg):
|
20 |
+
det_result = inference_detector(self.session_det, oriImg)
|
21 |
+
keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
|
22 |
+
|
23 |
+
keypoints_info = np.concatenate(
|
24 |
+
(keypoints, scores[..., None]), axis=-1)
|
25 |
+
# compute neck joint
|
26 |
+
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
27 |
+
# neck score when visualizing pred
|
28 |
+
neck[:, 2:4] = np.logical_and(
|
29 |
+
keypoints_info[:, 5, 2:4] > 0.3,
|
30 |
+
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
31 |
+
new_keypoints_info = np.insert(
|
32 |
+
keypoints_info, 17, neck, axis=1)
|
33 |
+
mmpose_idx = [
|
34 |
+
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
|
35 |
+
]
|
36 |
+
openpose_idx = [
|
37 |
+
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
|
38 |
+
]
|
39 |
+
new_keypoints_info[:, openpose_idx] = \
|
40 |
+
new_keypoints_info[:, mmpose_idx]
|
41 |
+
keypoints_info = new_keypoints_info
|
42 |
+
|
43 |
+
keypoints, scores = keypoints_info[
|
44 |
+
..., :2], keypoints_info[..., 2]
|
45 |
+
|
46 |
+
return keypoints, scores
|
47 |
+
|
48 |
+
|
UniAnimate/environment.yaml
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: /mnt/user/miniconda3/envs/dtrans
|
2 |
+
channels:
|
3 |
+
- http://mirrors.aliyun.com/anaconda/pkgs/main
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=main
|
7 |
+
- _openmp_mutex=5.1=1_gnu
|
8 |
+
- ca-certificates=2023.12.12=h06a4308_0
|
9 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
10 |
+
- libffi=3.4.4=h6a678d5_0
|
11 |
+
- libgcc-ng=11.2.0=h1234567_1
|
12 |
+
- libgomp=11.2.0=h1234567_1
|
13 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
14 |
+
- ncurses=6.4=h6a678d5_0
|
15 |
+
- openssl=3.0.12=h7f8727e_0
|
16 |
+
- pip=23.3.1=py39h06a4308_0
|
17 |
+
- python=3.9.18=h955ad1f_0
|
18 |
+
- readline=8.2=h5eee18b_0
|
19 |
+
- setuptools=68.2.2=py39h06a4308_0
|
20 |
+
- sqlite=3.41.2=h5eee18b_0
|
21 |
+
- tk=8.6.12=h1ccaba5_0
|
22 |
+
- wheel=0.41.2=py39h06a4308_0
|
23 |
+
- xz=5.4.5=h5eee18b_0
|
24 |
+
- zlib=1.2.13=h5eee18b_0
|
25 |
+
- pip:
|
26 |
+
- aiofiles==23.2.1
|
27 |
+
- aiohttp==3.9.1
|
28 |
+
- aiosignal==1.3.1
|
29 |
+
- aliyun-python-sdk-core==2.14.0
|
30 |
+
- aliyun-python-sdk-kms==2.16.2
|
31 |
+
- altair==5.2.0
|
32 |
+
- annotated-types==0.6.0
|
33 |
+
- antlr4-python3-runtime==4.9.3
|
34 |
+
- anyio==4.2.0
|
35 |
+
- argparse==1.4.0
|
36 |
+
- asttokens==2.4.1
|
37 |
+
- async-timeout==4.0.3
|
38 |
+
- attrs==23.2.0
|
39 |
+
- automat==22.10.0
|
40 |
+
- beartype==0.16.4
|
41 |
+
- blessed==1.20.0
|
42 |
+
- buildtools==1.0.6
|
43 |
+
- causal-conv1d==1.1.3.post1
|
44 |
+
- certifi==2023.11.17
|
45 |
+
- cffi==1.16.0
|
46 |
+
- chardet==5.2.0
|
47 |
+
- charset-normalizer==3.3.2
|
48 |
+
- clean-fid==0.1.35
|
49 |
+
- click==8.1.7
|
50 |
+
- clip==1.0
|
51 |
+
- cmake==3.28.1
|
52 |
+
- colorama==0.4.6
|
53 |
+
- coloredlogs==15.0.1
|
54 |
+
- constantly==23.10.4
|
55 |
+
- contourpy==1.2.0
|
56 |
+
- crcmod==1.7
|
57 |
+
- cryptography==41.0.7
|
58 |
+
- cycler==0.12.1
|
59 |
+
- decorator==5.1.1
|
60 |
+
- decord==0.6.0
|
61 |
+
- diffusers==0.26.3
|
62 |
+
- docopt==0.6.2
|
63 |
+
- easydict==1.11
|
64 |
+
- einops==0.7.0
|
65 |
+
- exceptiongroup==1.2.0
|
66 |
+
- executing==2.0.1
|
67 |
+
- fairscale==0.4.13
|
68 |
+
- fastapi==0.109.0
|
69 |
+
- ffmpeg==1.4
|
70 |
+
- ffmpy==0.3.1
|
71 |
+
- filelock==3.13.1
|
72 |
+
- flatbuffers==24.3.25
|
73 |
+
- fonttools==4.47.2
|
74 |
+
- frozenlist==1.4.1
|
75 |
+
- fsspec==2023.12.2
|
76 |
+
- ftfy==6.1.3
|
77 |
+
- furl==2.1.3
|
78 |
+
- gpustat==1.1.1
|
79 |
+
- gradio==4.14.0
|
80 |
+
- gradio-client==0.8.0
|
81 |
+
- greenlet==3.0.3
|
82 |
+
- h11==0.14.0
|
83 |
+
- httpcore==1.0.2
|
84 |
+
- httpx==0.26.0
|
85 |
+
- huggingface-hub==0.20.2
|
86 |
+
- humanfriendly==10.0
|
87 |
+
- hyperlink==21.0.0
|
88 |
+
- idna==3.6
|
89 |
+
- imageio==2.33.1
|
90 |
+
- imageio-ffmpeg==0.4.9
|
91 |
+
- importlib-metadata==7.0.1
|
92 |
+
- importlib-resources==6.1.1
|
93 |
+
- incremental==22.10.0
|
94 |
+
- ipdb==0.13.13
|
95 |
+
- ipython==8.18.1
|
96 |
+
- jedi==0.19.1
|
97 |
+
- jinja2==3.1.3
|
98 |
+
- jmespath==0.10.0
|
99 |
+
- joblib==1.3.2
|
100 |
+
- jsonschema==4.21.0
|
101 |
+
- jsonschema-specifications==2023.12.1
|
102 |
+
- kiwisolver==1.4.5
|
103 |
+
- kornia==0.7.1
|
104 |
+
- lazy-loader==0.3
|
105 |
+
- lightning-utilities==0.10.0
|
106 |
+
- lit==17.0.6
|
107 |
+
- lpips==0.1.4
|
108 |
+
- mamba-ssm==1.1.4
|
109 |
+
- markdown-it-py==3.0.0
|
110 |
+
- markupsafe==2.1.3
|
111 |
+
- matplotlib==3.8.2
|
112 |
+
- matplotlib-inline==0.1.6
|
113 |
+
- mdurl==0.1.2
|
114 |
+
- motion-vector-extractor==1.0.6
|
115 |
+
- mpmath==1.3.0
|
116 |
+
- multidict==6.0.4
|
117 |
+
- mypy-extensions==1.0.0
|
118 |
+
- networkx==3.2.1
|
119 |
+
- ninja==1.11.1.1
|
120 |
+
- numpy==1.26.3
|
121 |
+
- nvidia-cublas-cu11==11.10.3.66
|
122 |
+
- nvidia-cublas-cu12==12.1.3.1
|
123 |
+
- nvidia-cuda-cupti-cu11==11.7.101
|
124 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
125 |
+
- nvidia-cuda-nvrtc-cu11==11.7.99
|
126 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
127 |
+
- nvidia-cuda-runtime-cu11==11.7.99
|
128 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
129 |
+
- nvidia-cudnn-cu11==8.5.0.96
|
130 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
131 |
+
- nvidia-cufft-cu11==10.9.0.58
|
132 |
+
- nvidia-cufft-cu12==11.0.2.54
|
133 |
+
- nvidia-curand-cu11==10.2.10.91
|
134 |
+
- nvidia-curand-cu12==10.3.2.106
|
135 |
+
- nvidia-cusolver-cu11==11.4.0.1
|
136 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
137 |
+
- nvidia-cusparse-cu11==11.7.4.91
|
138 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
139 |
+
- nvidia-ml-py==12.535.133
|
140 |
+
- nvidia-nccl-cu11==2.14.3
|
141 |
+
- nvidia-nccl-cu12==2.19.3
|
142 |
+
- nvidia-nvjitlink-cu12==12.3.101
|
143 |
+
- nvidia-nvtx-cu11==11.7.91
|
144 |
+
- nvidia-nvtx-cu12==12.1.105
|
145 |
+
- omegaconf==2.3.0
|
146 |
+
- onnxruntime==1.18.0
|
147 |
+
- open-clip-torch==2.24.0
|
148 |
+
- opencv-python==4.5.3.56
|
149 |
+
- opencv-python-headless==4.9.0.80
|
150 |
+
- orderedmultidict==1.0.1
|
151 |
+
- orjson==3.9.10
|
152 |
+
- oss2==2.18.4
|
153 |
+
- packaging==23.2
|
154 |
+
- pandas==2.1.4
|
155 |
+
- parso==0.8.3
|
156 |
+
- pexpect==4.9.0
|
157 |
+
- pillow==10.2.0
|
158 |
+
- piq==0.8.0
|
159 |
+
- pkgconfig==1.5.5
|
160 |
+
- prompt-toolkit==3.0.43
|
161 |
+
- protobuf==4.25.2
|
162 |
+
- psutil==5.9.8
|
163 |
+
- ptflops==0.7.2.2
|
164 |
+
- ptyprocess==0.7.0
|
165 |
+
- pure-eval==0.2.2
|
166 |
+
- pycparser==2.21
|
167 |
+
- pycryptodome==3.20.0
|
168 |
+
- pydantic==2.5.3
|
169 |
+
- pydantic-core==2.14.6
|
170 |
+
- pydub==0.25.1
|
171 |
+
- pygments==2.17.2
|
172 |
+
- pynvml==11.5.0
|
173 |
+
- pyparsing==3.1.1
|
174 |
+
- pyre-extensions==0.0.29
|
175 |
+
- python-dateutil==2.8.2
|
176 |
+
- python-multipart==0.0.6
|
177 |
+
- pytorch-lightning==2.1.3
|
178 |
+
- pytz==2023.3.post1
|
179 |
+
- pyyaml==6.0.1
|
180 |
+
- redo==2.0.4
|
181 |
+
- referencing==0.32.1
|
182 |
+
- regex==2023.12.25
|
183 |
+
- requests==2.31.0
|
184 |
+
- rich==13.7.0
|
185 |
+
- rotary-embedding-torch==0.5.3
|
186 |
+
- rpds-py==0.17.1
|
187 |
+
- ruff==0.2.0
|
188 |
+
- safetensors==0.4.1
|
189 |
+
- scikit-image==0.22.0
|
190 |
+
- scikit-learn==1.4.0
|
191 |
+
- scipy==1.11.4
|
192 |
+
- semantic-version==2.10.0
|
193 |
+
- sentencepiece==0.1.99
|
194 |
+
- shellingham==1.5.4
|
195 |
+
- simplejson==3.19.2
|
196 |
+
- six==1.16.0
|
197 |
+
- sk-video==1.1.10
|
198 |
+
- sniffio==1.3.0
|
199 |
+
- sqlalchemy==2.0.27
|
200 |
+
- stack-data==0.6.3
|
201 |
+
- starlette==0.35.1
|
202 |
+
- sympy==1.12
|
203 |
+
- thop==0.1.1-2209072238
|
204 |
+
- threadpoolctl==3.2.0
|
205 |
+
- tifffile==2023.12.9
|
206 |
+
- timm==0.9.12
|
207 |
+
- tokenizers==0.15.0
|
208 |
+
- tomli==2.0.1
|
209 |
+
- tomlkit==0.12.0
|
210 |
+
- toolz==0.12.0
|
211 |
+
- torch==2.0.1+cu118
|
212 |
+
- torchaudio==2.0.2+cu118
|
213 |
+
- torchdiffeq==0.2.3
|
214 |
+
- torchmetrics==1.3.0.post0
|
215 |
+
- torchsde==0.2.6
|
216 |
+
- torchvision==0.15.2+cu118
|
217 |
+
- tqdm==4.66.1
|
218 |
+
- traitlets==5.14.1
|
219 |
+
- trampoline==0.1.2
|
220 |
+
- transformers==4.36.2
|
221 |
+
- triton==2.0.0
|
222 |
+
- twisted==23.10.0
|
223 |
+
- typer==0.9.0
|
224 |
+
- typing-extensions==4.9.0
|
225 |
+
- typing-inspect==0.9.0
|
226 |
+
- tzdata==2023.4
|
227 |
+
- urllib3==2.1.0
|
228 |
+
- uvicorn==0.26.0
|
229 |
+
- wcwidth==0.2.13
|
230 |
+
- websockets==11.0.3
|
231 |
+
- xformers==0.0.20
|
232 |
+
- yarl==1.9.4
|
233 |
+
- zipp==3.17.0
|
234 |
+
- zope-interface==6.2
|
235 |
+
- onnxruntime-gpu==1.13.1
|
236 |
+
prefix: /mnt/user/miniconda3/envs/dtrans
|
UniAnimate/inference.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import copy
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
import logging
|
8 |
+
import itertools
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from utils.config import Config
|
12 |
+
from utils.registry_class import INFER_ENGINE
|
13 |
+
|
14 |
+
from tools import *
|
15 |
+
|
16 |
+
if __name__ == '__main__':
|
17 |
+
cfg_update = Config(load=True)
|
18 |
+
INFER_ENGINE.build(dict(type=cfg_update.TASK_TYPE), cfg_update=cfg_update.cfg_dict)
|
UniAnimate/requirements.txt
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# antlr4-python3-runtime==4.9.3
|
2 |
+
# anyio==4.2.0
|
3 |
+
# asttokens==2.4.1
|
4 |
+
# async-timeout==4.0.3
|
5 |
+
# attrs==23.2.0
|
6 |
+
# Automat==22.10.0
|
7 |
+
# beartype==0.16.4
|
8 |
+
# blessed==1.20.0
|
9 |
+
# buildtools==1.0.6
|
10 |
+
# # causal-conv1d==1.1.3.post1
|
11 |
+
# certifi==2023.11.17
|
12 |
+
# cffi==1.16.0
|
13 |
+
# chardet==5.2.0
|
14 |
+
# charset-normalizer==3.3.2
|
15 |
+
# clean-fid==0.1.35
|
16 |
+
# click==8.1.7
|
17 |
+
# # clip==1.0
|
18 |
+
# cmake==3.28.1
|
19 |
+
# colorama==0.4.6
|
20 |
+
# coloredlogs==15.0.1
|
21 |
+
# constantly==23.10.4
|
22 |
+
# contourpy==1.2.0
|
23 |
+
# crcmod==1.7
|
24 |
+
# cryptography==41.0.7
|
25 |
+
# cycler==0.12.1
|
26 |
+
# decorator==5.1.1
|
27 |
+
# decord==0.6.0
|
28 |
+
# diffusers==0.26.3
|
29 |
+
# docopt==0.6.2
|
30 |
+
easydict==1.11
|
31 |
+
einops==0.7.0
|
32 |
+
# exceptiongroup==1.2.0
|
33 |
+
# executing==2.0.1
|
34 |
+
fairscale==0.4.13
|
35 |
+
# fastapi==0.109.0
|
36 |
+
# ffmpeg==1.4
|
37 |
+
# ffmpy==0.3.1
|
38 |
+
# filelock==3.13.1
|
39 |
+
# flatbuffers==24.3.25
|
40 |
+
# fonttools==4.47.2
|
41 |
+
# frozenlist==1.4.1
|
42 |
+
# fsspec==2023.12.2
|
43 |
+
# ftfy==6.1.3
|
44 |
+
# furl==2.1.3
|
45 |
+
# gpustat==1.1.1
|
46 |
+
# gradio==4.14.0
|
47 |
+
# gradio_client==0.8.0
|
48 |
+
# greenlet==3.0.3
|
49 |
+
# h11==0.14.0
|
50 |
+
# httpcore==1.0.2
|
51 |
+
# httpx==0.26.0
|
52 |
+
# huggingface-hub==0.20.2
|
53 |
+
# humanfriendly==10.0
|
54 |
+
# hyperlink==21.0.0
|
55 |
+
# idna==3.6
|
56 |
+
imageio==2.33.1
|
57 |
+
imageio-ffmpeg==0.4.9
|
58 |
+
# importlib-metadata==7.0.1
|
59 |
+
# importlib-resources==6.1.1
|
60 |
+
# incremental==22.10.0
|
61 |
+
# ipdb==0.13.13
|
62 |
+
# ipython==8.18.1
|
63 |
+
# jedi==0.19.1
|
64 |
+
# Jinja2==3.1.3
|
65 |
+
# jmespath==0.10.0
|
66 |
+
# joblib==1.3.2
|
67 |
+
# jsonschema==4.21.0
|
68 |
+
# jsonschema-specifications==2023.12.1
|
69 |
+
# kiwisolver==1.4.5
|
70 |
+
# kornia==0.7.1
|
71 |
+
# lazy_loader==0.3
|
72 |
+
# lightning-utilities==0.10.0
|
73 |
+
# lit==17.0.6
|
74 |
+
# lpips==0.1.4
|
75 |
+
# markdown-it-py==3.0.0
|
76 |
+
# MarkupSafe==2.1.3
|
77 |
+
matplotlib==3.8.2
|
78 |
+
matplotlib-inline==0.1.6
|
79 |
+
# mdurl==0.1.2
|
80 |
+
# # motion-vector-extractor==1.0.6
|
81 |
+
# mpmath==1.3.0
|
82 |
+
# multidict==6.0.4
|
83 |
+
# mypy-extensions==1.0.0
|
84 |
+
# networkx==3.2.1
|
85 |
+
# ninja==1.11.1.1
|
86 |
+
# numpy==1.26.3
|
87 |
+
# nvidia-cublas-cu11==11.10.3.66
|
88 |
+
# nvidia-cublas-cu12==12.1.3.1
|
89 |
+
# nvidia-cuda-cupti-cu11==11.7.101
|
90 |
+
# nvidia-cuda-cupti-cu12==12.1.105
|
91 |
+
# nvidia-cuda-nvrtc-cu11==11.7.99
|
92 |
+
# nvidia-cuda-nvrtc-cu12==12.1.105
|
93 |
+
# nvidia-cuda-runtime-cu11==11.7.99
|
94 |
+
# nvidia-cuda-runtime-cu12==12.1.105
|
95 |
+
# nvidia-cudnn-cu11==8.5.0.96
|
96 |
+
# nvidia-cudnn-cu12==8.9.2.26
|
97 |
+
# nvidia-cufft-cu11==10.9.0.58
|
98 |
+
# nvidia-cufft-cu12==11.0.2.54
|
99 |
+
# nvidia-curand-cu11==10.2.10.91
|
100 |
+
# nvidia-curand-cu12==10.3.2.106
|
101 |
+
# nvidia-cusolver-cu11==11.4.0.1
|
102 |
+
# nvidia-cusolver-cu12==11.4.5.107
|
103 |
+
# nvidia-cusparse-cu11==11.7.4.91
|
104 |
+
# nvidia-cusparse-cu12==12.1.0.106
|
105 |
+
# nvidia-ml-py==12.535.133
|
106 |
+
# nvidia-nccl-cu11==2.14.3
|
107 |
+
# nvidia-nccl-cu12==2.19.3
|
108 |
+
# nvidia-nvjitlink-cu12==12.3.101
|
109 |
+
# nvidia-nvtx-cu11==11.7.91
|
110 |
+
# nvidia-nvtx-cu12==12.1.105
|
111 |
+
# omegaconf==2.3.0
|
112 |
+
onnxruntime==1.18.0
|
113 |
+
open-clip-torch==2.24.0
|
114 |
+
opencv-python==4.5.3.56
|
115 |
+
opencv-python-headless==4.9.0.80
|
116 |
+
# orderedmultidict==1.0.1
|
117 |
+
# orjson==3.9.10
|
118 |
+
oss2==2.18.4
|
119 |
+
# # packaging==23.2
|
120 |
+
# pandas==2.1.4
|
121 |
+
# parso==0.8.3
|
122 |
+
# pexpect==4.9.0
|
123 |
+
pillow==10.2.0
|
124 |
+
# piq==0.8.0
|
125 |
+
# pkgconfig==1.5.5
|
126 |
+
# prompt-toolkit==3.0.43
|
127 |
+
# protobuf==4.25.2
|
128 |
+
# psutil==5.9.8
|
129 |
+
ptflops==0.7.2.2
|
130 |
+
# ptyprocess==0.7.0
|
131 |
+
# pure-eval==0.2.2
|
132 |
+
# pycparser==2.21
|
133 |
+
# pycryptodome==3.20.0
|
134 |
+
# pydantic==2.5.3
|
135 |
+
# pydantic_core==2.14.6
|
136 |
+
# pydub==0.25.1
|
137 |
+
# Pygments==2.17.2
|
138 |
+
pynvml==11.5.0
|
139 |
+
# pyparsing==3.1.1
|
140 |
+
# pyre-extensions==0.0.29
|
141 |
+
# python-dateutil==2.8.2
|
142 |
+
# python-multipart==0.0.6
|
143 |
+
# pytorch-lightning==2.1.3
|
144 |
+
# pytz==2023.3.post1
|
145 |
+
PyYAML==6.0.1
|
146 |
+
# redo==2.0.4
|
147 |
+
# referencing==0.32.1
|
148 |
+
# regex==2023.12.25
|
149 |
+
requests==2.31.0
|
150 |
+
# rich==13.7.0
|
151 |
+
rotary-embedding-torch==0.5.3
|
152 |
+
# rpds-py==0.17.1
|
153 |
+
# ruff==0.2.0
|
154 |
+
# safetensors==0.4.1
|
155 |
+
# scikit-image==0.22.0
|
156 |
+
# scikit-learn==1.4.0
|
157 |
+
# scipy==1.11.4
|
158 |
+
# semantic-version==2.10.0
|
159 |
+
# sentencepiece==0.1.99
|
160 |
+
# shellingham==1.5.4
|
161 |
+
simplejson==3.19.2
|
162 |
+
# six==1.16.0
|
163 |
+
# sk-video==1.1.10
|
164 |
+
# sniffio==1.3.0
|
165 |
+
# SQLAlchemy==2.0.27
|
166 |
+
# stack-data==0.6.3
|
167 |
+
# starlette==0.35.1
|
168 |
+
# sympy==1.12
|
169 |
+
thop==0.1.1.post2209072238
|
170 |
+
# threadpoolctl==3.2.0
|
171 |
+
# tifffile==2023.12.9
|
172 |
+
# timm==0.9.12
|
173 |
+
# tokenizers==0.15.0
|
174 |
+
# tomli==2.0.1
|
175 |
+
# tomlkit==0.12.0
|
176 |
+
# toolz==0.12.0
|
177 |
+
torch==2.0.1+cu118
|
178 |
+
# torchaudio==2.0.2+cu118
|
179 |
+
# torchdiffeq==0.2.3
|
180 |
+
# torchmetrics==1.3.0.post0
|
181 |
+
torchsde==0.2.6
|
182 |
+
torchvision==0.15.2+cu118
|
183 |
+
tqdm==4.66.1
|
184 |
+
# traitlets==5.14.1
|
185 |
+
# trampoline==0.1.2
|
186 |
+
# transformers==4.36.2
|
187 |
+
# triton==2.0.0
|
188 |
+
# Twisted==23.10.0
|
189 |
+
# typer==0.9.0
|
190 |
+
typing-inspect==0.9.0
|
191 |
+
typing_extensions==4.9.0
|
192 |
+
# tzdata==2023.4
|
193 |
+
# urllib3==2.1.0
|
194 |
+
# uvicorn==0.26.0
|
195 |
+
# wcwidth==0.2.13
|
196 |
+
# websockets==11.0.3
|
197 |
+
xformers==0.0.20
|
198 |
+
# yarl==1.9.4
|
199 |
+
# zipp==3.17.0
|
200 |
+
# zope.interface==6.2
|
201 |
+
onnxruntime-gpu==1.13.1
|
UniAnimate/run_align_pose.py
ADDED
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Openpose
|
2 |
+
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
|
3 |
+
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
|
4 |
+
# 3rd Edited by ControlNet
|
5 |
+
# 4th Edited by ControlNet (added face and correct hands)
|
6 |
+
|
7 |
+
import os
|
8 |
+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
import json
|
13 |
+
import copy
|
14 |
+
import torch
|
15 |
+
import random
|
16 |
+
import argparse
|
17 |
+
import shutil
|
18 |
+
import tempfile
|
19 |
+
import subprocess
|
20 |
+
import numpy as np
|
21 |
+
import math
|
22 |
+
|
23 |
+
import torch.multiprocessing as mp
|
24 |
+
import torch.distributed as dist
|
25 |
+
import pickle
|
26 |
+
import logging
|
27 |
+
from io import BytesIO
|
28 |
+
import oss2 as oss
|
29 |
+
import os.path as osp
|
30 |
+
|
31 |
+
import sys
|
32 |
+
import dwpose.util as util
|
33 |
+
from dwpose.wholebody import Wholebody
|
34 |
+
|
35 |
+
|
36 |
+
def smoothing_factor(t_e, cutoff):
|
37 |
+
r = 2 * math.pi * cutoff * t_e
|
38 |
+
return r / (r + 1)
|
39 |
+
|
40 |
+
|
41 |
+
def exponential_smoothing(a, x, x_prev):
|
42 |
+
return a * x + (1 - a) * x_prev
|
43 |
+
|
44 |
+
|
45 |
+
class OneEuroFilter:
|
46 |
+
def __init__(self, t0, x0, dx0=0.0, min_cutoff=1.0, beta=0.0,
|
47 |
+
d_cutoff=1.0):
|
48 |
+
"""Initialize the one euro filter."""
|
49 |
+
# The parameters.
|
50 |
+
self.min_cutoff = float(min_cutoff)
|
51 |
+
self.beta = float(beta)
|
52 |
+
self.d_cutoff = float(d_cutoff)
|
53 |
+
# Previous values.
|
54 |
+
self.x_prev = x0
|
55 |
+
self.dx_prev = float(dx0)
|
56 |
+
self.t_prev = float(t0)
|
57 |
+
|
58 |
+
def __call__(self, t, x):
|
59 |
+
"""Compute the filtered signal."""
|
60 |
+
t_e = t - self.t_prev
|
61 |
+
|
62 |
+
# The filtered derivative of the signal.
|
63 |
+
a_d = smoothing_factor(t_e, self.d_cutoff)
|
64 |
+
dx = (x - self.x_prev) / t_e
|
65 |
+
dx_hat = exponential_smoothing(a_d, dx, self.dx_prev)
|
66 |
+
|
67 |
+
# The filtered signal.
|
68 |
+
cutoff = self.min_cutoff + self.beta * abs(dx_hat)
|
69 |
+
a = smoothing_factor(t_e, cutoff)
|
70 |
+
x_hat = exponential_smoothing(a, x, self.x_prev)
|
71 |
+
|
72 |
+
# Memorize the previous values.
|
73 |
+
self.x_prev = x_hat
|
74 |
+
self.dx_prev = dx_hat
|
75 |
+
self.t_prev = t
|
76 |
+
|
77 |
+
return x_hat
|
78 |
+
|
79 |
+
|
80 |
+
def get_logger(name="essmc2"):
|
81 |
+
logger = logging.getLogger(name)
|
82 |
+
logger.propagate = False
|
83 |
+
if len(logger.handlers) == 0:
|
84 |
+
std_handler = logging.StreamHandler(sys.stdout)
|
85 |
+
formatter = logging.Formatter(
|
86 |
+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
87 |
+
std_handler.setFormatter(formatter)
|
88 |
+
std_handler.setLevel(logging.INFO)
|
89 |
+
logger.setLevel(logging.INFO)
|
90 |
+
logger.addHandler(std_handler)
|
91 |
+
return logger
|
92 |
+
|
93 |
+
class DWposeDetector:
|
94 |
+
def __init__(self):
|
95 |
+
|
96 |
+
self.pose_estimation = Wholebody()
|
97 |
+
|
98 |
+
def __call__(self, oriImg):
|
99 |
+
oriImg = oriImg.copy()
|
100 |
+
H, W, C = oriImg.shape
|
101 |
+
with torch.no_grad():
|
102 |
+
candidate, subset = self.pose_estimation(oriImg)
|
103 |
+
candidate = candidate[0][np.newaxis, :, :]
|
104 |
+
subset = subset[0][np.newaxis, :]
|
105 |
+
nums, keys, locs = candidate.shape
|
106 |
+
candidate[..., 0] /= float(W)
|
107 |
+
candidate[..., 1] /= float(H)
|
108 |
+
body = candidate[:,:18].copy()
|
109 |
+
body = body.reshape(nums*18, locs)
|
110 |
+
score = subset[:,:18].copy()
|
111 |
+
|
112 |
+
for i in range(len(score)):
|
113 |
+
for j in range(len(score[i])):
|
114 |
+
if score[i][j] > 0.3:
|
115 |
+
score[i][j] = int(18*i+j)
|
116 |
+
else:
|
117 |
+
score[i][j] = -1
|
118 |
+
|
119 |
+
un_visible = subset<0.3
|
120 |
+
candidate[un_visible] = -1
|
121 |
+
|
122 |
+
bodyfoot_score = subset[:,:24].copy()
|
123 |
+
for i in range(len(bodyfoot_score)):
|
124 |
+
for j in range(len(bodyfoot_score[i])):
|
125 |
+
if bodyfoot_score[i][j] > 0.3:
|
126 |
+
bodyfoot_score[i][j] = int(18*i+j)
|
127 |
+
else:
|
128 |
+
bodyfoot_score[i][j] = -1
|
129 |
+
if -1 not in bodyfoot_score[:,18] and -1 not in bodyfoot_score[:,19]:
|
130 |
+
bodyfoot_score[:,18] = np.array([18.])
|
131 |
+
else:
|
132 |
+
bodyfoot_score[:,18] = np.array([-1.])
|
133 |
+
if -1 not in bodyfoot_score[:,21] and -1 not in bodyfoot_score[:,22]:
|
134 |
+
bodyfoot_score[:,19] = np.array([19.])
|
135 |
+
else:
|
136 |
+
bodyfoot_score[:,19] = np.array([-1.])
|
137 |
+
bodyfoot_score = bodyfoot_score[:, :20]
|
138 |
+
|
139 |
+
bodyfoot = candidate[:,:24].copy()
|
140 |
+
|
141 |
+
for i in range(nums):
|
142 |
+
if -1 not in bodyfoot[i][18] and -1 not in bodyfoot[i][19]:
|
143 |
+
bodyfoot[i][18] = (bodyfoot[i][18]+bodyfoot[i][19])/2
|
144 |
+
else:
|
145 |
+
bodyfoot[i][18] = np.array([-1., -1.])
|
146 |
+
if -1 not in bodyfoot[i][21] and -1 not in bodyfoot[i][22]:
|
147 |
+
bodyfoot[i][19] = (bodyfoot[i][21]+bodyfoot[i][22])/2
|
148 |
+
else:
|
149 |
+
bodyfoot[i][19] = np.array([-1., -1.])
|
150 |
+
|
151 |
+
bodyfoot = bodyfoot[:,:20,:]
|
152 |
+
bodyfoot = bodyfoot.reshape(nums*20, locs)
|
153 |
+
|
154 |
+
foot = candidate[:,18:24]
|
155 |
+
|
156 |
+
faces = candidate[:,24:92]
|
157 |
+
|
158 |
+
hands = candidate[:,92:113]
|
159 |
+
hands = np.vstack([hands, candidate[:,113:]])
|
160 |
+
|
161 |
+
# bodies = dict(candidate=body, subset=score)
|
162 |
+
bodies = dict(candidate=bodyfoot, subset=bodyfoot_score)
|
163 |
+
pose = dict(bodies=bodies, hands=hands, faces=faces)
|
164 |
+
|
165 |
+
# return draw_pose(pose, H, W)
|
166 |
+
return pose
|
167 |
+
|
168 |
+
def draw_pose(pose, H, W):
|
169 |
+
bodies = pose['bodies']
|
170 |
+
faces = pose['faces']
|
171 |
+
hands = pose['hands']
|
172 |
+
candidate = bodies['candidate']
|
173 |
+
subset = bodies['subset']
|
174 |
+
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
175 |
+
|
176 |
+
canvas = util.draw_body_and_foot(canvas, candidate, subset)
|
177 |
+
|
178 |
+
canvas = util.draw_handpose(canvas, hands)
|
179 |
+
|
180 |
+
canvas_without_face = copy.deepcopy(canvas)
|
181 |
+
|
182 |
+
canvas = util.draw_facepose(canvas, faces)
|
183 |
+
|
184 |
+
return canvas_without_face, canvas
|
185 |
+
|
186 |
+
def dw_func(_id, frame, dwpose_model, dwpose_woface_folder='tmp_dwpose_wo_face', dwpose_withface_folder='tmp_dwpose_with_face'):
|
187 |
+
|
188 |
+
# frame = cv2.imread(frame_name, cv2.IMREAD_COLOR)
|
189 |
+
pose = dwpose_model(frame)
|
190 |
+
|
191 |
+
return pose
|
192 |
+
|
193 |
+
|
194 |
+
def mp_main(args):
|
195 |
+
|
196 |
+
if args.source_video_paths.endswith('mp4'):
|
197 |
+
video_paths = [args.source_video_paths]
|
198 |
+
else:
|
199 |
+
# video list
|
200 |
+
video_paths = [os.path.join(args.source_video_paths, frame_name) for frame_name in os.listdir(args.source_video_paths)]
|
201 |
+
|
202 |
+
|
203 |
+
logger.info("There are {} videos for extracting poses".format(len(video_paths)))
|
204 |
+
|
205 |
+
logger.info('LOAD: DW Pose Model')
|
206 |
+
dwpose_model = DWposeDetector()
|
207 |
+
|
208 |
+
results_vis = []
|
209 |
+
for i, file_path in enumerate(video_paths):
|
210 |
+
logger.info(f"{i}/{len(video_paths)}, {file_path}")
|
211 |
+
videoCapture = cv2.VideoCapture(file_path)
|
212 |
+
while videoCapture.isOpened():
|
213 |
+
# get a frame
|
214 |
+
ret, frame = videoCapture.read()
|
215 |
+
if ret:
|
216 |
+
pose = dw_func(i, frame, dwpose_model)
|
217 |
+
results_vis.append(pose)
|
218 |
+
else:
|
219 |
+
break
|
220 |
+
logger.info(f'all frames in {file_path} have been read.')
|
221 |
+
videoCapture.release()
|
222 |
+
|
223 |
+
# added
|
224 |
+
# results_vis = results_vis[8:]
|
225 |
+
print(len(results_vis))
|
226 |
+
|
227 |
+
ref_name = args.ref_name
|
228 |
+
save_motion = args.saved_pose_dir
|
229 |
+
os.system(f'rm -rf {save_motion}');
|
230 |
+
os.makedirs(save_motion, exist_ok=True)
|
231 |
+
save_warp = args.saved_pose_dir
|
232 |
+
# os.makedirs(save_warp, exist_ok=True)
|
233 |
+
|
234 |
+
ref_frame = cv2.imread(ref_name, cv2.IMREAD_COLOR)
|
235 |
+
pose_ref = dw_func(i, ref_frame, dwpose_model)
|
236 |
+
|
237 |
+
bodies = results_vis[0]['bodies']
|
238 |
+
faces = results_vis[0]['faces']
|
239 |
+
hands = results_vis[0]['hands']
|
240 |
+
candidate = bodies['candidate']
|
241 |
+
|
242 |
+
ref_bodies = pose_ref['bodies']
|
243 |
+
ref_faces = pose_ref['faces']
|
244 |
+
ref_hands = pose_ref['hands']
|
245 |
+
ref_candidate = ref_bodies['candidate']
|
246 |
+
|
247 |
+
|
248 |
+
ref_2_x = ref_candidate[2][0]
|
249 |
+
ref_2_y = ref_candidate[2][1]
|
250 |
+
ref_5_x = ref_candidate[5][0]
|
251 |
+
ref_5_y = ref_candidate[5][1]
|
252 |
+
ref_8_x = ref_candidate[8][0]
|
253 |
+
ref_8_y = ref_candidate[8][1]
|
254 |
+
ref_11_x = ref_candidate[11][0]
|
255 |
+
ref_11_y = ref_candidate[11][1]
|
256 |
+
ref_center1 = 0.5*(ref_candidate[2]+ref_candidate[5])
|
257 |
+
ref_center2 = 0.5*(ref_candidate[8]+ref_candidate[11])
|
258 |
+
|
259 |
+
zero_2_x = candidate[2][0]
|
260 |
+
zero_2_y = candidate[2][1]
|
261 |
+
zero_5_x = candidate[5][0]
|
262 |
+
zero_5_y = candidate[5][1]
|
263 |
+
zero_8_x = candidate[8][0]
|
264 |
+
zero_8_y = candidate[8][1]
|
265 |
+
zero_11_x = candidate[11][0]
|
266 |
+
zero_11_y = candidate[11][1]
|
267 |
+
zero_center1 = 0.5*(candidate[2]+candidate[5])
|
268 |
+
zero_center2 = 0.5*(candidate[8]+candidate[11])
|
269 |
+
|
270 |
+
x_ratio = (ref_5_x-ref_2_x)/(zero_5_x-zero_2_x)
|
271 |
+
y_ratio = (ref_center2[1]-ref_center1[1])/(zero_center2[1]-zero_center1[1])
|
272 |
+
|
273 |
+
results_vis[0]['bodies']['candidate'][:,0] *= x_ratio
|
274 |
+
results_vis[0]['bodies']['candidate'][:,1] *= y_ratio
|
275 |
+
results_vis[0]['faces'][:,:,0] *= x_ratio
|
276 |
+
results_vis[0]['faces'][:,:,1] *= y_ratio
|
277 |
+
results_vis[0]['hands'][:,:,0] *= x_ratio
|
278 |
+
results_vis[0]['hands'][:,:,1] *= y_ratio
|
279 |
+
|
280 |
+
########neck########
|
281 |
+
l_neck_ref = ((ref_candidate[0][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[0][1] - ref_candidate[1][1]) ** 2) ** 0.5
|
282 |
+
l_neck_0 = ((candidate[0][0] - candidate[1][0]) ** 2 + (candidate[0][1] - candidate[1][1]) ** 2) ** 0.5
|
283 |
+
neck_ratio = l_neck_ref / l_neck_0
|
284 |
+
|
285 |
+
x_offset_neck = (candidate[1][0]-candidate[0][0])*(1.-neck_ratio)
|
286 |
+
y_offset_neck = (candidate[1][1]-candidate[0][1])*(1.-neck_ratio)
|
287 |
+
|
288 |
+
results_vis[0]['bodies']['candidate'][0,0] += x_offset_neck
|
289 |
+
results_vis[0]['bodies']['candidate'][0,1] += y_offset_neck
|
290 |
+
results_vis[0]['bodies']['candidate'][14,0] += x_offset_neck
|
291 |
+
results_vis[0]['bodies']['candidate'][14,1] += y_offset_neck
|
292 |
+
results_vis[0]['bodies']['candidate'][15,0] += x_offset_neck
|
293 |
+
results_vis[0]['bodies']['candidate'][15,1] += y_offset_neck
|
294 |
+
results_vis[0]['bodies']['candidate'][16,0] += x_offset_neck
|
295 |
+
results_vis[0]['bodies']['candidate'][16,1] += y_offset_neck
|
296 |
+
results_vis[0]['bodies']['candidate'][17,0] += x_offset_neck
|
297 |
+
results_vis[0]['bodies']['candidate'][17,1] += y_offset_neck
|
298 |
+
|
299 |
+
########shoulder2########
|
300 |
+
l_shoulder2_ref = ((ref_candidate[2][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[2][1] - ref_candidate[1][1]) ** 2) ** 0.5
|
301 |
+
l_shoulder2_0 = ((candidate[2][0] - candidate[1][0]) ** 2 + (candidate[2][1] - candidate[1][1]) ** 2) ** 0.5
|
302 |
+
|
303 |
+
shoulder2_ratio = l_shoulder2_ref / l_shoulder2_0
|
304 |
+
|
305 |
+
x_offset_shoulder2 = (candidate[1][0]-candidate[2][0])*(1.-shoulder2_ratio)
|
306 |
+
y_offset_shoulder2 = (candidate[1][1]-candidate[2][1])*(1.-shoulder2_ratio)
|
307 |
+
|
308 |
+
results_vis[0]['bodies']['candidate'][2,0] += x_offset_shoulder2
|
309 |
+
results_vis[0]['bodies']['candidate'][2,1] += y_offset_shoulder2
|
310 |
+
results_vis[0]['bodies']['candidate'][3,0] += x_offset_shoulder2
|
311 |
+
results_vis[0]['bodies']['candidate'][3,1] += y_offset_shoulder2
|
312 |
+
results_vis[0]['bodies']['candidate'][4,0] += x_offset_shoulder2
|
313 |
+
results_vis[0]['bodies']['candidate'][4,1] += y_offset_shoulder2
|
314 |
+
results_vis[0]['hands'][1,:,0] += x_offset_shoulder2
|
315 |
+
results_vis[0]['hands'][1,:,1] += y_offset_shoulder2
|
316 |
+
|
317 |
+
########shoulder5########
|
318 |
+
l_shoulder5_ref = ((ref_candidate[5][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[5][1] - ref_candidate[1][1]) ** 2) ** 0.5
|
319 |
+
l_shoulder5_0 = ((candidate[5][0] - candidate[1][0]) ** 2 + (candidate[5][1] - candidate[1][1]) ** 2) ** 0.5
|
320 |
+
|
321 |
+
shoulder5_ratio = l_shoulder5_ref / l_shoulder5_0
|
322 |
+
|
323 |
+
x_offset_shoulder5 = (candidate[1][0]-candidate[5][0])*(1.-shoulder5_ratio)
|
324 |
+
y_offset_shoulder5 = (candidate[1][1]-candidate[5][1])*(1.-shoulder5_ratio)
|
325 |
+
|
326 |
+
results_vis[0]['bodies']['candidate'][5,0] += x_offset_shoulder5
|
327 |
+
results_vis[0]['bodies']['candidate'][5,1] += y_offset_shoulder5
|
328 |
+
results_vis[0]['bodies']['candidate'][6,0] += x_offset_shoulder5
|
329 |
+
results_vis[0]['bodies']['candidate'][6,1] += y_offset_shoulder5
|
330 |
+
results_vis[0]['bodies']['candidate'][7,0] += x_offset_shoulder5
|
331 |
+
results_vis[0]['bodies']['candidate'][7,1] += y_offset_shoulder5
|
332 |
+
results_vis[0]['hands'][0,:,0] += x_offset_shoulder5
|
333 |
+
results_vis[0]['hands'][0,:,1] += y_offset_shoulder5
|
334 |
+
|
335 |
+
########arm3########
|
336 |
+
l_arm3_ref = ((ref_candidate[3][0] - ref_candidate[2][0]) ** 2 + (ref_candidate[3][1] - ref_candidate[2][1]) ** 2) ** 0.5
|
337 |
+
l_arm3_0 = ((candidate[3][0] - candidate[2][0]) ** 2 + (candidate[3][1] - candidate[2][1]) ** 2) ** 0.5
|
338 |
+
|
339 |
+
arm3_ratio = l_arm3_ref / l_arm3_0
|
340 |
+
|
341 |
+
x_offset_arm3 = (candidate[2][0]-candidate[3][0])*(1.-arm3_ratio)
|
342 |
+
y_offset_arm3 = (candidate[2][1]-candidate[3][1])*(1.-arm3_ratio)
|
343 |
+
|
344 |
+
results_vis[0]['bodies']['candidate'][3,0] += x_offset_arm3
|
345 |
+
results_vis[0]['bodies']['candidate'][3,1] += y_offset_arm3
|
346 |
+
results_vis[0]['bodies']['candidate'][4,0] += x_offset_arm3
|
347 |
+
results_vis[0]['bodies']['candidate'][4,1] += y_offset_arm3
|
348 |
+
results_vis[0]['hands'][1,:,0] += x_offset_arm3
|
349 |
+
results_vis[0]['hands'][1,:,1] += y_offset_arm3
|
350 |
+
|
351 |
+
########arm4########
|
352 |
+
l_arm4_ref = ((ref_candidate[4][0] - ref_candidate[3][0]) ** 2 + (ref_candidate[4][1] - ref_candidate[3][1]) ** 2) ** 0.5
|
353 |
+
l_arm4_0 = ((candidate[4][0] - candidate[3][0]) ** 2 + (candidate[4][1] - candidate[3][1]) ** 2) ** 0.5
|
354 |
+
|
355 |
+
arm4_ratio = l_arm4_ref / l_arm4_0
|
356 |
+
|
357 |
+
x_offset_arm4 = (candidate[3][0]-candidate[4][0])*(1.-arm4_ratio)
|
358 |
+
y_offset_arm4 = (candidate[3][1]-candidate[4][1])*(1.-arm4_ratio)
|
359 |
+
|
360 |
+
results_vis[0]['bodies']['candidate'][4,0] += x_offset_arm4
|
361 |
+
results_vis[0]['bodies']['candidate'][4,1] += y_offset_arm4
|
362 |
+
results_vis[0]['hands'][1,:,0] += x_offset_arm4
|
363 |
+
results_vis[0]['hands'][1,:,1] += y_offset_arm4
|
364 |
+
|
365 |
+
########arm6########
|
366 |
+
l_arm6_ref = ((ref_candidate[6][0] - ref_candidate[5][0]) ** 2 + (ref_candidate[6][1] - ref_candidate[5][1]) ** 2) ** 0.5
|
367 |
+
l_arm6_0 = ((candidate[6][0] - candidate[5][0]) ** 2 + (candidate[6][1] - candidate[5][1]) ** 2) ** 0.5
|
368 |
+
|
369 |
+
arm6_ratio = l_arm6_ref / l_arm6_0
|
370 |
+
|
371 |
+
x_offset_arm6 = (candidate[5][0]-candidate[6][0])*(1.-arm6_ratio)
|
372 |
+
y_offset_arm6 = (candidate[5][1]-candidate[6][1])*(1.-arm6_ratio)
|
373 |
+
|
374 |
+
results_vis[0]['bodies']['candidate'][6,0] += x_offset_arm6
|
375 |
+
results_vis[0]['bodies']['candidate'][6,1] += y_offset_arm6
|
376 |
+
results_vis[0]['bodies']['candidate'][7,0] += x_offset_arm6
|
377 |
+
results_vis[0]['bodies']['candidate'][7,1] += y_offset_arm6
|
378 |
+
results_vis[0]['hands'][0,:,0] += x_offset_arm6
|
379 |
+
results_vis[0]['hands'][0,:,1] += y_offset_arm6
|
380 |
+
|
381 |
+
########arm7########
|
382 |
+
l_arm7_ref = ((ref_candidate[7][0] - ref_candidate[6][0]) ** 2 + (ref_candidate[7][1] - ref_candidate[6][1]) ** 2) ** 0.5
|
383 |
+
l_arm7_0 = ((candidate[7][0] - candidate[6][0]) ** 2 + (candidate[7][1] - candidate[6][1]) ** 2) ** 0.5
|
384 |
+
|
385 |
+
arm7_ratio = l_arm7_ref / l_arm7_0
|
386 |
+
|
387 |
+
x_offset_arm7 = (candidate[6][0]-candidate[7][0])*(1.-arm7_ratio)
|
388 |
+
y_offset_arm7 = (candidate[6][1]-candidate[7][1])*(1.-arm7_ratio)
|
389 |
+
|
390 |
+
results_vis[0]['bodies']['candidate'][7,0] += x_offset_arm7
|
391 |
+
results_vis[0]['bodies']['candidate'][7,1] += y_offset_arm7
|
392 |
+
results_vis[0]['hands'][0,:,0] += x_offset_arm7
|
393 |
+
results_vis[0]['hands'][0,:,1] += y_offset_arm7
|
394 |
+
|
395 |
+
########head14########
|
396 |
+
l_head14_ref = ((ref_candidate[14][0] - ref_candidate[0][0]) ** 2 + (ref_candidate[14][1] - ref_candidate[0][1]) ** 2) ** 0.5
|
397 |
+
l_head14_0 = ((candidate[14][0] - candidate[0][0]) ** 2 + (candidate[14][1] - candidate[0][1]) ** 2) ** 0.5
|
398 |
+
|
399 |
+
head14_ratio = l_head14_ref / l_head14_0
|
400 |
+
|
401 |
+
x_offset_head14 = (candidate[0][0]-candidate[14][0])*(1.-head14_ratio)
|
402 |
+
y_offset_head14 = (candidate[0][1]-candidate[14][1])*(1.-head14_ratio)
|
403 |
+
|
404 |
+
results_vis[0]['bodies']['candidate'][14,0] += x_offset_head14
|
405 |
+
results_vis[0]['bodies']['candidate'][14,1] += y_offset_head14
|
406 |
+
results_vis[0]['bodies']['candidate'][16,0] += x_offset_head14
|
407 |
+
results_vis[0]['bodies']['candidate'][16,1] += y_offset_head14
|
408 |
+
|
409 |
+
########head15########
|
410 |
+
l_head15_ref = ((ref_candidate[15][0] - ref_candidate[0][0]) ** 2 + (ref_candidate[15][1] - ref_candidate[0][1]) ** 2) ** 0.5
|
411 |
+
l_head15_0 = ((candidate[15][0] - candidate[0][0]) ** 2 + (candidate[15][1] - candidate[0][1]) ** 2) ** 0.5
|
412 |
+
|
413 |
+
head15_ratio = l_head15_ref / l_head15_0
|
414 |
+
|
415 |
+
x_offset_head15 = (candidate[0][0]-candidate[15][0])*(1.-head15_ratio)
|
416 |
+
y_offset_head15 = (candidate[0][1]-candidate[15][1])*(1.-head15_ratio)
|
417 |
+
|
418 |
+
results_vis[0]['bodies']['candidate'][15,0] += x_offset_head15
|
419 |
+
results_vis[0]['bodies']['candidate'][15,1] += y_offset_head15
|
420 |
+
results_vis[0]['bodies']['candidate'][17,0] += x_offset_head15
|
421 |
+
results_vis[0]['bodies']['candidate'][17,1] += y_offset_head15
|
422 |
+
|
423 |
+
########head16########
|
424 |
+
l_head16_ref = ((ref_candidate[16][0] - ref_candidate[14][0]) ** 2 + (ref_candidate[16][1] - ref_candidate[14][1]) ** 2) ** 0.5
|
425 |
+
l_head16_0 = ((candidate[16][0] - candidate[14][0]) ** 2 + (candidate[16][1] - candidate[14][1]) ** 2) ** 0.5
|
426 |
+
|
427 |
+
head16_ratio = l_head16_ref / l_head16_0
|
428 |
+
|
429 |
+
x_offset_head16 = (candidate[14][0]-candidate[16][0])*(1.-head16_ratio)
|
430 |
+
y_offset_head16 = (candidate[14][1]-candidate[16][1])*(1.-head16_ratio)
|
431 |
+
|
432 |
+
results_vis[0]['bodies']['candidate'][16,0] += x_offset_head16
|
433 |
+
results_vis[0]['bodies']['candidate'][16,1] += y_offset_head16
|
434 |
+
|
435 |
+
########head17########
|
436 |
+
l_head17_ref = ((ref_candidate[17][0] - ref_candidate[15][0]) ** 2 + (ref_candidate[17][1] - ref_candidate[15][1]) ** 2) ** 0.5
|
437 |
+
l_head17_0 = ((candidate[17][0] - candidate[15][0]) ** 2 + (candidate[17][1] - candidate[15][1]) ** 2) ** 0.5
|
438 |
+
|
439 |
+
head17_ratio = l_head17_ref / l_head17_0
|
440 |
+
|
441 |
+
x_offset_head17 = (candidate[15][0]-candidate[17][0])*(1.-head17_ratio)
|
442 |
+
y_offset_head17 = (candidate[15][1]-candidate[17][1])*(1.-head17_ratio)
|
443 |
+
|
444 |
+
results_vis[0]['bodies']['candidate'][17,0] += x_offset_head17
|
445 |
+
results_vis[0]['bodies']['candidate'][17,1] += y_offset_head17
|
446 |
+
|
447 |
+
########MovingAverage########
|
448 |
+
|
449 |
+
########left leg########
|
450 |
+
l_ll1_ref = ((ref_candidate[8][0] - ref_candidate[9][0]) ** 2 + (ref_candidate[8][1] - ref_candidate[9][1]) ** 2) ** 0.5
|
451 |
+
l_ll1_0 = ((candidate[8][0] - candidate[9][0]) ** 2 + (candidate[8][1] - candidate[9][1]) ** 2) ** 0.5
|
452 |
+
ll1_ratio = l_ll1_ref / l_ll1_0
|
453 |
+
|
454 |
+
x_offset_ll1 = (candidate[9][0]-candidate[8][0])*(ll1_ratio-1.)
|
455 |
+
y_offset_ll1 = (candidate[9][1]-candidate[8][1])*(ll1_ratio-1.)
|
456 |
+
|
457 |
+
results_vis[0]['bodies']['candidate'][9,0] += x_offset_ll1
|
458 |
+
results_vis[0]['bodies']['candidate'][9,1] += y_offset_ll1
|
459 |
+
results_vis[0]['bodies']['candidate'][10,0] += x_offset_ll1
|
460 |
+
results_vis[0]['bodies']['candidate'][10,1] += y_offset_ll1
|
461 |
+
results_vis[0]['bodies']['candidate'][19,0] += x_offset_ll1
|
462 |
+
results_vis[0]['bodies']['candidate'][19,1] += y_offset_ll1
|
463 |
+
|
464 |
+
l_ll2_ref = ((ref_candidate[9][0] - ref_candidate[10][0]) ** 2 + (ref_candidate[9][1] - ref_candidate[10][1]) ** 2) ** 0.5
|
465 |
+
l_ll2_0 = ((candidate[9][0] - candidate[10][0]) ** 2 + (candidate[9][1] - candidate[10][1]) ** 2) ** 0.5
|
466 |
+
ll2_ratio = l_ll2_ref / l_ll2_0
|
467 |
+
|
468 |
+
x_offset_ll2 = (candidate[10][0]-candidate[9][0])*(ll2_ratio-1.)
|
469 |
+
y_offset_ll2 = (candidate[10][1]-candidate[9][1])*(ll2_ratio-1.)
|
470 |
+
|
471 |
+
results_vis[0]['bodies']['candidate'][10,0] += x_offset_ll2
|
472 |
+
results_vis[0]['bodies']['candidate'][10,1] += y_offset_ll2
|
473 |
+
results_vis[0]['bodies']['candidate'][19,0] += x_offset_ll2
|
474 |
+
results_vis[0]['bodies']['candidate'][19,1] += y_offset_ll2
|
475 |
+
|
476 |
+
########right leg########
|
477 |
+
l_rl1_ref = ((ref_candidate[11][0] - ref_candidate[12][0]) ** 2 + (ref_candidate[11][1] - ref_candidate[12][1]) ** 2) ** 0.5
|
478 |
+
l_rl1_0 = ((candidate[11][0] - candidate[12][0]) ** 2 + (candidate[11][1] - candidate[12][1]) ** 2) ** 0.5
|
479 |
+
rl1_ratio = l_rl1_ref / l_rl1_0
|
480 |
+
|
481 |
+
x_offset_rl1 = (candidate[12][0]-candidate[11][0])*(rl1_ratio-1.)
|
482 |
+
y_offset_rl1 = (candidate[12][1]-candidate[11][1])*(rl1_ratio-1.)
|
483 |
+
|
484 |
+
results_vis[0]['bodies']['candidate'][12,0] += x_offset_rl1
|
485 |
+
results_vis[0]['bodies']['candidate'][12,1] += y_offset_rl1
|
486 |
+
results_vis[0]['bodies']['candidate'][13,0] += x_offset_rl1
|
487 |
+
results_vis[0]['bodies']['candidate'][13,1] += y_offset_rl1
|
488 |
+
results_vis[0]['bodies']['candidate'][18,0] += x_offset_rl1
|
489 |
+
results_vis[0]['bodies']['candidate'][18,1] += y_offset_rl1
|
490 |
+
|
491 |
+
l_rl2_ref = ((ref_candidate[12][0] - ref_candidate[13][0]) ** 2 + (ref_candidate[12][1] - ref_candidate[13][1]) ** 2) ** 0.5
|
492 |
+
l_rl2_0 = ((candidate[12][0] - candidate[13][0]) ** 2 + (candidate[12][1] - candidate[13][1]) ** 2) ** 0.5
|
493 |
+
rl2_ratio = l_rl2_ref / l_rl2_0
|
494 |
+
|
495 |
+
x_offset_rl2 = (candidate[13][0]-candidate[12][0])*(rl2_ratio-1.)
|
496 |
+
y_offset_rl2 = (candidate[13][1]-candidate[12][1])*(rl2_ratio-1.)
|
497 |
+
|
498 |
+
results_vis[0]['bodies']['candidate'][13,0] += x_offset_rl2
|
499 |
+
results_vis[0]['bodies']['candidate'][13,1] += y_offset_rl2
|
500 |
+
results_vis[0]['bodies']['candidate'][18,0] += x_offset_rl2
|
501 |
+
results_vis[0]['bodies']['candidate'][18,1] += y_offset_rl2
|
502 |
+
|
503 |
+
offset = ref_candidate[1] - results_vis[0]['bodies']['candidate'][1]
|
504 |
+
|
505 |
+
results_vis[0]['bodies']['candidate'] += offset[np.newaxis, :]
|
506 |
+
results_vis[0]['faces'] += offset[np.newaxis, np.newaxis, :]
|
507 |
+
results_vis[0]['hands'] += offset[np.newaxis, np.newaxis, :]
|
508 |
+
|
509 |
+
for i in range(1, len(results_vis)):
|
510 |
+
results_vis[i]['bodies']['candidate'][:,0] *= x_ratio
|
511 |
+
results_vis[i]['bodies']['candidate'][:,1] *= y_ratio
|
512 |
+
results_vis[i]['faces'][:,:,0] *= x_ratio
|
513 |
+
results_vis[i]['faces'][:,:,1] *= y_ratio
|
514 |
+
results_vis[i]['hands'][:,:,0] *= x_ratio
|
515 |
+
results_vis[i]['hands'][:,:,1] *= y_ratio
|
516 |
+
|
517 |
+
########neck########
|
518 |
+
x_offset_neck = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][0][0])*(1.-neck_ratio)
|
519 |
+
y_offset_neck = (results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][0][1])*(1.-neck_ratio)
|
520 |
+
|
521 |
+
results_vis[i]['bodies']['candidate'][0,0] += x_offset_neck
|
522 |
+
results_vis[i]['bodies']['candidate'][0,1] += y_offset_neck
|
523 |
+
results_vis[i]['bodies']['candidate'][14,0] += x_offset_neck
|
524 |
+
results_vis[i]['bodies']['candidate'][14,1] += y_offset_neck
|
525 |
+
results_vis[i]['bodies']['candidate'][15,0] += x_offset_neck
|
526 |
+
results_vis[i]['bodies']['candidate'][15,1] += y_offset_neck
|
527 |
+
results_vis[i]['bodies']['candidate'][16,0] += x_offset_neck
|
528 |
+
results_vis[i]['bodies']['candidate'][16,1] += y_offset_neck
|
529 |
+
results_vis[i]['bodies']['candidate'][17,0] += x_offset_neck
|
530 |
+
results_vis[i]['bodies']['candidate'][17,1] += y_offset_neck
|
531 |
+
|
532 |
+
########shoulder2########
|
533 |
+
|
534 |
+
|
535 |
+
x_offset_shoulder2 = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][2][0])*(1.-shoulder2_ratio)
|
536 |
+
y_offset_shoulder2 = (results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][2][1])*(1.-shoulder2_ratio)
|
537 |
+
|
538 |
+
results_vis[i]['bodies']['candidate'][2,0] += x_offset_shoulder2
|
539 |
+
results_vis[i]['bodies']['candidate'][2,1] += y_offset_shoulder2
|
540 |
+
results_vis[i]['bodies']['candidate'][3,0] += x_offset_shoulder2
|
541 |
+
results_vis[i]['bodies']['candidate'][3,1] += y_offset_shoulder2
|
542 |
+
results_vis[i]['bodies']['candidate'][4,0] += x_offset_shoulder2
|
543 |
+
results_vis[i]['bodies']['candidate'][4,1] += y_offset_shoulder2
|
544 |
+
results_vis[i]['hands'][1,:,0] += x_offset_shoulder2
|
545 |
+
results_vis[i]['hands'][1,:,1] += y_offset_shoulder2
|
546 |
+
|
547 |
+
########shoulder5########
|
548 |
+
|
549 |
+
x_offset_shoulder5 = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][5][0])*(1.-shoulder5_ratio)
|
550 |
+
y_offset_shoulder5 = (results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][5][1])*(1.-shoulder5_ratio)
|
551 |
+
|
552 |
+
results_vis[i]['bodies']['candidate'][5,0] += x_offset_shoulder5
|
553 |
+
results_vis[i]['bodies']['candidate'][5,1] += y_offset_shoulder5
|
554 |
+
results_vis[i]['bodies']['candidate'][6,0] += x_offset_shoulder5
|
555 |
+
results_vis[i]['bodies']['candidate'][6,1] += y_offset_shoulder5
|
556 |
+
results_vis[i]['bodies']['candidate'][7,0] += x_offset_shoulder5
|
557 |
+
results_vis[i]['bodies']['candidate'][7,1] += y_offset_shoulder5
|
558 |
+
results_vis[i]['hands'][0,:,0] += x_offset_shoulder5
|
559 |
+
results_vis[i]['hands'][0,:,1] += y_offset_shoulder5
|
560 |
+
|
561 |
+
########arm3########
|
562 |
+
|
563 |
+
x_offset_arm3 = (results_vis[i]['bodies']['candidate'][2][0]-results_vis[i]['bodies']['candidate'][3][0])*(1.-arm3_ratio)
|
564 |
+
y_offset_arm3 = (results_vis[i]['bodies']['candidate'][2][1]-results_vis[i]['bodies']['candidate'][3][1])*(1.-arm3_ratio)
|
565 |
+
|
566 |
+
results_vis[i]['bodies']['candidate'][3,0] += x_offset_arm3
|
567 |
+
results_vis[i]['bodies']['candidate'][3,1] += y_offset_arm3
|
568 |
+
results_vis[i]['bodies']['candidate'][4,0] += x_offset_arm3
|
569 |
+
results_vis[i]['bodies']['candidate'][4,1] += y_offset_arm3
|
570 |
+
results_vis[i]['hands'][1,:,0] += x_offset_arm3
|
571 |
+
results_vis[i]['hands'][1,:,1] += y_offset_arm3
|
572 |
+
|
573 |
+
########arm4########
|
574 |
+
|
575 |
+
x_offset_arm4 = (results_vis[i]['bodies']['candidate'][3][0]-results_vis[i]['bodies']['candidate'][4][0])*(1.-arm4_ratio)
|
576 |
+
y_offset_arm4 = (results_vis[i]['bodies']['candidate'][3][1]-results_vis[i]['bodies']['candidate'][4][1])*(1.-arm4_ratio)
|
577 |
+
|
578 |
+
results_vis[i]['bodies']['candidate'][4,0] += x_offset_arm4
|
579 |
+
results_vis[i]['bodies']['candidate'][4,1] += y_offset_arm4
|
580 |
+
results_vis[i]['hands'][1,:,0] += x_offset_arm4
|
581 |
+
results_vis[i]['hands'][1,:,1] += y_offset_arm4
|
582 |
+
|
583 |
+
########arm6########
|
584 |
+
|
585 |
+
x_offset_arm6 = (results_vis[i]['bodies']['candidate'][5][0]-results_vis[i]['bodies']['candidate'][6][0])*(1.-arm6_ratio)
|
586 |
+
y_offset_arm6 = (results_vis[i]['bodies']['candidate'][5][1]-results_vis[i]['bodies']['candidate'][6][1])*(1.-arm6_ratio)
|
587 |
+
|
588 |
+
results_vis[i]['bodies']['candidate'][6,0] += x_offset_arm6
|
589 |
+
results_vis[i]['bodies']['candidate'][6,1] += y_offset_arm6
|
590 |
+
results_vis[i]['bodies']['candidate'][7,0] += x_offset_arm6
|
591 |
+
results_vis[i]['bodies']['candidate'][7,1] += y_offset_arm6
|
592 |
+
results_vis[i]['hands'][0,:,0] += x_offset_arm6
|
593 |
+
results_vis[i]['hands'][0,:,1] += y_offset_arm6
|
594 |
+
|
595 |
+
########arm7########
|
596 |
+
|
597 |
+
x_offset_arm7 = (results_vis[i]['bodies']['candidate'][6][0]-results_vis[i]['bodies']['candidate'][7][0])*(1.-arm7_ratio)
|
598 |
+
y_offset_arm7 = (results_vis[i]['bodies']['candidate'][6][1]-results_vis[i]['bodies']['candidate'][7][1])*(1.-arm7_ratio)
|
599 |
+
|
600 |
+
results_vis[i]['bodies']['candidate'][7,0] += x_offset_arm7
|
601 |
+
results_vis[i]['bodies']['candidate'][7,1] += y_offset_arm7
|
602 |
+
results_vis[i]['hands'][0,:,0] += x_offset_arm7
|
603 |
+
results_vis[i]['hands'][0,:,1] += y_offset_arm7
|
604 |
+
|
605 |
+
########head14########
|
606 |
+
|
607 |
+
x_offset_head14 = (results_vis[i]['bodies']['candidate'][0][0]-results_vis[i]['bodies']['candidate'][14][0])*(1.-head14_ratio)
|
608 |
+
y_offset_head14 = (results_vis[i]['bodies']['candidate'][0][1]-results_vis[i]['bodies']['candidate'][14][1])*(1.-head14_ratio)
|
609 |
+
|
610 |
+
results_vis[i]['bodies']['candidate'][14,0] += x_offset_head14
|
611 |
+
results_vis[i]['bodies']['candidate'][14,1] += y_offset_head14
|
612 |
+
results_vis[i]['bodies']['candidate'][16,0] += x_offset_head14
|
613 |
+
results_vis[i]['bodies']['candidate'][16,1] += y_offset_head14
|
614 |
+
|
615 |
+
########head15########
|
616 |
+
|
617 |
+
x_offset_head15 = (results_vis[i]['bodies']['candidate'][0][0]-results_vis[i]['bodies']['candidate'][15][0])*(1.-head15_ratio)
|
618 |
+
y_offset_head15 = (results_vis[i]['bodies']['candidate'][0][1]-results_vis[i]['bodies']['candidate'][15][1])*(1.-head15_ratio)
|
619 |
+
|
620 |
+
results_vis[i]['bodies']['candidate'][15,0] += x_offset_head15
|
621 |
+
results_vis[i]['bodies']['candidate'][15,1] += y_offset_head15
|
622 |
+
results_vis[i]['bodies']['candidate'][17,0] += x_offset_head15
|
623 |
+
results_vis[i]['bodies']['candidate'][17,1] += y_offset_head15
|
624 |
+
|
625 |
+
########head16########
|
626 |
+
|
627 |
+
x_offset_head16 = (results_vis[i]['bodies']['candidate'][14][0]-results_vis[i]['bodies']['candidate'][16][0])*(1.-head16_ratio)
|
628 |
+
y_offset_head16 = (results_vis[i]['bodies']['candidate'][14][1]-results_vis[i]['bodies']['candidate'][16][1])*(1.-head16_ratio)
|
629 |
+
|
630 |
+
results_vis[i]['bodies']['candidate'][16,0] += x_offset_head16
|
631 |
+
results_vis[i]['bodies']['candidate'][16,1] += y_offset_head16
|
632 |
+
|
633 |
+
########head17########
|
634 |
+
x_offset_head17 = (results_vis[i]['bodies']['candidate'][15][0]-results_vis[i]['bodies']['candidate'][17][0])*(1.-head17_ratio)
|
635 |
+
y_offset_head17 = (results_vis[i]['bodies']['candidate'][15][1]-results_vis[i]['bodies']['candidate'][17][1])*(1.-head17_ratio)
|
636 |
+
|
637 |
+
results_vis[i]['bodies']['candidate'][17,0] += x_offset_head17
|
638 |
+
results_vis[i]['bodies']['candidate'][17,1] += y_offset_head17
|
639 |
+
|
640 |
+
# ########MovingAverage########
|
641 |
+
|
642 |
+
########left leg########
|
643 |
+
x_offset_ll1 = (results_vis[i]['bodies']['candidate'][9][0]-results_vis[i]['bodies']['candidate'][8][0])*(ll1_ratio-1.)
|
644 |
+
y_offset_ll1 = (results_vis[i]['bodies']['candidate'][9][1]-results_vis[i]['bodies']['candidate'][8][1])*(ll1_ratio-1.)
|
645 |
+
|
646 |
+
results_vis[i]['bodies']['candidate'][9,0] += x_offset_ll1
|
647 |
+
results_vis[i]['bodies']['candidate'][9,1] += y_offset_ll1
|
648 |
+
results_vis[i]['bodies']['candidate'][10,0] += x_offset_ll1
|
649 |
+
results_vis[i]['bodies']['candidate'][10,1] += y_offset_ll1
|
650 |
+
results_vis[i]['bodies']['candidate'][19,0] += x_offset_ll1
|
651 |
+
results_vis[i]['bodies']['candidate'][19,1] += y_offset_ll1
|
652 |
+
|
653 |
+
|
654 |
+
|
655 |
+
x_offset_ll2 = (results_vis[i]['bodies']['candidate'][10][0]-results_vis[i]['bodies']['candidate'][9][0])*(ll2_ratio-1.)
|
656 |
+
y_offset_ll2 = (results_vis[i]['bodies']['candidate'][10][1]-results_vis[i]['bodies']['candidate'][9][1])*(ll2_ratio-1.)
|
657 |
+
|
658 |
+
results_vis[i]['bodies']['candidate'][10,0] += x_offset_ll2
|
659 |
+
results_vis[i]['bodies']['candidate'][10,1] += y_offset_ll2
|
660 |
+
results_vis[i]['bodies']['candidate'][19,0] += x_offset_ll2
|
661 |
+
results_vis[i]['bodies']['candidate'][19,1] += y_offset_ll2
|
662 |
+
|
663 |
+
########right leg########
|
664 |
+
|
665 |
+
x_offset_rl1 = (results_vis[i]['bodies']['candidate'][12][0]-results_vis[i]['bodies']['candidate'][11][0])*(rl1_ratio-1.)
|
666 |
+
y_offset_rl1 = (results_vis[i]['bodies']['candidate'][12][1]-results_vis[i]['bodies']['candidate'][11][1])*(rl1_ratio-1.)
|
667 |
+
|
668 |
+
results_vis[i]['bodies']['candidate'][12,0] += x_offset_rl1
|
669 |
+
results_vis[i]['bodies']['candidate'][12,1] += y_offset_rl1
|
670 |
+
results_vis[i]['bodies']['candidate'][13,0] += x_offset_rl1
|
671 |
+
results_vis[i]['bodies']['candidate'][13,1] += y_offset_rl1
|
672 |
+
results_vis[i]['bodies']['candidate'][18,0] += x_offset_rl1
|
673 |
+
results_vis[i]['bodies']['candidate'][18,1] += y_offset_rl1
|
674 |
+
|
675 |
+
|
676 |
+
x_offset_rl2 = (results_vis[i]['bodies']['candidate'][13][0]-results_vis[i]['bodies']['candidate'][12][0])*(rl2_ratio-1.)
|
677 |
+
y_offset_rl2 = (results_vis[i]['bodies']['candidate'][13][1]-results_vis[i]['bodies']['candidate'][12][1])*(rl2_ratio-1.)
|
678 |
+
|
679 |
+
results_vis[i]['bodies']['candidate'][13,0] += x_offset_rl2
|
680 |
+
results_vis[i]['bodies']['candidate'][13,1] += y_offset_rl2
|
681 |
+
results_vis[i]['bodies']['candidate'][18,0] += x_offset_rl2
|
682 |
+
results_vis[i]['bodies']['candidate'][18,1] += y_offset_rl2
|
683 |
+
|
684 |
+
results_vis[i]['bodies']['candidate'] += offset[np.newaxis, :]
|
685 |
+
results_vis[i]['faces'] += offset[np.newaxis, np.newaxis, :]
|
686 |
+
results_vis[i]['hands'] += offset[np.newaxis, np.newaxis, :]
|
687 |
+
|
688 |
+
for i in range(len(results_vis)):
|
689 |
+
dwpose_woface, dwpose_wface = draw_pose(results_vis[i], H=768, W=512)
|
690 |
+
img_path = save_motion+'/' + str(i).zfill(4) + '.jpg'
|
691 |
+
cv2.imwrite(img_path, dwpose_woface)
|
692 |
+
|
693 |
+
dwpose_woface, dwpose_wface = draw_pose(pose_ref, H=768, W=512)
|
694 |
+
img_path = save_warp+'/' + 'ref_pose.jpg'
|
695 |
+
cv2.imwrite(img_path, dwpose_woface)
|
696 |
+
|
697 |
+
|
698 |
+
logger = get_logger('dw pose extraction')
|
699 |
+
|
700 |
+
|
701 |
+
if __name__=='__main__':
|
702 |
+
def parse_args():
|
703 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
704 |
+
parser.add_argument("--ref_name", type=str, default="data/images/IMG_20240514_104337.jpg",)
|
705 |
+
parser.add_argument("--source_video_paths", type=str, default="data/videos/source_video.mp4",)
|
706 |
+
parser.add_argument("--saved_pose_dir", type=str, default="data/saved_pose/IMG_20240514_104337",)
|
707 |
+
args = parser.parse_args()
|
708 |
+
|
709 |
+
return args
|
710 |
+
|
711 |
+
args = parse_args()
|
712 |
+
mp_main(args)
|
UniAnimate/test_func/save_targer_keys.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import imageio
|
6 |
+
import numpy as np
|
7 |
+
import os.path as osp
|
8 |
+
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
|
9 |
+
from thop import profile
|
10 |
+
from ptflops import get_model_complexity_info
|
11 |
+
|
12 |
+
import artist.data as data
|
13 |
+
from tools.modules.config import cfg
|
14 |
+
from tools.modules.unet.util import *
|
15 |
+
from utils.config import Config as pConfig
|
16 |
+
from utils.registry_class import ENGINE, MODEL
|
17 |
+
|
18 |
+
|
19 |
+
def save_temporal_key():
|
20 |
+
cfg_update = pConfig(load=True)
|
21 |
+
|
22 |
+
for k, v in cfg_update.cfg_dict.items():
|
23 |
+
if isinstance(v, dict) and k in cfg:
|
24 |
+
cfg[k].update(v)
|
25 |
+
else:
|
26 |
+
cfg[k] = v
|
27 |
+
|
28 |
+
model = MODEL.build(cfg.UNet)
|
29 |
+
|
30 |
+
temp_name = ''
|
31 |
+
temp_key_list = []
|
32 |
+
spth = 'workspace/module_list/UNetSD_I2V_vs_Text_temporal_key_list.json'
|
33 |
+
for name, module in model.named_modules():
|
34 |
+
if isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)):
|
35 |
+
temp_name = name
|
36 |
+
print(f'Model: {name}')
|
37 |
+
elif isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)):
|
38 |
+
temp_name = ''
|
39 |
+
|
40 |
+
if hasattr(module, 'weight'):
|
41 |
+
if temp_name != '' and (temp_name in name):
|
42 |
+
temp_key_list.append(name)
|
43 |
+
print(f'{name}')
|
44 |
+
# print(name)
|
45 |
+
|
46 |
+
save_module_list = []
|
47 |
+
for k, p in model.named_parameters():
|
48 |
+
for item in temp_key_list:
|
49 |
+
if item in k:
|
50 |
+
print(f'{item} --> {k}')
|
51 |
+
save_module_list.append(k)
|
52 |
+
|
53 |
+
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
|
54 |
+
|
55 |
+
# spth = 'workspace/module_list/{}'
|
56 |
+
json.dump(save_module_list, open(spth, 'w'))
|
57 |
+
a = 0
|
58 |
+
|
59 |
+
|
60 |
+
def save_spatial_key():
|
61 |
+
cfg_update = pConfig(load=True)
|
62 |
+
|
63 |
+
for k, v in cfg_update.cfg_dict.items():
|
64 |
+
if isinstance(v, dict) and k in cfg:
|
65 |
+
cfg[k].update(v)
|
66 |
+
else:
|
67 |
+
cfg[k] = v
|
68 |
+
|
69 |
+
model = MODEL.build(cfg.UNet)
|
70 |
+
temp_name = ''
|
71 |
+
temp_key_list = []
|
72 |
+
spth = 'workspace/module_list/UNetSD_I2V_HQ_P_spatial_key_list.json'
|
73 |
+
for name, module in model.named_modules():
|
74 |
+
if isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)):
|
75 |
+
temp_name = name
|
76 |
+
print(f'Model: {name}')
|
77 |
+
elif isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)):
|
78 |
+
temp_name = ''
|
79 |
+
|
80 |
+
if hasattr(module, 'weight'):
|
81 |
+
if temp_name != '' and (temp_name in name):
|
82 |
+
temp_key_list.append(name)
|
83 |
+
print(f'{name}')
|
84 |
+
# print(name)
|
85 |
+
|
86 |
+
save_module_list = []
|
87 |
+
for k, p in model.named_parameters():
|
88 |
+
for item in temp_key_list:
|
89 |
+
if item in k:
|
90 |
+
print(f'{item} --> {k}')
|
91 |
+
save_module_list.append(k)
|
92 |
+
|
93 |
+
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
|
94 |
+
|
95 |
+
# spth = 'workspace/module_list/{}'
|
96 |
+
json.dump(save_module_list, open(spth, 'w'))
|
97 |
+
a = 0
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == '__main__':
|
101 |
+
# save_temporal_key()
|
102 |
+
save_spatial_key()
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
# print([k for (k, _) in self.input_blocks.named_parameters()])
|
107 |
+
|
108 |
+
|
UniAnimate/test_func/test_EndDec.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import imageio
|
5 |
+
import numpy as np
|
6 |
+
import os.path as osp
|
7 |
+
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
|
8 |
+
from PIL import Image, ImageDraw, ImageFont
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
from tools import *
|
13 |
+
import utils.transforms as data
|
14 |
+
from utils.seed import setup_seed
|
15 |
+
from tools.modules.config import cfg
|
16 |
+
from utils.config import Config as pConfig
|
17 |
+
from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER
|
18 |
+
|
19 |
+
|
20 |
+
def test_enc_dec(gpu=0):
|
21 |
+
setup_seed(0)
|
22 |
+
cfg_update = pConfig(load=True)
|
23 |
+
|
24 |
+
for k, v in cfg_update.cfg_dict.items():
|
25 |
+
if isinstance(v, dict) and k in cfg:
|
26 |
+
cfg[k].update(v)
|
27 |
+
else:
|
28 |
+
cfg[k] = v
|
29 |
+
|
30 |
+
save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type'])
|
31 |
+
os.system('rm -rf %s' % (save_dir))
|
32 |
+
os.makedirs(save_dir, exist_ok=True)
|
33 |
+
|
34 |
+
train_trans = data.Compose([
|
35 |
+
data.CenterCropWide(size=cfg.resolution),
|
36 |
+
data.ToTensor(),
|
37 |
+
data.Normalize(mean=cfg.mean, std=cfg.std)])
|
38 |
+
|
39 |
+
vit_trans = data.Compose([
|
40 |
+
data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0]>cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution),
|
41 |
+
data.Resize(cfg.vit_resolution),
|
42 |
+
data.ToTensor(),
|
43 |
+
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
|
44 |
+
|
45 |
+
video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w
|
46 |
+
video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w
|
47 |
+
|
48 |
+
txt_size = cfg.resolution[1]
|
49 |
+
nc = int(38 * (txt_size / 256))
|
50 |
+
font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)
|
51 |
+
|
52 |
+
dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans)
|
53 |
+
print('There are %d videos' % (len(dataset)))
|
54 |
+
|
55 |
+
autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
|
56 |
+
autoencoder.eval() # freeze
|
57 |
+
for param in autoencoder.parameters():
|
58 |
+
param.requires_grad = False
|
59 |
+
autoencoder.to(gpu)
|
60 |
+
for idx, item in enumerate(dataset):
|
61 |
+
local_path = os.path.join(save_dir, '%04d.mp4' % idx)
|
62 |
+
# ref_frame, video_data, caption = item
|
63 |
+
ref_frame, vit_frame, video_data = item[:3]
|
64 |
+
video_data = video_data.to(gpu)
|
65 |
+
|
66 |
+
image_list = []
|
67 |
+
video_data_list = torch.chunk(video_data, video_data.shape[0]//cfg.chunk_size,dim=0)
|
68 |
+
with torch.no_grad():
|
69 |
+
decode_data = []
|
70 |
+
for chunk_data in video_data_list:
|
71 |
+
latent_z = autoencoder.encode_firsr_stage(chunk_data).detach()
|
72 |
+
# latent_z = get_first_stage_encoding(encoder_posterior).detach()
|
73 |
+
kwargs = {"timesteps": chunk_data.shape[0]}
|
74 |
+
recons_data = autoencoder.decode(latent_z, **kwargs)
|
75 |
+
|
76 |
+
vis_data = torch.cat([chunk_data, recons_data], dim=2).cpu()
|
77 |
+
vis_data = vis_data.mul_(video_std).add_(video_mean) # 8x3x16x256x384
|
78 |
+
vis_data = vis_data.cpu()
|
79 |
+
vis_data.clamp_(0, 1)
|
80 |
+
vis_data = vis_data.permute(0, 2, 3, 1)
|
81 |
+
vis_data = [(image.numpy() * 255).astype('uint8') for image in vis_data]
|
82 |
+
image_list.extend(vis_data)
|
83 |
+
|
84 |
+
num_image = len(image_list)
|
85 |
+
frame_dir = os.path.join(save_dir, 'temp')
|
86 |
+
os.makedirs(frame_dir, exist_ok=True)
|
87 |
+
for idx in range(num_image):
|
88 |
+
tpth = os.path.join(frame_dir, '%04d.png' % (idx+1))
|
89 |
+
cv2.imwrite(tpth, image_list[idx][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
90 |
+
cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8 -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
|
91 |
+
os.system(cmd); os.system(f'rm -rf {frame_dir}')
|
92 |
+
|
93 |
+
|
94 |
+
if __name__ == '__main__':
|
95 |
+
test_enc_dec()
|
UniAnimate/test_func/test_dataset.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import imageio
|
4 |
+
import numpy as np
|
5 |
+
import os.path as osp
|
6 |
+
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
|
7 |
+
from PIL import Image, ImageDraw, ImageFont
|
8 |
+
import torchvision.transforms as T
|
9 |
+
|
10 |
+
import utils.transforms as data
|
11 |
+
from tools.modules.config import cfg
|
12 |
+
from utils.config import Config as pConfig
|
13 |
+
from utils.registry_class import ENGINE, DATASETS
|
14 |
+
|
15 |
+
from tools import *
|
16 |
+
|
17 |
+
def test_video_dataset():
|
18 |
+
cfg_update = pConfig(load=True)
|
19 |
+
|
20 |
+
for k, v in cfg_update.cfg_dict.items():
|
21 |
+
if isinstance(v, dict) and k in cfg:
|
22 |
+
cfg[k].update(v)
|
23 |
+
else:
|
24 |
+
cfg[k] = v
|
25 |
+
|
26 |
+
exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
|
27 |
+
save_dir = os.path.join('workspace', 'test_data/datasets', cfg.vid_dataset['type'], exp_name)
|
28 |
+
os.system('rm -rf %s' % (save_dir))
|
29 |
+
os.makedirs(save_dir, exist_ok=True)
|
30 |
+
|
31 |
+
train_trans = data.Compose([
|
32 |
+
data.CenterCropWide(size=cfg.resolution),
|
33 |
+
data.ToTensor(),
|
34 |
+
data.Normalize(mean=cfg.mean, std=cfg.std)])
|
35 |
+
vit_trans = T.Compose([
|
36 |
+
data.CenterCropWide(cfg.vit_resolution),
|
37 |
+
T.ToTensor(),
|
38 |
+
T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
|
39 |
+
|
40 |
+
video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w
|
41 |
+
video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w
|
42 |
+
|
43 |
+
img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w
|
44 |
+
img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w
|
45 |
+
|
46 |
+
vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w
|
47 |
+
vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w
|
48 |
+
|
49 |
+
txt_size = cfg.resolution[1]
|
50 |
+
nc = int(38 * (txt_size / 256))
|
51 |
+
font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)
|
52 |
+
|
53 |
+
dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps[0], transforms=train_trans, vit_transforms=vit_trans)
|
54 |
+
print('There are %d videos' % (len(dataset)))
|
55 |
+
for idx, item in enumerate(dataset):
|
56 |
+
ref_frame, vit_frame, video_data, caption, video_key = item
|
57 |
+
|
58 |
+
video_data = video_data.mul_(video_std).add_(video_mean)
|
59 |
+
video_data.clamp_(0, 1)
|
60 |
+
video_data = video_data.permute(0, 2, 3, 1)
|
61 |
+
video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]
|
62 |
+
|
63 |
+
# Single Image
|
64 |
+
ref_frame = ref_frame.mul_(img_mean).add_(img_std)
|
65 |
+
ref_frame.clamp_(0, 1)
|
66 |
+
ref_frame = ref_frame.permute(1, 2, 0)
|
67 |
+
ref_frame = (ref_frame.numpy() * 255).astype('uint8')
|
68 |
+
|
69 |
+
# Text image
|
70 |
+
txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
|
71 |
+
draw = ImageDraw.Draw(txt_img)
|
72 |
+
lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
|
73 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
74 |
+
txt_img = np.array(txt_img)
|
75 |
+
|
76 |
+
video_data = [np.concatenate([ref_frame, u, txt_img], axis=1) for u in video_data]
|
77 |
+
spath = os.path.join(save_dir, '%04d.gif' % (idx))
|
78 |
+
imageio.mimwrite(spath, video_data, fps =8)
|
79 |
+
|
80 |
+
# if idx > 100: break
|
81 |
+
|
82 |
+
|
83 |
+
def test_vit_image(test_video_flag=True):
|
84 |
+
cfg_update = pConfig(load=True)
|
85 |
+
|
86 |
+
for k, v in cfg_update.cfg_dict.items():
|
87 |
+
if isinstance(v, dict) and k in cfg:
|
88 |
+
cfg[k].update(v)
|
89 |
+
else:
|
90 |
+
cfg[k] = v
|
91 |
+
|
92 |
+
exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
|
93 |
+
save_dir = os.path.join('workspace', 'test_data/datasets', cfg.img_dataset['type'], exp_name)
|
94 |
+
os.system('rm -rf %s' % (save_dir))
|
95 |
+
os.makedirs(save_dir, exist_ok=True)
|
96 |
+
|
97 |
+
train_trans = data.Compose([
|
98 |
+
data.CenterCropWide(size=cfg.resolution),
|
99 |
+
data.ToTensor(),
|
100 |
+
data.Normalize(mean=cfg.mean, std=cfg.std)])
|
101 |
+
vit_trans = data.Compose([
|
102 |
+
data.CenterCropWide(cfg.resolution),
|
103 |
+
data.Resize(cfg.vit_resolution),
|
104 |
+
data.ToTensor(),
|
105 |
+
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
|
106 |
+
|
107 |
+
img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w
|
108 |
+
img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w
|
109 |
+
|
110 |
+
vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w
|
111 |
+
vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w
|
112 |
+
|
113 |
+
txt_size = cfg.resolution[1]
|
114 |
+
nc = int(38 * (txt_size / 256))
|
115 |
+
font = ImageFont.truetype('artist/font/DejaVuSans.ttf', size=13)
|
116 |
+
|
117 |
+
dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans)
|
118 |
+
print('There are %d videos' % (len(dataset)))
|
119 |
+
for idx, item in enumerate(dataset):
|
120 |
+
ref_frame, vit_frame, video_data, caption, video_key = item
|
121 |
+
video_data = video_data.mul_(img_std).add_(img_mean)
|
122 |
+
video_data.clamp_(0, 1)
|
123 |
+
video_data = video_data.permute(0, 2, 3, 1)
|
124 |
+
video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]
|
125 |
+
|
126 |
+
# Single Image
|
127 |
+
vit_frame = vit_frame.mul_(vit_std).add_(vit_mean)
|
128 |
+
vit_frame.clamp_(0, 1)
|
129 |
+
vit_frame = vit_frame.permute(1, 2, 0)
|
130 |
+
vit_frame = (vit_frame.numpy() * 255).astype('uint8')
|
131 |
+
|
132 |
+
zero_frame = np.zeros((cfg.resolution[1], cfg.resolution[1], 3), dtype=np.uint8)
|
133 |
+
zero_frame[:vit_frame.shape[0], :vit_frame.shape[1], :] = vit_frame
|
134 |
+
|
135 |
+
# Text image
|
136 |
+
txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
|
137 |
+
draw = ImageDraw.Draw(txt_img)
|
138 |
+
lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
|
139 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
140 |
+
txt_img = np.array(txt_img)
|
141 |
+
|
142 |
+
video_data = [np.concatenate([zero_frame, u, txt_img], axis=1) for u in video_data]
|
143 |
+
spath = os.path.join(save_dir, '%04d.gif' % (idx))
|
144 |
+
imageio.mimwrite(spath, video_data, fps =8)
|
145 |
+
|
146 |
+
# if idx > 100: break
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == '__main__':
|
150 |
+
# test_video_dataset()
|
151 |
+
test_vit_image()
|
152 |
+
|
UniAnimate/test_func/test_models.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import imageio
|
5 |
+
import numpy as np
|
6 |
+
import os.path as osp
|
7 |
+
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
|
8 |
+
from thop import profile
|
9 |
+
from ptflops import get_model_complexity_info
|
10 |
+
|
11 |
+
import artist.data as data
|
12 |
+
from tools.modules.config import cfg
|
13 |
+
from utils.config import Config as pConfig
|
14 |
+
from utils.registry_class import ENGINE, MODEL
|
15 |
+
|
16 |
+
|
17 |
+
def test_model():
|
18 |
+
cfg_update = pConfig(load=True)
|
19 |
+
|
20 |
+
for k, v in cfg_update.cfg_dict.items():
|
21 |
+
if isinstance(v, dict) and k in cfg:
|
22 |
+
cfg[k].update(v)
|
23 |
+
else:
|
24 |
+
cfg[k] = v
|
25 |
+
|
26 |
+
model = MODEL.build(cfg.UNet)
|
27 |
+
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
|
28 |
+
|
29 |
+
# state_dict = torch.load('cache/pretrain_model/jiuniu_0600000.pth', map_location='cpu')
|
30 |
+
# model.load_state_dict(state_dict, strict=False)
|
31 |
+
model = model.cuda()
|
32 |
+
|
33 |
+
x = torch.Tensor(1, 4, 16, 32, 56).cuda()
|
34 |
+
t = torch.Tensor(1).cuda()
|
35 |
+
sims = torch.Tensor(1, 32).cuda()
|
36 |
+
fps = torch.Tensor([8]).cuda()
|
37 |
+
y = torch.Tensor(1, 1, 1024).cuda()
|
38 |
+
image = torch.Tensor(1, 3, 256, 448).cuda()
|
39 |
+
|
40 |
+
ret = model(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps)
|
41 |
+
print('Out shape if {}'.format(ret.shape))
|
42 |
+
|
43 |
+
# flops, params = profile(model=model, inputs=(x, t, y, image, sims, fps))
|
44 |
+
# print('Model: {:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))
|
45 |
+
|
46 |
+
def prepare_input(resolution):
|
47 |
+
return dict(x=[x, t, y, image, sims, fps])
|
48 |
+
|
49 |
+
flops, params = get_model_complexity_info(model, (1, 4, 16, 32, 56),
|
50 |
+
input_constructor = prepare_input,
|
51 |
+
as_strings=True, print_per_layer_stat=True)
|
52 |
+
print(' - Flops: ' + flops)
|
53 |
+
print(' - Params: ' + params)
|
54 |
+
|
55 |
+
if __name__ == '__main__':
|
56 |
+
test_model()
|
UniAnimate/test_func/test_save_video.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
cap = cv2.VideoCapture('workspace/img_dir/tst.mp4')
|
5 |
+
|
6 |
+
fourcc = cv2.VideoWriter_fourcc(*'H264')
|
7 |
+
|
8 |
+
ret, frame = cap.read()
|
9 |
+
vid_size = frame.shape[:2][::-1]
|
10 |
+
|
11 |
+
out = cv2.VideoWriter('workspace/img_dir/testwrite.mp4',fourcc, 8, vid_size)
|
12 |
+
out.write(frame)
|
13 |
+
|
14 |
+
while(cap.isOpened()):
|
15 |
+
ret, frame = cap.read()
|
16 |
+
if not ret: break
|
17 |
+
out.write(frame)
|
18 |
+
|
19 |
+
|
20 |
+
cap.release()
|
21 |
+
out.release()
|
22 |
+
|
23 |
+
|
24 |
+
|
UniAnimate/tools/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .datasets import *
|
2 |
+
from .modules import *
|
3 |
+
from .inferences import *
|
UniAnimate/tools/datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .image_dataset import *
|
2 |
+
from .video_dataset import *
|
UniAnimate/tools/datasets/image_dataset.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import logging
|
6 |
+
import tempfile
|
7 |
+
import numpy as np
|
8 |
+
from copy import copy
|
9 |
+
from PIL import Image
|
10 |
+
from io import BytesIO
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
from utils.registry_class import DATASETS
|
13 |
+
|
14 |
+
@DATASETS.register_class()
|
15 |
+
class ImageDataset(Dataset):
|
16 |
+
def __init__(self,
|
17 |
+
data_list,
|
18 |
+
data_dir_list,
|
19 |
+
max_words=1000,
|
20 |
+
vit_resolution=[224, 224],
|
21 |
+
resolution=(384, 256),
|
22 |
+
max_frames=1,
|
23 |
+
transforms=None,
|
24 |
+
vit_transforms=None,
|
25 |
+
**kwargs):
|
26 |
+
|
27 |
+
self.max_frames = max_frames
|
28 |
+
self.resolution = resolution
|
29 |
+
self.transforms = transforms
|
30 |
+
self.vit_resolution = vit_resolution
|
31 |
+
self.vit_transforms = vit_transforms
|
32 |
+
|
33 |
+
image_list = []
|
34 |
+
for item_path, data_dir in zip(data_list, data_dir_list):
|
35 |
+
lines = open(item_path, 'r').readlines()
|
36 |
+
lines = [[data_dir, item.strip()] for item in lines]
|
37 |
+
image_list.extend(lines)
|
38 |
+
self.image_list = image_list
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
return len(self.image_list)
|
42 |
+
|
43 |
+
def __getitem__(self, index):
|
44 |
+
data_dir, file_path = self.image_list[index]
|
45 |
+
img_key = file_path.split('|||')[0]
|
46 |
+
try:
|
47 |
+
ref_frame, vit_frame, video_data, caption = self._get_image_data(data_dir, file_path)
|
48 |
+
except Exception as e:
|
49 |
+
logging.info('{} get frames failed... with error: {}'.format(img_key, e))
|
50 |
+
caption = ''
|
51 |
+
img_key = ''
|
52 |
+
ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
|
53 |
+
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
|
54 |
+
video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
|
55 |
+
return ref_frame, vit_frame, video_data, caption, img_key
|
56 |
+
|
57 |
+
def _get_image_data(self, data_dir, file_path):
|
58 |
+
frame_list = []
|
59 |
+
img_key, caption = file_path.split('|||')
|
60 |
+
file_path = os.path.join(data_dir, img_key)
|
61 |
+
for _ in range(5):
|
62 |
+
try:
|
63 |
+
image = Image.open(file_path)
|
64 |
+
if image.mode != 'RGB':
|
65 |
+
image = image.convert('RGB')
|
66 |
+
frame_list.append(image)
|
67 |
+
break
|
68 |
+
except Exception as e:
|
69 |
+
logging.info('{} read video frame failed with error: {}'.format(img_key, e))
|
70 |
+
continue
|
71 |
+
|
72 |
+
video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
|
73 |
+
try:
|
74 |
+
if len(frame_list) > 0:
|
75 |
+
mid_frame = frame_list[0]
|
76 |
+
vit_frame = self.vit_transforms(mid_frame)
|
77 |
+
frame_tensor = self.transforms(frame_list)
|
78 |
+
video_data[:len(frame_list), ...] = frame_tensor
|
79 |
+
else:
|
80 |
+
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
|
81 |
+
except:
|
82 |
+
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
|
83 |
+
ref_frame = copy(video_data[0])
|
84 |
+
|
85 |
+
return ref_frame, vit_frame, video_data, caption
|
86 |
+
|
UniAnimate/tools/datasets/video_dataset.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
import logging
|
7 |
+
import tempfile
|
8 |
+
import numpy as np
|
9 |
+
from copy import copy
|
10 |
+
from PIL import Image
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
from utils.registry_class import DATASETS
|
13 |
+
|
14 |
+
|
15 |
+
@DATASETS.register_class()
|
16 |
+
class VideoDataset(Dataset):
|
17 |
+
def __init__(self,
|
18 |
+
data_list,
|
19 |
+
data_dir_list,
|
20 |
+
max_words=1000,
|
21 |
+
resolution=(384, 256),
|
22 |
+
vit_resolution=(224, 224),
|
23 |
+
max_frames=16,
|
24 |
+
sample_fps=8,
|
25 |
+
transforms=None,
|
26 |
+
vit_transforms=None,
|
27 |
+
get_first_frame=False,
|
28 |
+
**kwargs):
|
29 |
+
|
30 |
+
self.max_words = max_words
|
31 |
+
self.max_frames = max_frames
|
32 |
+
self.resolution = resolution
|
33 |
+
self.vit_resolution = vit_resolution
|
34 |
+
self.sample_fps = sample_fps
|
35 |
+
self.transforms = transforms
|
36 |
+
self.vit_transforms = vit_transforms
|
37 |
+
self.get_first_frame = get_first_frame
|
38 |
+
|
39 |
+
image_list = []
|
40 |
+
for item_path, data_dir in zip(data_list, data_dir_list):
|
41 |
+
lines = open(item_path, 'r').readlines()
|
42 |
+
lines = [[data_dir, item] for item in lines]
|
43 |
+
image_list.extend(lines)
|
44 |
+
self.image_list = image_list
|
45 |
+
|
46 |
+
|
47 |
+
def __getitem__(self, index):
|
48 |
+
data_dir, file_path = self.image_list[index]
|
49 |
+
video_key = file_path.split('|||')[0]
|
50 |
+
try:
|
51 |
+
ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path)
|
52 |
+
except Exception as e:
|
53 |
+
logging.info('{} get frames failed... with error: {}'.format(video_key, e))
|
54 |
+
caption = ''
|
55 |
+
video_key = ''
|
56 |
+
ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
|
57 |
+
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
|
58 |
+
video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
|
59 |
+
return ref_frame, vit_frame, video_data, caption, video_key
|
60 |
+
|
61 |
+
|
62 |
+
def _get_video_data(self, data_dir, file_path):
|
63 |
+
video_key, caption = file_path.split('|||')
|
64 |
+
file_path = os.path.join(data_dir, video_key)
|
65 |
+
|
66 |
+
for _ in range(5):
|
67 |
+
try:
|
68 |
+
capture = cv2.VideoCapture(file_path)
|
69 |
+
_fps = capture.get(cv2.CAP_PROP_FPS)
|
70 |
+
_total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
71 |
+
stride = round(_fps / self.sample_fps)
|
72 |
+
cover_frame_num = (stride * self.max_frames)
|
73 |
+
if _total_frame_num < cover_frame_num + 5:
|
74 |
+
start_frame = 0
|
75 |
+
end_frame = _total_frame_num
|
76 |
+
else:
|
77 |
+
start_frame = random.randint(0, _total_frame_num-cover_frame_num-5)
|
78 |
+
end_frame = start_frame + cover_frame_num
|
79 |
+
|
80 |
+
pointer, frame_list = 0, []
|
81 |
+
while(True):
|
82 |
+
ret, frame = capture.read()
|
83 |
+
pointer +=1
|
84 |
+
if (not ret) or (frame is None): break
|
85 |
+
if pointer < start_frame: continue
|
86 |
+
if pointer >= end_frame - 1: break
|
87 |
+
if (pointer - start_frame) % stride == 0:
|
88 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
89 |
+
frame = Image.fromarray(frame)
|
90 |
+
frame_list.append(frame)
|
91 |
+
break
|
92 |
+
except Exception as e:
|
93 |
+
logging.info('{} read video frame failed with error: {}'.format(video_key, e))
|
94 |
+
continue
|
95 |
+
|
96 |
+
video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
|
97 |
+
if self.get_first_frame:
|
98 |
+
ref_idx = 0
|
99 |
+
else:
|
100 |
+
ref_idx = int(len(frame_list)/2)
|
101 |
+
try:
|
102 |
+
if len(frame_list)>0:
|
103 |
+
mid_frame = copy(frame_list[ref_idx])
|
104 |
+
vit_frame = self.vit_transforms(mid_frame)
|
105 |
+
frames = self.transforms(frame_list)
|
106 |
+
video_data[:len(frame_list), ...] = frames
|
107 |
+
else:
|
108 |
+
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
|
109 |
+
except:
|
110 |
+
vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
|
111 |
+
ref_frame = copy(frames[ref_idx])
|
112 |
+
|
113 |
+
return ref_frame, vit_frame, video_data, caption
|
114 |
+
|
115 |
+
def __len__(self):
|
116 |
+
return len(self.image_list)
|
117 |
+
|
118 |
+
|
UniAnimate/tools/inferences/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .inference_unianimate_entrance import *
|
2 |
+
from .inference_unianimate_long_entrance import *
|
UniAnimate/tools/inferences/inference_unianimate_entrance.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
/*
|
3 |
+
*Copyright (c) 2021, Alibaba Group;
|
4 |
+
*Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
*you may not use this file except in compliance with the License.
|
6 |
+
*You may obtain a copy of the License at
|
7 |
+
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
|
10 |
+
*Unless required by applicable law or agreed to in writing, software
|
11 |
+
*distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
*See the License for the specific language governing permissions and
|
14 |
+
*limitations under the License.
|
15 |
+
*/
|
16 |
+
'''
|
17 |
+
|
18 |
+
import os
|
19 |
+
import re
|
20 |
+
import os.path as osp
|
21 |
+
import sys
|
22 |
+
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4]))
|
23 |
+
import json
|
24 |
+
import math
|
25 |
+
import torch
|
26 |
+
import pynvml
|
27 |
+
import logging
|
28 |
+
import cv2
|
29 |
+
import numpy as np
|
30 |
+
from PIL import Image
|
31 |
+
from tqdm import tqdm
|
32 |
+
import torch.cuda.amp as amp
|
33 |
+
from importlib import reload
|
34 |
+
import torch.distributed as dist
|
35 |
+
import torch.multiprocessing as mp
|
36 |
+
import random
|
37 |
+
from einops import rearrange
|
38 |
+
import torchvision.transforms as T
|
39 |
+
import torchvision.transforms.functional as TF
|
40 |
+
from torch.nn.parallel import DistributedDataParallel
|
41 |
+
|
42 |
+
import utils.transforms as data
|
43 |
+
from ..modules.config import cfg
|
44 |
+
from utils.seed import setup_seed
|
45 |
+
from utils.multi_port import find_free_port
|
46 |
+
from utils.assign_cfg import assign_signle_cfg
|
47 |
+
from utils.distributed import generalized_all_gather, all_reduce
|
48 |
+
from utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col
|
49 |
+
from tools.modules.autoencoder import get_first_stage_encoding
|
50 |
+
from utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION
|
51 |
+
from copy import copy
|
52 |
+
import cv2
|
53 |
+
|
54 |
+
|
55 |
+
@INFER_ENGINE.register_function()
|
56 |
+
def inference_unianimate_entrance(cfg_update, **kwargs):
|
57 |
+
for k, v in cfg_update.items():
|
58 |
+
if isinstance(v, dict) and k in cfg:
|
59 |
+
cfg[k].update(v)
|
60 |
+
else:
|
61 |
+
cfg[k] = v
|
62 |
+
|
63 |
+
if not 'MASTER_ADDR' in os.environ:
|
64 |
+
os.environ['MASTER_ADDR']='localhost'
|
65 |
+
os.environ['MASTER_PORT']= find_free_port()
|
66 |
+
cfg.pmi_rank = int(os.getenv('RANK', 0))
|
67 |
+
cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
|
68 |
+
|
69 |
+
if cfg.debug:
|
70 |
+
cfg.gpus_per_machine = 1
|
71 |
+
cfg.world_size = 1
|
72 |
+
else:
|
73 |
+
cfg.gpus_per_machine = torch.cuda.device_count()
|
74 |
+
cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine
|
75 |
+
|
76 |
+
if cfg.world_size == 1:
|
77 |
+
worker(0, cfg, cfg_update)
|
78 |
+
else:
|
79 |
+
mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update))
|
80 |
+
return cfg
|
81 |
+
|
82 |
+
|
83 |
+
def make_masked_images(imgs, masks):
|
84 |
+
masked_imgs = []
|
85 |
+
for i, mask in enumerate(masks):
|
86 |
+
# concatenation
|
87 |
+
masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1))
|
88 |
+
return torch.stack(masked_imgs, dim=0)
|
89 |
+
|
90 |
+
def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval = 1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]):
|
91 |
+
for _ in range(5):
|
92 |
+
try:
|
93 |
+
dwpose_all = {}
|
94 |
+
frames_all = {}
|
95 |
+
for ii_index in sorted(os.listdir(pose_file_path)):
|
96 |
+
if ii_index != "ref_pose.jpg":
|
97 |
+
dwpose_all[ii_index] = Image.open(os.path.join(pose_file_path, ii_index))
|
98 |
+
frames_all[ii_index] = Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path), cv2.COLOR_BGR2RGB))
|
99 |
+
|
100 |
+
pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))
|
101 |
+
|
102 |
+
# Sample max_frames poses for video generation
|
103 |
+
stride = frame_interval
|
104 |
+
total_frame_num = len(frames_all)
|
105 |
+
cover_frame_num = (stride * (max_frames - 1) + 1)
|
106 |
+
|
107 |
+
if total_frame_num < cover_frame_num:
|
108 |
+
print(f'_total_frame_num ({total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}), the sampled frame interval is changed')
|
109 |
+
start_frame = 0
|
110 |
+
end_frame = total_frame_num
|
111 |
+
stride = max((total_frame_num - 1) // (max_frames - 1), 1)
|
112 |
+
end_frame = stride * max_frames
|
113 |
+
else:
|
114 |
+
start_frame = 0
|
115 |
+
end_frame = start_frame + cover_frame_num
|
116 |
+
|
117 |
+
frame_list = []
|
118 |
+
dwpose_list = []
|
119 |
+
random_ref_frame = frames_all[list(frames_all.keys())[0]]
|
120 |
+
if random_ref_frame.mode != 'RGB':
|
121 |
+
random_ref_frame = random_ref_frame.convert('RGB')
|
122 |
+
random_ref_dwpose = pose_ref
|
123 |
+
if random_ref_dwpose.mode != 'RGB':
|
124 |
+
random_ref_dwpose = random_ref_dwpose.convert('RGB')
|
125 |
+
|
126 |
+
for i_index in range(start_frame, end_frame, stride):
|
127 |
+
if i_index < len(frames_all): # Check index within bounds
|
128 |
+
i_key = list(frames_all.keys())[i_index]
|
129 |
+
i_frame = frames_all[i_key]
|
130 |
+
if i_frame.mode != 'RGB':
|
131 |
+
i_frame = i_frame.convert('RGB')
|
132 |
+
|
133 |
+
i_dwpose = dwpose_all[i_key]
|
134 |
+
if i_dwpose.mode != 'RGB':
|
135 |
+
i_dwpose = i_dwpose.convert('RGB')
|
136 |
+
frame_list.append(i_frame)
|
137 |
+
dwpose_list.append(i_dwpose)
|
138 |
+
|
139 |
+
if frame_list:
|
140 |
+
middle_indix = 0
|
141 |
+
ref_frame = frame_list[middle_indix]
|
142 |
+
vit_frame = vit_transforms(ref_frame)
|
143 |
+
random_ref_frame_tmp = train_trans_pose(random_ref_frame)
|
144 |
+
random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
|
145 |
+
misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
|
146 |
+
video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
|
147 |
+
dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)
|
148 |
+
|
149 |
+
video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
150 |
+
dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
151 |
+
misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
152 |
+
random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
153 |
+
random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
154 |
+
|
155 |
+
video_data[:len(frame_list), ...] = video_data_tmp
|
156 |
+
misc_data[:len(frame_list), ...] = misc_data_tmp
|
157 |
+
dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
|
158 |
+
random_ref_frame_data[:, ...] = random_ref_frame_tmp
|
159 |
+
random_ref_dwpose_data[:, ...] = random_ref_dwpose_tmp
|
160 |
+
|
161 |
+
return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data
|
162 |
+
|
163 |
+
except Exception as e:
|
164 |
+
logging.info(f'Error reading video frame: {e}')
|
165 |
+
continue
|
166 |
+
|
167 |
+
return None, None, None, None, None, None
|
168 |
+
|
169 |
+
def worker(gpu, cfg, cfg_update):
|
170 |
+
'''
|
171 |
+
Inference worker for each gpu
|
172 |
+
'''
|
173 |
+
for k, v in cfg_update.items():
|
174 |
+
if isinstance(v, dict) and k in cfg:
|
175 |
+
cfg[k].update(v)
|
176 |
+
else:
|
177 |
+
cfg[k] = v
|
178 |
+
|
179 |
+
cfg.gpu = gpu
|
180 |
+
cfg.seed = int(cfg.seed)
|
181 |
+
cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu
|
182 |
+
setup_seed(cfg.seed + cfg.rank)
|
183 |
+
|
184 |
+
if not cfg.debug:
|
185 |
+
torch.cuda.set_device(gpu)
|
186 |
+
torch.backends.cudnn.benchmark = True
|
187 |
+
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
|
188 |
+
torch.backends.cudnn.benchmark = False
|
189 |
+
dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)
|
190 |
+
|
191 |
+
# [Log] Save logging and make log dir
|
192 |
+
log_dir = generalized_all_gather(cfg.log_dir)[0]
|
193 |
+
inf_name = osp.basename(cfg.cfg_file).split('.')[0]
|
194 |
+
test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1]
|
195 |
+
|
196 |
+
cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name))
|
197 |
+
os.makedirs(cfg.log_dir, exist_ok=True)
|
198 |
+
log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank))
|
199 |
+
cfg.log_file = log_file
|
200 |
+
reload(logging)
|
201 |
+
logging.basicConfig(
|
202 |
+
level=logging.INFO,
|
203 |
+
format='[%(asctime)s] %(levelname)s: %(message)s',
|
204 |
+
handlers=[
|
205 |
+
logging.FileHandler(filename=log_file),
|
206 |
+
logging.StreamHandler(stream=sys.stdout)])
|
207 |
+
logging.info(cfg)
|
208 |
+
logging.info(f"Running UniAnimate inference on gpu {gpu}")
|
209 |
+
|
210 |
+
# [Diffusion]
|
211 |
+
diffusion = DIFFUSION.build(cfg.Diffusion)
|
212 |
+
|
213 |
+
# [Data] Data Transform
|
214 |
+
train_trans = data.Compose([
|
215 |
+
data.Resize(cfg.resolution),
|
216 |
+
data.ToTensor(),
|
217 |
+
data.Normalize(mean=cfg.mean, std=cfg.std)
|
218 |
+
])
|
219 |
+
|
220 |
+
train_trans_pose = data.Compose([
|
221 |
+
data.Resize(cfg.resolution),
|
222 |
+
data.ToTensor(),
|
223 |
+
]
|
224 |
+
)
|
225 |
+
|
226 |
+
vit_transforms = T.Compose([
|
227 |
+
data.Resize(cfg.vit_resolution),
|
228 |
+
T.ToTensor(),
|
229 |
+
T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
|
230 |
+
|
231 |
+
# [Model] embedder
|
232 |
+
clip_encoder = EMBEDDER.build(cfg.embedder)
|
233 |
+
clip_encoder.model.to(gpu)
|
234 |
+
with torch.no_grad():
|
235 |
+
_, _, zero_y = clip_encoder(text="")
|
236 |
+
|
237 |
+
|
238 |
+
# [Model] auotoencoder
|
239 |
+
autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
|
240 |
+
autoencoder.eval() # freeze
|
241 |
+
for param in autoencoder.parameters():
|
242 |
+
param.requires_grad = False
|
243 |
+
autoencoder.cuda()
|
244 |
+
|
245 |
+
# [Model] UNet
|
246 |
+
if "config" in cfg.UNet:
|
247 |
+
cfg.UNet["config"] = cfg
|
248 |
+
cfg.UNet["zero_y"] = zero_y
|
249 |
+
model = MODEL.build(cfg.UNet)
|
250 |
+
state_dict = torch.load(cfg.test_model, map_location='cpu')
|
251 |
+
if 'state_dict' in state_dict:
|
252 |
+
state_dict = state_dict['state_dict']
|
253 |
+
if 'step' in state_dict:
|
254 |
+
resume_step = state_dict['step']
|
255 |
+
else:
|
256 |
+
resume_step = 0
|
257 |
+
status = model.load_state_dict(state_dict, strict=True)
|
258 |
+
logging.info('Load model from {} with status {}'.format(cfg.test_model, status))
|
259 |
+
model = model.to(gpu)
|
260 |
+
model.eval()
|
261 |
+
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
|
262 |
+
model.to(torch.float16)
|
263 |
+
else:
|
264 |
+
model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
|
265 |
+
torch.cuda.empty_cache()
|
266 |
+
|
267 |
+
|
268 |
+
|
269 |
+
test_list = cfg.test_list_path
|
270 |
+
num_videos = len(test_list)
|
271 |
+
logging.info(f'There are {num_videos} videos. with {cfg.round} times')
|
272 |
+
# test_list = [item for item in test_list for _ in range(cfg.round)]
|
273 |
+
test_list = [item for _ in range(cfg.round) for item in test_list]
|
274 |
+
|
275 |
+
for idx, file_path in enumerate(test_list):
|
276 |
+
cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]
|
277 |
+
|
278 |
+
manual_seed = int(cfg.seed + cfg.rank + idx//num_videos)
|
279 |
+
setup_seed(manual_seed)
|
280 |
+
logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")
|
281 |
+
|
282 |
+
|
283 |
+
vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval =cfg.frame_interval, resolution=cfg.resolution)
|
284 |
+
misc_data = misc_data.unsqueeze(0).to(gpu)
|
285 |
+
vit_frame = vit_frame.unsqueeze(0).to(gpu)
|
286 |
+
dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
|
287 |
+
random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
|
288 |
+
random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
|
289 |
+
|
290 |
+
### save for visualization
|
291 |
+
misc_backups = copy(misc_data)
|
292 |
+
frames_num = misc_data.shape[1]
|
293 |
+
misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
|
294 |
+
mv_data_video = []
|
295 |
+
|
296 |
+
|
297 |
+
### local image (first frame)
|
298 |
+
image_local = []
|
299 |
+
if 'local_image' in cfg.video_compositions:
|
300 |
+
frames_num = misc_data.shape[1]
|
301 |
+
bs_vd_local = misc_data.shape[0]
|
302 |
+
image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1)
|
303 |
+
image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
|
304 |
+
image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
|
305 |
+
if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
|
306 |
+
with torch.no_grad():
|
307 |
+
temporal_length = frames_num
|
308 |
+
encoder_posterior = autoencoder.encode(video_data[:,0])
|
309 |
+
local_image_data = get_first_stage_encoding(encoder_posterior).detach()
|
310 |
+
image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
### encode the video_data
|
315 |
+
bs_vd = misc_data.shape[0]
|
316 |
+
misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
|
317 |
+
misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0)
|
318 |
+
|
319 |
+
|
320 |
+
with torch.no_grad():
|
321 |
+
|
322 |
+
random_ref_frame = []
|
323 |
+
if 'randomref' in cfg.video_compositions:
|
324 |
+
random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
|
325 |
+
if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref:
|
326 |
+
|
327 |
+
temporal_length = random_ref_frame_data.shape[1]
|
328 |
+
encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5))
|
329 |
+
random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach()
|
330 |
+
random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
|
331 |
+
|
332 |
+
random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
|
333 |
+
|
334 |
+
|
335 |
+
if 'dwpose' in cfg.video_compositions:
|
336 |
+
bs_vd_local = dwpose_data.shape[0]
|
337 |
+
dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local)
|
338 |
+
if 'randomref_pose' in cfg.video_compositions:
|
339 |
+
dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1)
|
340 |
+
dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local)
|
341 |
+
|
342 |
+
|
343 |
+
y_visual = []
|
344 |
+
if 'image' in cfg.video_compositions:
|
345 |
+
with torch.no_grad():
|
346 |
+
vit_frame = vit_frame.squeeze(1)
|
347 |
+
y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
|
348 |
+
y_visual0 = y_visual.clone()
|
349 |
+
|
350 |
+
|
351 |
+
with amp.autocast(enabled=True):
|
352 |
+
pynvml.nvmlInit()
|
353 |
+
handle=pynvml.nvmlDeviceGetHandleByIndex(0)
|
354 |
+
meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle)
|
355 |
+
cur_seed = torch.initial_seed()
|
356 |
+
logging.info(f"Current seed {cur_seed} ...")
|
357 |
+
|
358 |
+
noise = torch.randn([1, 4, cfg.max_frames, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)])
|
359 |
+
noise = noise.to(gpu)
|
360 |
+
|
361 |
+
if hasattr(cfg.Diffusion, "noise_strength"):
|
362 |
+
b, c, f, _, _= noise.shape
|
363 |
+
offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
|
364 |
+
noise = noise + cfg.Diffusion.noise_strength * offset_noise
|
365 |
+
|
366 |
+
# add a noise prior
|
367 |
+
noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 949), noise=noise)
|
368 |
+
|
369 |
+
# construct model inputs (CFG)
|
370 |
+
full_model_kwargs=[{
|
371 |
+
'y': None,
|
372 |
+
"local_image": None if len(image_local) == 0 else image_local[:],
|
373 |
+
'image': None if len(y_visual) == 0 else y_visual0[:],
|
374 |
+
'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:],
|
375 |
+
'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:],
|
376 |
+
},
|
377 |
+
{
|
378 |
+
'y': None,
|
379 |
+
"local_image": None,
|
380 |
+
'image': None,
|
381 |
+
'randomref': None,
|
382 |
+
'dwpose': None,
|
383 |
+
}]
|
384 |
+
|
385 |
+
# for visualization
|
386 |
+
full_model_kwargs_vis =[{
|
387 |
+
'y': None,
|
388 |
+
"local_image": None if len(image_local) == 0 else image_local_clone[:],
|
389 |
+
'image': None,
|
390 |
+
'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:],
|
391 |
+
'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3],
|
392 |
+
},
|
393 |
+
{
|
394 |
+
'y': None,
|
395 |
+
"local_image": None,
|
396 |
+
'image': None,
|
397 |
+
'randomref': None,
|
398 |
+
'dwpose': None,
|
399 |
+
}]
|
400 |
+
|
401 |
+
|
402 |
+
partial_keys = [
|
403 |
+
['image', 'randomref', "dwpose"],
|
404 |
+
]
|
405 |
+
if hasattr(cfg, "partial_keys") and cfg.partial_keys:
|
406 |
+
partial_keys = cfg.partial_keys
|
407 |
+
|
408 |
+
|
409 |
+
for partial_keys_one in partial_keys:
|
410 |
+
model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one,
|
411 |
+
full_model_kwargs = full_model_kwargs,
|
412 |
+
use_fps_condition = cfg.use_fps_condition)
|
413 |
+
model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one,
|
414 |
+
full_model_kwargs = full_model_kwargs_vis,
|
415 |
+
use_fps_condition = cfg.use_fps_condition)
|
416 |
+
noise_one = noise
|
417 |
+
|
418 |
+
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
|
419 |
+
clip_encoder.cpu() # add this line
|
420 |
+
autoencoder.cpu() # add this line
|
421 |
+
torch.cuda.empty_cache() # add this line
|
422 |
+
|
423 |
+
video_data = diffusion.ddim_sample_loop(
|
424 |
+
noise=noise_one,
|
425 |
+
model=model.eval(),
|
426 |
+
model_kwargs=model_kwargs_one,
|
427 |
+
guide_scale=cfg.guide_scale,
|
428 |
+
ddim_timesteps=cfg.ddim_timesteps,
|
429 |
+
eta=0.0)
|
430 |
+
|
431 |
+
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
|
432 |
+
# if run forward of autoencoder or clip_encoder second times, load them again
|
433 |
+
clip_encoder.cuda()
|
434 |
+
autoencoder.cuda()
|
435 |
+
video_data = 1. / cfg.scale_factor * video_data
|
436 |
+
video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
|
437 |
+
chunk_size = min(cfg.decoder_bs, video_data.shape[0])
|
438 |
+
video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0)
|
439 |
+
decode_data = []
|
440 |
+
for vd_data in video_data_list:
|
441 |
+
gen_frames = autoencoder.decode(vd_data)
|
442 |
+
decode_data.append(gen_frames)
|
443 |
+
video_data = torch.cat(decode_data, dim=0)
|
444 |
+
video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float()
|
445 |
+
|
446 |
+
text_size = cfg.resolution[-1]
|
447 |
+
cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0]) # .replace(' ', '_')
|
448 |
+
name = f'seed_{cur_seed}'
|
449 |
+
for ii in partial_keys_one:
|
450 |
+
name = name + "_" + ii
|
451 |
+
file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4'
|
452 |
+
local_path = os.path.join(cfg.log_dir, f'{file_name}')
|
453 |
+
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
454 |
+
captions = "human"
|
455 |
+
del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
|
456 |
+
del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]
|
457 |
+
|
458 |
+
save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups,
|
459 |
+
cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps)
|
460 |
+
|
461 |
+
# try:
|
462 |
+
# save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
|
463 |
+
# logging.info('Save video to dir %s:' % (local_path))
|
464 |
+
# except Exception as e:
|
465 |
+
# logging.info(f'Step: save text or video error with {e}')
|
466 |
+
|
467 |
+
logging.info('Congratulations! The inference is completed!')
|
468 |
+
# synchronize to finish some processes
|
469 |
+
if not cfg.debug:
|
470 |
+
torch.cuda.synchronize()
|
471 |
+
dist.barrier()
|
472 |
+
|
473 |
+
def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False):
|
474 |
+
|
475 |
+
if use_fps_condition is True:
|
476 |
+
partial_keys.append('fps')
|
477 |
+
|
478 |
+
partial_model_kwargs = [{}, {}]
|
479 |
+
for partial_key in partial_keys:
|
480 |
+
partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
|
481 |
+
partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
|
482 |
+
|
483 |
+
return partial_model_kwargs
|
UniAnimate/tools/inferences/inference_unianimate_long_entrance.py
ADDED
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
/*
|
3 |
+
*Copyright (c) 2021, Alibaba Group;
|
4 |
+
*Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
*you may not use this file except in compliance with the License.
|
6 |
+
*You may obtain a copy of the License at
|
7 |
+
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
|
10 |
+
*Unless required by applicable law or agreed to in writing, software
|
11 |
+
*distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
*See the License for the specific language governing permissions and
|
14 |
+
*limitations under the License.
|
15 |
+
*/
|
16 |
+
'''
|
17 |
+
|
18 |
+
import os
|
19 |
+
import re
|
20 |
+
import os.path as osp
|
21 |
+
import sys
|
22 |
+
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4]))
|
23 |
+
import json
|
24 |
+
import math
|
25 |
+
import torch
|
26 |
+
import pynvml
|
27 |
+
import logging
|
28 |
+
import cv2
|
29 |
+
import numpy as np
|
30 |
+
from PIL import Image
|
31 |
+
from tqdm import tqdm
|
32 |
+
import torch.cuda.amp as amp
|
33 |
+
from importlib import reload
|
34 |
+
import torch.distributed as dist
|
35 |
+
import torch.multiprocessing as mp
|
36 |
+
import random
|
37 |
+
from einops import rearrange
|
38 |
+
import torchvision.transforms as T
|
39 |
+
import torchvision.transforms.functional as TF
|
40 |
+
from torch.nn.parallel import DistributedDataParallel
|
41 |
+
|
42 |
+
import utils.transforms as data
|
43 |
+
from ..modules.config import cfg
|
44 |
+
from utils.seed import setup_seed
|
45 |
+
from utils.multi_port import find_free_port
|
46 |
+
from utils.assign_cfg import assign_signle_cfg
|
47 |
+
from utils.distributed import generalized_all_gather, all_reduce
|
48 |
+
from utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col
|
49 |
+
from tools.modules.autoencoder import get_first_stage_encoding
|
50 |
+
from utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION
|
51 |
+
from copy import copy
|
52 |
+
import cv2
|
53 |
+
|
54 |
+
|
55 |
+
@INFER_ENGINE.register_function()
|
56 |
+
def inference_unianimate_long_entrance(cfg_update, **kwargs):
|
57 |
+
for k, v in cfg_update.items():
|
58 |
+
if isinstance(v, dict) and k in cfg:
|
59 |
+
cfg[k].update(v)
|
60 |
+
else:
|
61 |
+
cfg[k] = v
|
62 |
+
|
63 |
+
if not 'MASTER_ADDR' in os.environ:
|
64 |
+
os.environ['MASTER_ADDR']='localhost'
|
65 |
+
os.environ['MASTER_PORT']= find_free_port()
|
66 |
+
cfg.pmi_rank = int(os.getenv('RANK', 0))
|
67 |
+
cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
|
68 |
+
|
69 |
+
if cfg.debug:
|
70 |
+
cfg.gpus_per_machine = 1
|
71 |
+
cfg.world_size = 1
|
72 |
+
else:
|
73 |
+
cfg.gpus_per_machine = torch.cuda.device_count()
|
74 |
+
cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine
|
75 |
+
|
76 |
+
if cfg.world_size == 1:
|
77 |
+
worker(0, cfg, cfg_update)
|
78 |
+
else:
|
79 |
+
mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update))
|
80 |
+
return cfg
|
81 |
+
|
82 |
+
|
83 |
+
def make_masked_images(imgs, masks):
|
84 |
+
masked_imgs = []
|
85 |
+
for i, mask in enumerate(masks):
|
86 |
+
# concatenation
|
87 |
+
masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1))
|
88 |
+
return torch.stack(masked_imgs, dim=0)
|
89 |
+
|
90 |
+
def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval = 1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]):
|
91 |
+
|
92 |
+
for _ in range(5):
|
93 |
+
try:
|
94 |
+
dwpose_all = {}
|
95 |
+
frames_all = {}
|
96 |
+
for ii_index in sorted(os.listdir(pose_file_path)):
|
97 |
+
if ii_index != "ref_pose.jpg":
|
98 |
+
dwpose_all[ii_index] = Image.open(pose_file_path+"/"+ii_index)
|
99 |
+
frames_all[ii_index] = Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path),cv2.COLOR_BGR2RGB))
|
100 |
+
# frames_all[ii_index] = Image.open(ref_image_path)
|
101 |
+
|
102 |
+
pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))
|
103 |
+
first_eq_ref = False
|
104 |
+
|
105 |
+
# sample max_frames poses for video generation
|
106 |
+
stride = frame_interval
|
107 |
+
_total_frame_num = len(frames_all)
|
108 |
+
if max_frames == "None":
|
109 |
+
max_frames = (_total_frame_num-1)//frame_interval + 1
|
110 |
+
cover_frame_num = (stride * (max_frames-1)+1)
|
111 |
+
if _total_frame_num < cover_frame_num:
|
112 |
+
print('_total_frame_num is smaller than cover_frame_num, the sampled frame interval is changed')
|
113 |
+
start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame
|
114 |
+
end_frame = _total_frame_num
|
115 |
+
stride = max((_total_frame_num-1//(max_frames-1)),1)
|
116 |
+
end_frame = stride*max_frames
|
117 |
+
else:
|
118 |
+
start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame
|
119 |
+
end_frame = start_frame + cover_frame_num
|
120 |
+
|
121 |
+
frame_list = []
|
122 |
+
dwpose_list = []
|
123 |
+
random_ref_frame = frames_all[list(frames_all.keys())[0]]
|
124 |
+
if random_ref_frame.mode != 'RGB':
|
125 |
+
random_ref_frame = random_ref_frame.convert('RGB')
|
126 |
+
random_ref_dwpose = pose_ref
|
127 |
+
if random_ref_dwpose.mode != 'RGB':
|
128 |
+
random_ref_dwpose = random_ref_dwpose.convert('RGB')
|
129 |
+
for i_index in range(start_frame, end_frame, stride):
|
130 |
+
if i_index == start_frame and first_eq_ref:
|
131 |
+
i_key = list(frames_all.keys())[i_index]
|
132 |
+
i_frame = frames_all[i_key]
|
133 |
+
|
134 |
+
if i_frame.mode != 'RGB':
|
135 |
+
i_frame = i_frame.convert('RGB')
|
136 |
+
i_dwpose = frames_pose_ref
|
137 |
+
if i_dwpose.mode != 'RGB':
|
138 |
+
i_dwpose = i_dwpose.convert('RGB')
|
139 |
+
frame_list.append(i_frame)
|
140 |
+
dwpose_list.append(i_dwpose)
|
141 |
+
else:
|
142 |
+
# added
|
143 |
+
if first_eq_ref:
|
144 |
+
i_index = i_index - stride
|
145 |
+
|
146 |
+
i_key = list(frames_all.keys())[i_index]
|
147 |
+
i_frame = frames_all[i_key]
|
148 |
+
if i_frame.mode != 'RGB':
|
149 |
+
i_frame = i_frame.convert('RGB')
|
150 |
+
i_dwpose = dwpose_all[i_key]
|
151 |
+
if i_dwpose.mode != 'RGB':
|
152 |
+
i_dwpose = i_dwpose.convert('RGB')
|
153 |
+
frame_list.append(i_frame)
|
154 |
+
dwpose_list.append(i_dwpose)
|
155 |
+
have_frames = len(frame_list)>0
|
156 |
+
middle_indix = 0
|
157 |
+
if have_frames:
|
158 |
+
ref_frame = frame_list[middle_indix]
|
159 |
+
vit_frame = vit_transforms(ref_frame)
|
160 |
+
random_ref_frame_tmp = train_trans_pose(random_ref_frame)
|
161 |
+
random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
|
162 |
+
misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
|
163 |
+
video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
|
164 |
+
dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)
|
165 |
+
|
166 |
+
video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
167 |
+
dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
168 |
+
misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
169 |
+
random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) # [32, 3, 512, 768]
|
170 |
+
random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
|
171 |
+
if have_frames:
|
172 |
+
video_data[:len(frame_list), ...] = video_data_tmp
|
173 |
+
misc_data[:len(frame_list), ...] = misc_data_tmp
|
174 |
+
dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
|
175 |
+
random_ref_frame_data[:,...] = random_ref_frame_tmp
|
176 |
+
random_ref_dwpose_data[:,...] = random_ref_dwpose_tmp
|
177 |
+
|
178 |
+
break
|
179 |
+
|
180 |
+
except Exception as e:
|
181 |
+
logging.info('{} read video frame failed with error: {}'.format(pose_file_path, e))
|
182 |
+
continue
|
183 |
+
|
184 |
+
return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames
|
185 |
+
|
186 |
+
|
187 |
+
|
188 |
+
def worker(gpu, cfg, cfg_update):
|
189 |
+
'''
|
190 |
+
Inference worker for each gpu
|
191 |
+
'''
|
192 |
+
for k, v in cfg_update.items():
|
193 |
+
if isinstance(v, dict) and k in cfg:
|
194 |
+
cfg[k].update(v)
|
195 |
+
else:
|
196 |
+
cfg[k] = v
|
197 |
+
|
198 |
+
cfg.gpu = gpu
|
199 |
+
cfg.seed = int(cfg.seed)
|
200 |
+
cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu
|
201 |
+
setup_seed(cfg.seed + cfg.rank)
|
202 |
+
|
203 |
+
if not cfg.debug:
|
204 |
+
torch.cuda.set_device(gpu)
|
205 |
+
torch.backends.cudnn.benchmark = True
|
206 |
+
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
|
207 |
+
torch.backends.cudnn.benchmark = False
|
208 |
+
dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)
|
209 |
+
|
210 |
+
# [Log] Save logging and make log dir
|
211 |
+
log_dir = generalized_all_gather(cfg.log_dir)[0]
|
212 |
+
inf_name = osp.basename(cfg.cfg_file).split('.')[0]
|
213 |
+
test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1]
|
214 |
+
|
215 |
+
cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name))
|
216 |
+
os.makedirs(cfg.log_dir, exist_ok=True)
|
217 |
+
log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank))
|
218 |
+
cfg.log_file = log_file
|
219 |
+
reload(logging)
|
220 |
+
logging.basicConfig(
|
221 |
+
level=logging.INFO,
|
222 |
+
format='[%(asctime)s] %(levelname)s: %(message)s',
|
223 |
+
handlers=[
|
224 |
+
logging.FileHandler(filename=log_file),
|
225 |
+
logging.StreamHandler(stream=sys.stdout)])
|
226 |
+
logging.info(cfg)
|
227 |
+
logging.info(f"Running UniAnimate inference on gpu {gpu}")
|
228 |
+
|
229 |
+
# [Diffusion]
|
230 |
+
diffusion = DIFFUSION.build(cfg.Diffusion)
|
231 |
+
|
232 |
+
# [Data] Data Transform
|
233 |
+
train_trans = data.Compose([
|
234 |
+
data.Resize(cfg.resolution),
|
235 |
+
data.ToTensor(),
|
236 |
+
data.Normalize(mean=cfg.mean, std=cfg.std)
|
237 |
+
])
|
238 |
+
|
239 |
+
train_trans_pose = data.Compose([
|
240 |
+
data.Resize(cfg.resolution),
|
241 |
+
data.ToTensor(),
|
242 |
+
]
|
243 |
+
)
|
244 |
+
|
245 |
+
vit_transforms = T.Compose([
|
246 |
+
data.Resize(cfg.vit_resolution),
|
247 |
+
T.ToTensor(),
|
248 |
+
T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
|
249 |
+
|
250 |
+
# [Model] embedder
|
251 |
+
clip_encoder = EMBEDDER.build(cfg.embedder)
|
252 |
+
clip_encoder.model.to(gpu)
|
253 |
+
with torch.no_grad():
|
254 |
+
_, _, zero_y = clip_encoder(text="")
|
255 |
+
|
256 |
+
|
257 |
+
# [Model] auotoencoder
|
258 |
+
autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
|
259 |
+
autoencoder.eval() # freeze
|
260 |
+
for param in autoencoder.parameters():
|
261 |
+
param.requires_grad = False
|
262 |
+
autoencoder.cuda()
|
263 |
+
|
264 |
+
# [Model] UNet
|
265 |
+
if "config" in cfg.UNet:
|
266 |
+
cfg.UNet["config"] = cfg
|
267 |
+
cfg.UNet["zero_y"] = zero_y
|
268 |
+
model = MODEL.build(cfg.UNet)
|
269 |
+
state_dict = torch.load(cfg.test_model, map_location='cpu')
|
270 |
+
if 'state_dict' in state_dict:
|
271 |
+
state_dict = state_dict['state_dict']
|
272 |
+
if 'step' in state_dict:
|
273 |
+
resume_step = state_dict['step']
|
274 |
+
else:
|
275 |
+
resume_step = 0
|
276 |
+
status = model.load_state_dict(state_dict, strict=True)
|
277 |
+
logging.info('Load model from {} with status {}'.format(cfg.test_model, status))
|
278 |
+
model = model.to(gpu)
|
279 |
+
model.eval()
|
280 |
+
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
|
281 |
+
model.to(torch.float16)
|
282 |
+
else:
|
283 |
+
model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
|
284 |
+
torch.cuda.empty_cache()
|
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
test_list = cfg.test_list_path
|
289 |
+
num_videos = len(test_list)
|
290 |
+
logging.info(f'There are {num_videos} videos. with {cfg.round} times')
|
291 |
+
test_list = [item for _ in range(cfg.round) for item in test_list]
|
292 |
+
|
293 |
+
for idx, file_path in enumerate(test_list):
|
294 |
+
cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]
|
295 |
+
|
296 |
+
manual_seed = int(cfg.seed + cfg.rank + idx//num_videos)
|
297 |
+
setup_seed(manual_seed)
|
298 |
+
logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")
|
299 |
+
|
300 |
+
|
301 |
+
vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval =cfg.frame_interval, resolution=cfg.resolution)
|
302 |
+
cfg.max_frames_new = max_frames
|
303 |
+
misc_data = misc_data.unsqueeze(0).to(gpu)
|
304 |
+
vit_frame = vit_frame.unsqueeze(0).to(gpu)
|
305 |
+
dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
|
306 |
+
random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
|
307 |
+
random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
|
308 |
+
|
309 |
+
### save for visualization
|
310 |
+
misc_backups = copy(misc_data)
|
311 |
+
frames_num = misc_data.shape[1]
|
312 |
+
misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
|
313 |
+
mv_data_video = []
|
314 |
+
|
315 |
+
|
316 |
+
### local image (first frame)
|
317 |
+
image_local = []
|
318 |
+
if 'local_image' in cfg.video_compositions:
|
319 |
+
frames_num = misc_data.shape[1]
|
320 |
+
bs_vd_local = misc_data.shape[0]
|
321 |
+
image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1)
|
322 |
+
image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
|
323 |
+
image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
|
324 |
+
if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
|
325 |
+
with torch.no_grad():
|
326 |
+
temporal_length = frames_num
|
327 |
+
encoder_posterior = autoencoder.encode(video_data[:,0])
|
328 |
+
local_image_data = get_first_stage_encoding(encoder_posterior).detach()
|
329 |
+
image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
|
330 |
+
|
331 |
+
|
332 |
+
|
333 |
+
### encode the video_data
|
334 |
+
bs_vd = misc_data.shape[0]
|
335 |
+
misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
|
336 |
+
misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0)
|
337 |
+
|
338 |
+
|
339 |
+
with torch.no_grad():
|
340 |
+
|
341 |
+
random_ref_frame = []
|
342 |
+
if 'randomref' in cfg.video_compositions:
|
343 |
+
random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
|
344 |
+
if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref:
|
345 |
+
|
346 |
+
temporal_length = random_ref_frame_data.shape[1]
|
347 |
+
encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5))
|
348 |
+
random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach()
|
349 |
+
random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40]
|
350 |
+
|
351 |
+
random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
|
352 |
+
|
353 |
+
|
354 |
+
if 'dwpose' in cfg.video_compositions:
|
355 |
+
bs_vd_local = dwpose_data.shape[0]
|
356 |
+
dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local)
|
357 |
+
if 'randomref_pose' in cfg.video_compositions:
|
358 |
+
dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1)
|
359 |
+
dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local)
|
360 |
+
|
361 |
+
|
362 |
+
y_visual = []
|
363 |
+
if 'image' in cfg.video_compositions:
|
364 |
+
with torch.no_grad():
|
365 |
+
vit_frame = vit_frame.squeeze(1)
|
366 |
+
y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024]
|
367 |
+
y_visual0 = y_visual.clone()
|
368 |
+
|
369 |
+
|
370 |
+
with amp.autocast(enabled=True):
|
371 |
+
pynvml.nvmlInit()
|
372 |
+
handle=pynvml.nvmlDeviceGetHandleByIndex(0)
|
373 |
+
meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle)
|
374 |
+
cur_seed = torch.initial_seed()
|
375 |
+
logging.info(f"Current seed {cur_seed} ..., cfg.max_frames_new: {cfg.max_frames_new} ....")
|
376 |
+
|
377 |
+
noise = torch.randn([1, 4, cfg.max_frames_new, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)])
|
378 |
+
noise = noise.to(gpu)
|
379 |
+
|
380 |
+
# add a noise prior
|
381 |
+
noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise)
|
382 |
+
|
383 |
+
if hasattr(cfg.Diffusion, "noise_strength"):
|
384 |
+
b, c, f, _, _= noise.shape
|
385 |
+
offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
|
386 |
+
noise = noise + cfg.Diffusion.noise_strength * offset_noise
|
387 |
+
|
388 |
+
# construct model inputs (CFG)
|
389 |
+
full_model_kwargs=[{
|
390 |
+
'y': None,
|
391 |
+
"local_image": None if len(image_local) == 0 else image_local[:],
|
392 |
+
'image': None if len(y_visual) == 0 else y_visual0[:],
|
393 |
+
'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:],
|
394 |
+
'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:],
|
395 |
+
},
|
396 |
+
{
|
397 |
+
'y': None,
|
398 |
+
"local_image": None,
|
399 |
+
'image': None,
|
400 |
+
'randomref': None,
|
401 |
+
'dwpose': None,
|
402 |
+
}]
|
403 |
+
|
404 |
+
# for visualization
|
405 |
+
full_model_kwargs_vis =[{
|
406 |
+
'y': None,
|
407 |
+
"local_image": None if len(image_local) == 0 else image_local_clone[:],
|
408 |
+
'image': None,
|
409 |
+
'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:],
|
410 |
+
'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3],
|
411 |
+
},
|
412 |
+
{
|
413 |
+
'y': None,
|
414 |
+
"local_image": None,
|
415 |
+
'image': None,
|
416 |
+
'randomref': None,
|
417 |
+
'dwpose': None,
|
418 |
+
}]
|
419 |
+
|
420 |
+
|
421 |
+
partial_keys = [
|
422 |
+
['image', 'randomref', "dwpose"],
|
423 |
+
]
|
424 |
+
if hasattr(cfg, "partial_keys") and cfg.partial_keys:
|
425 |
+
partial_keys = cfg.partial_keys
|
426 |
+
|
427 |
+
for partial_keys_one in partial_keys:
|
428 |
+
model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one,
|
429 |
+
full_model_kwargs = full_model_kwargs,
|
430 |
+
use_fps_condition = cfg.use_fps_condition)
|
431 |
+
model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one,
|
432 |
+
full_model_kwargs = full_model_kwargs_vis,
|
433 |
+
use_fps_condition = cfg.use_fps_condition)
|
434 |
+
noise_one = noise
|
435 |
+
|
436 |
+
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
|
437 |
+
clip_encoder.cpu() # add this line
|
438 |
+
autoencoder.cpu() # add this line
|
439 |
+
torch.cuda.empty_cache() # add this line
|
440 |
+
|
441 |
+
video_data = diffusion.ddim_sample_loop(
|
442 |
+
noise=noise_one,
|
443 |
+
context_size=cfg.context_size,
|
444 |
+
context_stride=cfg.context_stride,
|
445 |
+
context_overlap=cfg.context_overlap,
|
446 |
+
model=model.eval(),
|
447 |
+
model_kwargs=model_kwargs_one,
|
448 |
+
guide_scale=cfg.guide_scale,
|
449 |
+
ddim_timesteps=cfg.ddim_timesteps,
|
450 |
+
eta=0.0,
|
451 |
+
context_batch_size=getattr(cfg, "context_batch_size", 1)
|
452 |
+
)
|
453 |
+
|
454 |
+
if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
|
455 |
+
# if run forward of autoencoder or clip_encoder second times, load them again
|
456 |
+
clip_encoder.cuda()
|
457 |
+
autoencoder.cuda()
|
458 |
+
|
459 |
+
|
460 |
+
video_data = 1. / cfg.scale_factor * video_data # [1, 4, h, w]
|
461 |
+
video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
|
462 |
+
chunk_size = min(cfg.decoder_bs, video_data.shape[0])
|
463 |
+
video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0)
|
464 |
+
decode_data = []
|
465 |
+
for vd_data in video_data_list:
|
466 |
+
gen_frames = autoencoder.decode(vd_data)
|
467 |
+
decode_data.append(gen_frames)
|
468 |
+
video_data = torch.cat(decode_data, dim=0)
|
469 |
+
video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float()
|
470 |
+
|
471 |
+
text_size = cfg.resolution[-1]
|
472 |
+
cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0]) # .replace(' ', '_')
|
473 |
+
name = f'seed_{cur_seed}'
|
474 |
+
for ii in partial_keys_one:
|
475 |
+
name = name + "_" + ii
|
476 |
+
file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4'
|
477 |
+
local_path = os.path.join(cfg.log_dir, f'{file_name}')
|
478 |
+
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
479 |
+
captions = "human"
|
480 |
+
del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
|
481 |
+
del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]
|
482 |
+
|
483 |
+
save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups,
|
484 |
+
cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps)
|
485 |
+
|
486 |
+
# try:
|
487 |
+
# save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
|
488 |
+
# logging.info('Save video to dir %s:' % (local_path))
|
489 |
+
# except Exception as e:
|
490 |
+
# logging.info(f'Step: save text or video error with {e}')
|
491 |
+
|
492 |
+
logging.info('Congratulations! The inference is completed!')
|
493 |
+
# synchronize to finish some processes
|
494 |
+
if not cfg.debug:
|
495 |
+
torch.cuda.synchronize()
|
496 |
+
dist.barrier()
|
497 |
+
|
498 |
+
def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False):
|
499 |
+
|
500 |
+
if use_fps_condition is True:
|
501 |
+
partial_keys.append('fps')
|
502 |
+
|
503 |
+
partial_model_kwargs = [{}, {}]
|
504 |
+
for partial_key in partial_keys:
|
505 |
+
partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
|
506 |
+
partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]
|
507 |
+
|
508 |
+
return partial_model_kwargs
|
UniAnimate/tools/modules/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .clip_embedder import FrozenOpenCLIPEmbedder
|
2 |
+
from .autoencoder import DiagonalGaussianDistribution, AutoencoderKL
|
3 |
+
from .clip_embedder import *
|
4 |
+
from .autoencoder import *
|
5 |
+
from .unet import *
|
6 |
+
from .diffusions import *
|
7 |
+
from .embedding_manager import *
|
UniAnimate/tools/modules/autoencoder.py
ADDED
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import logging
|
3 |
+
import collections
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from utils.registry_class import AUTO_ENCODER,DISTRIBUTION
|
9 |
+
|
10 |
+
|
11 |
+
def nonlinearity(x):
|
12 |
+
# swish
|
13 |
+
return x*torch.sigmoid(x)
|
14 |
+
|
15 |
+
def Normalize(in_channels, num_groups=32):
|
16 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
17 |
+
|
18 |
+
|
19 |
+
@torch.no_grad()
|
20 |
+
def get_first_stage_encoding(encoder_posterior, scale_factor=0.18215):
|
21 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
22 |
+
z = encoder_posterior.sample()
|
23 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
24 |
+
z = encoder_posterior
|
25 |
+
else:
|
26 |
+
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
|
27 |
+
return scale_factor * z
|
28 |
+
|
29 |
+
|
30 |
+
@AUTO_ENCODER.register_class()
|
31 |
+
class AutoencoderKL(nn.Module):
|
32 |
+
def __init__(self,
|
33 |
+
ddconfig,
|
34 |
+
embed_dim,
|
35 |
+
pretrained=None,
|
36 |
+
ignore_keys=[],
|
37 |
+
image_key="image",
|
38 |
+
colorize_nlabels=None,
|
39 |
+
monitor=None,
|
40 |
+
ema_decay=None,
|
41 |
+
learn_logvar=False,
|
42 |
+
use_vid_decoder=False,
|
43 |
+
**kwargs):
|
44 |
+
super().__init__()
|
45 |
+
self.learn_logvar = learn_logvar
|
46 |
+
self.image_key = image_key
|
47 |
+
self.encoder = Encoder(**ddconfig)
|
48 |
+
self.decoder = Decoder(**ddconfig)
|
49 |
+
assert ddconfig["double_z"]
|
50 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
51 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
52 |
+
self.embed_dim = embed_dim
|
53 |
+
if colorize_nlabels is not None:
|
54 |
+
assert type(colorize_nlabels)==int
|
55 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
56 |
+
if monitor is not None:
|
57 |
+
self.monitor = monitor
|
58 |
+
|
59 |
+
self.use_ema = ema_decay is not None
|
60 |
+
|
61 |
+
if pretrained is not None:
|
62 |
+
self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)
|
63 |
+
|
64 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
65 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
66 |
+
keys = list(sd.keys())
|
67 |
+
sd_new = collections.OrderedDict()
|
68 |
+
for k in keys:
|
69 |
+
if k.find('first_stage_model') >= 0:
|
70 |
+
k_new = k.split('first_stage_model.')[-1]
|
71 |
+
sd_new[k_new] = sd[k]
|
72 |
+
self.load_state_dict(sd_new, strict=True)
|
73 |
+
logging.info(f"Restored from {path}")
|
74 |
+
|
75 |
+
def on_train_batch_end(self, *args, **kwargs):
|
76 |
+
if self.use_ema:
|
77 |
+
self.model_ema(self)
|
78 |
+
|
79 |
+
def encode(self, x):
|
80 |
+
h = self.encoder(x)
|
81 |
+
moments = self.quant_conv(h)
|
82 |
+
posterior = DiagonalGaussianDistribution(moments)
|
83 |
+
return posterior
|
84 |
+
|
85 |
+
def encode_firsr_stage(self, x, scale_factor=1.0):
|
86 |
+
h = self.encoder(x)
|
87 |
+
moments = self.quant_conv(h)
|
88 |
+
posterior = DiagonalGaussianDistribution(moments)
|
89 |
+
z = get_first_stage_encoding(posterior, scale_factor)
|
90 |
+
return z
|
91 |
+
|
92 |
+
def encode_ms(self, x):
|
93 |
+
hs = self.encoder(x, True)
|
94 |
+
h = hs[-1]
|
95 |
+
moments = self.quant_conv(h)
|
96 |
+
posterior = DiagonalGaussianDistribution(moments)
|
97 |
+
hs[-1] = h
|
98 |
+
return hs
|
99 |
+
|
100 |
+
def decode(self, z, **kwargs):
|
101 |
+
z = self.post_quant_conv(z)
|
102 |
+
dec = self.decoder(z, **kwargs)
|
103 |
+
return dec
|
104 |
+
|
105 |
+
|
106 |
+
def forward(self, input, sample_posterior=True):
|
107 |
+
posterior = self.encode(input)
|
108 |
+
if sample_posterior:
|
109 |
+
z = posterior.sample()
|
110 |
+
else:
|
111 |
+
z = posterior.mode()
|
112 |
+
dec = self.decode(z)
|
113 |
+
return dec, posterior
|
114 |
+
|
115 |
+
def get_input(self, batch, k):
|
116 |
+
x = batch[k]
|
117 |
+
if len(x.shape) == 3:
|
118 |
+
x = x[..., None]
|
119 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
120 |
+
return x
|
121 |
+
|
122 |
+
def get_last_layer(self):
|
123 |
+
return self.decoder.conv_out.weight
|
124 |
+
|
125 |
+
@torch.no_grad()
|
126 |
+
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
127 |
+
log = dict()
|
128 |
+
x = self.get_input(batch, self.image_key)
|
129 |
+
x = x.to(self.device)
|
130 |
+
if not only_inputs:
|
131 |
+
xrec, posterior = self(x)
|
132 |
+
if x.shape[1] > 3:
|
133 |
+
# colorize with random projection
|
134 |
+
assert xrec.shape[1] > 3
|
135 |
+
x = self.to_rgb(x)
|
136 |
+
xrec = self.to_rgb(xrec)
|
137 |
+
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
138 |
+
log["reconstructions"] = xrec
|
139 |
+
if log_ema or self.use_ema:
|
140 |
+
with self.ema_scope():
|
141 |
+
xrec_ema, posterior_ema = self(x)
|
142 |
+
if x.shape[1] > 3:
|
143 |
+
# colorize with random projection
|
144 |
+
assert xrec_ema.shape[1] > 3
|
145 |
+
xrec_ema = self.to_rgb(xrec_ema)
|
146 |
+
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
|
147 |
+
log["reconstructions_ema"] = xrec_ema
|
148 |
+
log["inputs"] = x
|
149 |
+
return log
|
150 |
+
|
151 |
+
def to_rgb(self, x):
|
152 |
+
assert self.image_key == "segmentation"
|
153 |
+
if not hasattr(self, "colorize"):
|
154 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
155 |
+
x = F.conv2d(x, weight=self.colorize)
|
156 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
157 |
+
return x
|
158 |
+
|
159 |
+
|
160 |
+
@AUTO_ENCODER.register_class()
|
161 |
+
class AutoencoderVideo(AutoencoderKL):
|
162 |
+
def __init__(self,
|
163 |
+
ddconfig,
|
164 |
+
embed_dim,
|
165 |
+
pretrained=None,
|
166 |
+
ignore_keys=[],
|
167 |
+
image_key="image",
|
168 |
+
colorize_nlabels=None,
|
169 |
+
monitor=None,
|
170 |
+
ema_decay=None,
|
171 |
+
use_vid_decoder=True,
|
172 |
+
learn_logvar=False,
|
173 |
+
**kwargs):
|
174 |
+
use_vid_decoder = True
|
175 |
+
super().__init__(ddconfig, embed_dim, pretrained, ignore_keys, image_key, colorize_nlabels, monitor, ema_decay, learn_logvar, use_vid_decoder, **kwargs)
|
176 |
+
|
177 |
+
def decode(self, z, **kwargs):
|
178 |
+
# z = self.post_quant_conv(z)
|
179 |
+
dec = self.decoder(z, **kwargs)
|
180 |
+
return dec
|
181 |
+
|
182 |
+
def encode(self, x):
|
183 |
+
h = self.encoder(x)
|
184 |
+
# moments = self.quant_conv(h)
|
185 |
+
moments = h
|
186 |
+
posterior = DiagonalGaussianDistribution(moments)
|
187 |
+
return posterior
|
188 |
+
|
189 |
+
|
190 |
+
class IdentityFirstStage(torch.nn.Module):
|
191 |
+
def __init__(self, *args, vq_interface=False, **kwargs):
|
192 |
+
self.vq_interface = vq_interface
|
193 |
+
super().__init__()
|
194 |
+
|
195 |
+
def encode(self, x, *args, **kwargs):
|
196 |
+
return x
|
197 |
+
|
198 |
+
def decode(self, x, *args, **kwargs):
|
199 |
+
return x
|
200 |
+
|
201 |
+
def quantize(self, x, *args, **kwargs):
|
202 |
+
if self.vq_interface:
|
203 |
+
return x, None, [None, None, None]
|
204 |
+
return x
|
205 |
+
|
206 |
+
def forward(self, x, *args, **kwargs):
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
@DISTRIBUTION.register_class()
|
212 |
+
class DiagonalGaussianDistribution(object):
|
213 |
+
def __init__(self, parameters, deterministic=False):
|
214 |
+
self.parameters = parameters
|
215 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
216 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
217 |
+
self.deterministic = deterministic
|
218 |
+
self.std = torch.exp(0.5 * self.logvar)
|
219 |
+
self.var = torch.exp(self.logvar)
|
220 |
+
if self.deterministic:
|
221 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
222 |
+
|
223 |
+
def sample(self):
|
224 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
225 |
+
return x
|
226 |
+
|
227 |
+
def kl(self, other=None):
|
228 |
+
if self.deterministic:
|
229 |
+
return torch.Tensor([0.])
|
230 |
+
else:
|
231 |
+
if other is None:
|
232 |
+
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
233 |
+
+ self.var - 1.0 - self.logvar,
|
234 |
+
dim=[1, 2, 3])
|
235 |
+
else:
|
236 |
+
return 0.5 * torch.sum(
|
237 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
238 |
+
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
239 |
+
dim=[1, 2, 3])
|
240 |
+
|
241 |
+
def nll(self, sample, dims=[1,2,3]):
|
242 |
+
if self.deterministic:
|
243 |
+
return torch.Tensor([0.])
|
244 |
+
logtwopi = np.log(2.0 * np.pi)
|
245 |
+
return 0.5 * torch.sum(
|
246 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
247 |
+
dim=dims)
|
248 |
+
|
249 |
+
def mode(self):
|
250 |
+
return self.mean
|
251 |
+
|
252 |
+
|
253 |
+
# -------------------------------modules--------------------------------
|
254 |
+
|
255 |
+
class Downsample(nn.Module):
|
256 |
+
def __init__(self, in_channels, with_conv):
|
257 |
+
super().__init__()
|
258 |
+
self.with_conv = with_conv
|
259 |
+
if self.with_conv:
|
260 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
261 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
262 |
+
in_channels,
|
263 |
+
kernel_size=3,
|
264 |
+
stride=2,
|
265 |
+
padding=0)
|
266 |
+
|
267 |
+
def forward(self, x):
|
268 |
+
if self.with_conv:
|
269 |
+
pad = (0,1,0,1)
|
270 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
271 |
+
x = self.conv(x)
|
272 |
+
else:
|
273 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
274 |
+
return x
|
275 |
+
|
276 |
+
class ResnetBlock(nn.Module):
|
277 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
278 |
+
dropout, temb_channels=512):
|
279 |
+
super().__init__()
|
280 |
+
self.in_channels = in_channels
|
281 |
+
out_channels = in_channels if out_channels is None else out_channels
|
282 |
+
self.out_channels = out_channels
|
283 |
+
self.use_conv_shortcut = conv_shortcut
|
284 |
+
|
285 |
+
self.norm1 = Normalize(in_channels)
|
286 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
287 |
+
out_channels,
|
288 |
+
kernel_size=3,
|
289 |
+
stride=1,
|
290 |
+
padding=1)
|
291 |
+
if temb_channels > 0:
|
292 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
293 |
+
out_channels)
|
294 |
+
self.norm2 = Normalize(out_channels)
|
295 |
+
self.dropout = torch.nn.Dropout(dropout)
|
296 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
297 |
+
out_channels,
|
298 |
+
kernel_size=3,
|
299 |
+
stride=1,
|
300 |
+
padding=1)
|
301 |
+
if self.in_channels != self.out_channels:
|
302 |
+
if self.use_conv_shortcut:
|
303 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
304 |
+
out_channels,
|
305 |
+
kernel_size=3,
|
306 |
+
stride=1,
|
307 |
+
padding=1)
|
308 |
+
else:
|
309 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
310 |
+
out_channels,
|
311 |
+
kernel_size=1,
|
312 |
+
stride=1,
|
313 |
+
padding=0)
|
314 |
+
|
315 |
+
def forward(self, x, temb):
|
316 |
+
h = x
|
317 |
+
h = self.norm1(h)
|
318 |
+
h = nonlinearity(h)
|
319 |
+
h = self.conv1(h)
|
320 |
+
|
321 |
+
if temb is not None:
|
322 |
+
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
323 |
+
|
324 |
+
h = self.norm2(h)
|
325 |
+
h = nonlinearity(h)
|
326 |
+
h = self.dropout(h)
|
327 |
+
h = self.conv2(h)
|
328 |
+
|
329 |
+
if self.in_channels != self.out_channels:
|
330 |
+
if self.use_conv_shortcut:
|
331 |
+
x = self.conv_shortcut(x)
|
332 |
+
else:
|
333 |
+
x = self.nin_shortcut(x)
|
334 |
+
|
335 |
+
return x+h
|
336 |
+
|
337 |
+
|
338 |
+
class AttnBlock(nn.Module):
|
339 |
+
def __init__(self, in_channels):
|
340 |
+
super().__init__()
|
341 |
+
self.in_channels = in_channels
|
342 |
+
|
343 |
+
self.norm = Normalize(in_channels)
|
344 |
+
self.q = torch.nn.Conv2d(in_channels,
|
345 |
+
in_channels,
|
346 |
+
kernel_size=1,
|
347 |
+
stride=1,
|
348 |
+
padding=0)
|
349 |
+
self.k = torch.nn.Conv2d(in_channels,
|
350 |
+
in_channels,
|
351 |
+
kernel_size=1,
|
352 |
+
stride=1,
|
353 |
+
padding=0)
|
354 |
+
self.v = torch.nn.Conv2d(in_channels,
|
355 |
+
in_channels,
|
356 |
+
kernel_size=1,
|
357 |
+
stride=1,
|
358 |
+
padding=0)
|
359 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
360 |
+
in_channels,
|
361 |
+
kernel_size=1,
|
362 |
+
stride=1,
|
363 |
+
padding=0)
|
364 |
+
|
365 |
+
def forward(self, x):
|
366 |
+
h_ = x
|
367 |
+
h_ = self.norm(h_)
|
368 |
+
q = self.q(h_)
|
369 |
+
k = self.k(h_)
|
370 |
+
v = self.v(h_)
|
371 |
+
|
372 |
+
# compute attention
|
373 |
+
b,c,h,w = q.shape
|
374 |
+
q = q.reshape(b,c,h*w)
|
375 |
+
q = q.permute(0,2,1) # b,hw,c
|
376 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
377 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
378 |
+
w_ = w_ * (int(c)**(-0.5))
|
379 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
380 |
+
|
381 |
+
# attend to values
|
382 |
+
v = v.reshape(b,c,h*w)
|
383 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
384 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
385 |
+
h_ = h_.reshape(b,c,h,w)
|
386 |
+
|
387 |
+
h_ = self.proj_out(h_)
|
388 |
+
|
389 |
+
return x+h_
|
390 |
+
|
391 |
+
class AttnBlock(nn.Module):
|
392 |
+
def __init__(self, in_channels):
|
393 |
+
super().__init__()
|
394 |
+
self.in_channels = in_channels
|
395 |
+
|
396 |
+
self.norm = Normalize(in_channels)
|
397 |
+
self.q = torch.nn.Conv2d(in_channels,
|
398 |
+
in_channels,
|
399 |
+
kernel_size=1,
|
400 |
+
stride=1,
|
401 |
+
padding=0)
|
402 |
+
self.k = torch.nn.Conv2d(in_channels,
|
403 |
+
in_channels,
|
404 |
+
kernel_size=1,
|
405 |
+
stride=1,
|
406 |
+
padding=0)
|
407 |
+
self.v = torch.nn.Conv2d(in_channels,
|
408 |
+
in_channels,
|
409 |
+
kernel_size=1,
|
410 |
+
stride=1,
|
411 |
+
padding=0)
|
412 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
413 |
+
in_channels,
|
414 |
+
kernel_size=1,
|
415 |
+
stride=1,
|
416 |
+
padding=0)
|
417 |
+
|
418 |
+
def forward(self, x):
|
419 |
+
h_ = x
|
420 |
+
h_ = self.norm(h_)
|
421 |
+
q = self.q(h_)
|
422 |
+
k = self.k(h_)
|
423 |
+
v = self.v(h_)
|
424 |
+
|
425 |
+
# compute attention
|
426 |
+
b,c,h,w = q.shape
|
427 |
+
q = q.reshape(b,c,h*w)
|
428 |
+
q = q.permute(0,2,1) # b,hw,c
|
429 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
430 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
431 |
+
w_ = w_ * (int(c)**(-0.5))
|
432 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
433 |
+
|
434 |
+
# attend to values
|
435 |
+
v = v.reshape(b,c,h*w)
|
436 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
437 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
438 |
+
h_ = h_.reshape(b,c,h,w)
|
439 |
+
|
440 |
+
h_ = self.proj_out(h_)
|
441 |
+
|
442 |
+
return x+h_
|
443 |
+
|
444 |
+
class Upsample(nn.Module):
|
445 |
+
def __init__(self, in_channels, with_conv):
|
446 |
+
super().__init__()
|
447 |
+
self.with_conv = with_conv
|
448 |
+
if self.with_conv:
|
449 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
450 |
+
in_channels,
|
451 |
+
kernel_size=3,
|
452 |
+
stride=1,
|
453 |
+
padding=1)
|
454 |
+
|
455 |
+
def forward(self, x):
|
456 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
457 |
+
if self.with_conv:
|
458 |
+
x = self.conv(x)
|
459 |
+
return x
|
460 |
+
|
461 |
+
|
462 |
+
class Downsample(nn.Module):
|
463 |
+
def __init__(self, in_channels, with_conv):
|
464 |
+
super().__init__()
|
465 |
+
self.with_conv = with_conv
|
466 |
+
if self.with_conv:
|
467 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
468 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
469 |
+
in_channels,
|
470 |
+
kernel_size=3,
|
471 |
+
stride=2,
|
472 |
+
padding=0)
|
473 |
+
|
474 |
+
def forward(self, x):
|
475 |
+
if self.with_conv:
|
476 |
+
pad = (0,1,0,1)
|
477 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
478 |
+
x = self.conv(x)
|
479 |
+
else:
|
480 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
481 |
+
return x
|
482 |
+
|
483 |
+
class Encoder(nn.Module):
|
484 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
485 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
486 |
+
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
487 |
+
**ignore_kwargs):
|
488 |
+
super().__init__()
|
489 |
+
if use_linear_attn: attn_type = "linear"
|
490 |
+
self.ch = ch
|
491 |
+
self.temb_ch = 0
|
492 |
+
self.num_resolutions = len(ch_mult)
|
493 |
+
self.num_res_blocks = num_res_blocks
|
494 |
+
self.resolution = resolution
|
495 |
+
self.in_channels = in_channels
|
496 |
+
|
497 |
+
# downsampling
|
498 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
499 |
+
self.ch,
|
500 |
+
kernel_size=3,
|
501 |
+
stride=1,
|
502 |
+
padding=1)
|
503 |
+
|
504 |
+
curr_res = resolution
|
505 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
506 |
+
self.in_ch_mult = in_ch_mult
|
507 |
+
self.down = nn.ModuleList()
|
508 |
+
for i_level in range(self.num_resolutions):
|
509 |
+
block = nn.ModuleList()
|
510 |
+
attn = nn.ModuleList()
|
511 |
+
block_in = ch*in_ch_mult[i_level]
|
512 |
+
block_out = ch*ch_mult[i_level]
|
513 |
+
for i_block in range(self.num_res_blocks):
|
514 |
+
block.append(ResnetBlock(in_channels=block_in,
|
515 |
+
out_channels=block_out,
|
516 |
+
temb_channels=self.temb_ch,
|
517 |
+
dropout=dropout))
|
518 |
+
block_in = block_out
|
519 |
+
if curr_res in attn_resolutions:
|
520 |
+
attn.append(AttnBlock(block_in))
|
521 |
+
down = nn.Module()
|
522 |
+
down.block = block
|
523 |
+
down.attn = attn
|
524 |
+
if i_level != self.num_resolutions-1:
|
525 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
526 |
+
curr_res = curr_res // 2
|
527 |
+
self.down.append(down)
|
528 |
+
|
529 |
+
# middle
|
530 |
+
self.mid = nn.Module()
|
531 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
532 |
+
out_channels=block_in,
|
533 |
+
temb_channels=self.temb_ch,
|
534 |
+
dropout=dropout)
|
535 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
536 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
537 |
+
out_channels=block_in,
|
538 |
+
temb_channels=self.temb_ch,
|
539 |
+
dropout=dropout)
|
540 |
+
|
541 |
+
# end
|
542 |
+
self.norm_out = Normalize(block_in)
|
543 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
544 |
+
2*z_channels if double_z else z_channels,
|
545 |
+
kernel_size=3,
|
546 |
+
stride=1,
|
547 |
+
padding=1)
|
548 |
+
|
549 |
+
def forward(self, x, return_feat=False):
|
550 |
+
# timestep embedding
|
551 |
+
temb = None
|
552 |
+
|
553 |
+
# downsampling
|
554 |
+
hs = [self.conv_in(x)]
|
555 |
+
for i_level in range(self.num_resolutions):
|
556 |
+
for i_block in range(self.num_res_blocks):
|
557 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
558 |
+
if len(self.down[i_level].attn) > 0:
|
559 |
+
h = self.down[i_level].attn[i_block](h)
|
560 |
+
hs.append(h)
|
561 |
+
if i_level != self.num_resolutions-1:
|
562 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
563 |
+
|
564 |
+
# middle
|
565 |
+
h = hs[-1]
|
566 |
+
h = self.mid.block_1(h, temb)
|
567 |
+
h = self.mid.attn_1(h)
|
568 |
+
h = self.mid.block_2(h, temb)
|
569 |
+
|
570 |
+
# end
|
571 |
+
h = self.norm_out(h)
|
572 |
+
h = nonlinearity(h)
|
573 |
+
h = self.conv_out(h)
|
574 |
+
if return_feat:
|
575 |
+
hs[-1] = h
|
576 |
+
return hs
|
577 |
+
else:
|
578 |
+
return h
|
579 |
+
|
580 |
+
|
581 |
+
class Decoder(nn.Module):
|
582 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
583 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
584 |
+
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
585 |
+
attn_type="vanilla", **ignorekwargs):
|
586 |
+
super().__init__()
|
587 |
+
if use_linear_attn: attn_type = "linear"
|
588 |
+
self.ch = ch
|
589 |
+
self.temb_ch = 0
|
590 |
+
self.num_resolutions = len(ch_mult)
|
591 |
+
self.num_res_blocks = num_res_blocks
|
592 |
+
self.resolution = resolution
|
593 |
+
self.in_channels = in_channels
|
594 |
+
self.give_pre_end = give_pre_end
|
595 |
+
self.tanh_out = tanh_out
|
596 |
+
|
597 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
598 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
599 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
600 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
601 |
+
self.z_shape = (1,z_channels, curr_res, curr_res)
|
602 |
+
# logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
603 |
+
|
604 |
+
# z to block_in
|
605 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
606 |
+
block_in,
|
607 |
+
kernel_size=3,
|
608 |
+
stride=1,
|
609 |
+
padding=1)
|
610 |
+
|
611 |
+
# middle
|
612 |
+
self.mid = nn.Module()
|
613 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
614 |
+
out_channels=block_in,
|
615 |
+
temb_channels=self.temb_ch,
|
616 |
+
dropout=dropout)
|
617 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
618 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
619 |
+
out_channels=block_in,
|
620 |
+
temb_channels=self.temb_ch,
|
621 |
+
dropout=dropout)
|
622 |
+
|
623 |
+
# upsampling
|
624 |
+
self.up = nn.ModuleList()
|
625 |
+
for i_level in reversed(range(self.num_resolutions)):
|
626 |
+
block = nn.ModuleList()
|
627 |
+
attn = nn.ModuleList()
|
628 |
+
block_out = ch*ch_mult[i_level]
|
629 |
+
for i_block in range(self.num_res_blocks+1):
|
630 |
+
block.append(ResnetBlock(in_channels=block_in,
|
631 |
+
out_channels=block_out,
|
632 |
+
temb_channels=self.temb_ch,
|
633 |
+
dropout=dropout))
|
634 |
+
block_in = block_out
|
635 |
+
if curr_res in attn_resolutions:
|
636 |
+
attn.append(AttnBlock(block_in))
|
637 |
+
up = nn.Module()
|
638 |
+
up.block = block
|
639 |
+
up.attn = attn
|
640 |
+
if i_level != 0:
|
641 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
642 |
+
curr_res = curr_res * 2
|
643 |
+
self.up.insert(0, up) # prepend to get consistent order
|
644 |
+
|
645 |
+
# end
|
646 |
+
self.norm_out = Normalize(block_in)
|
647 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
648 |
+
out_ch,
|
649 |
+
kernel_size=3,
|
650 |
+
stride=1,
|
651 |
+
padding=1)
|
652 |
+
|
653 |
+
def forward(self, z, **kwargs):
|
654 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
655 |
+
self.last_z_shape = z.shape
|
656 |
+
|
657 |
+
# timestep embedding
|
658 |
+
temb = None
|
659 |
+
|
660 |
+
# z to block_in
|
661 |
+
h = self.conv_in(z)
|
662 |
+
|
663 |
+
# middle
|
664 |
+
h = self.mid.block_1(h, temb)
|
665 |
+
h = self.mid.attn_1(h)
|
666 |
+
h = self.mid.block_2(h, temb)
|
667 |
+
|
668 |
+
# upsampling
|
669 |
+
for i_level in reversed(range(self.num_resolutions)):
|
670 |
+
for i_block in range(self.num_res_blocks+1):
|
671 |
+
h = self.up[i_level].block[i_block](h, temb)
|
672 |
+
if len(self.up[i_level].attn) > 0:
|
673 |
+
h = self.up[i_level].attn[i_block](h)
|
674 |
+
if i_level != 0:
|
675 |
+
h = self.up[i_level].upsample(h)
|
676 |
+
|
677 |
+
# end
|
678 |
+
if self.give_pre_end:
|
679 |
+
return h
|
680 |
+
|
681 |
+
h = self.norm_out(h)
|
682 |
+
h = nonlinearity(h)
|
683 |
+
h = self.conv_out(h)
|
684 |
+
if self.tanh_out:
|
685 |
+
h = torch.tanh(h)
|
686 |
+
return h
|
687 |
+
|
688 |
+
|
689 |
+
|
690 |
+
|
UniAnimate/tools/modules/clip_embedder.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
+
import open_clip
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn as nn
|
7 |
+
import torchvision.transforms as T
|
8 |
+
|
9 |
+
from utils.registry_class import EMBEDDER
|
10 |
+
|
11 |
+
|
12 |
+
@EMBEDDER.register_class()
|
13 |
+
class FrozenOpenCLIPEmbedder(nn.Module):
|
14 |
+
"""
|
15 |
+
Uses the OpenCLIP transformer encoder for text
|
16 |
+
"""
|
17 |
+
LAYERS = [
|
18 |
+
#"pooled",
|
19 |
+
"last",
|
20 |
+
"penultimate"
|
21 |
+
]
|
22 |
+
def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77,
|
23 |
+
freeze=True, layer="last"):
|
24 |
+
super().__init__()
|
25 |
+
assert layer in self.LAYERS
|
26 |
+
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
|
27 |
+
del model.visual
|
28 |
+
self.model = model
|
29 |
+
|
30 |
+
self.device = device
|
31 |
+
self.max_length = max_length
|
32 |
+
if freeze:
|
33 |
+
self.freeze()
|
34 |
+
self.layer = layer
|
35 |
+
if self.layer == "last":
|
36 |
+
self.layer_idx = 0
|
37 |
+
elif self.layer == "penultimate":
|
38 |
+
self.layer_idx = 1
|
39 |
+
else:
|
40 |
+
raise NotImplementedError()
|
41 |
+
|
42 |
+
def freeze(self):
|
43 |
+
self.model = self.model.eval()
|
44 |
+
for param in self.parameters():
|
45 |
+
param.requires_grad = False
|
46 |
+
|
47 |
+
def forward(self, text):
|
48 |
+
tokens = open_clip.tokenize(text)
|
49 |
+
z = self.encode_with_transformer(tokens.to(self.device))
|
50 |
+
return z
|
51 |
+
|
52 |
+
def encode_with_transformer(self, text):
|
53 |
+
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
54 |
+
x = x + self.model.positional_embedding
|
55 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
56 |
+
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
57 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
58 |
+
x = self.model.ln_final(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
62 |
+
for i, r in enumerate(self.model.transformer.resblocks):
|
63 |
+
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
64 |
+
break
|
65 |
+
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
66 |
+
x = checkpoint(r, x, attn_mask)
|
67 |
+
else:
|
68 |
+
x = r(x, attn_mask=attn_mask)
|
69 |
+
return x
|
70 |
+
|
71 |
+
def encode(self, text):
|
72 |
+
return self(text)
|
73 |
+
|
74 |
+
|
75 |
+
@EMBEDDER.register_class()
|
76 |
+
class FrozenOpenCLIPVisualEmbedder(nn.Module):
|
77 |
+
"""
|
78 |
+
Uses the OpenCLIP transformer encoder for text
|
79 |
+
"""
|
80 |
+
LAYERS = [
|
81 |
+
#"pooled",
|
82 |
+
"last",
|
83 |
+
"penultimate"
|
84 |
+
]
|
85 |
+
def __init__(self, pretrained, vit_resolution=(224, 224), arch="ViT-H-14", device="cuda", max_length=77,
|
86 |
+
freeze=True, layer="last"):
|
87 |
+
super().__init__()
|
88 |
+
assert layer in self.LAYERS
|
89 |
+
model, _, preprocess = open_clip.create_model_and_transforms(
|
90 |
+
arch, device=torch.device('cpu'), pretrained=pretrained)
|
91 |
+
|
92 |
+
del model.transformer
|
93 |
+
self.model = model
|
94 |
+
data_white = np.ones((vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8)*255
|
95 |
+
self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
|
96 |
+
|
97 |
+
self.device = device
|
98 |
+
self.max_length = max_length # 77
|
99 |
+
if freeze:
|
100 |
+
self.freeze()
|
101 |
+
self.layer = layer # 'penultimate'
|
102 |
+
if self.layer == "last":
|
103 |
+
self.layer_idx = 0
|
104 |
+
elif self.layer == "penultimate":
|
105 |
+
self.layer_idx = 1
|
106 |
+
else:
|
107 |
+
raise NotImplementedError()
|
108 |
+
|
109 |
+
def freeze(self):
|
110 |
+
self.model = self.model.eval()
|
111 |
+
for param in self.parameters():
|
112 |
+
param.requires_grad = False
|
113 |
+
|
114 |
+
def forward(self, image):
|
115 |
+
# tokens = open_clip.tokenize(text)
|
116 |
+
z = self.model.encode_image(image.to(self.device))
|
117 |
+
return z
|
118 |
+
|
119 |
+
def encode_with_transformer(self, text):
|
120 |
+
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
121 |
+
x = x + self.model.positional_embedding
|
122 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
123 |
+
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
124 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
125 |
+
x = self.model.ln_final(x)
|
126 |
+
|
127 |
+
return x
|
128 |
+
|
129 |
+
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
130 |
+
for i, r in enumerate(self.model.transformer.resblocks):
|
131 |
+
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
132 |
+
break
|
133 |
+
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
134 |
+
x = checkpoint(r, x, attn_mask)
|
135 |
+
else:
|
136 |
+
x = r(x, attn_mask=attn_mask)
|
137 |
+
return x
|
138 |
+
|
139 |
+
def encode(self, text):
|
140 |
+
return self(text)
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
@EMBEDDER.register_class()
|
145 |
+
class FrozenOpenCLIPTextVisualEmbedder(nn.Module):
|
146 |
+
"""
|
147 |
+
Uses the OpenCLIP transformer encoder for text
|
148 |
+
"""
|
149 |
+
LAYERS = [
|
150 |
+
#"pooled",
|
151 |
+
"last",
|
152 |
+
"penultimate"
|
153 |
+
]
|
154 |
+
def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77,
|
155 |
+
freeze=True, layer="last", **kwargs):
|
156 |
+
super().__init__()
|
157 |
+
assert layer in self.LAYERS
|
158 |
+
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
|
159 |
+
self.model = model
|
160 |
+
|
161 |
+
self.device = device
|
162 |
+
self.max_length = max_length
|
163 |
+
if freeze:
|
164 |
+
self.freeze()
|
165 |
+
self.layer = layer
|
166 |
+
if self.layer == "last":
|
167 |
+
self.layer_idx = 0
|
168 |
+
elif self.layer == "penultimate":
|
169 |
+
self.layer_idx = 1
|
170 |
+
else:
|
171 |
+
raise NotImplementedError()
|
172 |
+
|
173 |
+
def freeze(self):
|
174 |
+
self.model = self.model.eval()
|
175 |
+
for param in self.parameters():
|
176 |
+
param.requires_grad = False
|
177 |
+
|
178 |
+
|
179 |
+
def forward(self, image=None, text=None):
|
180 |
+
|
181 |
+
xi = self.model.encode_image(image.to(self.device)) if image is not None else None
|
182 |
+
tokens = open_clip.tokenize(text)
|
183 |
+
xt, x = self.encode_with_transformer(tokens.to(self.device))
|
184 |
+
return xi, xt, x
|
185 |
+
|
186 |
+
def encode_with_transformer(self, text):
|
187 |
+
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
188 |
+
x = x + self.model.positional_embedding
|
189 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
190 |
+
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
191 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
192 |
+
x = self.model.ln_final(x)
|
193 |
+
xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
|
194 |
+
return xt, x
|
195 |
+
|
196 |
+
|
197 |
+
def encode_image(self, image):
|
198 |
+
return self.model.visual(image)
|
199 |
+
|
200 |
+
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
201 |
+
for i, r in enumerate(self.model.transformer.resblocks):
|
202 |
+
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
203 |
+
break
|
204 |
+
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
205 |
+
x = checkpoint(r, x, attn_mask)
|
206 |
+
else:
|
207 |
+
x = r(x, attn_mask=attn_mask)
|
208 |
+
return x
|
209 |
+
|
210 |
+
def encode(self, text):
|
211 |
+
|
212 |
+
return self(text)
|
UniAnimate/tools/modules/config.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import logging
|
3 |
+
import os.path as osp
|
4 |
+
from datetime import datetime
|
5 |
+
from easydict import EasyDict
|
6 |
+
import os
|
7 |
+
|
8 |
+
cfg = EasyDict(__name__='Config: VideoLDM Decoder')
|
9 |
+
|
10 |
+
# -------------------------------distributed training--------------------------
|
11 |
+
pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
|
12 |
+
gpus_per_machine = torch.cuda.device_count()
|
13 |
+
world_size = pmi_world_size * gpus_per_machine
|
14 |
+
# -----------------------------------------------------------------------------
|
15 |
+
|
16 |
+
|
17 |
+
# ---------------------------Dataset Parameter---------------------------------
|
18 |
+
cfg.mean = [0.5, 0.5, 0.5]
|
19 |
+
cfg.std = [0.5, 0.5, 0.5]
|
20 |
+
cfg.max_words = 1000
|
21 |
+
cfg.num_workers = 8
|
22 |
+
cfg.prefetch_factor = 2
|
23 |
+
|
24 |
+
# PlaceHolder
|
25 |
+
cfg.resolution = [448, 256]
|
26 |
+
cfg.vit_out_dim = 1024
|
27 |
+
cfg.vit_resolution = 336
|
28 |
+
cfg.depth_clamp = 10.0
|
29 |
+
cfg.misc_size = 384
|
30 |
+
cfg.depth_std = 20.0
|
31 |
+
|
32 |
+
cfg.save_fps = 8
|
33 |
+
|
34 |
+
cfg.frame_lens = [32, 32, 32, 1]
|
35 |
+
cfg.sample_fps = [4, ]
|
36 |
+
cfg.vid_dataset = {
|
37 |
+
'type': 'VideoBaseDataset',
|
38 |
+
'data_list': [],
|
39 |
+
'max_words': cfg.max_words,
|
40 |
+
'resolution': cfg.resolution}
|
41 |
+
cfg.img_dataset = {
|
42 |
+
'type': 'ImageBaseDataset',
|
43 |
+
'data_list': ['laion_400m',],
|
44 |
+
'max_words': cfg.max_words,
|
45 |
+
'resolution': cfg.resolution}
|
46 |
+
|
47 |
+
cfg.batch_sizes = {
|
48 |
+
str(1):256,
|
49 |
+
str(4):4,
|
50 |
+
str(8):4,
|
51 |
+
str(16):4}
|
52 |
+
# -----------------------------------------------------------------------------
|
53 |
+
|
54 |
+
|
55 |
+
# ---------------------------Mode Parameters-----------------------------------
|
56 |
+
# Diffusion
|
57 |
+
cfg.Diffusion = {
|
58 |
+
'type': 'DiffusionDDIM',
|
59 |
+
'schedule': 'cosine', # cosine
|
60 |
+
'schedule_param': {
|
61 |
+
'num_timesteps': 1000,
|
62 |
+
'cosine_s': 0.008,
|
63 |
+
'zero_terminal_snr': True,
|
64 |
+
},
|
65 |
+
'mean_type': 'v', # [v, eps]
|
66 |
+
'loss_type': 'mse',
|
67 |
+
'var_type': 'fixed_small',
|
68 |
+
'rescale_timesteps': False,
|
69 |
+
'noise_strength': 0.1,
|
70 |
+
'ddim_timesteps': 50
|
71 |
+
}
|
72 |
+
cfg.ddim_timesteps = 50 # official: 250
|
73 |
+
cfg.use_div_loss = False
|
74 |
+
# classifier-free guidance
|
75 |
+
cfg.p_zero = 0.9
|
76 |
+
cfg.guide_scale = 3.0
|
77 |
+
|
78 |
+
# clip vision encoder
|
79 |
+
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
|
80 |
+
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
|
81 |
+
|
82 |
+
# sketch
|
83 |
+
cfg.sketch_mean = [0.485, 0.456, 0.406]
|
84 |
+
cfg.sketch_std = [0.229, 0.224, 0.225]
|
85 |
+
# cfg.misc_size = 256
|
86 |
+
cfg.depth_std = 20.0
|
87 |
+
cfg.depth_clamp = 10.0
|
88 |
+
cfg.hist_sigma = 10.0
|
89 |
+
|
90 |
+
# Model
|
91 |
+
cfg.scale_factor = 0.18215
|
92 |
+
cfg.use_checkpoint = True
|
93 |
+
cfg.use_sharded_ddp = False
|
94 |
+
cfg.use_fsdp = False
|
95 |
+
cfg.use_fp16 = True
|
96 |
+
cfg.temporal_attention = True
|
97 |
+
|
98 |
+
cfg.UNet = {
|
99 |
+
'type': 'UNetSD',
|
100 |
+
'in_dim': 4,
|
101 |
+
'dim': 320,
|
102 |
+
'y_dim': cfg.vit_out_dim,
|
103 |
+
'context_dim': 1024,
|
104 |
+
'out_dim': 8,
|
105 |
+
'dim_mult': [1, 2, 4, 4],
|
106 |
+
'num_heads': 8,
|
107 |
+
'head_dim': 64,
|
108 |
+
'num_res_blocks': 2,
|
109 |
+
'attn_scales': [1 / 1, 1 / 2, 1 / 4],
|
110 |
+
'dropout': 0.1,
|
111 |
+
'temporal_attention': cfg.temporal_attention,
|
112 |
+
'temporal_attn_times': 1,
|
113 |
+
'use_checkpoint': cfg.use_checkpoint,
|
114 |
+
'use_fps_condition': False,
|
115 |
+
'use_sim_mask': False
|
116 |
+
}
|
117 |
+
|
118 |
+
# auotoencoder from stabel diffusion
|
119 |
+
cfg.guidances = []
|
120 |
+
cfg.auto_encoder = {
|
121 |
+
'type': 'AutoencoderKL',
|
122 |
+
'ddconfig': {
|
123 |
+
'double_z': True,
|
124 |
+
'z_channels': 4,
|
125 |
+
'resolution': 256,
|
126 |
+
'in_channels': 3,
|
127 |
+
'out_ch': 3,
|
128 |
+
'ch': 128,
|
129 |
+
'ch_mult': [1, 2, 4, 4],
|
130 |
+
'num_res_blocks': 2,
|
131 |
+
'attn_resolutions': [],
|
132 |
+
'dropout': 0.0,
|
133 |
+
'video_kernel_size': [3, 1, 1]
|
134 |
+
},
|
135 |
+
'embed_dim': 4,
|
136 |
+
'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
|
137 |
+
}
|
138 |
+
# clip embedder
|
139 |
+
cfg.embedder = {
|
140 |
+
'type': 'FrozenOpenCLIPEmbedder',
|
141 |
+
'layer': 'penultimate',
|
142 |
+
'pretrained': 'models/open_clip_pytorch_model.bin'
|
143 |
+
}
|
144 |
+
# -----------------------------------------------------------------------------
|
145 |
+
|
146 |
+
# ---------------------------Training Settings---------------------------------
|
147 |
+
# training and optimizer
|
148 |
+
cfg.ema_decay = 0.9999
|
149 |
+
cfg.num_steps = 600000
|
150 |
+
cfg.lr = 5e-5
|
151 |
+
cfg.weight_decay = 0.0
|
152 |
+
cfg.betas = (0.9, 0.999)
|
153 |
+
cfg.eps = 1.0e-8
|
154 |
+
cfg.chunk_size = 16
|
155 |
+
cfg.decoder_bs = 8
|
156 |
+
cfg.alpha = 0.7
|
157 |
+
cfg.save_ckp_interval = 1000
|
158 |
+
|
159 |
+
# scheduler
|
160 |
+
cfg.warmup_steps = 10
|
161 |
+
cfg.decay_mode = 'cosine'
|
162 |
+
|
163 |
+
# acceleration
|
164 |
+
cfg.use_ema = True
|
165 |
+
if world_size<2:
|
166 |
+
cfg.use_ema = False
|
167 |
+
cfg.load_from = None
|
168 |
+
# -----------------------------------------------------------------------------
|
169 |
+
|
170 |
+
|
171 |
+
# ----------------------------Pretrain Settings---------------------------------
|
172 |
+
cfg.Pretrain = {
|
173 |
+
'type': 'pretrain_specific_strategies',
|
174 |
+
'fix_weight': False,
|
175 |
+
'grad_scale': 0.2,
|
176 |
+
'resume_checkpoint': 'models/jiuniu_0267000.pth',
|
177 |
+
'sd_keys_path': 'models/stable_diffusion_image_key_temporal_attention_x1.json',
|
178 |
+
}
|
179 |
+
# -----------------------------------------------------------------------------
|
180 |
+
|
181 |
+
|
182 |
+
# -----------------------------Visual-------------------------------------------
|
183 |
+
# Visual videos
|
184 |
+
cfg.viz_interval = 1000
|
185 |
+
cfg.visual_train = {
|
186 |
+
'type': 'VisualTrainTextImageToVideo',
|
187 |
+
}
|
188 |
+
cfg.visual_inference = {
|
189 |
+
'type': 'VisualGeneratedVideos',
|
190 |
+
}
|
191 |
+
cfg.inference_list_path = ''
|
192 |
+
|
193 |
+
# logging
|
194 |
+
cfg.log_interval = 100
|
195 |
+
|
196 |
+
### Default log_dir
|
197 |
+
cfg.log_dir = 'outputs/'
|
198 |
+
# -----------------------------------------------------------------------------
|
199 |
+
|
200 |
+
|
201 |
+
# ---------------------------Others--------------------------------------------
|
202 |
+
# seed
|
203 |
+
cfg.seed = 8888
|
204 |
+
cfg.negative_prompt = 'Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms'
|
205 |
+
# -----------------------------------------------------------------------------
|
206 |
+
|
UniAnimate/tools/modules/diffusions/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .diffusion_ddim import *
|
UniAnimate/tools/modules/diffusions/diffusion_ddim.py
ADDED
@@ -0,0 +1,1121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
|
4 |
+
from utils.registry_class import DIFFUSION
|
5 |
+
from .schedules import beta_schedule, sigma_schedule
|
6 |
+
from .losses import kl_divergence, discretized_gaussian_log_likelihood
|
7 |
+
# from .dpm_solver import NoiseScheduleVP, model_wrapper_guided_diffusion, model_wrapper, DPM_Solver
|
8 |
+
from typing import Callable, List, Optional
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
def _i(tensor, t, x):
|
12 |
+
r"""Index tensor using t and format the output according to x.
|
13 |
+
"""
|
14 |
+
if tensor.device != x.device:
|
15 |
+
tensor = tensor.to(x.device)
|
16 |
+
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
17 |
+
return tensor[t].view(shape).to(x)
|
18 |
+
|
19 |
+
@DIFFUSION.register_class()
|
20 |
+
class DiffusionDDIMSR(object):
|
21 |
+
def __init__(self, reverse_diffusion, forward_diffusion, **kwargs):
|
22 |
+
from .diffusion_gauss import GaussianDiffusion
|
23 |
+
self.reverse_diffusion = GaussianDiffusion(sigmas=sigma_schedule(reverse_diffusion.schedule, **reverse_diffusion.schedule_param),
|
24 |
+
prediction_type=reverse_diffusion.mean_type)
|
25 |
+
self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param),
|
26 |
+
prediction_type=forward_diffusion.mean_type)
|
27 |
+
|
28 |
+
|
29 |
+
@DIFFUSION.register_class()
|
30 |
+
class DiffusionDPM(object):
|
31 |
+
def __init__(self, forward_diffusion, **kwargs):
|
32 |
+
from .diffusion_gauss import GaussianDiffusion
|
33 |
+
self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param),
|
34 |
+
prediction_type=forward_diffusion.mean_type)
|
35 |
+
|
36 |
+
|
37 |
+
@DIFFUSION.register_class()
|
38 |
+
class DiffusionDDIM(object):
|
39 |
+
def __init__(self,
|
40 |
+
schedule='linear_sd',
|
41 |
+
schedule_param={},
|
42 |
+
mean_type='eps',
|
43 |
+
var_type='learned_range',
|
44 |
+
loss_type='mse',
|
45 |
+
epsilon = 1e-12,
|
46 |
+
rescale_timesteps=False,
|
47 |
+
noise_strength=0.0,
|
48 |
+
**kwargs):
|
49 |
+
|
50 |
+
assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v']
|
51 |
+
assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small']
|
52 |
+
assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier']
|
53 |
+
|
54 |
+
betas = beta_schedule(schedule, **schedule_param)
|
55 |
+
assert min(betas) > 0 and max(betas) <= 1
|
56 |
+
|
57 |
+
if not isinstance(betas, torch.DoubleTensor):
|
58 |
+
betas = torch.tensor(betas, dtype=torch.float64)
|
59 |
+
|
60 |
+
self.betas = betas
|
61 |
+
self.num_timesteps = len(betas)
|
62 |
+
self.mean_type = mean_type # eps
|
63 |
+
self.var_type = var_type # 'fixed_small'
|
64 |
+
self.loss_type = loss_type # mse
|
65 |
+
self.epsilon = epsilon # 1e-12
|
66 |
+
self.rescale_timesteps = rescale_timesteps # False
|
67 |
+
self.noise_strength = noise_strength # 0.0
|
68 |
+
|
69 |
+
# alphas
|
70 |
+
alphas = 1 - self.betas
|
71 |
+
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
72 |
+
self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
73 |
+
self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])])
|
74 |
+
|
75 |
+
# q(x_t | x_{t-1})
|
76 |
+
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
77 |
+
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
|
78 |
+
self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
|
79 |
+
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
80 |
+
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
|
81 |
+
|
82 |
+
# q(x_{t-1} | x_t, x_0)
|
83 |
+
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
84 |
+
self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20))
|
85 |
+
self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
86 |
+
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
87 |
+
|
88 |
+
|
89 |
+
def sample_loss(self, x0, noise=None):
|
90 |
+
if noise is None:
|
91 |
+
noise = torch.randn_like(x0)
|
92 |
+
if self.noise_strength > 0:
|
93 |
+
b, c, f, _, _= x0.shape
|
94 |
+
offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device)
|
95 |
+
noise = noise + self.noise_strength * offset_noise
|
96 |
+
return noise
|
97 |
+
|
98 |
+
|
99 |
+
def q_sample(self, x0, t, noise=None):
|
100 |
+
r"""Sample from q(x_t | x_0).
|
101 |
+
"""
|
102 |
+
# noise = torch.randn_like(x0) if noise is None else noise
|
103 |
+
noise = self.sample_loss(x0, noise)
|
104 |
+
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
|
105 |
+
_i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
|
106 |
+
|
107 |
+
def q_mean_variance(self, x0, t):
|
108 |
+
r"""Distribution of q(x_t | x_0).
|
109 |
+
"""
|
110 |
+
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
|
111 |
+
var = _i(1.0 - self.alphas_cumprod, t, x0)
|
112 |
+
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
|
113 |
+
return mu, var, log_var
|
114 |
+
|
115 |
+
def q_posterior_mean_variance(self, x0, xt, t):
|
116 |
+
r"""Distribution of q(x_{t-1} | x_t, x_0).
|
117 |
+
"""
|
118 |
+
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt
|
119 |
+
var = _i(self.posterior_variance, t, xt)
|
120 |
+
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
121 |
+
return mu, var, log_var
|
122 |
+
|
123 |
+
@torch.no_grad()
|
124 |
+
def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
|
125 |
+
r"""Sample from p(x_{t-1} | x_t).
|
126 |
+
- condition_fn: for classifier-based guidance (guided-diffusion).
|
127 |
+
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
128 |
+
"""
|
129 |
+
# predict distribution of p(x_{t-1} | x_t)
|
130 |
+
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
|
131 |
+
|
132 |
+
# random sample (with optional conditional function)
|
133 |
+
noise = torch.randn_like(xt)
|
134 |
+
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0
|
135 |
+
if condition_fn is not None:
|
136 |
+
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
137 |
+
mu = mu.float() + var * grad.float()
|
138 |
+
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
|
139 |
+
return xt_1, x0
|
140 |
+
|
141 |
+
@torch.no_grad()
|
142 |
+
def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
|
143 |
+
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
|
144 |
+
"""
|
145 |
+
# prepare input
|
146 |
+
b = noise.size(0)
|
147 |
+
xt = noise
|
148 |
+
|
149 |
+
# diffusion process
|
150 |
+
for step in torch.arange(self.num_timesteps).flip(0):
|
151 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
152 |
+
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale)
|
153 |
+
return xt
|
154 |
+
|
155 |
+
def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None):
|
156 |
+
r"""Distribution of p(x_{t-1} | x_t).
|
157 |
+
"""
|
158 |
+
# predict distribution
|
159 |
+
if guide_scale is None:
|
160 |
+
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
161 |
+
else:
|
162 |
+
# classifier-free guidance
|
163 |
+
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
164 |
+
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
165 |
+
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
|
166 |
+
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
|
167 |
+
dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2
|
168 |
+
out = torch.cat([
|
169 |
+
u_out[:, :dim] + guide_scale * (y_out[:, :dim] - u_out[:, :dim]),
|
170 |
+
y_out[:, dim:]], dim=1) # guide_scale=9.0
|
171 |
+
|
172 |
+
# compute variance
|
173 |
+
if self.var_type == 'learned':
|
174 |
+
out, log_var = out.chunk(2, dim=1)
|
175 |
+
var = torch.exp(log_var)
|
176 |
+
elif self.var_type == 'learned_range':
|
177 |
+
out, fraction = out.chunk(2, dim=1)
|
178 |
+
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
179 |
+
max_log_var = _i(torch.log(self.betas), t, xt)
|
180 |
+
fraction = (fraction + 1) / 2.0
|
181 |
+
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
|
182 |
+
var = torch.exp(log_var)
|
183 |
+
elif self.var_type == 'fixed_large':
|
184 |
+
var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt)
|
185 |
+
log_var = torch.log(var)
|
186 |
+
elif self.var_type == 'fixed_small':
|
187 |
+
var = _i(self.posterior_variance, t, xt)
|
188 |
+
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
189 |
+
|
190 |
+
# compute mean and x0
|
191 |
+
if self.mean_type == 'x_{t-1}':
|
192 |
+
mu = out # x_{t-1}
|
193 |
+
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
|
194 |
+
_i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt
|
195 |
+
elif self.mean_type == 'x0':
|
196 |
+
x0 = out
|
197 |
+
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
198 |
+
elif self.mean_type == 'eps':
|
199 |
+
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
200 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
|
201 |
+
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
202 |
+
elif self.mean_type == 'v':
|
203 |
+
x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \
|
204 |
+
_i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out
|
205 |
+
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
206 |
+
|
207 |
+
# restrict the range of x0
|
208 |
+
if percentile is not None:
|
209 |
+
assert percentile > 0 and percentile <= 1 # e.g., 0.995
|
210 |
+
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1)
|
211 |
+
x0 = torch.min(s, torch.max(-s, x0)) / s
|
212 |
+
elif clamp is not None:
|
213 |
+
x0 = x0.clamp(-clamp, clamp)
|
214 |
+
return mu, var, log_var, x0
|
215 |
+
|
216 |
+
@torch.no_grad()
|
217 |
+
def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
|
218 |
+
r"""Sample from p(x_{t-1} | x_t) using DDIM.
|
219 |
+
- condition_fn: for classifier-based guidance (guided-diffusion).
|
220 |
+
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
221 |
+
"""
|
222 |
+
stride = self.num_timesteps // ddim_timesteps
|
223 |
+
|
224 |
+
# predict distribution of p(x_{t-1} | x_t)
|
225 |
+
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
|
226 |
+
if condition_fn is not None:
|
227 |
+
# x0 -> eps
|
228 |
+
alpha = _i(self.alphas_cumprod, t, xt)
|
229 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
230 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
231 |
+
eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
232 |
+
|
233 |
+
# eps -> x0
|
234 |
+
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
235 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
236 |
+
|
237 |
+
# derive variables
|
238 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
239 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
240 |
+
alphas = _i(self.alphas_cumprod, t, xt)
|
241 |
+
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
242 |
+
sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
243 |
+
|
244 |
+
# random sample
|
245 |
+
noise = torch.randn_like(xt)
|
246 |
+
direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps
|
247 |
+
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
248 |
+
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
249 |
+
return xt_1, x0
|
250 |
+
|
251 |
+
@torch.no_grad()
|
252 |
+
def ddim_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
|
253 |
+
# prepare input
|
254 |
+
b = noise.size(0)
|
255 |
+
xt = noise
|
256 |
+
|
257 |
+
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
|
258 |
+
steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
|
259 |
+
from tqdm import tqdm
|
260 |
+
for step in tqdm(steps):
|
261 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
262 |
+
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta)
|
263 |
+
# from ipdb import set_trace; set_trace()
|
264 |
+
return xt
|
265 |
+
|
266 |
+
@torch.no_grad()
|
267 |
+
def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
|
268 |
+
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
|
269 |
+
"""
|
270 |
+
stride = self.num_timesteps // ddim_timesteps
|
271 |
+
|
272 |
+
# predict distribution of p(x_{t-1} | x_t)
|
273 |
+
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
|
274 |
+
|
275 |
+
# derive variables
|
276 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
277 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
278 |
+
alphas_next = _i(
|
279 |
+
torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]),
|
280 |
+
(t + stride).clamp(0, self.num_timesteps), xt)
|
281 |
+
|
282 |
+
# reverse sample
|
283 |
+
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
|
284 |
+
return mu, x0
|
285 |
+
|
286 |
+
@torch.no_grad()
|
287 |
+
def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
|
288 |
+
# prepare input
|
289 |
+
b = x0.size(0)
|
290 |
+
xt = x0
|
291 |
+
|
292 |
+
# reconstruction steps
|
293 |
+
steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)
|
294 |
+
for step in steps:
|
295 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
296 |
+
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps)
|
297 |
+
return xt
|
298 |
+
|
299 |
+
@torch.no_grad()
|
300 |
+
def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20):
|
301 |
+
r"""Sample from p(x_{t-1} | x_t) using PLMS.
|
302 |
+
- condition_fn: for classifier-based guidance (guided-diffusion).
|
303 |
+
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
304 |
+
"""
|
305 |
+
stride = self.num_timesteps // plms_timesteps
|
306 |
+
|
307 |
+
# function for compute eps
|
308 |
+
def compute_eps(xt, t):
|
309 |
+
# predict distribution of p(x_{t-1} | x_t)
|
310 |
+
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
|
311 |
+
|
312 |
+
# condition
|
313 |
+
if condition_fn is not None:
|
314 |
+
# x0 -> eps
|
315 |
+
alpha = _i(self.alphas_cumprod, t, xt)
|
316 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
317 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
318 |
+
eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
319 |
+
|
320 |
+
# eps -> x0
|
321 |
+
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
322 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
323 |
+
|
324 |
+
# derive eps
|
325 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
326 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
327 |
+
return eps
|
328 |
+
|
329 |
+
# function for compute x_0 and x_{t-1}
|
330 |
+
def compute_x0(eps, t):
|
331 |
+
# eps -> x0
|
332 |
+
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
333 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
334 |
+
|
335 |
+
# deterministic sample
|
336 |
+
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
337 |
+
direction = torch.sqrt(1 - alphas_prev) * eps
|
338 |
+
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
339 |
+
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
|
340 |
+
return xt_1, x0
|
341 |
+
|
342 |
+
# PLMS sample
|
343 |
+
eps = compute_eps(xt, t)
|
344 |
+
if len(eps_cache) == 0:
|
345 |
+
# 2nd order pseudo improved Euler
|
346 |
+
xt_1, x0 = compute_x0(eps, t)
|
347 |
+
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
|
348 |
+
eps_prime = (eps + eps_next) / 2.0
|
349 |
+
elif len(eps_cache) == 1:
|
350 |
+
# 2nd order pseudo linear multistep (Adams-Bashforth)
|
351 |
+
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
|
352 |
+
elif len(eps_cache) == 2:
|
353 |
+
# 3nd order pseudo linear multistep (Adams-Bashforth)
|
354 |
+
eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0
|
355 |
+
elif len(eps_cache) >= 3:
|
356 |
+
# 4nd order pseudo linear multistep (Adams-Bashforth)
|
357 |
+
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0
|
358 |
+
xt_1, x0 = compute_x0(eps_prime, t)
|
359 |
+
return xt_1, x0, eps
|
360 |
+
|
361 |
+
@torch.no_grad()
|
362 |
+
def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20):
|
363 |
+
# prepare input
|
364 |
+
b = noise.size(0)
|
365 |
+
xt = noise
|
366 |
+
|
367 |
+
# diffusion process
|
368 |
+
steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
|
369 |
+
eps_cache = []
|
370 |
+
for step in steps:
|
371 |
+
# PLMS sampling step
|
372 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
373 |
+
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache)
|
374 |
+
|
375 |
+
# update eps cache
|
376 |
+
eps_cache.append(eps)
|
377 |
+
if len(eps_cache) >= 4:
|
378 |
+
eps_cache.pop(0)
|
379 |
+
return xt
|
380 |
+
|
381 |
+
def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None):
|
382 |
+
|
383 |
+
# noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32]
|
384 |
+
noise = self.sample_loss(x0, noise)
|
385 |
+
|
386 |
+
xt = self.q_sample(x0, t, noise=noise)
|
387 |
+
|
388 |
+
# compute loss
|
389 |
+
if self.loss_type in ['kl', 'rescaled_kl']:
|
390 |
+
loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs)
|
391 |
+
if self.loss_type == 'rescaled_kl':
|
392 |
+
loss = loss * self.num_timesteps
|
393 |
+
elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse
|
394 |
+
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
395 |
+
|
396 |
+
# VLB for variation
|
397 |
+
loss_vlb = 0.0
|
398 |
+
if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small'
|
399 |
+
out, var = out.chunk(2, dim=1)
|
400 |
+
frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean
|
401 |
+
loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
|
402 |
+
if self.loss_type.startswith('rescaled_'):
|
403 |
+
loss_vlb = loss_vlb * self.num_timesteps / 1000.0
|
404 |
+
|
405 |
+
# MSE/L1 for x0/eps
|
406 |
+
# target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type]
|
407 |
+
target = {
|
408 |
+
'eps': noise,
|
409 |
+
'x0': x0,
|
410 |
+
'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0],
|
411 |
+
'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type]
|
412 |
+
if loss_mask is not None:
|
413 |
+
loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same)
|
414 |
+
loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w
|
415 |
+
# use masked diffusion
|
416 |
+
loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
|
417 |
+
else:
|
418 |
+
loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
|
419 |
+
if weight is not None:
|
420 |
+
loss = loss*weight
|
421 |
+
|
422 |
+
# div loss
|
423 |
+
if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1:
|
424 |
+
|
425 |
+
# derive x0
|
426 |
+
x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
427 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
|
428 |
+
|
429 |
+
# # derive xt_1, set eta=0 as ddim
|
430 |
+
# alphas_prev = _i(self.alphas_cumprod, (t - 1).clamp(0), xt)
|
431 |
+
# direction = torch.sqrt(1 - alphas_prev) * out
|
432 |
+
# xt_1 = torch.sqrt(alphas_prev) * x0_ + direction
|
433 |
+
|
434 |
+
# ncfhw, std on f
|
435 |
+
div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4)
|
436 |
+
# print(div_loss,loss)
|
437 |
+
loss = loss+div_loss
|
438 |
+
|
439 |
+
# total loss
|
440 |
+
loss = loss + loss_vlb
|
441 |
+
elif self.loss_type in ['charbonnier']:
|
442 |
+
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
443 |
+
|
444 |
+
# VLB for variation
|
445 |
+
loss_vlb = 0.0
|
446 |
+
if self.var_type in ['learned', 'learned_range']:
|
447 |
+
out, var = out.chunk(2, dim=1)
|
448 |
+
frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean
|
449 |
+
loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
|
450 |
+
if self.loss_type.startswith('rescaled_'):
|
451 |
+
loss_vlb = loss_vlb * self.num_timesteps / 1000.0
|
452 |
+
|
453 |
+
# MSE/L1 for x0/eps
|
454 |
+
target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type]
|
455 |
+
loss = torch.sqrt((out - target)**2 + self.epsilon)
|
456 |
+
if weight is not None:
|
457 |
+
loss = loss*weight
|
458 |
+
loss = loss.flatten(1).mean(dim=1)
|
459 |
+
|
460 |
+
# total loss
|
461 |
+
loss = loss + loss_vlb
|
462 |
+
return loss
|
463 |
+
|
464 |
+
def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None):
|
465 |
+
# compute groundtruth and predicted distributions
|
466 |
+
mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
|
467 |
+
mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile)
|
468 |
+
|
469 |
+
# compute KL loss
|
470 |
+
kl = kl_divergence(mu1, log_var1, mu2, log_var2)
|
471 |
+
kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
|
472 |
+
|
473 |
+
# compute discretized NLL loss (for p(x0 | x1) only)
|
474 |
+
nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2)
|
475 |
+
nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
|
476 |
+
|
477 |
+
# NLL for p(x0 | x1) and KL otherwise
|
478 |
+
vlb = torch.where(t == 0, nll, kl)
|
479 |
+
return vlb, x0
|
480 |
+
|
481 |
+
@torch.no_grad()
|
482 |
+
def variational_lower_bound_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None):
|
483 |
+
r"""Compute the entire variational lower bound, measured in bits-per-dim.
|
484 |
+
"""
|
485 |
+
# prepare input and output
|
486 |
+
b = x0.size(0)
|
487 |
+
metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
|
488 |
+
|
489 |
+
# loop
|
490 |
+
for step in torch.arange(self.num_timesteps).flip(0):
|
491 |
+
# compute VLB
|
492 |
+
t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
|
493 |
+
# noise = torch.randn_like(x0)
|
494 |
+
noise = self.sample_loss(x0)
|
495 |
+
xt = self.q_sample(x0, t, noise)
|
496 |
+
vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile)
|
497 |
+
|
498 |
+
# predict eps from x0
|
499 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
500 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
501 |
+
|
502 |
+
# collect metrics
|
503 |
+
metrics['vlb'].append(vlb)
|
504 |
+
metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1))
|
505 |
+
metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1))
|
506 |
+
metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
|
507 |
+
|
508 |
+
# compute the prior KL term for VLB, measured in bits-per-dim
|
509 |
+
mu, _, log_var = self.q_mean_variance(x0, t)
|
510 |
+
kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var))
|
511 |
+
kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
|
512 |
+
|
513 |
+
# update metrics
|
514 |
+
metrics['prior_bits_per_dim'] = kl_prior
|
515 |
+
metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
|
516 |
+
return metrics
|
517 |
+
|
518 |
+
def _scale_timesteps(self, t):
|
519 |
+
if self.rescale_timesteps:
|
520 |
+
return t.float() * 1000.0 / self.num_timesteps
|
521 |
+
return t
|
522 |
+
#return t.float()
|
523 |
+
|
524 |
+
|
525 |
+
|
526 |
+
|
527 |
+
|
528 |
+
|
529 |
+
@DIFFUSION.register_class()
|
530 |
+
class DiffusionDDIMLong(object):
|
531 |
+
def __init__(self,
|
532 |
+
schedule='linear_sd',
|
533 |
+
schedule_param={},
|
534 |
+
mean_type='eps',
|
535 |
+
var_type='learned_range',
|
536 |
+
loss_type='mse',
|
537 |
+
epsilon = 1e-12,
|
538 |
+
rescale_timesteps=False,
|
539 |
+
noise_strength=0.0,
|
540 |
+
**kwargs):
|
541 |
+
|
542 |
+
assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v']
|
543 |
+
assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small']
|
544 |
+
assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier']
|
545 |
+
|
546 |
+
betas = beta_schedule(schedule, **schedule_param)
|
547 |
+
assert min(betas) > 0 and max(betas) <= 1
|
548 |
+
|
549 |
+
if not isinstance(betas, torch.DoubleTensor):
|
550 |
+
betas = torch.tensor(betas, dtype=torch.float64)
|
551 |
+
|
552 |
+
self.betas = betas
|
553 |
+
self.num_timesteps = len(betas)
|
554 |
+
self.mean_type = mean_type # v
|
555 |
+
self.var_type = var_type # 'fixed_small'
|
556 |
+
self.loss_type = loss_type # mse
|
557 |
+
self.epsilon = epsilon # 1e-12
|
558 |
+
self.rescale_timesteps = rescale_timesteps # False
|
559 |
+
self.noise_strength = noise_strength
|
560 |
+
|
561 |
+
# alphas
|
562 |
+
alphas = 1 - self.betas
|
563 |
+
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
564 |
+
self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
565 |
+
self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])])
|
566 |
+
|
567 |
+
# q(x_t | x_{t-1})
|
568 |
+
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
569 |
+
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
|
570 |
+
self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
|
571 |
+
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
572 |
+
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
|
573 |
+
|
574 |
+
# q(x_{t-1} | x_t, x_0)
|
575 |
+
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
576 |
+
self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20))
|
577 |
+
self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
578 |
+
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
579 |
+
|
580 |
+
|
581 |
+
def sample_loss(self, x0, noise=None):
|
582 |
+
if noise is None:
|
583 |
+
noise = torch.randn_like(x0)
|
584 |
+
if self.noise_strength > 0:
|
585 |
+
b, c, f, _, _= x0.shape
|
586 |
+
offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device)
|
587 |
+
noise = noise + self.noise_strength * offset_noise
|
588 |
+
return noise
|
589 |
+
|
590 |
+
|
591 |
+
def q_sample(self, x0, t, noise=None):
|
592 |
+
r"""Sample from q(x_t | x_0).
|
593 |
+
"""
|
594 |
+
# noise = torch.randn_like(x0) if noise is None else noise
|
595 |
+
noise = self.sample_loss(x0, noise)
|
596 |
+
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
|
597 |
+
_i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
|
598 |
+
|
599 |
+
def q_mean_variance(self, x0, t):
|
600 |
+
r"""Distribution of q(x_t | x_0).
|
601 |
+
"""
|
602 |
+
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
|
603 |
+
var = _i(1.0 - self.alphas_cumprod, t, x0)
|
604 |
+
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
|
605 |
+
return mu, var, log_var
|
606 |
+
|
607 |
+
def q_posterior_mean_variance(self, x0, xt, t):
|
608 |
+
r"""Distribution of q(x_{t-1} | x_t, x_0).
|
609 |
+
"""
|
610 |
+
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt
|
611 |
+
var = _i(self.posterior_variance, t, xt)
|
612 |
+
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
613 |
+
return mu, var, log_var
|
614 |
+
|
615 |
+
@torch.no_grad()
|
616 |
+
def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
|
617 |
+
r"""Sample from p(x_{t-1} | x_t).
|
618 |
+
- condition_fn: for classifier-based guidance (guided-diffusion).
|
619 |
+
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
620 |
+
"""
|
621 |
+
# predict distribution of p(x_{t-1} | x_t)
|
622 |
+
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
|
623 |
+
|
624 |
+
# random sample (with optional conditional function)
|
625 |
+
noise = torch.randn_like(xt)
|
626 |
+
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0
|
627 |
+
if condition_fn is not None:
|
628 |
+
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
629 |
+
mu = mu.float() + var * grad.float()
|
630 |
+
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
|
631 |
+
return xt_1, x0
|
632 |
+
|
633 |
+
@torch.no_grad()
|
634 |
+
def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None):
|
635 |
+
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
|
636 |
+
"""
|
637 |
+
# prepare input
|
638 |
+
b = noise.size(0)
|
639 |
+
xt = noise
|
640 |
+
|
641 |
+
# diffusion process
|
642 |
+
for step in torch.arange(self.num_timesteps).flip(0):
|
643 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
644 |
+
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale)
|
645 |
+
return xt
|
646 |
+
|
647 |
+
def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1):
|
648 |
+
r"""Distribution of p(x_{t-1} | x_t).
|
649 |
+
"""
|
650 |
+
noise = xt
|
651 |
+
context_queue = list(
|
652 |
+
context_scheduler(
|
653 |
+
0,
|
654 |
+
31,
|
655 |
+
noise.shape[2],
|
656 |
+
context_size=context_size,
|
657 |
+
context_stride=context_stride,
|
658 |
+
context_overlap=context_overlap,
|
659 |
+
)
|
660 |
+
)
|
661 |
+
context_step = min(
|
662 |
+
context_stride, int(np.ceil(np.log2(noise.shape[2] / context_size))) + 1
|
663 |
+
)
|
664 |
+
# replace the final segment to improve temporal consistency
|
665 |
+
num_frames = noise.shape[2]
|
666 |
+
context_queue[-1] = [
|
667 |
+
e % num_frames
|
668 |
+
for e in range(num_frames - context_size * context_step, num_frames, context_step)
|
669 |
+
]
|
670 |
+
|
671 |
+
import math
|
672 |
+
# context_batch_size = 1
|
673 |
+
num_context_batches = math.ceil(len(context_queue) / context_batch_size)
|
674 |
+
global_context = []
|
675 |
+
for i in range(num_context_batches):
|
676 |
+
global_context.append(
|
677 |
+
context_queue[
|
678 |
+
i * context_batch_size : (i + 1) * context_batch_size
|
679 |
+
]
|
680 |
+
)
|
681 |
+
noise_pred = torch.zeros_like(noise)
|
682 |
+
noise_pred_uncond = torch.zeros_like(noise)
|
683 |
+
counter = torch.zeros(
|
684 |
+
(1, 1, xt.shape[2], 1, 1),
|
685 |
+
device=xt.device,
|
686 |
+
dtype=xt.dtype,
|
687 |
+
)
|
688 |
+
|
689 |
+
for i_index, context in enumerate(global_context):
|
690 |
+
|
691 |
+
|
692 |
+
latent_model_input = torch.cat([xt[:, :, c] for c in context])
|
693 |
+
bs_context = len(context)
|
694 |
+
|
695 |
+
model_kwargs_new = [{
|
696 |
+
'y': None,
|
697 |
+
"local_image": None if not model_kwargs[0].__contains__('local_image') else torch.cat([model_kwargs[0]["local_image"][:, :, c] for c in context]),
|
698 |
+
'image': None if not model_kwargs[0].__contains__('image') else model_kwargs[0]["image"].repeat(bs_context, 1, 1),
|
699 |
+
'dwpose': None if not model_kwargs[0].__contains__('dwpose') else torch.cat([model_kwargs[0]["dwpose"][:, :, [0]+[ii+1 for ii in c]] for c in context]),
|
700 |
+
'randomref': None if not model_kwargs[0].__contains__('randomref') else torch.cat([model_kwargs[0]["randomref"][:, :, c] for c in context]),
|
701 |
+
},
|
702 |
+
{
|
703 |
+
'y': None,
|
704 |
+
"local_image": None,
|
705 |
+
'image': None,
|
706 |
+
'randomref': None,
|
707 |
+
'dwpose': None,
|
708 |
+
}]
|
709 |
+
|
710 |
+
if guide_scale is None:
|
711 |
+
out = model(latent_model_input, self._scale_timesteps(t), **model_kwargs)
|
712 |
+
for j, c in enumerate(context):
|
713 |
+
noise_pred[:, :, c] = noise_pred[:, :, c] + out
|
714 |
+
counter[:, :, c] = counter[:, :, c] + 1
|
715 |
+
else:
|
716 |
+
# classifier-free guidance
|
717 |
+
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
718 |
+
# assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
719 |
+
y_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[0])
|
720 |
+
u_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[1])
|
721 |
+
dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2
|
722 |
+
for j, c in enumerate(context):
|
723 |
+
noise_pred[:, :, c] = noise_pred[:, :, c] + y_out[j:j+1]
|
724 |
+
noise_pred_uncond[:, :, c] = noise_pred_uncond[:, :, c] + u_out[j:j+1]
|
725 |
+
counter[:, :, c] = counter[:, :, c] + 1
|
726 |
+
|
727 |
+
noise_pred = noise_pred / counter
|
728 |
+
noise_pred_uncond = noise_pred_uncond / counter
|
729 |
+
out = torch.cat([
|
730 |
+
noise_pred_uncond[:, :dim] + guide_scale * (noise_pred[:, :dim] - noise_pred_uncond[:, :dim]),
|
731 |
+
noise_pred[:, dim:]], dim=1) # guide_scale=2.5
|
732 |
+
|
733 |
+
|
734 |
+
# compute variance
|
735 |
+
if self.var_type == 'learned':
|
736 |
+
out, log_var = out.chunk(2, dim=1)
|
737 |
+
var = torch.exp(log_var)
|
738 |
+
elif self.var_type == 'learned_range':
|
739 |
+
out, fraction = out.chunk(2, dim=1)
|
740 |
+
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
741 |
+
max_log_var = _i(torch.log(self.betas), t, xt)
|
742 |
+
fraction = (fraction + 1) / 2.0
|
743 |
+
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
|
744 |
+
var = torch.exp(log_var)
|
745 |
+
elif self.var_type == 'fixed_large':
|
746 |
+
var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt)
|
747 |
+
log_var = torch.log(var)
|
748 |
+
elif self.var_type == 'fixed_small':
|
749 |
+
var = _i(self.posterior_variance, t, xt)
|
750 |
+
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
751 |
+
|
752 |
+
# compute mean and x0
|
753 |
+
if self.mean_type == 'x_{t-1}':
|
754 |
+
mu = out # x_{t-1}
|
755 |
+
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
|
756 |
+
_i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt
|
757 |
+
elif self.mean_type == 'x0':
|
758 |
+
x0 = out
|
759 |
+
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
760 |
+
elif self.mean_type == 'eps':
|
761 |
+
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
762 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
|
763 |
+
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
764 |
+
elif self.mean_type == 'v':
|
765 |
+
x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \
|
766 |
+
_i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out
|
767 |
+
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
768 |
+
|
769 |
+
# restrict the range of x0
|
770 |
+
if percentile is not None:
|
771 |
+
assert percentile > 0 and percentile <= 1 # e.g., 0.995
|
772 |
+
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1)
|
773 |
+
x0 = torch.min(s, torch.max(-s, x0)) / s
|
774 |
+
elif clamp is not None:
|
775 |
+
x0 = x0.clamp(-clamp, clamp)
|
776 |
+
return mu, var, log_var, x0
|
777 |
+
|
778 |
+
@torch.no_grad()
|
779 |
+
def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1):
|
780 |
+
r"""Sample from p(x_{t-1} | x_t) using DDIM.
|
781 |
+
- condition_fn: for classifier-based guidance (guided-diffusion).
|
782 |
+
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
783 |
+
"""
|
784 |
+
stride = self.num_timesteps // ddim_timesteps
|
785 |
+
|
786 |
+
# predict distribution of p(x_{t-1} | x_t)
|
787 |
+
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale, context_size, context_stride, context_overlap, context_batch_size)
|
788 |
+
if condition_fn is not None:
|
789 |
+
# x0 -> eps
|
790 |
+
alpha = _i(self.alphas_cumprod, t, xt)
|
791 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
792 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
793 |
+
eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
794 |
+
|
795 |
+
# eps -> x0
|
796 |
+
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
797 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
798 |
+
|
799 |
+
# derive variables
|
800 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
801 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
802 |
+
alphas = _i(self.alphas_cumprod, t, xt)
|
803 |
+
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
804 |
+
sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
805 |
+
|
806 |
+
# random sample
|
807 |
+
noise = torch.randn_like(xt)
|
808 |
+
direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps
|
809 |
+
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
810 |
+
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
811 |
+
return xt_1, x0
|
812 |
+
|
813 |
+
@torch.no_grad()
|
814 |
+
def ddim_sample_loop(self, noise, context_size, context_stride, context_overlap, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_batch_size=1):
|
815 |
+
# prepare input
|
816 |
+
b = noise.size(0)
|
817 |
+
xt = noise
|
818 |
+
|
819 |
+
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
|
820 |
+
steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
|
821 |
+
from tqdm import tqdm
|
822 |
+
|
823 |
+
for step in tqdm(steps):
|
824 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
825 |
+
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta, context_size=context_size, context_stride=context_stride, context_overlap=context_overlap, context_batch_size=context_batch_size)
|
826 |
+
return xt
|
827 |
+
|
828 |
+
@torch.no_grad()
|
829 |
+
def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
|
830 |
+
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
|
831 |
+
"""
|
832 |
+
stride = self.num_timesteps // ddim_timesteps
|
833 |
+
|
834 |
+
# predict distribution of p(x_{t-1} | x_t)
|
835 |
+
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
|
836 |
+
|
837 |
+
# derive variables
|
838 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
839 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
840 |
+
alphas_next = _i(
|
841 |
+
torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]),
|
842 |
+
(t + stride).clamp(0, self.num_timesteps), xt)
|
843 |
+
|
844 |
+
# reverse sample
|
845 |
+
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
|
846 |
+
return mu, x0
|
847 |
+
|
848 |
+
@torch.no_grad()
|
849 |
+
def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
|
850 |
+
# prepare input
|
851 |
+
b = x0.size(0)
|
852 |
+
xt = x0
|
853 |
+
|
854 |
+
# reconstruction steps
|
855 |
+
steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)
|
856 |
+
for step in steps:
|
857 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
858 |
+
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps)
|
859 |
+
return xt
|
860 |
+
|
861 |
+
@torch.no_grad()
|
862 |
+
def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20):
|
863 |
+
r"""Sample from p(x_{t-1} | x_t) using PLMS.
|
864 |
+
- condition_fn: for classifier-based guidance (guided-diffusion).
|
865 |
+
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
866 |
+
"""
|
867 |
+
stride = self.num_timesteps // plms_timesteps
|
868 |
+
|
869 |
+
# function for compute eps
|
870 |
+
def compute_eps(xt, t):
|
871 |
+
# predict distribution of p(x_{t-1} | x_t)
|
872 |
+
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
|
873 |
+
|
874 |
+
# condition
|
875 |
+
if condition_fn is not None:
|
876 |
+
# x0 -> eps
|
877 |
+
alpha = _i(self.alphas_cumprod, t, xt)
|
878 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
879 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
880 |
+
eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
881 |
+
|
882 |
+
# eps -> x0
|
883 |
+
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
884 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
885 |
+
|
886 |
+
# derive eps
|
887 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
888 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
889 |
+
return eps
|
890 |
+
|
891 |
+
# function for compute x_0 and x_{t-1}
|
892 |
+
def compute_x0(eps, t):
|
893 |
+
# eps -> x0
|
894 |
+
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
895 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
896 |
+
|
897 |
+
# deterministic sample
|
898 |
+
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
899 |
+
direction = torch.sqrt(1 - alphas_prev) * eps
|
900 |
+
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
901 |
+
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
|
902 |
+
return xt_1, x0
|
903 |
+
|
904 |
+
# PLMS sample
|
905 |
+
eps = compute_eps(xt, t)
|
906 |
+
if len(eps_cache) == 0:
|
907 |
+
# 2nd order pseudo improved Euler
|
908 |
+
xt_1, x0 = compute_x0(eps, t)
|
909 |
+
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
|
910 |
+
eps_prime = (eps + eps_next) / 2.0
|
911 |
+
elif len(eps_cache) == 1:
|
912 |
+
# 2nd order pseudo linear multistep (Adams-Bashforth)
|
913 |
+
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
|
914 |
+
elif len(eps_cache) == 2:
|
915 |
+
# 3nd order pseudo linear multistep (Adams-Bashforth)
|
916 |
+
eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0
|
917 |
+
elif len(eps_cache) >= 3:
|
918 |
+
# 4nd order pseudo linear multistep (Adams-Bashforth)
|
919 |
+
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0
|
920 |
+
xt_1, x0 = compute_x0(eps_prime, t)
|
921 |
+
return xt_1, x0, eps
|
922 |
+
|
923 |
+
@torch.no_grad()
|
924 |
+
def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20):
|
925 |
+
# prepare input
|
926 |
+
b = noise.size(0)
|
927 |
+
xt = noise
|
928 |
+
|
929 |
+
# diffusion process
|
930 |
+
steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
|
931 |
+
eps_cache = []
|
932 |
+
for step in steps:
|
933 |
+
# PLMS sampling step
|
934 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
935 |
+
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache)
|
936 |
+
|
937 |
+
# update eps cache
|
938 |
+
eps_cache.append(eps)
|
939 |
+
if len(eps_cache) >= 4:
|
940 |
+
eps_cache.pop(0)
|
941 |
+
return xt
|
942 |
+
|
943 |
+
def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None):
|
944 |
+
|
945 |
+
# noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32]
|
946 |
+
noise = self.sample_loss(x0, noise)
|
947 |
+
|
948 |
+
xt = self.q_sample(x0, t, noise=noise)
|
949 |
+
|
950 |
+
# compute loss
|
951 |
+
if self.loss_type in ['kl', 'rescaled_kl']:
|
952 |
+
loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs)
|
953 |
+
if self.loss_type == 'rescaled_kl':
|
954 |
+
loss = loss * self.num_timesteps
|
955 |
+
elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse
|
956 |
+
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
957 |
+
|
958 |
+
# VLB for variation
|
959 |
+
loss_vlb = 0.0
|
960 |
+
if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small'
|
961 |
+
out, var = out.chunk(2, dim=1)
|
962 |
+
frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean
|
963 |
+
loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
|
964 |
+
if self.loss_type.startswith('rescaled_'):
|
965 |
+
loss_vlb = loss_vlb * self.num_timesteps / 1000.0
|
966 |
+
|
967 |
+
# MSE/L1 for x0/eps
|
968 |
+
# target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type]
|
969 |
+
target = {
|
970 |
+
'eps': noise,
|
971 |
+
'x0': x0,
|
972 |
+
'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0],
|
973 |
+
'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type]
|
974 |
+
if loss_mask is not None:
|
975 |
+
loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same)
|
976 |
+
loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w
|
977 |
+
# use masked diffusion
|
978 |
+
loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
|
979 |
+
else:
|
980 |
+
loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1)
|
981 |
+
if weight is not None:
|
982 |
+
loss = loss*weight
|
983 |
+
|
984 |
+
# div loss
|
985 |
+
if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1:
|
986 |
+
|
987 |
+
# derive x0
|
988 |
+
x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
|
989 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
|
990 |
+
|
991 |
+
|
992 |
+
# ncfhw, std on f
|
993 |
+
div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4)
|
994 |
+
# print(div_loss,loss)
|
995 |
+
loss = loss+div_loss
|
996 |
+
|
997 |
+
# total loss
|
998 |
+
loss = loss + loss_vlb
|
999 |
+
elif self.loss_type in ['charbonnier']:
|
1000 |
+
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
1001 |
+
|
1002 |
+
# VLB for variation
|
1003 |
+
loss_vlb = 0.0
|
1004 |
+
if self.var_type in ['learned', 'learned_range']:
|
1005 |
+
out, var = out.chunk(2, dim=1)
|
1006 |
+
frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean
|
1007 |
+
loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen)
|
1008 |
+
if self.loss_type.startswith('rescaled_'):
|
1009 |
+
loss_vlb = loss_vlb * self.num_timesteps / 1000.0
|
1010 |
+
|
1011 |
+
# MSE/L1 for x0/eps
|
1012 |
+
target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type]
|
1013 |
+
loss = torch.sqrt((out - target)**2 + self.epsilon)
|
1014 |
+
if weight is not None:
|
1015 |
+
loss = loss*weight
|
1016 |
+
loss = loss.flatten(1).mean(dim=1)
|
1017 |
+
|
1018 |
+
# total loss
|
1019 |
+
loss = loss + loss_vlb
|
1020 |
+
return loss
|
1021 |
+
|
1022 |
+
def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None):
|
1023 |
+
# compute groundtruth and predicted distributions
|
1024 |
+
mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
|
1025 |
+
mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile)
|
1026 |
+
|
1027 |
+
# compute KL loss
|
1028 |
+
kl = kl_divergence(mu1, log_var1, mu2, log_var2)
|
1029 |
+
kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
|
1030 |
+
|
1031 |
+
# compute discretized NLL loss (for p(x0 | x1) only)
|
1032 |
+
nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2)
|
1033 |
+
nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
|
1034 |
+
|
1035 |
+
# NLL for p(x0 | x1) and KL otherwise
|
1036 |
+
vlb = torch.where(t == 0, nll, kl)
|
1037 |
+
return vlb, x0
|
1038 |
+
|
1039 |
+
@torch.no_grad()
|
1040 |
+
def variational_lower_bound_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None):
|
1041 |
+
r"""Compute the entire variational lower bound, measured in bits-per-dim.
|
1042 |
+
"""
|
1043 |
+
# prepare input and output
|
1044 |
+
b = x0.size(0)
|
1045 |
+
metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
|
1046 |
+
|
1047 |
+
# loop
|
1048 |
+
for step in torch.arange(self.num_timesteps).flip(0):
|
1049 |
+
# compute VLB
|
1050 |
+
t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
|
1051 |
+
# noise = torch.randn_like(x0)
|
1052 |
+
noise = self.sample_loss(x0)
|
1053 |
+
xt = self.q_sample(x0, t, noise)
|
1054 |
+
vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile)
|
1055 |
+
|
1056 |
+
# predict eps from x0
|
1057 |
+
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
|
1058 |
+
_i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
1059 |
+
|
1060 |
+
# collect metrics
|
1061 |
+
metrics['vlb'].append(vlb)
|
1062 |
+
metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1))
|
1063 |
+
metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1))
|
1064 |
+
metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
|
1065 |
+
|
1066 |
+
# compute the prior KL term for VLB, measured in bits-per-dim
|
1067 |
+
mu, _, log_var = self.q_mean_variance(x0, t)
|
1068 |
+
kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var))
|
1069 |
+
kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
|
1070 |
+
|
1071 |
+
# update metrics
|
1072 |
+
metrics['prior_bits_per_dim'] = kl_prior
|
1073 |
+
metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
|
1074 |
+
return metrics
|
1075 |
+
|
1076 |
+
def _scale_timesteps(self, t):
|
1077 |
+
if self.rescale_timesteps:
|
1078 |
+
return t.float() * 1000.0 / self.num_timesteps
|
1079 |
+
return t
|
1080 |
+
#return t.float()
|
1081 |
+
|
1082 |
+
|
1083 |
+
|
1084 |
+
def ordered_halving(val):
|
1085 |
+
bin_str = f"{val:064b}"
|
1086 |
+
bin_flip = bin_str[::-1]
|
1087 |
+
as_int = int(bin_flip, 2)
|
1088 |
+
|
1089 |
+
return as_int / (1 << 64)
|
1090 |
+
|
1091 |
+
|
1092 |
+
def context_scheduler(
|
1093 |
+
step: int = ...,
|
1094 |
+
num_steps: Optional[int] = None,
|
1095 |
+
num_frames: int = ...,
|
1096 |
+
context_size: Optional[int] = None,
|
1097 |
+
context_stride: int = 3,
|
1098 |
+
context_overlap: int = 4,
|
1099 |
+
closed_loop: bool = False,
|
1100 |
+
):
|
1101 |
+
if num_frames <= context_size:
|
1102 |
+
yield list(range(num_frames))
|
1103 |
+
return
|
1104 |
+
|
1105 |
+
context_stride = min(
|
1106 |
+
context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
for context_step in 1 << np.arange(context_stride):
|
1110 |
+
pad = int(round(num_frames * ordered_halving(step)))
|
1111 |
+
for j in range(
|
1112 |
+
int(ordered_halving(step) * context_step) + pad,
|
1113 |
+
num_frames + pad + (0 if closed_loop else -context_overlap),
|
1114 |
+
(context_size * context_step - context_overlap),
|
1115 |
+
):
|
1116 |
+
|
1117 |
+
yield [
|
1118 |
+
e % num_frames
|
1119 |
+
for e in range(j, j + context_size * context_step, context_step)
|
1120 |
+
]
|
1121 |
+
|
UniAnimate/tools/modules/diffusions/diffusion_gauss.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
GaussianDiffusion wraps operators for denoising diffusion models, including the
|
3 |
+
diffusion and denoising processes, as well as the loss evaluation.
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torchsde
|
7 |
+
import random
|
8 |
+
from tqdm.auto import trange
|
9 |
+
|
10 |
+
|
11 |
+
__all__ = ['GaussianDiffusion']
|
12 |
+
|
13 |
+
|
14 |
+
def _i(tensor, t, x):
|
15 |
+
"""
|
16 |
+
Index tensor using t and format the output according to x.
|
17 |
+
"""
|
18 |
+
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
19 |
+
return tensor[t.to(tensor.device)].view(shape).to(x.device)
|
20 |
+
|
21 |
+
|
22 |
+
class BatchedBrownianTree:
|
23 |
+
"""
|
24 |
+
A wrapper around torchsde.BrownianTree that enables batches of entropy.
|
25 |
+
"""
|
26 |
+
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
27 |
+
t0, t1, self.sign = self.sort(t0, t1)
|
28 |
+
w0 = kwargs.get('w0', torch.zeros_like(x))
|
29 |
+
if seed is None:
|
30 |
+
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
31 |
+
self.batched = True
|
32 |
+
try:
|
33 |
+
assert len(seed) == x.shape[0]
|
34 |
+
w0 = w0[0]
|
35 |
+
except TypeError:
|
36 |
+
seed = [seed]
|
37 |
+
self.batched = False
|
38 |
+
self.trees = [torchsde.BrownianTree(
|
39 |
+
t0, w0, t1, entropy=s, **kwargs
|
40 |
+
) for s in seed]
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def sort(a, b):
|
44 |
+
return (a, b, 1) if a < b else (b, a, -1)
|
45 |
+
|
46 |
+
def __call__(self, t0, t1):
|
47 |
+
t0, t1, sign = self.sort(t0, t1)
|
48 |
+
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
49 |
+
return w if self.batched else w[0]
|
50 |
+
|
51 |
+
|
52 |
+
class BrownianTreeNoiseSampler:
|
53 |
+
"""
|
54 |
+
A noise sampler backed by a torchsde.BrownianTree.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
x (Tensor): The tensor whose shape, device and dtype to use to generate
|
58 |
+
random samples.
|
59 |
+
sigma_min (float): The low end of the valid interval.
|
60 |
+
sigma_max (float): The high end of the valid interval.
|
61 |
+
seed (int or List[int]): The random seed. If a list of seeds is
|
62 |
+
supplied instead of a single integer, then the noise sampler will
|
63 |
+
use one BrownianTree per batch item, each with its own seed.
|
64 |
+
transform (callable): A function that maps sigma to the sampler's
|
65 |
+
internal timestep.
|
66 |
+
"""
|
67 |
+
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
|
68 |
+
self.transform = transform
|
69 |
+
t0 = self.transform(torch.as_tensor(sigma_min))
|
70 |
+
t1 = self.transform(torch.as_tensor(sigma_max))
|
71 |
+
self.tree = BatchedBrownianTree(x, t0, t1, seed)
|
72 |
+
|
73 |
+
def __call__(self, sigma, sigma_next):
|
74 |
+
t0 = self.transform(torch.as_tensor(sigma))
|
75 |
+
t1 = self.transform(torch.as_tensor(sigma_next))
|
76 |
+
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
77 |
+
|
78 |
+
|
79 |
+
def get_scalings(sigma):
|
80 |
+
c_out = -sigma
|
81 |
+
c_in = 1 / (sigma ** 2 + 1. ** 2) ** 0.5
|
82 |
+
return c_out, c_in
|
83 |
+
|
84 |
+
|
85 |
+
@torch.no_grad()
|
86 |
+
def sample_dpmpp_2m_sde(
|
87 |
+
noise,
|
88 |
+
model,
|
89 |
+
sigmas,
|
90 |
+
eta=1.,
|
91 |
+
s_noise=1.,
|
92 |
+
solver_type='midpoint',
|
93 |
+
show_progress=True
|
94 |
+
):
|
95 |
+
"""
|
96 |
+
DPM-Solver++ (2M) SDE.
|
97 |
+
"""
|
98 |
+
assert solver_type in {'heun', 'midpoint'}
|
99 |
+
|
100 |
+
x = noise * sigmas[0]
|
101 |
+
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[sigmas < float('inf')].max()
|
102 |
+
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
|
103 |
+
old_denoised = None
|
104 |
+
h_last = None
|
105 |
+
|
106 |
+
for i in trange(len(sigmas) - 1, disable=not show_progress):
|
107 |
+
if sigmas[i] == float('inf'):
|
108 |
+
# Euler method
|
109 |
+
denoised = model(noise, sigmas[i])
|
110 |
+
x = denoised + sigmas[i + 1] * noise
|
111 |
+
else:
|
112 |
+
_, c_in = get_scalings(sigmas[i])
|
113 |
+
denoised = model(x * c_in, sigmas[i])
|
114 |
+
if sigmas[i + 1] == 0:
|
115 |
+
# Denoising step
|
116 |
+
x = denoised
|
117 |
+
else:
|
118 |
+
# DPM-Solver++(2M) SDE
|
119 |
+
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
120 |
+
h = s - t
|
121 |
+
eta_h = eta * h
|
122 |
+
|
123 |
+
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
|
124 |
+
(-h - eta_h).expm1().neg() * denoised
|
125 |
+
|
126 |
+
if old_denoised is not None:
|
127 |
+
r = h_last / h
|
128 |
+
if solver_type == 'heun':
|
129 |
+
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
|
130 |
+
(1 / r) * (denoised - old_denoised)
|
131 |
+
elif solver_type == 'midpoint':
|
132 |
+
x = x + 0.5 * (-h - eta_h).expm1().neg() * \
|
133 |
+
(1 / r) * (denoised - old_denoised)
|
134 |
+
|
135 |
+
x = x + noise_sampler(
|
136 |
+
sigmas[i],
|
137 |
+
sigmas[i + 1]
|
138 |
+
) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
139 |
+
|
140 |
+
old_denoised = denoised
|
141 |
+
h_last = h
|
142 |
+
return x
|
143 |
+
|
144 |
+
|
145 |
+
class GaussianDiffusion(object):
|
146 |
+
|
147 |
+
def __init__(self, sigmas, prediction_type='eps'):
|
148 |
+
assert prediction_type in {'x0', 'eps', 'v'}
|
149 |
+
self.sigmas = sigmas.float() # noise coefficients
|
150 |
+
self.alphas = torch.sqrt(1 - sigmas ** 2).float() # signal coefficients
|
151 |
+
self.num_timesteps = len(sigmas)
|
152 |
+
self.prediction_type = prediction_type
|
153 |
+
|
154 |
+
def diffuse(self, x0, t, noise=None):
|
155 |
+
"""
|
156 |
+
Add Gaussian noise to signal x0 according to:
|
157 |
+
q(x_t | x_0) = N(x_t | alpha_t x_0, sigma_t^2 I).
|
158 |
+
"""
|
159 |
+
noise = torch.randn_like(x0) if noise is None else noise
|
160 |
+
xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
|
161 |
+
return xt
|
162 |
+
|
163 |
+
def denoise(
|
164 |
+
self,
|
165 |
+
xt,
|
166 |
+
t,
|
167 |
+
s,
|
168 |
+
model,
|
169 |
+
model_kwargs={},
|
170 |
+
guide_scale=None,
|
171 |
+
guide_rescale=None,
|
172 |
+
clamp=None,
|
173 |
+
percentile=None
|
174 |
+
):
|
175 |
+
"""
|
176 |
+
Apply one step of denoising from the posterior distribution q(x_s | x_t, x0).
|
177 |
+
Since x0 is not available, estimate the denoising results using the learned
|
178 |
+
distribution p(x_s | x_t, \hat{x}_0 == f(x_t)).
|
179 |
+
"""
|
180 |
+
s = t - 1 if s is None else s
|
181 |
+
|
182 |
+
# hyperparams
|
183 |
+
sigmas = _i(self.sigmas, t, xt)
|
184 |
+
alphas = _i(self.alphas, t, xt)
|
185 |
+
alphas_s = _i(self.alphas, s.clamp(0), xt)
|
186 |
+
alphas_s[s < 0] = 1.
|
187 |
+
sigmas_s = torch.sqrt(1 - alphas_s ** 2)
|
188 |
+
|
189 |
+
# precompute variables
|
190 |
+
betas = 1 - (alphas / alphas_s) ** 2
|
191 |
+
coef1 = betas * alphas_s / sigmas ** 2
|
192 |
+
coef2 = (alphas * sigmas_s ** 2) / (alphas_s * sigmas ** 2)
|
193 |
+
var = betas * (sigmas_s / sigmas) ** 2
|
194 |
+
log_var = torch.log(var).clamp_(-20, 20)
|
195 |
+
|
196 |
+
# prediction
|
197 |
+
if guide_scale is None:
|
198 |
+
assert isinstance(model_kwargs, dict)
|
199 |
+
out = model(xt, t=t, **model_kwargs)
|
200 |
+
else:
|
201 |
+
# classifier-free guidance (arXiv:2207.12598)
|
202 |
+
# model_kwargs[0]: conditional kwargs
|
203 |
+
# model_kwargs[1]: non-conditional kwargs
|
204 |
+
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
205 |
+
y_out = model(xt, t=t, **model_kwargs[0])
|
206 |
+
if guide_scale == 1.:
|
207 |
+
out = y_out
|
208 |
+
else:
|
209 |
+
u_out = model(xt, t=t, **model_kwargs[1])
|
210 |
+
out = u_out + guide_scale * (y_out - u_out)
|
211 |
+
|
212 |
+
# rescale the output according to arXiv:2305.08891
|
213 |
+
if guide_rescale is not None:
|
214 |
+
assert guide_rescale >= 0 and guide_rescale <= 1
|
215 |
+
ratio = (y_out.flatten(1).std(dim=1) / (
|
216 |
+
out.flatten(1).std(dim=1) + 1e-12
|
217 |
+
)).view((-1, ) + (1, ) * (y_out.ndim - 1))
|
218 |
+
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
|
219 |
+
|
220 |
+
# compute x0
|
221 |
+
if self.prediction_type == 'x0':
|
222 |
+
x0 = out
|
223 |
+
elif self.prediction_type == 'eps':
|
224 |
+
x0 = (xt - sigmas * out) / alphas
|
225 |
+
elif self.prediction_type == 'v':
|
226 |
+
x0 = alphas * xt - sigmas * out
|
227 |
+
else:
|
228 |
+
raise NotImplementedError(
|
229 |
+
f'prediction_type {self.prediction_type} not implemented'
|
230 |
+
)
|
231 |
+
|
232 |
+
# restrict the range of x0
|
233 |
+
if percentile is not None:
|
234 |
+
# NOTE: percentile should only be used when data is within range [-1, 1]
|
235 |
+
assert percentile > 0 and percentile <= 1
|
236 |
+
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
|
237 |
+
s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
|
238 |
+
x0 = torch.min(s, torch.max(-s, x0)) / s
|
239 |
+
elif clamp is not None:
|
240 |
+
x0 = x0.clamp(-clamp, clamp)
|
241 |
+
|
242 |
+
# recompute eps using the restricted x0
|
243 |
+
eps = (xt - alphas * x0) / sigmas
|
244 |
+
|
245 |
+
# compute mu (mean of posterior distribution) using the restricted x0
|
246 |
+
mu = coef1 * x0 + coef2 * xt
|
247 |
+
return mu, var, log_var, x0, eps
|
248 |
+
|
249 |
+
@torch.no_grad()
|
250 |
+
def sample(
|
251 |
+
self,
|
252 |
+
noise,
|
253 |
+
model,
|
254 |
+
model_kwargs={},
|
255 |
+
condition_fn=None,
|
256 |
+
guide_scale=None,
|
257 |
+
guide_rescale=None,
|
258 |
+
clamp=None,
|
259 |
+
percentile=None,
|
260 |
+
solver='euler_a',
|
261 |
+
steps=20,
|
262 |
+
t_max=None,
|
263 |
+
t_min=None,
|
264 |
+
discretization=None,
|
265 |
+
discard_penultimate_step=None,
|
266 |
+
return_intermediate=None,
|
267 |
+
show_progress=False,
|
268 |
+
seed=-1,
|
269 |
+
**kwargs
|
270 |
+
):
|
271 |
+
# sanity check
|
272 |
+
assert isinstance(steps, (int, torch.LongTensor))
|
273 |
+
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
|
274 |
+
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
|
275 |
+
assert discretization in (None, 'leading', 'linspace', 'trailing')
|
276 |
+
assert discard_penultimate_step in (None, True, False)
|
277 |
+
assert return_intermediate in (None, 'x0', 'xt')
|
278 |
+
|
279 |
+
# function of diffusion solver
|
280 |
+
solver_fn = {
|
281 |
+
# 'heun': sample_heun,
|
282 |
+
'dpmpp_2m_sde': sample_dpmpp_2m_sde
|
283 |
+
}[solver]
|
284 |
+
|
285 |
+
# options
|
286 |
+
schedule = 'karras' if 'karras' in solver else None
|
287 |
+
discretization = discretization or 'linspace'
|
288 |
+
seed = seed if seed >= 0 else random.randint(0, 2 ** 31)
|
289 |
+
if isinstance(steps, torch.LongTensor):
|
290 |
+
discard_penultimate_step = False
|
291 |
+
if discard_penultimate_step is None:
|
292 |
+
discard_penultimate_step = True if solver in (
|
293 |
+
'dpm2',
|
294 |
+
'dpm2_ancestral',
|
295 |
+
'dpmpp_2m_sde',
|
296 |
+
'dpm2_karras',
|
297 |
+
'dpm2_ancestral_karras',
|
298 |
+
'dpmpp_2m_sde_karras'
|
299 |
+
) else False
|
300 |
+
|
301 |
+
# function for denoising xt to get x0
|
302 |
+
intermediates = []
|
303 |
+
def model_fn(xt, sigma):
|
304 |
+
# denoising
|
305 |
+
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
306 |
+
x0 = self.denoise(
|
307 |
+
xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp,
|
308 |
+
percentile
|
309 |
+
)[-2]
|
310 |
+
|
311 |
+
# collect intermediate outputs
|
312 |
+
if return_intermediate == 'xt':
|
313 |
+
intermediates.append(xt)
|
314 |
+
elif return_intermediate == 'x0':
|
315 |
+
intermediates.append(x0)
|
316 |
+
return x0
|
317 |
+
|
318 |
+
# get timesteps
|
319 |
+
if isinstance(steps, int):
|
320 |
+
steps += 1 if discard_penultimate_step else 0
|
321 |
+
t_max = self.num_timesteps - 1 if t_max is None else t_max
|
322 |
+
t_min = 0 if t_min is None else t_min
|
323 |
+
|
324 |
+
# discretize timesteps
|
325 |
+
if discretization == 'leading':
|
326 |
+
steps = torch.arange(
|
327 |
+
t_min, t_max + 1, (t_max - t_min + 1) / steps
|
328 |
+
).flip(0)
|
329 |
+
elif discretization == 'linspace':
|
330 |
+
steps = torch.linspace(t_max, t_min, steps)
|
331 |
+
elif discretization == 'trailing':
|
332 |
+
steps = torch.arange(t_max, t_min - 1, -((t_max - t_min + 1) / steps))
|
333 |
+
else:
|
334 |
+
raise NotImplementedError(
|
335 |
+
f'{discretization} discretization not implemented'
|
336 |
+
)
|
337 |
+
steps = steps.clamp_(t_min, t_max)
|
338 |
+
steps = torch.as_tensor(steps, dtype=torch.float32, device=noise.device)
|
339 |
+
|
340 |
+
# get sigmas
|
341 |
+
sigmas = self._t_to_sigma(steps)
|
342 |
+
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
343 |
+
if schedule == 'karras':
|
344 |
+
if sigmas[0] == float('inf'):
|
345 |
+
sigmas = karras_schedule(
|
346 |
+
n=len(steps) - 1,
|
347 |
+
sigma_min=sigmas[sigmas > 0].min().item(),
|
348 |
+
sigma_max=sigmas[sigmas < float('inf')].max().item(),
|
349 |
+
rho=7.
|
350 |
+
).to(sigmas)
|
351 |
+
sigmas = torch.cat([
|
352 |
+
sigmas.new_tensor([float('inf')]), sigmas, sigmas.new_zeros([1])
|
353 |
+
])
|
354 |
+
else:
|
355 |
+
sigmas = karras_schedule(
|
356 |
+
n=len(steps),
|
357 |
+
sigma_min=sigmas[sigmas > 0].min().item(),
|
358 |
+
sigma_max=sigmas.max().item(),
|
359 |
+
rho=7.
|
360 |
+
).to(sigmas)
|
361 |
+
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
362 |
+
if discard_penultimate_step:
|
363 |
+
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
364 |
+
|
365 |
+
# sampling
|
366 |
+
x0 = solver_fn(
|
367 |
+
noise,
|
368 |
+
model_fn,
|
369 |
+
sigmas,
|
370 |
+
show_progress=show_progress,
|
371 |
+
**kwargs
|
372 |
+
)
|
373 |
+
return (x0, intermediates) if return_intermediate is not None else x0
|
374 |
+
|
375 |
+
@torch.no_grad()
|
376 |
+
def ddim_reverse_sample(
|
377 |
+
self,
|
378 |
+
xt,
|
379 |
+
t,
|
380 |
+
model,
|
381 |
+
model_kwargs={},
|
382 |
+
clamp=None,
|
383 |
+
percentile=None,
|
384 |
+
guide_scale=None,
|
385 |
+
guide_rescale=None,
|
386 |
+
ddim_timesteps=20,
|
387 |
+
reverse_steps=600
|
388 |
+
):
|
389 |
+
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
|
390 |
+
"""
|
391 |
+
stride = reverse_steps // ddim_timesteps
|
392 |
+
|
393 |
+
# predict distribution of p(x_{t-1} | x_t)
|
394 |
+
_, _, _, x0, eps = self.denoise(
|
395 |
+
xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp,
|
396 |
+
percentile
|
397 |
+
)
|
398 |
+
# derive variables
|
399 |
+
s = (t + stride).clamp(0, reverse_steps-1)
|
400 |
+
# hyperparams
|
401 |
+
sigmas = _i(self.sigmas, t, xt)
|
402 |
+
alphas = _i(self.alphas, t, xt)
|
403 |
+
alphas_s = _i(self.alphas, s.clamp(0), xt)
|
404 |
+
alphas_s[s < 0] = 1.
|
405 |
+
sigmas_s = torch.sqrt(1 - alphas_s ** 2)
|
406 |
+
|
407 |
+
# reverse sample
|
408 |
+
mu = alphas_s * x0 + sigmas_s * eps
|
409 |
+
return mu, x0
|
410 |
+
|
411 |
+
@torch.no_grad()
|
412 |
+
def ddim_reverse_sample_loop(
|
413 |
+
self,
|
414 |
+
x0,
|
415 |
+
model,
|
416 |
+
model_kwargs={},
|
417 |
+
clamp=None,
|
418 |
+
percentile=None,
|
419 |
+
guide_scale=None,
|
420 |
+
guide_rescale=None,
|
421 |
+
ddim_timesteps=20,
|
422 |
+
reverse_steps=600
|
423 |
+
):
|
424 |
+
# prepare input
|
425 |
+
b = x0.size(0)
|
426 |
+
xt = x0
|
427 |
+
|
428 |
+
# reconstruction steps
|
429 |
+
steps = torch.arange(0, reverse_steps, reverse_steps // ddim_timesteps)
|
430 |
+
for step in steps:
|
431 |
+
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
432 |
+
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, guide_rescale, ddim_timesteps, reverse_steps)
|
433 |
+
return xt
|
434 |
+
|
435 |
+
def _sigma_to_t(self, sigma):
|
436 |
+
if sigma == float('inf'):
|
437 |
+
t = torch.full_like(sigma, len(self.sigmas) - 1)
|
438 |
+
else:
|
439 |
+
log_sigmas = torch.sqrt(
|
440 |
+
self.sigmas ** 2 / (1 - self.sigmas ** 2)
|
441 |
+
).log().to(sigma)
|
442 |
+
log_sigma = sigma.log()
|
443 |
+
dists = log_sigma - log_sigmas[:, None]
|
444 |
+
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
|
445 |
+
max=log_sigmas.shape[0] - 2
|
446 |
+
)
|
447 |
+
high_idx = low_idx + 1
|
448 |
+
low, high = log_sigmas[low_idx], log_sigmas[high_idx]
|
449 |
+
w = (low - log_sigma) / (low - high)
|
450 |
+
w = w.clamp(0, 1)
|
451 |
+
t = (1 - w) * low_idx + w * high_idx
|
452 |
+
t = t.view(sigma.shape)
|
453 |
+
if t.ndim == 0:
|
454 |
+
t = t.unsqueeze(0)
|
455 |
+
return t
|
456 |
+
|
457 |
+
def _t_to_sigma(self, t):
|
458 |
+
t = t.float()
|
459 |
+
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
460 |
+
log_sigmas = torch.sqrt(self.sigmas ** 2 / (1 - self.sigmas ** 2)).log().to(t)
|
461 |
+
log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
|
462 |
+
log_sigma[torch.isnan(log_sigma) | torch.isinf(log_sigma)] = float('inf')
|
463 |
+
return log_sigma.exp()
|
464 |
+
|
465 |
+
def prev_step(self, model_out, t, xt, inference_steps=50):
|
466 |
+
prev_t = t - self.num_timesteps // inference_steps
|
467 |
+
|
468 |
+
sigmas = _i(self.sigmas, t, xt)
|
469 |
+
alphas = _i(self.alphas, t, xt)
|
470 |
+
alphas_prev = _i(self.alphas, prev_t.clamp(0), xt)
|
471 |
+
alphas_prev[prev_t < 0] = 1.
|
472 |
+
sigmas_prev = torch.sqrt(1 - alphas_prev ** 2)
|
473 |
+
|
474 |
+
x0 = alphas * xt - sigmas * model_out
|
475 |
+
eps = (xt - alphas * x0) / sigmas
|
476 |
+
prev_sample = alphas_prev * x0 + sigmas_prev * eps
|
477 |
+
return prev_sample
|
478 |
+
|
479 |
+
def next_step(self, model_out, t, xt, inference_steps=50):
|
480 |
+
t, next_t = min(t - self.num_timesteps // inference_steps, 999), t
|
481 |
+
|
482 |
+
sigmas = _i(self.sigmas, t, xt)
|
483 |
+
alphas = _i(self.alphas, t, xt)
|
484 |
+
alphas_next = _i(self.alphas, next_t.clamp(0), xt)
|
485 |
+
alphas_next[next_t < 0] = 1.
|
486 |
+
sigmas_next = torch.sqrt(1 - alphas_next ** 2)
|
487 |
+
|
488 |
+
x0 = alphas * xt - sigmas * model_out
|
489 |
+
eps = (xt - alphas * x0) / sigmas
|
490 |
+
next_sample = alphas_next * x0 + sigmas_next * eps
|
491 |
+
return next_sample
|
492 |
+
|
493 |
+
def get_noise_pred_single(self, xt, t, model, model_kwargs):
|
494 |
+
assert isinstance(model_kwargs, dict)
|
495 |
+
out = model(xt, t=t, **model_kwargs)
|
496 |
+
return out
|
497 |
+
|
498 |
+
|
UniAnimate/tools/modules/diffusions/losses.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
|
4 |
+
__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']
|
5 |
+
|
6 |
+
def kl_divergence(mu1, logvar1, mu2, logvar2):
|
7 |
+
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2))
|
8 |
+
|
9 |
+
def standard_normal_cdf(x):
|
10 |
+
r"""A fast approximation of the cumulative distribution function of the standard normal.
|
11 |
+
"""
|
12 |
+
return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
13 |
+
|
14 |
+
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
|
15 |
+
assert x0.shape == mean.shape == log_scale.shape
|
16 |
+
cx = x0 - mean
|
17 |
+
inv_stdv = torch.exp(-log_scale)
|
18 |
+
cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
|
19 |
+
cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
|
20 |
+
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
|
21 |
+
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
|
22 |
+
cdf_delta = cdf_plus - cdf_min
|
23 |
+
log_probs = torch.where(
|
24 |
+
x0 < -0.999,
|
25 |
+
log_cdf_plus,
|
26 |
+
torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))))
|
27 |
+
assert log_probs.shape == x0.shape
|
28 |
+
return log_probs
|
UniAnimate/tools/modules/diffusions/schedules.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def beta_schedule(schedule='cosine',
|
6 |
+
num_timesteps=1000,
|
7 |
+
zero_terminal_snr=False,
|
8 |
+
**kwargs):
|
9 |
+
# compute betas
|
10 |
+
betas = {
|
11 |
+
# 'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
|
12 |
+
'linear': linear_schedule,
|
13 |
+
'linear_sd': linear_sd_schedule,
|
14 |
+
'quadratic': quadratic_schedule,
|
15 |
+
'cosine': cosine_schedule
|
16 |
+
}[schedule](num_timesteps, **kwargs)
|
17 |
+
|
18 |
+
if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001:
|
19 |
+
betas = rescale_zero_terminal_snr(betas)
|
20 |
+
|
21 |
+
return betas
|
22 |
+
|
23 |
+
|
24 |
+
def sigma_schedule(schedule='cosine',
|
25 |
+
num_timesteps=1000,
|
26 |
+
zero_terminal_snr=False,
|
27 |
+
**kwargs):
|
28 |
+
# compute betas
|
29 |
+
betas = {
|
30 |
+
'logsnr_cosine_interp': logsnr_cosine_interp_schedule,
|
31 |
+
'linear': linear_schedule,
|
32 |
+
'linear_sd': linear_sd_schedule,
|
33 |
+
'quadratic': quadratic_schedule,
|
34 |
+
'cosine': cosine_schedule
|
35 |
+
}[schedule](num_timesteps, **kwargs)
|
36 |
+
if schedule == 'logsnr_cosine_interp':
|
37 |
+
sigma = betas
|
38 |
+
else:
|
39 |
+
sigma = betas_to_sigmas(betas)
|
40 |
+
if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001:
|
41 |
+
sigma = rescale_zero_terminal_snr(sigma)
|
42 |
+
|
43 |
+
return sigma
|
44 |
+
|
45 |
+
|
46 |
+
def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs):
|
47 |
+
scale = 1000.0 / num_timesteps
|
48 |
+
init_beta = init_beta or scale * 0.0001
|
49 |
+
ast_beta = last_beta or scale * 0.02
|
50 |
+
return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64)
|
51 |
+
|
52 |
+
def logsnr_cosine_interp_schedule(
|
53 |
+
num_timesteps,
|
54 |
+
scale_min=2,
|
55 |
+
scale_max=4,
|
56 |
+
logsnr_min=-15,
|
57 |
+
logsnr_max=15,
|
58 |
+
**kwargs):
|
59 |
+
return logsnrs_to_sigmas(
|
60 |
+
_logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max))
|
61 |
+
|
62 |
+
def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs):
|
63 |
+
return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
|
64 |
+
|
65 |
+
|
66 |
+
def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs):
|
67 |
+
init_beta = init_beta or 0.0015
|
68 |
+
last_beta = last_beta or 0.0195
|
69 |
+
return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2
|
70 |
+
|
71 |
+
|
72 |
+
def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs):
|
73 |
+
betas = []
|
74 |
+
for step in range(num_timesteps):
|
75 |
+
t1 = step / num_timesteps
|
76 |
+
t2 = (step + 1) / num_timesteps
|
77 |
+
fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2
|
78 |
+
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
|
79 |
+
return torch.tensor(betas, dtype=torch.float64)
|
80 |
+
|
81 |
+
|
82 |
+
# def cosine_schedule(n, cosine_s=0.008, **kwargs):
|
83 |
+
# ramp = torch.linspace(0, 1, n + 1)
|
84 |
+
# square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2
|
85 |
+
# betas = (1 - square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999)
|
86 |
+
# return betas_to_sigmas(betas)
|
87 |
+
|
88 |
+
|
89 |
+
def betas_to_sigmas(betas):
|
90 |
+
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
|
91 |
+
|
92 |
+
|
93 |
+
def sigmas_to_betas(sigmas):
|
94 |
+
square_alphas = 1 - sigmas**2
|
95 |
+
betas = 1 - torch.cat(
|
96 |
+
[square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
|
97 |
+
return betas
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
def sigmas_to_logsnrs(sigmas):
|
102 |
+
square_sigmas = sigmas**2
|
103 |
+
return torch.log(square_sigmas / (1 - square_sigmas))
|
104 |
+
|
105 |
+
|
106 |
+
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
|
107 |
+
t_min = math.atan(math.exp(-0.5 * logsnr_min))
|
108 |
+
t_max = math.atan(math.exp(-0.5 * logsnr_max))
|
109 |
+
t = torch.linspace(1, 0, n)
|
110 |
+
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
|
111 |
+
return logsnrs
|
112 |
+
|
113 |
+
|
114 |
+
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
|
115 |
+
logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
|
116 |
+
logsnrs += 2 * math.log(1 / scale)
|
117 |
+
return logsnrs
|
118 |
+
|
119 |
+
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
|
120 |
+
ramp = torch.linspace(1, 0, n)
|
121 |
+
min_inv_rho = sigma_min**(1 / rho)
|
122 |
+
max_inv_rho = sigma_max**(1 / rho)
|
123 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
|
124 |
+
sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
|
125 |
+
return sigmas
|
126 |
+
|
127 |
+
def _logsnr_cosine_interp(n,
|
128 |
+
logsnr_min=-15,
|
129 |
+
logsnr_max=15,
|
130 |
+
scale_min=2,
|
131 |
+
scale_max=4):
|
132 |
+
t = torch.linspace(1, 0, n)
|
133 |
+
logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
|
134 |
+
logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
|
135 |
+
logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
|
136 |
+
return logsnrs
|
137 |
+
|
138 |
+
|
139 |
+
def logsnrs_to_sigmas(logsnrs):
|
140 |
+
return torch.sqrt(torch.sigmoid(-logsnrs))
|
141 |
+
|
142 |
+
|
143 |
+
def rescale_zero_terminal_snr(betas):
|
144 |
+
"""
|
145 |
+
Rescale Schedule to Zero Terminal SNR
|
146 |
+
"""
|
147 |
+
# Convert betas to alphas_bar_sqrt
|
148 |
+
alphas = 1 - betas
|
149 |
+
alphas_bar = alphas.cumprod(0)
|
150 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
151 |
+
|
152 |
+
# Store old values. 8 alphas_bar_sqrt_0 = a
|
153 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
154 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
155 |
+
# Shift so last timestep is zero.
|
156 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
157 |
+
# Scale so first timestep is back to old value.
|
158 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
159 |
+
|
160 |
+
# Convert alphas_bar_sqrt to betas
|
161 |
+
alphas_bar = alphas_bar_sqrt ** 2
|
162 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
163 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
164 |
+
betas = 1 - alphas
|
165 |
+
return betas
|
166 |
+
|
UniAnimate/tools/modules/embedding_manager.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import open_clip
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
from utils.registry_class import EMBEDMANAGER
|
8 |
+
|
9 |
+
DEFAULT_PLACEHOLDER_TOKEN = ["*"]
|
10 |
+
|
11 |
+
PROGRESSIVE_SCALE = 2000
|
12 |
+
|
13 |
+
per_img_token_list = [
|
14 |
+
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
|
15 |
+
]
|
16 |
+
|
17 |
+
def get_clip_token_for_string(string):
|
18 |
+
tokens = open_clip.tokenize(string)
|
19 |
+
|
20 |
+
return tokens[0, 1]
|
21 |
+
|
22 |
+
def get_embedding_for_clip_token(embedder, token):
|
23 |
+
return embedder(token.unsqueeze(0))[0]
|
24 |
+
|
25 |
+
|
26 |
+
@EMBEDMANAGER.register_class()
|
27 |
+
class EmbeddingManager(nn.Module):
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
embedder,
|
31 |
+
placeholder_strings=None,
|
32 |
+
initializer_words=None,
|
33 |
+
per_image_tokens=False,
|
34 |
+
num_vectors_per_token=1,
|
35 |
+
progressive_words=False,
|
36 |
+
temporal_prompt_length=1,
|
37 |
+
token_dim=1024,
|
38 |
+
**kwargs
|
39 |
+
):
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.string_to_token_dict = {}
|
43 |
+
|
44 |
+
self.string_to_param_dict = nn.ParameterDict()
|
45 |
+
|
46 |
+
self.initial_embeddings = nn.ParameterDict() # These should not be optimized
|
47 |
+
|
48 |
+
self.progressive_words = progressive_words
|
49 |
+
self.progressive_counter = 0
|
50 |
+
|
51 |
+
self.max_vectors_per_token = num_vectors_per_token
|
52 |
+
|
53 |
+
get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.model.token_embedding.cpu())
|
54 |
+
|
55 |
+
if per_image_tokens:
|
56 |
+
placeholder_strings.extend(per_img_token_list)
|
57 |
+
|
58 |
+
for idx, placeholder_string in enumerate(placeholder_strings):
|
59 |
+
|
60 |
+
token = get_clip_token_for_string(placeholder_string)
|
61 |
+
|
62 |
+
if initializer_words and idx < len(initializer_words):
|
63 |
+
init_word_token = get_clip_token_for_string(initializer_words[idx])
|
64 |
+
|
65 |
+
with torch.no_grad():
|
66 |
+
init_word_embedding = get_embedding_for_tkn(init_word_token)
|
67 |
+
|
68 |
+
token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
|
69 |
+
self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
|
70 |
+
else:
|
71 |
+
token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
|
72 |
+
|
73 |
+
self.string_to_token_dict[placeholder_string] = token
|
74 |
+
self.string_to_param_dict[placeholder_string] = token_params
|
75 |
+
|
76 |
+
|
77 |
+
def forward(
|
78 |
+
self,
|
79 |
+
tokenized_text,
|
80 |
+
embedded_text,
|
81 |
+
):
|
82 |
+
b, n, device = *tokenized_text.shape, tokenized_text.device
|
83 |
+
|
84 |
+
for placeholder_string, placeholder_token in self.string_to_token_dict.items():
|
85 |
+
|
86 |
+
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
|
87 |
+
|
88 |
+
if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
|
89 |
+
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
|
90 |
+
embedded_text[placeholder_idx] = placeholder_embedding
|
91 |
+
else: # otherwise, need to insert and keep track of changing indices
|
92 |
+
if self.progressive_words:
|
93 |
+
self.progressive_counter += 1
|
94 |
+
max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
|
95 |
+
else:
|
96 |
+
max_step_tokens = self.max_vectors_per_token
|
97 |
+
|
98 |
+
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
|
99 |
+
|
100 |
+
placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
|
101 |
+
|
102 |
+
if placeholder_rows.nelement() == 0:
|
103 |
+
continue
|
104 |
+
|
105 |
+
sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
|
106 |
+
sorted_rows = placeholder_rows[sort_idx]
|
107 |
+
|
108 |
+
for idx in range(len(sorted_rows)):
|
109 |
+
row = sorted_rows[idx]
|
110 |
+
col = sorted_cols[idx]
|
111 |
+
|
112 |
+
new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
|
113 |
+
new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
|
114 |
+
|
115 |
+
embedded_text[row] = new_embed_row
|
116 |
+
tokenized_text[row] = new_token_row
|
117 |
+
|
118 |
+
return embedded_text
|
119 |
+
|
120 |
+
def forward_with_text_img(
|
121 |
+
self,
|
122 |
+
tokenized_text,
|
123 |
+
embedded_text,
|
124 |
+
embedded_img,
|
125 |
+
):
|
126 |
+
device = tokenized_text.device
|
127 |
+
for placeholder_string, placeholder_token in self.string_to_token_dict.items():
|
128 |
+
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
|
129 |
+
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
|
130 |
+
embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + embedded_img + placeholder_embedding
|
131 |
+
return embedded_text
|
132 |
+
|
133 |
+
def forward_with_text(
|
134 |
+
self,
|
135 |
+
tokenized_text,
|
136 |
+
embedded_text
|
137 |
+
):
|
138 |
+
device = tokenized_text.device
|
139 |
+
for placeholder_string, placeholder_token in self.string_to_token_dict.items():
|
140 |
+
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
|
141 |
+
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
|
142 |
+
embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + placeholder_embedding
|
143 |
+
return embedded_text
|
144 |
+
|
145 |
+
def save(self, ckpt_path):
|
146 |
+
torch.save({"string_to_token": self.string_to_token_dict,
|
147 |
+
"string_to_param": self.string_to_param_dict}, ckpt_path)
|
148 |
+
|
149 |
+
def load(self, ckpt_path):
|
150 |
+
ckpt = torch.load(ckpt_path, map_location='cpu')
|
151 |
+
|
152 |
+
string_to_token = ckpt["string_to_token"]
|
153 |
+
string_to_param = ckpt["string_to_param"]
|
154 |
+
for string, token in string_to_token.items():
|
155 |
+
self.string_to_token_dict[string] = token
|
156 |
+
for string, param in string_to_param.items():
|
157 |
+
self.string_to_param_dict[string] = param
|
158 |
+
|
159 |
+
def get_embedding_norms_squared(self):
|
160 |
+
all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
|
161 |
+
param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
|
162 |
+
|
163 |
+
return param_norm_squared
|
164 |
+
|
165 |
+
def embedding_parameters(self):
|
166 |
+
return self.string_to_param_dict.parameters()
|
167 |
+
|
168 |
+
def embedding_to_coarse_loss(self):
|
169 |
+
|
170 |
+
loss = 0.
|
171 |
+
num_embeddings = len(self.initial_embeddings)
|
172 |
+
|
173 |
+
for key in self.initial_embeddings:
|
174 |
+
optimized = self.string_to_param_dict[key]
|
175 |
+
coarse = self.initial_embeddings[key].clone().to(optimized.device)
|
176 |
+
|
177 |
+
loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
|
178 |
+
|
179 |
+
return loss
|
UniAnimate/tools/modules/unet/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .unet_unianimate import *
|
2 |
+
|
UniAnimate/tools/modules/unet/mha_flash.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.cuda.amp as amp
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import random
|
10 |
+
|
11 |
+
# from flash_attn.flash_attention import FlashAttention
|
12 |
+
class FlashAttentionBlock(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
|
15 |
+
# consider head_dim first, then num_heads
|
16 |
+
num_heads = dim // head_dim if head_dim else num_heads
|
17 |
+
head_dim = dim // num_heads
|
18 |
+
assert num_heads * head_dim == dim
|
19 |
+
super(FlashAttentionBlock, self).__init__()
|
20 |
+
self.dim = dim
|
21 |
+
self.context_dim = context_dim
|
22 |
+
self.num_heads = num_heads
|
23 |
+
self.head_dim = head_dim
|
24 |
+
self.scale = math.pow(head_dim, -0.25)
|
25 |
+
|
26 |
+
# layers
|
27 |
+
self.norm = nn.GroupNorm(32, dim)
|
28 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
29 |
+
if context_dim is not None:
|
30 |
+
self.context_kv = nn.Linear(context_dim, dim * 2)
|
31 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
32 |
+
|
33 |
+
if self.head_dim <= 128 and (self.head_dim % 8) == 0:
|
34 |
+
new_scale = math.pow(head_dim, -0.5)
|
35 |
+
self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)
|
36 |
+
|
37 |
+
# zero out the last layer params
|
38 |
+
nn.init.zeros_(self.proj.weight)
|
39 |
+
# self.apply(self._init_weight)
|
40 |
+
|
41 |
+
|
42 |
+
def _init_weight(self, module):
|
43 |
+
if isinstance(module, nn.Linear):
|
44 |
+
module.weight.data.normal_(mean=0.0, std=0.15)
|
45 |
+
if module.bias is not None:
|
46 |
+
module.bias.data.zero_()
|
47 |
+
elif isinstance(module, nn.Conv2d):
|
48 |
+
module.weight.data.normal_(mean=0.0, std=0.15)
|
49 |
+
if module.bias is not None:
|
50 |
+
module.bias.data.zero_()
|
51 |
+
|
52 |
+
def forward(self, x, context=None):
|
53 |
+
r"""x: [B, C, H, W].
|
54 |
+
context: [B, L, C] or None.
|
55 |
+
"""
|
56 |
+
identity = x
|
57 |
+
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
58 |
+
|
59 |
+
# compute query, key, value
|
60 |
+
x = self.norm(x)
|
61 |
+
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
62 |
+
if context is not None:
|
63 |
+
ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
|
64 |
+
k = torch.cat([ck, k], dim=-1)
|
65 |
+
v = torch.cat([cv, v], dim=-1)
|
66 |
+
cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
|
67 |
+
q = torch.cat([q, cq], dim=-1)
|
68 |
+
|
69 |
+
qkv = torch.cat([q,k,v], dim=1)
|
70 |
+
origin_dtype = qkv.dtype
|
71 |
+
qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous()
|
72 |
+
out, _ = self.flash_attn(qkv)
|
73 |
+
out.to(origin_dtype)
|
74 |
+
|
75 |
+
if context is not None:
|
76 |
+
out = out[:, :-4, :, :]
|
77 |
+
out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)
|
78 |
+
|
79 |
+
# output
|
80 |
+
x = self.proj(out)
|
81 |
+
return x + identity
|
82 |
+
|
83 |
+
if __name__ == '__main__':
|
84 |
+
batch_size = 8
|
85 |
+
flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda()
|
86 |
+
|
87 |
+
x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
|
88 |
+
context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
|
89 |
+
# context = None
|
90 |
+
flash_net.eval()
|
91 |
+
|
92 |
+
with amp.autocast(enabled=True):
|
93 |
+
# warm up
|
94 |
+
for i in range(5):
|
95 |
+
y = flash_net(x, context)
|
96 |
+
torch.cuda.synchronize()
|
97 |
+
s1 = time.time()
|
98 |
+
for i in range(10):
|
99 |
+
y = flash_net(x, context)
|
100 |
+
torch.cuda.synchronize()
|
101 |
+
s2 = time.time()
|
102 |
+
|
103 |
+
print(f'Average cost time {(s2-s1)*1000/10} ms')
|
UniAnimate/tools/modules/unet/unet_unianimate.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import xformers
|
4 |
+
import xformers.ops
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from rotary_embedding_torch import RotaryEmbedding
|
9 |
+
from fairscale.nn.checkpoint import checkpoint_wrapper
|
10 |
+
|
11 |
+
from .util import *
|
12 |
+
# from .mha_flash import FlashAttentionBlock
|
13 |
+
from utils.registry_class import MODEL
|
14 |
+
|
15 |
+
|
16 |
+
USE_TEMPORAL_TRANSFORMER = True
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class PreNormattention(nn.Module):
|
21 |
+
def __init__(self, dim, fn):
|
22 |
+
super().__init__()
|
23 |
+
self.norm = nn.LayerNorm(dim)
|
24 |
+
self.fn = fn
|
25 |
+
def forward(self, x, **kwargs):
|
26 |
+
return self.fn(self.norm(x), **kwargs) + x
|
27 |
+
|
28 |
+
class PreNormattention_qkv(nn.Module):
|
29 |
+
def __init__(self, dim, fn):
|
30 |
+
super().__init__()
|
31 |
+
self.norm = nn.LayerNorm(dim)
|
32 |
+
self.fn = fn
|
33 |
+
def forward(self, q, k, v, **kwargs):
|
34 |
+
return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q
|
35 |
+
|
36 |
+
class Attention(nn.Module):
|
37 |
+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
38 |
+
super().__init__()
|
39 |
+
inner_dim = dim_head * heads
|
40 |
+
project_out = not (heads == 1 and dim_head == dim)
|
41 |
+
|
42 |
+
self.heads = heads
|
43 |
+
self.scale = dim_head ** -0.5
|
44 |
+
|
45 |
+
self.attend = nn.Softmax(dim = -1)
|
46 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
47 |
+
|
48 |
+
self.to_out = nn.Sequential(
|
49 |
+
nn.Linear(inner_dim, dim),
|
50 |
+
nn.Dropout(dropout)
|
51 |
+
) if project_out else nn.Identity()
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
b, n, _, h = *x.shape, self.heads
|
55 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
56 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
57 |
+
|
58 |
+
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
59 |
+
|
60 |
+
attn = self.attend(dots)
|
61 |
+
|
62 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
63 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
64 |
+
return self.to_out(out)
|
65 |
+
|
66 |
+
|
67 |
+
class Attention_qkv(nn.Module):
|
68 |
+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
69 |
+
super().__init__()
|
70 |
+
inner_dim = dim_head * heads
|
71 |
+
project_out = not (heads == 1 and dim_head == dim)
|
72 |
+
|
73 |
+
self.heads = heads
|
74 |
+
self.scale = dim_head ** -0.5
|
75 |
+
|
76 |
+
self.attend = nn.Softmax(dim = -1)
|
77 |
+
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
78 |
+
self.to_k = nn.Linear(dim, inner_dim, bias = False)
|
79 |
+
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
80 |
+
|
81 |
+
self.to_out = nn.Sequential(
|
82 |
+
nn.Linear(inner_dim, dim),
|
83 |
+
nn.Dropout(dropout)
|
84 |
+
) if project_out else nn.Identity()
|
85 |
+
|
86 |
+
def forward(self, q, k, v):
|
87 |
+
b, n, _, h = *q.shape, self.heads
|
88 |
+
bk = k.shape[0]
|
89 |
+
|
90 |
+
q = self.to_q(q)
|
91 |
+
k = self.to_k(k)
|
92 |
+
v = self.to_v(v)
|
93 |
+
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
|
94 |
+
k = rearrange(k, 'b n (h d) -> b h n d', b=bk, h = h)
|
95 |
+
v = rearrange(v, 'b n (h d) -> b h n d', b=bk, h = h)
|
96 |
+
|
97 |
+
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
98 |
+
|
99 |
+
attn = self.attend(dots)
|
100 |
+
|
101 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
102 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
103 |
+
return self.to_out(out)
|
104 |
+
|
105 |
+
class PostNormattention(nn.Module):
|
106 |
+
def __init__(self, dim, fn):
|
107 |
+
super().__init__()
|
108 |
+
self.norm = nn.LayerNorm(dim)
|
109 |
+
self.fn = fn
|
110 |
+
def forward(self, x, **kwargs):
|
111 |
+
return self.norm(self.fn(x, **kwargs) + x)
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
class Transformer_v2(nn.Module):
|
117 |
+
def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1):
|
118 |
+
super().__init__()
|
119 |
+
self.layers = nn.ModuleList([])
|
120 |
+
self.depth = depth
|
121 |
+
for _ in range(depth):
|
122 |
+
self.layers.append(nn.ModuleList([
|
123 |
+
PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)),
|
124 |
+
FeedForward(dim, mlp_dim, dropout = dropout_ffn),
|
125 |
+
]))
|
126 |
+
def forward(self, x):
|
127 |
+
for attn, ff in self.layers[:1]:
|
128 |
+
x = attn(x)
|
129 |
+
x = ff(x) + x
|
130 |
+
if self.depth > 1:
|
131 |
+
for attn, ff in self.layers[1:]:
|
132 |
+
x = attn(x)
|
133 |
+
x = ff(x) + x
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
class DropPath(nn.Module):
|
138 |
+
r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
|
139 |
+
"""
|
140 |
+
def __init__(self, p):
|
141 |
+
super(DropPath, self).__init__()
|
142 |
+
self.p = p
|
143 |
+
|
144 |
+
def forward(self, *args, zero=None, keep=None):
|
145 |
+
if not self.training:
|
146 |
+
return args[0] if len(args) == 1 else args
|
147 |
+
|
148 |
+
# params
|
149 |
+
x = args[0]
|
150 |
+
b = x.size(0)
|
151 |
+
n = (torch.rand(b) < self.p).sum()
|
152 |
+
|
153 |
+
# non-zero and non-keep mask
|
154 |
+
mask = x.new_ones(b, dtype=torch.bool)
|
155 |
+
if keep is not None:
|
156 |
+
mask[keep] = False
|
157 |
+
if zero is not None:
|
158 |
+
mask[zero] = False
|
159 |
+
|
160 |
+
# drop-path index
|
161 |
+
index = torch.where(mask)[0]
|
162 |
+
index = index[torch.randperm(len(index))[:n]]
|
163 |
+
if zero is not None:
|
164 |
+
index = torch.cat([index, torch.where(zero)[0]], dim=0)
|
165 |
+
|
166 |
+
# drop-path multiplier
|
167 |
+
multiplier = x.new_ones(b)
|
168 |
+
multiplier[index] = 0.0
|
169 |
+
output = tuple(u * self.broadcast(multiplier, u) for u in args)
|
170 |
+
return output[0] if len(args) == 1 else output
|
171 |
+
|
172 |
+
def broadcast(self, src, dst):
|
173 |
+
assert src.size(0) == dst.size(0)
|
174 |
+
shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
|
175 |
+
return src.view(shape)
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
@MODEL.register_class()
|
181 |
+
class UNetSD_UniAnimate(nn.Module):
|
182 |
+
|
183 |
+
def __init__(self,
|
184 |
+
config=None,
|
185 |
+
in_dim=4,
|
186 |
+
dim=512,
|
187 |
+
y_dim=512,
|
188 |
+
context_dim=1024,
|
189 |
+
hist_dim = 156,
|
190 |
+
concat_dim = 8,
|
191 |
+
out_dim=6,
|
192 |
+
dim_mult=[1, 2, 3, 4],
|
193 |
+
num_heads=None,
|
194 |
+
head_dim=64,
|
195 |
+
num_res_blocks=3,
|
196 |
+
attn_scales=[1 / 2, 1 / 4, 1 / 8],
|
197 |
+
use_scale_shift_norm=True,
|
198 |
+
dropout=0.1,
|
199 |
+
temporal_attn_times=1,
|
200 |
+
temporal_attention = True,
|
201 |
+
use_checkpoint=False,
|
202 |
+
use_image_dataset=False,
|
203 |
+
use_fps_condition= False,
|
204 |
+
use_sim_mask = False,
|
205 |
+
misc_dropout = 0.5,
|
206 |
+
training=True,
|
207 |
+
inpainting=True,
|
208 |
+
p_all_zero=0.1,
|
209 |
+
p_all_keep=0.1,
|
210 |
+
zero_y = None,
|
211 |
+
black_image_feature = None,
|
212 |
+
adapter_transformer_layers = 1,
|
213 |
+
num_tokens=4,
|
214 |
+
**kwargs
|
215 |
+
):
|
216 |
+
embed_dim = dim * 4
|
217 |
+
num_heads=num_heads if num_heads else dim//32
|
218 |
+
super(UNetSD_UniAnimate, self).__init__()
|
219 |
+
self.zero_y = zero_y
|
220 |
+
self.black_image_feature = black_image_feature
|
221 |
+
self.cfg = config
|
222 |
+
self.in_dim = in_dim
|
223 |
+
self.dim = dim
|
224 |
+
self.y_dim = y_dim
|
225 |
+
self.context_dim = context_dim
|
226 |
+
self.num_tokens = num_tokens
|
227 |
+
self.hist_dim = hist_dim
|
228 |
+
self.concat_dim = concat_dim
|
229 |
+
self.embed_dim = embed_dim
|
230 |
+
self.out_dim = out_dim
|
231 |
+
self.dim_mult = dim_mult
|
232 |
+
|
233 |
+
self.num_heads = num_heads
|
234 |
+
|
235 |
+
self.head_dim = head_dim
|
236 |
+
self.num_res_blocks = num_res_blocks
|
237 |
+
self.attn_scales = attn_scales
|
238 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
239 |
+
self.temporal_attn_times = temporal_attn_times
|
240 |
+
self.temporal_attention = temporal_attention
|
241 |
+
self.use_checkpoint = use_checkpoint
|
242 |
+
self.use_image_dataset = use_image_dataset
|
243 |
+
self.use_fps_condition = use_fps_condition
|
244 |
+
self.use_sim_mask = use_sim_mask
|
245 |
+
self.training=training
|
246 |
+
self.inpainting = inpainting
|
247 |
+
self.video_compositions = self.cfg.video_compositions
|
248 |
+
self.misc_dropout = misc_dropout
|
249 |
+
self.p_all_zero = p_all_zero
|
250 |
+
self.p_all_keep = p_all_keep
|
251 |
+
|
252 |
+
use_linear_in_temporal = False
|
253 |
+
transformer_depth = 1
|
254 |
+
disabled_sa = False
|
255 |
+
# params
|
256 |
+
enc_dims = [dim * u for u in [1] + dim_mult]
|
257 |
+
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
258 |
+
shortcut_dims = []
|
259 |
+
scale = 1.0
|
260 |
+
self.resolution = config.resolution
|
261 |
+
|
262 |
+
|
263 |
+
# embeddings
|
264 |
+
self.time_embed = nn.Sequential(
|
265 |
+
nn.Linear(dim, embed_dim),
|
266 |
+
nn.SiLU(),
|
267 |
+
nn.Linear(embed_dim, embed_dim))
|
268 |
+
if 'image' in self.video_compositions:
|
269 |
+
self.pre_image_condition = nn.Sequential(
|
270 |
+
nn.Linear(self.context_dim, self.context_dim),
|
271 |
+
nn.SiLU(),
|
272 |
+
nn.Linear(self.context_dim, self.context_dim*self.num_tokens))
|
273 |
+
|
274 |
+
|
275 |
+
if 'local_image' in self.video_compositions:
|
276 |
+
self.local_image_embedding = nn.Sequential(
|
277 |
+
nn.Conv2d(3, concat_dim * 4, 3, padding=1),
|
278 |
+
nn.SiLU(),
|
279 |
+
nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
|
280 |
+
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
|
281 |
+
nn.SiLU(),
|
282 |
+
nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
|
283 |
+
self.local_image_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
|
284 |
+
|
285 |
+
if 'dwpose' in self.video_compositions:
|
286 |
+
self.dwpose_embedding = nn.Sequential(
|
287 |
+
nn.Conv2d(3, concat_dim * 4, 3, padding=1),
|
288 |
+
nn.SiLU(),
|
289 |
+
nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
|
290 |
+
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
|
291 |
+
nn.SiLU(),
|
292 |
+
nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
|
293 |
+
self.dwpose_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
|
294 |
+
|
295 |
+
if 'randomref_pose' in self.video_compositions:
|
296 |
+
randomref_dim = 4
|
297 |
+
self.randomref_pose2_embedding = nn.Sequential(
|
298 |
+
nn.Conv2d(3, concat_dim * 4, 3, padding=1),
|
299 |
+
nn.SiLU(),
|
300 |
+
nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)),
|
301 |
+
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
|
302 |
+
nn.SiLU(),
|
303 |
+
nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=2, padding=1))
|
304 |
+
self.randomref_pose2_embedding_after = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
|
305 |
+
|
306 |
+
if 'randomref' in self.video_compositions:
|
307 |
+
randomref_dim = 4
|
308 |
+
self.randomref_embedding2 = nn.Sequential(
|
309 |
+
nn.Conv2d(randomref_dim, concat_dim * 4, 3, padding=1),
|
310 |
+
nn.SiLU(),
|
311 |
+
nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=1, padding=1),
|
312 |
+
nn.SiLU(),
|
313 |
+
nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=1, padding=1))
|
314 |
+
self.randomref_embedding_after2 = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers)
|
315 |
+
|
316 |
+
### Condition Dropout
|
317 |
+
self.misc_dropout = DropPath(misc_dropout)
|
318 |
+
|
319 |
+
|
320 |
+
if temporal_attention and not USE_TEMPORAL_TRANSFORMER:
|
321 |
+
self.rotary_emb = RotaryEmbedding(min(32, head_dim))
|
322 |
+
self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32) # realistically will not be able to generate that many frames of video... yet
|
323 |
+
|
324 |
+
if self.use_fps_condition:
|
325 |
+
self.fps_embedding = nn.Sequential(
|
326 |
+
nn.Linear(dim, embed_dim),
|
327 |
+
nn.SiLU(),
|
328 |
+
nn.Linear(embed_dim, embed_dim))
|
329 |
+
nn.init.zeros_(self.fps_embedding[-1].weight)
|
330 |
+
nn.init.zeros_(self.fps_embedding[-1].bias)
|
331 |
+
|
332 |
+
# encoder
|
333 |
+
self.input_blocks = nn.ModuleList()
|
334 |
+
self.pre_image = nn.Sequential()
|
335 |
+
init_block = nn.ModuleList([nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)])
|
336 |
+
|
337 |
+
#### need an initial temporal attention?
|
338 |
+
if temporal_attention:
|
339 |
+
if USE_TEMPORAL_TRANSFORMER:
|
340 |
+
init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
|
341 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
|
342 |
+
else:
|
343 |
+
init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset))
|
344 |
+
|
345 |
+
self.input_blocks.append(init_block)
|
346 |
+
shortcut_dims.append(dim)
|
347 |
+
for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
348 |
+
for j in range(num_res_blocks):
|
349 |
+
|
350 |
+
block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,)])
|
351 |
+
|
352 |
+
if scale in attn_scales:
|
353 |
+
block.append(
|
354 |
+
SpatialTransformer(
|
355 |
+
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
|
356 |
+
disable_self_attn=False, use_linear=True
|
357 |
+
)
|
358 |
+
)
|
359 |
+
if self.temporal_attention:
|
360 |
+
if USE_TEMPORAL_TRANSFORMER:
|
361 |
+
block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
|
362 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset))
|
363 |
+
else:
|
364 |
+
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
|
365 |
+
in_dim = out_dim
|
366 |
+
self.input_blocks.append(block)
|
367 |
+
shortcut_dims.append(out_dim)
|
368 |
+
|
369 |
+
# downsample
|
370 |
+
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
371 |
+
downsample = Downsample(
|
372 |
+
out_dim, True, dims=2, out_channels=out_dim
|
373 |
+
)
|
374 |
+
shortcut_dims.append(out_dim)
|
375 |
+
scale /= 2.0
|
376 |
+
self.input_blocks.append(downsample)
|
377 |
+
|
378 |
+
# middle
|
379 |
+
self.middle_block = nn.ModuleList([
|
380 |
+
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,),
|
381 |
+
SpatialTransformer(
|
382 |
+
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim,
|
383 |
+
disable_self_attn=False, use_linear=True
|
384 |
+
)])
|
385 |
+
|
386 |
+
if self.temporal_attention:
|
387 |
+
if USE_TEMPORAL_TRANSFORMER:
|
388 |
+
self.middle_block.append(
|
389 |
+
TemporalTransformer(
|
390 |
+
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
|
391 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal,
|
392 |
+
multiply_zero=use_image_dataset,
|
393 |
+
)
|
394 |
+
)
|
395 |
+
else:
|
396 |
+
self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
|
397 |
+
|
398 |
+
self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
|
399 |
+
|
400 |
+
|
401 |
+
# decoder
|
402 |
+
self.output_blocks = nn.ModuleList()
|
403 |
+
for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
404 |
+
for j in range(num_res_blocks + 1):
|
405 |
+
|
406 |
+
block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )])
|
407 |
+
if scale in attn_scales:
|
408 |
+
block.append(
|
409 |
+
SpatialTransformer(
|
410 |
+
out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024,
|
411 |
+
disable_self_attn=False, use_linear=True
|
412 |
+
)
|
413 |
+
)
|
414 |
+
if self.temporal_attention:
|
415 |
+
if USE_TEMPORAL_TRANSFORMER:
|
416 |
+
block.append(
|
417 |
+
TemporalTransformer(
|
418 |
+
out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim,
|
419 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset
|
420 |
+
)
|
421 |
+
)
|
422 |
+
else:
|
423 |
+
block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times))
|
424 |
+
in_dim = out_dim
|
425 |
+
|
426 |
+
# upsample
|
427 |
+
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
428 |
+
upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim)
|
429 |
+
scale *= 2.0
|
430 |
+
block.append(upsample)
|
431 |
+
self.output_blocks.append(block)
|
432 |
+
|
433 |
+
# head
|
434 |
+
self.out = nn.Sequential(
|
435 |
+
nn.GroupNorm(32, out_dim),
|
436 |
+
nn.SiLU(),
|
437 |
+
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
438 |
+
|
439 |
+
# zero out the last layer params
|
440 |
+
nn.init.zeros_(self.out[-1].weight)
|
441 |
+
|
442 |
+
def forward(self,
|
443 |
+
x,
|
444 |
+
t,
|
445 |
+
y = None,
|
446 |
+
depth = None,
|
447 |
+
image = None,
|
448 |
+
motion = None,
|
449 |
+
local_image = None,
|
450 |
+
single_sketch = None,
|
451 |
+
masked = None,
|
452 |
+
canny = None,
|
453 |
+
sketch = None,
|
454 |
+
dwpose = None,
|
455 |
+
randomref = None,
|
456 |
+
histogram = None,
|
457 |
+
fps = None,
|
458 |
+
video_mask = None,
|
459 |
+
focus_present_mask = None,
|
460 |
+
prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
|
461 |
+
mask_last_frame_num = 0 # mask last frame num
|
462 |
+
):
|
463 |
+
|
464 |
+
|
465 |
+
assert self.inpainting or masked is None, 'inpainting is not supported'
|
466 |
+
|
467 |
+
batch, c, f, h, w= x.shape
|
468 |
+
frames = f
|
469 |
+
device = x.device
|
470 |
+
self.batch = batch
|
471 |
+
|
472 |
+
#### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
|
473 |
+
if mask_last_frame_num > 0:
|
474 |
+
focus_present_mask = None
|
475 |
+
video_mask[-mask_last_frame_num:] = False
|
476 |
+
else:
|
477 |
+
focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device))
|
478 |
+
|
479 |
+
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
|
480 |
+
time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)
|
481 |
+
else:
|
482 |
+
time_rel_pos_bias = None
|
483 |
+
|
484 |
+
|
485 |
+
# all-zero and all-keep masks
|
486 |
+
zero = torch.zeros(batch, dtype=torch.bool).to(x.device)
|
487 |
+
keep = torch.zeros(batch, dtype=torch.bool).to(x.device)
|
488 |
+
if self.training:
|
489 |
+
nzero = (torch.rand(batch) < self.p_all_zero).sum()
|
490 |
+
nkeep = (torch.rand(batch) < self.p_all_keep).sum()
|
491 |
+
index = torch.randperm(batch)
|
492 |
+
zero[index[0:nzero]] = True
|
493 |
+
keep[index[nzero:nzero + nkeep]] = True
|
494 |
+
assert not (zero & keep).any()
|
495 |
+
misc_dropout = partial(self.misc_dropout, zero = zero, keep = keep)
|
496 |
+
|
497 |
+
|
498 |
+
concat = x.new_zeros(batch, self.concat_dim, f, h, w)
|
499 |
+
|
500 |
+
|
501 |
+
# local_image_embedding (first frame)
|
502 |
+
if local_image is not None:
|
503 |
+
local_image = rearrange(local_image, 'b c f h w -> (b f) c h w')
|
504 |
+
local_image = self.local_image_embedding(local_image)
|
505 |
+
|
506 |
+
h = local_image.shape[2]
|
507 |
+
local_image = self.local_image_embedding_after(rearrange(local_image, '(b f) c h w -> (b h w) f c', b = batch))
|
508 |
+
local_image = rearrange(local_image, '(b h w) f c -> b c f h w', b = batch, h = h)
|
509 |
+
|
510 |
+
concat = concat + misc_dropout(local_image)
|
511 |
+
|
512 |
+
if dwpose is not None:
|
513 |
+
if 'randomref_pose' in self.video_compositions:
|
514 |
+
dwpose_random_ref = dwpose[:,:,:1].clone()
|
515 |
+
dwpose = dwpose[:,:,1:]
|
516 |
+
dwpose = rearrange(dwpose, 'b c f h w -> (b f) c h w')
|
517 |
+
dwpose = self.dwpose_embedding(dwpose)
|
518 |
+
|
519 |
+
h = dwpose.shape[2]
|
520 |
+
dwpose = self.dwpose_embedding_after(rearrange(dwpose, '(b f) c h w -> (b h w) f c', b = batch))
|
521 |
+
dwpose = rearrange(dwpose, '(b h w) f c -> b c f h w', b = batch, h = h)
|
522 |
+
concat = concat + misc_dropout(dwpose)
|
523 |
+
|
524 |
+
randomref_b = x.new_zeros(batch, self.concat_dim+4, 1, h, w)
|
525 |
+
if randomref is not None:
|
526 |
+
randomref = rearrange(randomref[:,:,:1,], 'b c f h w -> (b f) c h w')
|
527 |
+
randomref = self.randomref_embedding2(randomref)
|
528 |
+
|
529 |
+
h = randomref.shape[2]
|
530 |
+
randomref = self.randomref_embedding_after2(rearrange(randomref, '(b f) c h w -> (b h w) f c', b = batch))
|
531 |
+
if 'randomref_pose' in self.video_compositions:
|
532 |
+
dwpose_random_ref = rearrange(dwpose_random_ref, 'b c f h w -> (b f) c h w')
|
533 |
+
dwpose_random_ref = self.randomref_pose2_embedding(dwpose_random_ref)
|
534 |
+
dwpose_random_ref = self.randomref_pose2_embedding_after(rearrange(dwpose_random_ref, '(b f) c h w -> (b h w) f c', b = batch))
|
535 |
+
randomref = randomref + dwpose_random_ref
|
536 |
+
|
537 |
+
randomref_a = rearrange(randomref, '(b h w) f c -> b c f h w', b = batch, h = h)
|
538 |
+
randomref_b = randomref_b + randomref_a
|
539 |
+
|
540 |
+
|
541 |
+
x = torch.cat([randomref_b, torch.cat([x, concat], dim=1)], dim=2)
|
542 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
543 |
+
x = self.pre_image(x)
|
544 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
|
545 |
+
|
546 |
+
# embeddings
|
547 |
+
if self.use_fps_condition and fps is not None:
|
548 |
+
e = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim))
|
549 |
+
else:
|
550 |
+
e = self.time_embed(sinusoidal_embedding(t, self.dim))
|
551 |
+
|
552 |
+
context = x.new_zeros(batch, 0, self.context_dim)
|
553 |
+
|
554 |
+
|
555 |
+
if image is not None:
|
556 |
+
y_context = self.zero_y.repeat(batch, 1, 1)
|
557 |
+
context = torch.cat([context, y_context], dim=1)
|
558 |
+
|
559 |
+
image_context = misc_dropout(self.pre_image_condition(image).view(-1, self.num_tokens, self.context_dim)) # torch.cat([y[:,:-1,:], self.pre_image_condition(y[:,-1:,:]) ], dim=1)
|
560 |
+
context = torch.cat([context, image_context], dim=1)
|
561 |
+
else:
|
562 |
+
y_context = self.zero_y.repeat(batch, 1, 1)
|
563 |
+
context = torch.cat([context, y_context], dim=1)
|
564 |
+
image_context = torch.zeros_like(self.zero_y.repeat(batch, 1, 1))[:,:self.num_tokens]
|
565 |
+
context = torch.cat([context, image_context], dim=1)
|
566 |
+
|
567 |
+
# repeat f times for spatial e and context
|
568 |
+
e = e.repeat_interleave(repeats=f+1, dim=0)
|
569 |
+
context = context.repeat_interleave(repeats=f+1, dim=0)
|
570 |
+
|
571 |
+
|
572 |
+
|
573 |
+
## always in shape (b f) c h w, except for temporal layer
|
574 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
575 |
+
# encoder
|
576 |
+
xs = []
|
577 |
+
for block in self.input_blocks:
|
578 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask)
|
579 |
+
xs.append(x)
|
580 |
+
|
581 |
+
# middle
|
582 |
+
for block in self.middle_block:
|
583 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask)
|
584 |
+
|
585 |
+
# decoder
|
586 |
+
for block in self.output_blocks:
|
587 |
+
x = torch.cat([x, xs.pop()], dim=1)
|
588 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None)
|
589 |
+
|
590 |
+
# head
|
591 |
+
x = self.out(x)
|
592 |
+
|
593 |
+
# reshape back to (b c f h w)
|
594 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b = batch)
|
595 |
+
return x[:,:,1:]
|
596 |
+
|
597 |
+
def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None):
|
598 |
+
if isinstance(module, ResidualBlock):
|
599 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
600 |
+
x = x.contiguous()
|
601 |
+
x = module(x, e, reference)
|
602 |
+
elif isinstance(module, ResBlock):
|
603 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
604 |
+
x = x.contiguous()
|
605 |
+
x = module(x, e, self.batch)
|
606 |
+
elif isinstance(module, SpatialTransformer):
|
607 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
608 |
+
x = module(x, context)
|
609 |
+
elif isinstance(module, TemporalTransformer):
|
610 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
611 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
|
612 |
+
x = module(x, context)
|
613 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
614 |
+
elif isinstance(module, CrossAttention):
|
615 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
616 |
+
x = module(x, context)
|
617 |
+
elif isinstance(module, MemoryEfficientCrossAttention):
|
618 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
619 |
+
x = module(x, context)
|
620 |
+
elif isinstance(module, BasicTransformerBlock):
|
621 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
622 |
+
x = module(x, context)
|
623 |
+
elif isinstance(module, FeedForward):
|
624 |
+
x = module(x, context)
|
625 |
+
elif isinstance(module, Upsample):
|
626 |
+
x = module(x)
|
627 |
+
elif isinstance(module, Downsample):
|
628 |
+
x = module(x)
|
629 |
+
elif isinstance(module, Resample):
|
630 |
+
x = module(x, reference)
|
631 |
+
elif isinstance(module, TemporalAttentionBlock):
|
632 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
633 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
|
634 |
+
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
635 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
636 |
+
elif isinstance(module, TemporalAttentionMultiBlock):
|
637 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
638 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
|
639 |
+
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
640 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
641 |
+
elif isinstance(module, InitTemporalConvBlock):
|
642 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
643 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
|
644 |
+
x = module(x)
|
645 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
646 |
+
elif isinstance(module, TemporalConvBlock):
|
647 |
+
module = checkpoint_wrapper(module) if self.use_checkpoint else module
|
648 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
|
649 |
+
x = module(x)
|
650 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
651 |
+
elif isinstance(module, nn.ModuleList):
|
652 |
+
for block in module:
|
653 |
+
x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference)
|
654 |
+
else:
|
655 |
+
x = module(x)
|
656 |
+
return x
|
657 |
+
|
658 |
+
|
659 |
+
|
UniAnimate/tools/modules/unet/util.py
ADDED
@@ -0,0 +1,1741 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import xformers
|
4 |
+
import open_clip
|
5 |
+
import xformers.ops
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch import einsum
|
8 |
+
from einops import rearrange
|
9 |
+
from functools import partial
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.nn.init as init
|
12 |
+
from rotary_embedding_torch import RotaryEmbedding
|
13 |
+
from fairscale.nn.checkpoint import checkpoint_wrapper
|
14 |
+
|
15 |
+
# from .mha_flash import FlashAttentionBlock
|
16 |
+
from utils.registry_class import MODEL
|
17 |
+
|
18 |
+
|
19 |
+
### load all keys started with prefix and replace them with new_prefix
|
20 |
+
def load_Block(state, prefix, new_prefix=None):
|
21 |
+
if new_prefix is None:
|
22 |
+
new_prefix = prefix
|
23 |
+
|
24 |
+
state_dict = {}
|
25 |
+
state = {key:value for key,value in state.items() if prefix in key}
|
26 |
+
for key,value in state.items():
|
27 |
+
new_key = key.replace(prefix, new_prefix)
|
28 |
+
state_dict[new_key]=value
|
29 |
+
return state_dict
|
30 |
+
|
31 |
+
|
32 |
+
def load_2d_pretrained_state_dict(state,cfg):
|
33 |
+
|
34 |
+
new_state_dict = {}
|
35 |
+
|
36 |
+
dim = cfg.unet_dim
|
37 |
+
num_res_blocks = cfg.unet_res_blocks
|
38 |
+
temporal_attention = cfg.temporal_attention
|
39 |
+
temporal_conv = cfg.temporal_conv
|
40 |
+
dim_mult = cfg.unet_dim_mult
|
41 |
+
attn_scales = cfg.unet_attn_scales
|
42 |
+
|
43 |
+
# params
|
44 |
+
enc_dims = [dim * u for u in [1] + dim_mult]
|
45 |
+
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
46 |
+
shortcut_dims = []
|
47 |
+
scale = 1.0
|
48 |
+
|
49 |
+
#embeddings
|
50 |
+
state_dict = load_Block(state,prefix=f'time_embedding')
|
51 |
+
new_state_dict.update(state_dict)
|
52 |
+
state_dict = load_Block(state,prefix=f'y_embedding')
|
53 |
+
new_state_dict.update(state_dict)
|
54 |
+
state_dict = load_Block(state,prefix=f'context_embedding')
|
55 |
+
new_state_dict.update(state_dict)
|
56 |
+
|
57 |
+
encoder_idx = 0
|
58 |
+
### init block
|
59 |
+
state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0')
|
60 |
+
new_state_dict.update(state_dict)
|
61 |
+
encoder_idx += 1
|
62 |
+
|
63 |
+
shortcut_dims.append(dim)
|
64 |
+
for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
65 |
+
for j in range(num_res_blocks):
|
66 |
+
# residual (+attention) blocks
|
67 |
+
idx = 0
|
68 |
+
idx_ = 0
|
69 |
+
# residual (+attention) blocks
|
70 |
+
state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}')
|
71 |
+
new_state_dict.update(state_dict)
|
72 |
+
idx += 1
|
73 |
+
idx_ = 2
|
74 |
+
|
75 |
+
if scale in attn_scales:
|
76 |
+
# block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim))
|
77 |
+
state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}')
|
78 |
+
new_state_dict.update(state_dict)
|
79 |
+
# if temporal_attention:
|
80 |
+
# block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
|
81 |
+
in_dim = out_dim
|
82 |
+
encoder_idx += 1
|
83 |
+
shortcut_dims.append(out_dim)
|
84 |
+
|
85 |
+
# downsample
|
86 |
+
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
87 |
+
# downsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 0.5, dropout)
|
88 |
+
state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0')
|
89 |
+
new_state_dict.update(state_dict)
|
90 |
+
|
91 |
+
shortcut_dims.append(out_dim)
|
92 |
+
scale /= 2.0
|
93 |
+
encoder_idx += 1
|
94 |
+
|
95 |
+
# middle
|
96 |
+
# self.middle = nn.ModuleList([
|
97 |
+
# ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'),
|
98 |
+
# TemporalConvBlock(out_dim),
|
99 |
+
# AttentionBlock(out_dim, context_dim, num_heads, head_dim)])
|
100 |
+
# if temporal_attention:
|
101 |
+
# self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
|
102 |
+
# elif temporal_conv:
|
103 |
+
# self.middle.append(TemporalConvBlock(out_dim,dropout=dropout))
|
104 |
+
# self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'))
|
105 |
+
# self.middle.append(TemporalConvBlock(out_dim))
|
106 |
+
|
107 |
+
|
108 |
+
# middle
|
109 |
+
middle_idx = 0
|
110 |
+
# self.middle = nn.ModuleList([
|
111 |
+
# ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout),
|
112 |
+
# AttentionBlock(out_dim, context_dim, num_heads, head_dim)])
|
113 |
+
state_dict = load_Block(state,prefix=f'middle.{middle_idx}')
|
114 |
+
new_state_dict.update(state_dict)
|
115 |
+
middle_idx += 2
|
116 |
+
|
117 |
+
state_dict = load_Block(state,prefix=f'middle.1',new_prefix=f'middle.{middle_idx}')
|
118 |
+
new_state_dict.update(state_dict)
|
119 |
+
middle_idx += 1
|
120 |
+
|
121 |
+
for _ in range(cfg.temporal_attn_times):
|
122 |
+
# self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
|
123 |
+
middle_idx += 1
|
124 |
+
|
125 |
+
# self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout))
|
126 |
+
state_dict = load_Block(state,prefix=f'middle.2',new_prefix=f'middle.{middle_idx}')
|
127 |
+
new_state_dict.update(state_dict)
|
128 |
+
middle_idx += 2
|
129 |
+
|
130 |
+
|
131 |
+
decoder_idx = 0
|
132 |
+
for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
133 |
+
for j in range(num_res_blocks + 1):
|
134 |
+
idx = 0
|
135 |
+
idx_ = 0
|
136 |
+
# residual (+attention) blocks
|
137 |
+
# block = nn.ModuleList([ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)])
|
138 |
+
state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
|
139 |
+
new_state_dict.update(state_dict)
|
140 |
+
idx += 1
|
141 |
+
idx_ += 2
|
142 |
+
if scale in attn_scales:
|
143 |
+
# block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim))
|
144 |
+
state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
|
145 |
+
new_state_dict.update(state_dict)
|
146 |
+
idx += 1
|
147 |
+
idx_ += 1
|
148 |
+
for _ in range(cfg.temporal_attn_times):
|
149 |
+
# block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb))
|
150 |
+
idx_ +=1
|
151 |
+
|
152 |
+
in_dim = out_dim
|
153 |
+
|
154 |
+
# upsample
|
155 |
+
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
156 |
+
|
157 |
+
# upsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 2.0, dropout)
|
158 |
+
state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}')
|
159 |
+
new_state_dict.update(state_dict)
|
160 |
+
idx += 1
|
161 |
+
idx_ += 2
|
162 |
+
|
163 |
+
scale *= 2.0
|
164 |
+
# block.append(upsample)
|
165 |
+
# self.decoder.append(block)
|
166 |
+
decoder_idx += 1
|
167 |
+
|
168 |
+
# head
|
169 |
+
# self.head = nn.Sequential(
|
170 |
+
# nn.GroupNorm(32, out_dim),
|
171 |
+
# nn.SiLU(),
|
172 |
+
# nn.Conv3d(out_dim, self.out_dim, (1,3,3), padding=(0,1,1)))
|
173 |
+
state_dict = load_Block(state,prefix=f'head')
|
174 |
+
new_state_dict.update(state_dict)
|
175 |
+
|
176 |
+
return new_state_dict
|
177 |
+
|
178 |
+
def sinusoidal_embedding(timesteps, dim):
|
179 |
+
# check input
|
180 |
+
half = dim // 2
|
181 |
+
timesteps = timesteps.float()
|
182 |
+
|
183 |
+
# compute sinusoidal embedding
|
184 |
+
sinusoid = torch.outer(
|
185 |
+
timesteps,
|
186 |
+
torch.pow(10000, -torch.arange(half).to(timesteps).div(half)))
|
187 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
188 |
+
if dim % 2 != 0:
|
189 |
+
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
190 |
+
return x
|
191 |
+
|
192 |
+
def exists(x):
|
193 |
+
return x is not None
|
194 |
+
|
195 |
+
def default(val, d):
|
196 |
+
if exists(val):
|
197 |
+
return val
|
198 |
+
return d() if callable(d) else d
|
199 |
+
|
200 |
+
def prob_mask_like(shape, prob, device):
|
201 |
+
if prob == 1:
|
202 |
+
return torch.ones(shape, device = device, dtype = torch.bool)
|
203 |
+
elif prob == 0:
|
204 |
+
return torch.zeros(shape, device = device, dtype = torch.bool)
|
205 |
+
else:
|
206 |
+
mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
|
207 |
+
### aviod mask all, which will cause find_unused_parameters error
|
208 |
+
if mask.all():
|
209 |
+
mask[0]=False
|
210 |
+
return mask
|
211 |
+
|
212 |
+
|
213 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
214 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
215 |
+
def __init__(self, query_dim, max_bs=4096, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
216 |
+
super().__init__()
|
217 |
+
inner_dim = dim_head * heads
|
218 |
+
context_dim = default(context_dim, query_dim)
|
219 |
+
|
220 |
+
self.max_bs = max_bs
|
221 |
+
self.heads = heads
|
222 |
+
self.dim_head = dim_head
|
223 |
+
|
224 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
225 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
226 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
227 |
+
|
228 |
+
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
229 |
+
self.attention_op: Optional[Any] = None
|
230 |
+
|
231 |
+
def forward(self, x, context=None, mask=None):
|
232 |
+
q = self.to_q(x)
|
233 |
+
context = default(context, x)
|
234 |
+
k = self.to_k(context)
|
235 |
+
v = self.to_v(context)
|
236 |
+
|
237 |
+
b, _, _ = q.shape
|
238 |
+
q, k, v = map(
|
239 |
+
lambda t: t.unsqueeze(3)
|
240 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
241 |
+
.permute(0, 2, 1, 3)
|
242 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
243 |
+
.contiguous(),
|
244 |
+
(q, k, v),
|
245 |
+
)
|
246 |
+
|
247 |
+
# actually compute the attention, what we cannot get enough of
|
248 |
+
if q.shape[0] > self.max_bs:
|
249 |
+
q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
|
250 |
+
k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
|
251 |
+
v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
|
252 |
+
out_list = []
|
253 |
+
for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
|
254 |
+
out = xformers.ops.memory_efficient_attention(
|
255 |
+
q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
|
256 |
+
out_list.append(out)
|
257 |
+
out = torch.cat(out_list, dim=0)
|
258 |
+
else:
|
259 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
260 |
+
|
261 |
+
if exists(mask):
|
262 |
+
raise NotImplementedError
|
263 |
+
out = (
|
264 |
+
out.unsqueeze(0)
|
265 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
266 |
+
.permute(0, 2, 1, 3)
|
267 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
268 |
+
)
|
269 |
+
return self.to_out(out)
|
270 |
+
|
271 |
+
class RelativePositionBias(nn.Module):
|
272 |
+
def __init__(
|
273 |
+
self,
|
274 |
+
heads = 8,
|
275 |
+
num_buckets = 32,
|
276 |
+
max_distance = 128
|
277 |
+
):
|
278 |
+
super().__init__()
|
279 |
+
self.num_buckets = num_buckets
|
280 |
+
self.max_distance = max_distance
|
281 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
282 |
+
|
283 |
+
@staticmethod
|
284 |
+
def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
|
285 |
+
ret = 0
|
286 |
+
n = -relative_position
|
287 |
+
|
288 |
+
num_buckets //= 2
|
289 |
+
ret += (n < 0).long() * num_buckets
|
290 |
+
n = torch.abs(n)
|
291 |
+
|
292 |
+
max_exact = num_buckets // 2
|
293 |
+
is_small = n < max_exact
|
294 |
+
|
295 |
+
val_if_large = max_exact + (
|
296 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
297 |
+
).long()
|
298 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
299 |
+
|
300 |
+
ret += torch.where(is_small, n, val_if_large)
|
301 |
+
return ret
|
302 |
+
|
303 |
+
def forward(self, n, device):
|
304 |
+
q_pos = torch.arange(n, dtype = torch.long, device = device)
|
305 |
+
k_pos = torch.arange(n, dtype = torch.long, device = device)
|
306 |
+
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
307 |
+
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
308 |
+
values = self.relative_attention_bias(rp_bucket)
|
309 |
+
return rearrange(values, 'i j h -> h i j')
|
310 |
+
|
311 |
+
class SpatialTransformer(nn.Module):
|
312 |
+
"""
|
313 |
+
Transformer block for image-like data.
|
314 |
+
First, project the input (aka embedding)
|
315 |
+
and reshape to b, t, d.
|
316 |
+
Then apply standard transformer action.
|
317 |
+
Finally, reshape to image
|
318 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
319 |
+
"""
|
320 |
+
def __init__(self, in_channels, n_heads, d_head,
|
321 |
+
depth=1, dropout=0., context_dim=None,
|
322 |
+
disable_self_attn=False, use_linear=False,
|
323 |
+
use_checkpoint=True):
|
324 |
+
super().__init__()
|
325 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
326 |
+
context_dim = [context_dim]
|
327 |
+
self.in_channels = in_channels
|
328 |
+
inner_dim = n_heads * d_head
|
329 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
330 |
+
if not use_linear:
|
331 |
+
self.proj_in = nn.Conv2d(in_channels,
|
332 |
+
inner_dim,
|
333 |
+
kernel_size=1,
|
334 |
+
stride=1,
|
335 |
+
padding=0)
|
336 |
+
else:
|
337 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
338 |
+
|
339 |
+
self.transformer_blocks = nn.ModuleList(
|
340 |
+
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
341 |
+
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
|
342 |
+
for d in range(depth)]
|
343 |
+
)
|
344 |
+
if not use_linear:
|
345 |
+
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
346 |
+
in_channels,
|
347 |
+
kernel_size=1,
|
348 |
+
stride=1,
|
349 |
+
padding=0))
|
350 |
+
else:
|
351 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
352 |
+
self.use_linear = use_linear
|
353 |
+
|
354 |
+
def forward(self, x, context=None):
|
355 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
356 |
+
if not isinstance(context, list):
|
357 |
+
context = [context]
|
358 |
+
b, c, h, w = x.shape
|
359 |
+
x_in = x
|
360 |
+
x = self.norm(x)
|
361 |
+
if not self.use_linear:
|
362 |
+
x = self.proj_in(x)
|
363 |
+
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
364 |
+
if self.use_linear:
|
365 |
+
x = self.proj_in(x)
|
366 |
+
for i, block in enumerate(self.transformer_blocks):
|
367 |
+
x = block(x, context=context[i])
|
368 |
+
if self.use_linear:
|
369 |
+
x = self.proj_out(x)
|
370 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
371 |
+
if not self.use_linear:
|
372 |
+
x = self.proj_out(x)
|
373 |
+
return x + x_in
|
374 |
+
|
375 |
+
|
376 |
+
class SpatialTransformerWithAdapter(nn.Module):
|
377 |
+
"""
|
378 |
+
Transformer block for image-like data.
|
379 |
+
First, project the input (aka embedding)
|
380 |
+
and reshape to b, t, d.
|
381 |
+
Then apply standard transformer action.
|
382 |
+
Finally, reshape to image
|
383 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
384 |
+
"""
|
385 |
+
def __init__(self, in_channels, n_heads, d_head,
|
386 |
+
depth=1, dropout=0., context_dim=None,
|
387 |
+
disable_self_attn=False, use_linear=False,
|
388 |
+
use_checkpoint=True,
|
389 |
+
adapter_list=[], adapter_position_list=['', 'parallel', ''],
|
390 |
+
adapter_hidden_dim=None):
|
391 |
+
super().__init__()
|
392 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
393 |
+
context_dim = [context_dim]
|
394 |
+
self.in_channels = in_channels
|
395 |
+
inner_dim = n_heads * d_head
|
396 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
397 |
+
if not use_linear:
|
398 |
+
self.proj_in = nn.Conv2d(in_channels,
|
399 |
+
inner_dim,
|
400 |
+
kernel_size=1,
|
401 |
+
stride=1,
|
402 |
+
padding=0)
|
403 |
+
else:
|
404 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
405 |
+
|
406 |
+
self.transformer_blocks = nn.ModuleList(
|
407 |
+
[BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
408 |
+
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint,
|
409 |
+
adapter_list=adapter_list, adapter_position_list=adapter_position_list,
|
410 |
+
adapter_hidden_dim=adapter_hidden_dim)
|
411 |
+
for d in range(depth)]
|
412 |
+
)
|
413 |
+
if not use_linear:
|
414 |
+
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
415 |
+
in_channels,
|
416 |
+
kernel_size=1,
|
417 |
+
stride=1,
|
418 |
+
padding=0))
|
419 |
+
else:
|
420 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
421 |
+
self.use_linear = use_linear
|
422 |
+
|
423 |
+
def forward(self, x, context=None):
|
424 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
425 |
+
if not isinstance(context, list):
|
426 |
+
context = [context]
|
427 |
+
b, c, h, w = x.shape
|
428 |
+
x_in = x
|
429 |
+
x = self.norm(x)
|
430 |
+
if not self.use_linear:
|
431 |
+
x = self.proj_in(x)
|
432 |
+
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
433 |
+
if self.use_linear:
|
434 |
+
x = self.proj_in(x)
|
435 |
+
for i, block in enumerate(self.transformer_blocks):
|
436 |
+
x = block(x, context=context[i])
|
437 |
+
if self.use_linear:
|
438 |
+
x = self.proj_out(x)
|
439 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
440 |
+
if not self.use_linear:
|
441 |
+
x = self.proj_out(x)
|
442 |
+
return x + x_in
|
443 |
+
|
444 |
+
import os
|
445 |
+
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
446 |
+
|
447 |
+
class CrossAttention(nn.Module):
|
448 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
449 |
+
super().__init__()
|
450 |
+
inner_dim = dim_head * heads
|
451 |
+
context_dim = default(context_dim, query_dim)
|
452 |
+
|
453 |
+
self.scale = dim_head ** -0.5
|
454 |
+
self.heads = heads
|
455 |
+
|
456 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
457 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
458 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
459 |
+
|
460 |
+
self.to_out = nn.Sequential(
|
461 |
+
nn.Linear(inner_dim, query_dim),
|
462 |
+
nn.Dropout(dropout)
|
463 |
+
)
|
464 |
+
|
465 |
+
def forward(self, x, context=None, mask=None):
|
466 |
+
h = self.heads
|
467 |
+
|
468 |
+
q = self.to_q(x)
|
469 |
+
context = default(context, x)
|
470 |
+
k = self.to_k(context)
|
471 |
+
v = self.to_v(context)
|
472 |
+
|
473 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
474 |
+
|
475 |
+
# force cast to fp32 to avoid overflowing
|
476 |
+
if _ATTN_PRECISION =="fp32":
|
477 |
+
with torch.autocast(enabled=False, device_type = 'cuda'):
|
478 |
+
q, k = q.float(), k.float()
|
479 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
480 |
+
else:
|
481 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
482 |
+
|
483 |
+
del q, k
|
484 |
+
|
485 |
+
if exists(mask):
|
486 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
487 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
488 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
489 |
+
sim.masked_fill_(~mask, max_neg_value)
|
490 |
+
|
491 |
+
# attention, what we cannot get enough of
|
492 |
+
sim = sim.softmax(dim=-1)
|
493 |
+
|
494 |
+
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
495 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
496 |
+
return self.to_out(out)
|
497 |
+
|
498 |
+
|
499 |
+
class Adapter(nn.Module):
|
500 |
+
def __init__(self, in_dim, hidden_dim, condition_dim=None):
|
501 |
+
super().__init__()
|
502 |
+
self.down_linear = nn.Linear(in_dim, hidden_dim)
|
503 |
+
self.up_linear = nn.Linear(hidden_dim, in_dim)
|
504 |
+
self.condition_dim = condition_dim
|
505 |
+
if condition_dim is not None:
|
506 |
+
self.condition_linear = nn.Linear(condition_dim, in_dim)
|
507 |
+
|
508 |
+
init.zeros_(self.up_linear.weight)
|
509 |
+
init.zeros_(self.up_linear.bias)
|
510 |
+
|
511 |
+
def forward(self, x, condition=None, condition_lam=1):
|
512 |
+
x_in = x
|
513 |
+
if self.condition_dim is not None and condition is not None:
|
514 |
+
x = x + condition_lam * self.condition_linear(condition)
|
515 |
+
x = self.down_linear(x)
|
516 |
+
x = F.gelu(x)
|
517 |
+
x = self.up_linear(x)
|
518 |
+
x += x_in
|
519 |
+
return x
|
520 |
+
|
521 |
+
|
522 |
+
class MemoryEfficientCrossAttention_attemask(nn.Module):
|
523 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
524 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
525 |
+
super().__init__()
|
526 |
+
inner_dim = dim_head * heads
|
527 |
+
context_dim = default(context_dim, query_dim)
|
528 |
+
|
529 |
+
self.heads = heads
|
530 |
+
self.dim_head = dim_head
|
531 |
+
|
532 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
533 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
534 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
535 |
+
|
536 |
+
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
537 |
+
self.attention_op: Optional[Any] = None
|
538 |
+
|
539 |
+
def forward(self, x, context=None, mask=None):
|
540 |
+
q = self.to_q(x)
|
541 |
+
context = default(context, x)
|
542 |
+
k = self.to_k(context)
|
543 |
+
v = self.to_v(context)
|
544 |
+
|
545 |
+
b, _, _ = q.shape
|
546 |
+
q, k, v = map(
|
547 |
+
lambda t: t.unsqueeze(3)
|
548 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
549 |
+
.permute(0, 2, 1, 3)
|
550 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
551 |
+
.contiguous(),
|
552 |
+
(q, k, v),
|
553 |
+
)
|
554 |
+
|
555 |
+
# actually compute the attention, what we cannot get enough of
|
556 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask(), op=self.attention_op)
|
557 |
+
|
558 |
+
if exists(mask):
|
559 |
+
raise NotImplementedError
|
560 |
+
out = (
|
561 |
+
out.unsqueeze(0)
|
562 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
563 |
+
.permute(0, 2, 1, 3)
|
564 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
565 |
+
)
|
566 |
+
return self.to_out(out)
|
567 |
+
|
568 |
+
|
569 |
+
|
570 |
+
class BasicTransformerBlock_attemask(nn.Module):
|
571 |
+
# ATTENTION_MODES = {
|
572 |
+
# "softmax": CrossAttention, # vanilla attention
|
573 |
+
# "softmax-xformers": MemoryEfficientCrossAttention
|
574 |
+
# }
|
575 |
+
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
576 |
+
disable_self_attn=False):
|
577 |
+
super().__init__()
|
578 |
+
# attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
579 |
+
# assert attn_mode in self.ATTENTION_MODES
|
580 |
+
# attn_cls = CrossAttention
|
581 |
+
attn_cls = MemoryEfficientCrossAttention_attemask
|
582 |
+
self.disable_self_attn = disable_self_attn
|
583 |
+
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
584 |
+
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
585 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
586 |
+
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
587 |
+
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
588 |
+
self.norm1 = nn.LayerNorm(dim)
|
589 |
+
self.norm2 = nn.LayerNorm(dim)
|
590 |
+
self.norm3 = nn.LayerNorm(dim)
|
591 |
+
self.checkpoint = checkpoint
|
592 |
+
|
593 |
+
def forward_(self, x, context=None):
|
594 |
+
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
595 |
+
|
596 |
+
def forward(self, x, context=None):
|
597 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
598 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
599 |
+
x = self.ff(self.norm3(x)) + x
|
600 |
+
return x
|
601 |
+
|
602 |
+
|
603 |
+
class BasicTransformerBlockWithAdapter(nn.Module):
|
604 |
+
# ATTENTION_MODES = {
|
605 |
+
# "softmax": CrossAttention, # vanilla attention
|
606 |
+
# "softmax-xformers": MemoryEfficientCrossAttention
|
607 |
+
# }
|
608 |
+
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False,
|
609 |
+
adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'], adapter_hidden_dim=None, adapter_condition_dim=None
|
610 |
+
):
|
611 |
+
super().__init__()
|
612 |
+
# attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
613 |
+
# assert attn_mode in self.ATTENTION_MODES
|
614 |
+
# attn_cls = CrossAttention
|
615 |
+
attn_cls = MemoryEfficientCrossAttention
|
616 |
+
self.disable_self_attn = disable_self_attn
|
617 |
+
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
618 |
+
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
619 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
620 |
+
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
621 |
+
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
622 |
+
self.norm1 = nn.LayerNorm(dim)
|
623 |
+
self.norm2 = nn.LayerNorm(dim)
|
624 |
+
self.norm3 = nn.LayerNorm(dim)
|
625 |
+
self.checkpoint = checkpoint
|
626 |
+
# adapter
|
627 |
+
self.adapter_list = adapter_list
|
628 |
+
self.adapter_position_list = adapter_position_list
|
629 |
+
hidden_dim = dim//2 if not adapter_hidden_dim else adapter_hidden_dim
|
630 |
+
if "self_attention" in adapter_list:
|
631 |
+
self.attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim)
|
632 |
+
if "cross_attention" in adapter_list:
|
633 |
+
self.cross_attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim)
|
634 |
+
if "feedforward" in adapter_list:
|
635 |
+
self.ff_adapter = Adapter(dim, hidden_dim, adapter_condition_dim)
|
636 |
+
|
637 |
+
|
638 |
+
def forward_(self, x, context=None, adapter_condition=None, adapter_condition_lam=1):
|
639 |
+
return checkpoint(self._forward, (x, context, adapter_condition, adapter_condition_lam), self.parameters(), self.checkpoint)
|
640 |
+
|
641 |
+
def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1):
|
642 |
+
if "self_attention" in self.adapter_list:
|
643 |
+
if self.adapter_position_list[0] == 'parallel':
|
644 |
+
# parallel
|
645 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + self.attn_adapter(x, adapter_condition, adapter_condition_lam)
|
646 |
+
elif self.adapter_position_list[0] == 'serial':
|
647 |
+
# serial
|
648 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
649 |
+
x = self.attn_adapter(x, adapter_condition, adapter_condition_lam)
|
650 |
+
else:
|
651 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
652 |
+
|
653 |
+
if "cross_attention" in self.adapter_list:
|
654 |
+
if self.adapter_position_list[1] == 'parallel':
|
655 |
+
# parallel
|
656 |
+
x = self.attn2(self.norm2(x), context=context) + self.cross_attn_adapter(x, adapter_condition, adapter_condition_lam)
|
657 |
+
elif self.adapter_position_list[1] == 'serial':
|
658 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
659 |
+
x = self.cross_attn_adapter(x, adapter_condition, adapter_condition_lam)
|
660 |
+
else:
|
661 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
662 |
+
|
663 |
+
if "feedforward" in self.adapter_list:
|
664 |
+
if self.adapter_position_list[2] == 'parallel':
|
665 |
+
x = self.ff(self.norm3(x)) + self.ff_adapter(x, adapter_condition, adapter_condition_lam)
|
666 |
+
elif self.adapter_position_list[2] == 'serial':
|
667 |
+
x = self.ff(self.norm3(x)) + x
|
668 |
+
x = self.ff_adapter(x, adapter_condition, adapter_condition_lam)
|
669 |
+
else:
|
670 |
+
x = self.ff(self.norm3(x)) + x
|
671 |
+
|
672 |
+
return x
|
673 |
+
|
674 |
+
class BasicTransformerBlock(nn.Module):
|
675 |
+
# ATTENTION_MODES = {
|
676 |
+
# "softmax": CrossAttention, # vanilla attention
|
677 |
+
# "softmax-xformers": MemoryEfficientCrossAttention
|
678 |
+
# }
|
679 |
+
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
680 |
+
disable_self_attn=False):
|
681 |
+
super().__init__()
|
682 |
+
# attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
683 |
+
# assert attn_mode in self.ATTENTION_MODES
|
684 |
+
# attn_cls = CrossAttention
|
685 |
+
attn_cls = MemoryEfficientCrossAttention
|
686 |
+
self.disable_self_attn = disable_self_attn
|
687 |
+
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
688 |
+
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
689 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
690 |
+
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
691 |
+
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
692 |
+
self.norm1 = nn.LayerNorm(dim)
|
693 |
+
self.norm2 = nn.LayerNorm(dim)
|
694 |
+
self.norm3 = nn.LayerNorm(dim)
|
695 |
+
self.checkpoint = checkpoint
|
696 |
+
|
697 |
+
def forward_(self, x, context=None):
|
698 |
+
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
699 |
+
|
700 |
+
def forward(self, x, context=None):
|
701 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
702 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
703 |
+
x = self.ff(self.norm3(x)) + x
|
704 |
+
return x
|
705 |
+
|
706 |
+
# feedforward
|
707 |
+
class GEGLU(nn.Module):
|
708 |
+
def __init__(self, dim_in, dim_out):
|
709 |
+
super().__init__()
|
710 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
711 |
+
|
712 |
+
def forward(self, x):
|
713 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
714 |
+
return x * F.gelu(gate)
|
715 |
+
|
716 |
+
def zero_module(module):
|
717 |
+
"""
|
718 |
+
Zero out the parameters of a module and return it.
|
719 |
+
"""
|
720 |
+
for p in module.parameters():
|
721 |
+
p.detach().zero_()
|
722 |
+
return module
|
723 |
+
|
724 |
+
class FeedForward(nn.Module):
|
725 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
726 |
+
super().__init__()
|
727 |
+
inner_dim = int(dim * mult)
|
728 |
+
dim_out = default(dim_out, dim)
|
729 |
+
project_in = nn.Sequential(
|
730 |
+
nn.Linear(dim, inner_dim),
|
731 |
+
nn.GELU()
|
732 |
+
) if not glu else GEGLU(dim, inner_dim)
|
733 |
+
|
734 |
+
self.net = nn.Sequential(
|
735 |
+
project_in,
|
736 |
+
nn.Dropout(dropout),
|
737 |
+
nn.Linear(inner_dim, dim_out)
|
738 |
+
)
|
739 |
+
|
740 |
+
def forward(self, x):
|
741 |
+
return self.net(x)
|
742 |
+
|
743 |
+
class Upsample(nn.Module):
|
744 |
+
"""
|
745 |
+
An upsampling layer with an optional convolution.
|
746 |
+
:param channels: channels in the inputs and outputs.
|
747 |
+
:param use_conv: a bool determining if a convolution is applied.
|
748 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
749 |
+
upsampling occurs in the inner-two dimensions.
|
750 |
+
"""
|
751 |
+
|
752 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
753 |
+
super().__init__()
|
754 |
+
self.channels = channels
|
755 |
+
self.out_channels = out_channels or channels
|
756 |
+
self.use_conv = use_conv
|
757 |
+
self.dims = dims
|
758 |
+
if use_conv:
|
759 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
|
760 |
+
|
761 |
+
def forward(self, x):
|
762 |
+
assert x.shape[1] == self.channels
|
763 |
+
if self.dims == 3:
|
764 |
+
x = F.interpolate(
|
765 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
766 |
+
)
|
767 |
+
else:
|
768 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
769 |
+
if self.use_conv:
|
770 |
+
x = self.conv(x)
|
771 |
+
return x
|
772 |
+
|
773 |
+
|
774 |
+
class UpsampleSR600(nn.Module):
|
775 |
+
"""
|
776 |
+
An upsampling layer with an optional convolution.
|
777 |
+
:param channels: channels in the inputs and outputs.
|
778 |
+
:param use_conv: a bool determining if a convolution is applied.
|
779 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
780 |
+
upsampling occurs in the inner-two dimensions.
|
781 |
+
"""
|
782 |
+
|
783 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
784 |
+
super().__init__()
|
785 |
+
self.channels = channels
|
786 |
+
self.out_channels = out_channels or channels
|
787 |
+
self.use_conv = use_conv
|
788 |
+
self.dims = dims
|
789 |
+
if use_conv:
|
790 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
|
791 |
+
|
792 |
+
def forward(self, x):
|
793 |
+
assert x.shape[1] == self.channels
|
794 |
+
if self.dims == 3:
|
795 |
+
x = F.interpolate(
|
796 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
797 |
+
)
|
798 |
+
else:
|
799 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
800 |
+
# TODO: to match input_blocks, remove elements of two sides
|
801 |
+
x = x[..., 1:-1, :]
|
802 |
+
if self.use_conv:
|
803 |
+
x = self.conv(x)
|
804 |
+
return x
|
805 |
+
|
806 |
+
|
807 |
+
class ResBlock(nn.Module):
|
808 |
+
"""
|
809 |
+
A residual block that can optionally change the number of channels.
|
810 |
+
:param channels: the number of input channels.
|
811 |
+
:param emb_channels: the number of timestep embedding channels.
|
812 |
+
:param dropout: the rate of dropout.
|
813 |
+
:param out_channels: if specified, the number of out channels.
|
814 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
815 |
+
convolution instead of a smaller 1x1 convolution to change the
|
816 |
+
channels in the skip connection.
|
817 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
818 |
+
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
819 |
+
:param up: if True, use this block for upsampling.
|
820 |
+
:param down: if True, use this block for downsampling.
|
821 |
+
"""
|
822 |
+
def __init__(
|
823 |
+
self,
|
824 |
+
channels,
|
825 |
+
emb_channels,
|
826 |
+
dropout,
|
827 |
+
out_channels=None,
|
828 |
+
use_conv=False,
|
829 |
+
use_scale_shift_norm=False,
|
830 |
+
dims=2,
|
831 |
+
up=False,
|
832 |
+
down=False,
|
833 |
+
use_temporal_conv=True,
|
834 |
+
use_image_dataset=False,
|
835 |
+
):
|
836 |
+
super().__init__()
|
837 |
+
self.channels = channels
|
838 |
+
self.emb_channels = emb_channels
|
839 |
+
self.dropout = dropout
|
840 |
+
self.out_channels = out_channels or channels
|
841 |
+
self.use_conv = use_conv
|
842 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
843 |
+
self.use_temporal_conv = use_temporal_conv
|
844 |
+
|
845 |
+
self.in_layers = nn.Sequential(
|
846 |
+
nn.GroupNorm(32, channels),
|
847 |
+
nn.SiLU(),
|
848 |
+
nn.Conv2d(channels, self.out_channels, 3, padding=1),
|
849 |
+
)
|
850 |
+
|
851 |
+
self.updown = up or down
|
852 |
+
|
853 |
+
if up:
|
854 |
+
self.h_upd = Upsample(channels, False, dims)
|
855 |
+
self.x_upd = Upsample(channels, False, dims)
|
856 |
+
elif down:
|
857 |
+
self.h_upd = Downsample(channels, False, dims)
|
858 |
+
self.x_upd = Downsample(channels, False, dims)
|
859 |
+
else:
|
860 |
+
self.h_upd = self.x_upd = nn.Identity()
|
861 |
+
|
862 |
+
self.emb_layers = nn.Sequential(
|
863 |
+
nn.SiLU(),
|
864 |
+
nn.Linear(
|
865 |
+
emb_channels,
|
866 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
867 |
+
),
|
868 |
+
)
|
869 |
+
self.out_layers = nn.Sequential(
|
870 |
+
nn.GroupNorm(32, self.out_channels),
|
871 |
+
nn.SiLU(),
|
872 |
+
nn.Dropout(p=dropout),
|
873 |
+
zero_module(
|
874 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)
|
875 |
+
),
|
876 |
+
)
|
877 |
+
|
878 |
+
if self.out_channels == channels:
|
879 |
+
self.skip_connection = nn.Identity()
|
880 |
+
elif use_conv:
|
881 |
+
self.skip_connection = conv_nd(
|
882 |
+
dims, channels, self.out_channels, 3, padding=1
|
883 |
+
)
|
884 |
+
else:
|
885 |
+
self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
|
886 |
+
|
887 |
+
if self.use_temporal_conv:
|
888 |
+
self.temopral_conv = TemporalConvBlock_v2(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset)
|
889 |
+
# self.temopral_conv_2 = TemporalConvBlock(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset)
|
890 |
+
|
891 |
+
def forward(self, x, emb, batch_size):
|
892 |
+
"""
|
893 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
894 |
+
:param x: an [N x C x ...] Tensor of features.
|
895 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
896 |
+
:return: an [N x C x ...] Tensor of outputs.
|
897 |
+
"""
|
898 |
+
return self._forward(x, emb, batch_size)
|
899 |
+
|
900 |
+
def _forward(self, x, emb, batch_size):
|
901 |
+
if self.updown:
|
902 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
903 |
+
h = in_rest(x)
|
904 |
+
h = self.h_upd(h)
|
905 |
+
x = self.x_upd(x)
|
906 |
+
h = in_conv(h)
|
907 |
+
else:
|
908 |
+
h = self.in_layers(x)
|
909 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
910 |
+
while len(emb_out.shape) < len(h.shape):
|
911 |
+
emb_out = emb_out[..., None]
|
912 |
+
if self.use_scale_shift_norm:
|
913 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
914 |
+
scale, shift = th.chunk(emb_out, 2, dim=1)
|
915 |
+
h = out_norm(h) * (1 + scale) + shift
|
916 |
+
h = out_rest(h)
|
917 |
+
else:
|
918 |
+
h = h + emb_out
|
919 |
+
h = self.out_layers(h)
|
920 |
+
h = self.skip_connection(x) + h
|
921 |
+
|
922 |
+
if self.use_temporal_conv:
|
923 |
+
h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
|
924 |
+
h = self.temopral_conv(h)
|
925 |
+
# h = self.temopral_conv_2(h)
|
926 |
+
h = rearrange(h, 'b c f h w -> (b f) c h w')
|
927 |
+
return h
|
928 |
+
|
929 |
+
class Downsample(nn.Module):
|
930 |
+
"""
|
931 |
+
A downsampling layer with an optional convolution.
|
932 |
+
:param channels: channels in the inputs and outputs.
|
933 |
+
:param use_conv: a bool determining if a convolution is applied.
|
934 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
935 |
+
downsampling occurs in the inner-two dimensions.
|
936 |
+
"""
|
937 |
+
|
938 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
939 |
+
super().__init__()
|
940 |
+
self.channels = channels
|
941 |
+
self.out_channels = out_channels or channels
|
942 |
+
self.use_conv = use_conv
|
943 |
+
self.dims = dims
|
944 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
945 |
+
if use_conv:
|
946 |
+
self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
947 |
+
else:
|
948 |
+
assert self.channels == self.out_channels
|
949 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
950 |
+
|
951 |
+
def forward(self, x):
|
952 |
+
assert x.shape[1] == self.channels
|
953 |
+
return self.op(x)
|
954 |
+
|
955 |
+
class Resample(nn.Module):
|
956 |
+
|
957 |
+
def __init__(self, in_dim, out_dim, mode):
|
958 |
+
assert mode in ['none', 'upsample', 'downsample']
|
959 |
+
super(Resample, self).__init__()
|
960 |
+
self.in_dim = in_dim
|
961 |
+
self.out_dim = out_dim
|
962 |
+
self.mode = mode
|
963 |
+
|
964 |
+
def forward(self, x, reference=None):
|
965 |
+
if self.mode == 'upsample':
|
966 |
+
assert reference is not None
|
967 |
+
x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
|
968 |
+
elif self.mode == 'downsample':
|
969 |
+
x = F.adaptive_avg_pool2d(x, output_size=tuple(u // 2 for u in x.shape[-2:]))
|
970 |
+
return x
|
971 |
+
|
972 |
+
class ResidualBlock(nn.Module):
|
973 |
+
|
974 |
+
def __init__(self, in_dim, embed_dim, out_dim, use_scale_shift_norm=True,
|
975 |
+
mode='none', dropout=0.0):
|
976 |
+
super(ResidualBlock, self).__init__()
|
977 |
+
self.in_dim = in_dim
|
978 |
+
self.embed_dim = embed_dim
|
979 |
+
self.out_dim = out_dim
|
980 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
981 |
+
self.mode = mode
|
982 |
+
|
983 |
+
# layers
|
984 |
+
self.layer1 = nn.Sequential(
|
985 |
+
nn.GroupNorm(32, in_dim),
|
986 |
+
nn.SiLU(),
|
987 |
+
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
988 |
+
self.resample = Resample(in_dim, in_dim, mode)
|
989 |
+
self.embedding = nn.Sequential(
|
990 |
+
nn.SiLU(),
|
991 |
+
nn.Linear(embed_dim, out_dim * 2 if use_scale_shift_norm else out_dim))
|
992 |
+
self.layer2 = nn.Sequential(
|
993 |
+
nn.GroupNorm(32, out_dim),
|
994 |
+
nn.SiLU(),
|
995 |
+
nn.Dropout(dropout),
|
996 |
+
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
997 |
+
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(in_dim, out_dim, 1)
|
998 |
+
|
999 |
+
# zero out the last layer params
|
1000 |
+
nn.init.zeros_(self.layer2[-1].weight)
|
1001 |
+
|
1002 |
+
def forward(self, x, e, reference=None):
|
1003 |
+
identity = self.resample(x, reference)
|
1004 |
+
x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
|
1005 |
+
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
|
1006 |
+
if self.use_scale_shift_norm:
|
1007 |
+
scale, shift = e.chunk(2, dim=1)
|
1008 |
+
x = self.layer2[0](x) * (1 + scale) + shift
|
1009 |
+
x = self.layer2[1:](x)
|
1010 |
+
else:
|
1011 |
+
x = x + e
|
1012 |
+
x = self.layer2(x)
|
1013 |
+
x = x + self.shortcut(identity)
|
1014 |
+
return x
|
1015 |
+
|
1016 |
+
class AttentionBlock(nn.Module):
|
1017 |
+
|
1018 |
+
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
|
1019 |
+
# consider head_dim first, then num_heads
|
1020 |
+
num_heads = dim // head_dim if head_dim else num_heads
|
1021 |
+
head_dim = dim // num_heads
|
1022 |
+
assert num_heads * head_dim == dim
|
1023 |
+
super(AttentionBlock, self).__init__()
|
1024 |
+
self.dim = dim
|
1025 |
+
self.context_dim = context_dim
|
1026 |
+
self.num_heads = num_heads
|
1027 |
+
self.head_dim = head_dim
|
1028 |
+
self.scale = math.pow(head_dim, -0.25)
|
1029 |
+
|
1030 |
+
# layers
|
1031 |
+
self.norm = nn.GroupNorm(32, dim)
|
1032 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
1033 |
+
if context_dim is not None:
|
1034 |
+
self.context_kv = nn.Linear(context_dim, dim * 2)
|
1035 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
1036 |
+
|
1037 |
+
# zero out the last layer params
|
1038 |
+
nn.init.zeros_(self.proj.weight)
|
1039 |
+
|
1040 |
+
def forward(self, x, context=None):
|
1041 |
+
r"""x: [B, C, H, W].
|
1042 |
+
context: [B, L, C] or None.
|
1043 |
+
"""
|
1044 |
+
identity = x
|
1045 |
+
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
1046 |
+
|
1047 |
+
# compute query, key, value
|
1048 |
+
x = self.norm(x)
|
1049 |
+
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
1050 |
+
if context is not None:
|
1051 |
+
ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
|
1052 |
+
k = torch.cat([ck, k], dim=-1)
|
1053 |
+
v = torch.cat([cv, v], dim=-1)
|
1054 |
+
|
1055 |
+
# compute attention
|
1056 |
+
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
|
1057 |
+
attn = F.softmax(attn, dim=-1)
|
1058 |
+
|
1059 |
+
# gather context
|
1060 |
+
x = torch.matmul(v, attn.transpose(-1, -2))
|
1061 |
+
x = x.reshape(b, c, h, w)
|
1062 |
+
|
1063 |
+
# output
|
1064 |
+
x = self.proj(x)
|
1065 |
+
return x + identity
|
1066 |
+
|
1067 |
+
|
1068 |
+
class TemporalAttentionBlock(nn.Module):
|
1069 |
+
def __init__(
|
1070 |
+
self,
|
1071 |
+
dim,
|
1072 |
+
heads = 4,
|
1073 |
+
dim_head = 32,
|
1074 |
+
rotary_emb = None,
|
1075 |
+
use_image_dataset = False,
|
1076 |
+
use_sim_mask = False
|
1077 |
+
):
|
1078 |
+
super().__init__()
|
1079 |
+
# consider num_heads first, as pos_bias needs fixed num_heads
|
1080 |
+
# heads = dim // dim_head if dim_head else heads
|
1081 |
+
dim_head = dim // heads
|
1082 |
+
assert heads * dim_head == dim
|
1083 |
+
self.use_image_dataset = use_image_dataset
|
1084 |
+
self.use_sim_mask = use_sim_mask
|
1085 |
+
|
1086 |
+
self.scale = dim_head ** -0.5
|
1087 |
+
self.heads = heads
|
1088 |
+
hidden_dim = dim_head * heads
|
1089 |
+
|
1090 |
+
self.norm = nn.GroupNorm(32, dim)
|
1091 |
+
self.rotary_emb = rotary_emb
|
1092 |
+
self.to_qkv = nn.Linear(dim, hidden_dim * 3)#, bias = False)
|
1093 |
+
self.to_out = nn.Linear(hidden_dim, dim)#, bias = False)
|
1094 |
+
|
1095 |
+
# nn.init.zeros_(self.to_out.weight)
|
1096 |
+
# nn.init.zeros_(self.to_out.bias)
|
1097 |
+
|
1098 |
+
def forward(
|
1099 |
+
self,
|
1100 |
+
x,
|
1101 |
+
pos_bias = None,
|
1102 |
+
focus_present_mask = None,
|
1103 |
+
video_mask = None
|
1104 |
+
):
|
1105 |
+
|
1106 |
+
identity = x
|
1107 |
+
n, height, device = x.shape[2], x.shape[-2], x.device
|
1108 |
+
|
1109 |
+
x = self.norm(x)
|
1110 |
+
x = rearrange(x, 'b c f h w -> b (h w) f c')
|
1111 |
+
|
1112 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
1113 |
+
|
1114 |
+
if exists(focus_present_mask) and focus_present_mask.all():
|
1115 |
+
# if all batch samples are focusing on present
|
1116 |
+
# it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
|
1117 |
+
values = qkv[-1]
|
1118 |
+
out = self.to_out(values)
|
1119 |
+
out = rearrange(out, 'b (h w) f c -> b c f h w', h = height)
|
1120 |
+
|
1121 |
+
return out + identity
|
1122 |
+
|
1123 |
+
# split out heads
|
1124 |
+
# q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h = self.heads)
|
1125 |
+
# shape [b (hw) h n c/h], n=f
|
1126 |
+
q= rearrange(qkv[0], '... n (h d) -> ... h n d', h = self.heads)
|
1127 |
+
k= rearrange(qkv[1], '... n (h d) -> ... h n d', h = self.heads)
|
1128 |
+
v= rearrange(qkv[2], '... n (h d) -> ... h n d', h = self.heads)
|
1129 |
+
|
1130 |
+
|
1131 |
+
# scale
|
1132 |
+
|
1133 |
+
q = q * self.scale
|
1134 |
+
|
1135 |
+
# rotate positions into queries and keys for time attention
|
1136 |
+
if exists(self.rotary_emb):
|
1137 |
+
q = self.rotary_emb.rotate_queries_or_keys(q)
|
1138 |
+
k = self.rotary_emb.rotate_queries_or_keys(k)
|
1139 |
+
|
1140 |
+
# similarity
|
1141 |
+
# shape [b (hw) h n n], n=f
|
1142 |
+
sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
|
1143 |
+
|
1144 |
+
# relative positional bias
|
1145 |
+
|
1146 |
+
if exists(pos_bias):
|
1147 |
+
# print(sim.shape,pos_bias.shape)
|
1148 |
+
sim = sim + pos_bias
|
1149 |
+
|
1150 |
+
if (focus_present_mask is None and video_mask is not None):
|
1151 |
+
#video_mask: [B, n]
|
1152 |
+
mask = video_mask[:, None, :] * video_mask[:, :, None] # [b,n,n]
|
1153 |
+
mask = mask.unsqueeze(1).unsqueeze(1) #[b,1,1,n,n]
|
1154 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
1155 |
+
elif exists(focus_present_mask) and not (~focus_present_mask).all():
|
1156 |
+
attend_all_mask = torch.ones((n, n), device = device, dtype = torch.bool)
|
1157 |
+
attend_self_mask = torch.eye(n, device = device, dtype = torch.bool)
|
1158 |
+
|
1159 |
+
mask = torch.where(
|
1160 |
+
rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
|
1161 |
+
rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
|
1162 |
+
rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
|
1163 |
+
)
|
1164 |
+
|
1165 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
1166 |
+
|
1167 |
+
if self.use_sim_mask:
|
1168 |
+
sim_mask = torch.tril(torch.ones((n, n), device = device, dtype = torch.bool), diagonal=0)
|
1169 |
+
sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)
|
1170 |
+
|
1171 |
+
# numerical stability
|
1172 |
+
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
1173 |
+
attn = sim.softmax(dim = -1)
|
1174 |
+
|
1175 |
+
# aggregate values
|
1176 |
+
|
1177 |
+
out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
|
1178 |
+
out = rearrange(out, '... h n d -> ... n (h d)')
|
1179 |
+
out = self.to_out(out)
|
1180 |
+
|
1181 |
+
out = rearrange(out, 'b (h w) f c -> b c f h w', h = height)
|
1182 |
+
|
1183 |
+
if self.use_image_dataset:
|
1184 |
+
out = identity + 0*out
|
1185 |
+
else:
|
1186 |
+
out = identity + out
|
1187 |
+
return out
|
1188 |
+
|
1189 |
+
class TemporalTransformer(nn.Module):
|
1190 |
+
"""
|
1191 |
+
Transformer block for image-like data.
|
1192 |
+
First, project the input (aka embedding)
|
1193 |
+
and reshape to b, t, d.
|
1194 |
+
Then apply standard transformer action.
|
1195 |
+
Finally, reshape to image
|
1196 |
+
"""
|
1197 |
+
def __init__(self, in_channels, n_heads, d_head,
|
1198 |
+
depth=1, dropout=0., context_dim=None,
|
1199 |
+
disable_self_attn=False, use_linear=False,
|
1200 |
+
use_checkpoint=True, only_self_att=True, multiply_zero=False):
|
1201 |
+
super().__init__()
|
1202 |
+
self.multiply_zero = multiply_zero
|
1203 |
+
self.only_self_att = only_self_att
|
1204 |
+
self.use_adaptor = False
|
1205 |
+
if self.only_self_att:
|
1206 |
+
context_dim = None
|
1207 |
+
if not isinstance(context_dim, list):
|
1208 |
+
context_dim = [context_dim]
|
1209 |
+
self.in_channels = in_channels
|
1210 |
+
inner_dim = n_heads * d_head
|
1211 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
1212 |
+
if not use_linear:
|
1213 |
+
self.proj_in = nn.Conv1d(in_channels,
|
1214 |
+
inner_dim,
|
1215 |
+
kernel_size=1,
|
1216 |
+
stride=1,
|
1217 |
+
padding=0)
|
1218 |
+
else:
|
1219 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
1220 |
+
if self.use_adaptor:
|
1221 |
+
self.adaptor_in = nn.Linear(frames, frames)
|
1222 |
+
|
1223 |
+
self.transformer_blocks = nn.ModuleList(
|
1224 |
+
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
1225 |
+
checkpoint=use_checkpoint)
|
1226 |
+
for d in range(depth)]
|
1227 |
+
)
|
1228 |
+
if not use_linear:
|
1229 |
+
self.proj_out = zero_module(nn.Conv1d(inner_dim,
|
1230 |
+
in_channels,
|
1231 |
+
kernel_size=1,
|
1232 |
+
stride=1,
|
1233 |
+
padding=0))
|
1234 |
+
else:
|
1235 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
1236 |
+
if self.use_adaptor:
|
1237 |
+
self.adaptor_out = nn.Linear(frames, frames)
|
1238 |
+
self.use_linear = use_linear
|
1239 |
+
|
1240 |
+
def forward(self, x, context=None):
|
1241 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
1242 |
+
if self.only_self_att:
|
1243 |
+
context = None
|
1244 |
+
if not isinstance(context, list):
|
1245 |
+
context = [context]
|
1246 |
+
b, c, f, h, w = x.shape
|
1247 |
+
x_in = x
|
1248 |
+
x = self.norm(x)
|
1249 |
+
|
1250 |
+
if not self.use_linear:
|
1251 |
+
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
|
1252 |
+
x = self.proj_in(x)
|
1253 |
+
# [16384, 16, 320]
|
1254 |
+
if self.use_linear:
|
1255 |
+
x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
|
1256 |
+
x = self.proj_in(x)
|
1257 |
+
|
1258 |
+
if self.only_self_att:
|
1259 |
+
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
|
1260 |
+
for i, block in enumerate(self.transformer_blocks):
|
1261 |
+
x = block(x)
|
1262 |
+
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
|
1263 |
+
else:
|
1264 |
+
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
|
1265 |
+
for i, block in enumerate(self.transformer_blocks):
|
1266 |
+
# context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
|
1267 |
+
context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
|
1268 |
+
# calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
|
1269 |
+
for j in range(b):
|
1270 |
+
context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
|
1271 |
+
x[j] = block(x[j], context=context_i_j)
|
1272 |
+
|
1273 |
+
if self.use_linear:
|
1274 |
+
x = self.proj_out(x)
|
1275 |
+
x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
|
1276 |
+
if not self.use_linear:
|
1277 |
+
# x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
|
1278 |
+
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
|
1279 |
+
x = self.proj_out(x)
|
1280 |
+
x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
|
1281 |
+
|
1282 |
+
if self.multiply_zero:
|
1283 |
+
x = 0.0 * x + x_in
|
1284 |
+
else:
|
1285 |
+
x = x + x_in
|
1286 |
+
return x
|
1287 |
+
|
1288 |
+
|
1289 |
+
class TemporalTransformerWithAdapter(nn.Module):
|
1290 |
+
"""
|
1291 |
+
Transformer block for image-like data.
|
1292 |
+
First, project the input (aka embedding)
|
1293 |
+
and reshape to b, t, d.
|
1294 |
+
Then apply standard transformer action.
|
1295 |
+
Finally, reshape to image
|
1296 |
+
"""
|
1297 |
+
def __init__(self, in_channels, n_heads, d_head,
|
1298 |
+
depth=1, dropout=0., context_dim=None,
|
1299 |
+
disable_self_attn=False, use_linear=False,
|
1300 |
+
use_checkpoint=True, only_self_att=True, multiply_zero=False,
|
1301 |
+
adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'],
|
1302 |
+
adapter_hidden_dim=None, adapter_condition_dim=None):
|
1303 |
+
super().__init__()
|
1304 |
+
self.multiply_zero = multiply_zero
|
1305 |
+
self.only_self_att = only_self_att
|
1306 |
+
self.use_adaptor = False
|
1307 |
+
if self.only_self_att:
|
1308 |
+
context_dim = None
|
1309 |
+
if not isinstance(context_dim, list):
|
1310 |
+
context_dim = [context_dim]
|
1311 |
+
self.in_channels = in_channels
|
1312 |
+
inner_dim = n_heads * d_head
|
1313 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
1314 |
+
if not use_linear:
|
1315 |
+
self.proj_in = nn.Conv1d(in_channels,
|
1316 |
+
inner_dim,
|
1317 |
+
kernel_size=1,
|
1318 |
+
stride=1,
|
1319 |
+
padding=0)
|
1320 |
+
else:
|
1321 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
1322 |
+
if self.use_adaptor:
|
1323 |
+
self.adaptor_in = nn.Linear(frames, frames)
|
1324 |
+
|
1325 |
+
self.transformer_blocks = nn.ModuleList(
|
1326 |
+
[BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
1327 |
+
checkpoint=use_checkpoint, adapter_list=adapter_list, adapter_position_list=adapter_position_list,
|
1328 |
+
adapter_hidden_dim=adapter_hidden_dim, adapter_condition_dim=adapter_condition_dim)
|
1329 |
+
for d in range(depth)]
|
1330 |
+
)
|
1331 |
+
if not use_linear:
|
1332 |
+
self.proj_out = zero_module(nn.Conv1d(inner_dim,
|
1333 |
+
in_channels,
|
1334 |
+
kernel_size=1,
|
1335 |
+
stride=1,
|
1336 |
+
padding=0))
|
1337 |
+
else:
|
1338 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
1339 |
+
if self.use_adaptor:
|
1340 |
+
self.adaptor_out = nn.Linear(frames, frames)
|
1341 |
+
self.use_linear = use_linear
|
1342 |
+
|
1343 |
+
def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1):
|
1344 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
1345 |
+
if self.only_self_att:
|
1346 |
+
context = None
|
1347 |
+
if not isinstance(context, list):
|
1348 |
+
context = [context]
|
1349 |
+
b, c, f, h, w = x.shape
|
1350 |
+
x_in = x
|
1351 |
+
x = self.norm(x)
|
1352 |
+
|
1353 |
+
if not self.use_linear:
|
1354 |
+
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
|
1355 |
+
x = self.proj_in(x)
|
1356 |
+
# [16384, 16, 320]
|
1357 |
+
if self.use_linear:
|
1358 |
+
x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
|
1359 |
+
x = self.proj_in(x)
|
1360 |
+
|
1361 |
+
if adapter_condition is not None:
|
1362 |
+
b_cond, f_cond, c_cond = adapter_condition.shape
|
1363 |
+
adapter_condition = adapter_condition.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1, 1)
|
1364 |
+
adapter_condition = adapter_condition.reshape(b_cond*h*w, f_cond, c_cond)
|
1365 |
+
|
1366 |
+
if self.only_self_att:
|
1367 |
+
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
|
1368 |
+
for i, block in enumerate(self.transformer_blocks):
|
1369 |
+
x = block(x, adapter_condition=adapter_condition, adapter_condition_lam=adapter_condition_lam)
|
1370 |
+
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
|
1371 |
+
else:
|
1372 |
+
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
|
1373 |
+
for i, block in enumerate(self.transformer_blocks):
|
1374 |
+
# context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
|
1375 |
+
context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
|
1376 |
+
# calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
|
1377 |
+
for j in range(b):
|
1378 |
+
context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
|
1379 |
+
x[j] = block(x[j], context=context_i_j)
|
1380 |
+
|
1381 |
+
if self.use_linear:
|
1382 |
+
x = self.proj_out(x)
|
1383 |
+
x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
|
1384 |
+
if not self.use_linear:
|
1385 |
+
# x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
|
1386 |
+
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
|
1387 |
+
x = self.proj_out(x)
|
1388 |
+
x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
|
1389 |
+
|
1390 |
+
if self.multiply_zero:
|
1391 |
+
x = 0.0 * x + x_in
|
1392 |
+
else:
|
1393 |
+
x = x + x_in
|
1394 |
+
return x
|
1395 |
+
|
1396 |
+
class Attention(nn.Module):
|
1397 |
+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
1398 |
+
super().__init__()
|
1399 |
+
inner_dim = dim_head * heads
|
1400 |
+
project_out = not (heads == 1 and dim_head == dim)
|
1401 |
+
|
1402 |
+
self.heads = heads
|
1403 |
+
self.scale = dim_head ** -0.5
|
1404 |
+
|
1405 |
+
self.attend = nn.Softmax(dim = -1)
|
1406 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
1407 |
+
|
1408 |
+
self.to_out = nn.Sequential(
|
1409 |
+
nn.Linear(inner_dim, dim),
|
1410 |
+
nn.Dropout(dropout)
|
1411 |
+
) if project_out else nn.Identity()
|
1412 |
+
|
1413 |
+
def forward(self, x):
|
1414 |
+
b, n, _, h = *x.shape, self.heads
|
1415 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
1416 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
1417 |
+
|
1418 |
+
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
1419 |
+
|
1420 |
+
attn = self.attend(dots)
|
1421 |
+
|
1422 |
+
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
|
1423 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
1424 |
+
return self.to_out(out)
|
1425 |
+
|
1426 |
+
class PreNormattention(nn.Module):
|
1427 |
+
def __init__(self, dim, fn):
|
1428 |
+
super().__init__()
|
1429 |
+
self.norm = nn.LayerNorm(dim)
|
1430 |
+
self.fn = fn
|
1431 |
+
def forward(self, x, **kwargs):
|
1432 |
+
return self.fn(self.norm(x), **kwargs) + x
|
1433 |
+
|
1434 |
+
class TransformerV2(nn.Module):
|
1435 |
+
def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1):
|
1436 |
+
super().__init__()
|
1437 |
+
self.layers = nn.ModuleList([])
|
1438 |
+
self.depth = depth
|
1439 |
+
for _ in range(depth):
|
1440 |
+
self.layers.append(nn.ModuleList([
|
1441 |
+
PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)),
|
1442 |
+
FeedForward(dim, mlp_dim, dropout = dropout_ffn),
|
1443 |
+
]))
|
1444 |
+
def forward(self, x):
|
1445 |
+
# if self.depth
|
1446 |
+
for attn, ff in self.layers[:1]:
|
1447 |
+
x = attn(x)
|
1448 |
+
x = ff(x) + x
|
1449 |
+
if self.depth > 1:
|
1450 |
+
for attn, ff in self.layers[1:]:
|
1451 |
+
x = attn(x)
|
1452 |
+
x = ff(x) + x
|
1453 |
+
return x
|
1454 |
+
|
1455 |
+
class TemporalTransformer_attemask(nn.Module):
|
1456 |
+
"""
|
1457 |
+
Transformer block for image-like data.
|
1458 |
+
First, project the input (aka embedding)
|
1459 |
+
and reshape to b, t, d.
|
1460 |
+
Then apply standard transformer action.
|
1461 |
+
Finally, reshape to image
|
1462 |
+
"""
|
1463 |
+
def __init__(self, in_channels, n_heads, d_head,
|
1464 |
+
depth=1, dropout=0., context_dim=None,
|
1465 |
+
disable_self_attn=False, use_linear=False,
|
1466 |
+
use_checkpoint=True, only_self_att=True, multiply_zero=False):
|
1467 |
+
super().__init__()
|
1468 |
+
self.multiply_zero = multiply_zero
|
1469 |
+
self.only_self_att = only_self_att
|
1470 |
+
self.use_adaptor = False
|
1471 |
+
if self.only_self_att:
|
1472 |
+
context_dim = None
|
1473 |
+
if not isinstance(context_dim, list):
|
1474 |
+
context_dim = [context_dim]
|
1475 |
+
self.in_channels = in_channels
|
1476 |
+
inner_dim = n_heads * d_head
|
1477 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
1478 |
+
if not use_linear:
|
1479 |
+
self.proj_in = nn.Conv1d(in_channels,
|
1480 |
+
inner_dim,
|
1481 |
+
kernel_size=1,
|
1482 |
+
stride=1,
|
1483 |
+
padding=0)
|
1484 |
+
else:
|
1485 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
1486 |
+
if self.use_adaptor:
|
1487 |
+
self.adaptor_in = nn.Linear(frames, frames)
|
1488 |
+
|
1489 |
+
self.transformer_blocks = nn.ModuleList(
|
1490 |
+
[BasicTransformerBlock_attemask(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
1491 |
+
checkpoint=use_checkpoint)
|
1492 |
+
for d in range(depth)]
|
1493 |
+
)
|
1494 |
+
if not use_linear:
|
1495 |
+
self.proj_out = zero_module(nn.Conv1d(inner_dim,
|
1496 |
+
in_channels,
|
1497 |
+
kernel_size=1,
|
1498 |
+
stride=1,
|
1499 |
+
padding=0))
|
1500 |
+
else:
|
1501 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
1502 |
+
if self.use_adaptor:
|
1503 |
+
self.adaptor_out = nn.Linear(frames, frames)
|
1504 |
+
self.use_linear = use_linear
|
1505 |
+
|
1506 |
+
def forward(self, x, context=None):
|
1507 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
1508 |
+
if self.only_self_att:
|
1509 |
+
context = None
|
1510 |
+
if not isinstance(context, list):
|
1511 |
+
context = [context]
|
1512 |
+
b, c, f, h, w = x.shape
|
1513 |
+
x_in = x
|
1514 |
+
x = self.norm(x)
|
1515 |
+
|
1516 |
+
if not self.use_linear:
|
1517 |
+
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
|
1518 |
+
x = self.proj_in(x)
|
1519 |
+
# [16384, 16, 320]
|
1520 |
+
if self.use_linear:
|
1521 |
+
x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
|
1522 |
+
x = self.proj_in(x)
|
1523 |
+
|
1524 |
+
if self.only_self_att:
|
1525 |
+
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
|
1526 |
+
for i, block in enumerate(self.transformer_blocks):
|
1527 |
+
x = block(x)
|
1528 |
+
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
|
1529 |
+
else:
|
1530 |
+
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
|
1531 |
+
for i, block in enumerate(self.transformer_blocks):
|
1532 |
+
# context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
|
1533 |
+
context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous()
|
1534 |
+
# calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
|
1535 |
+
for j in range(b):
|
1536 |
+
context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous()
|
1537 |
+
x[j] = block(x[j], context=context_i_j)
|
1538 |
+
|
1539 |
+
if self.use_linear:
|
1540 |
+
x = self.proj_out(x)
|
1541 |
+
x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
|
1542 |
+
if not self.use_linear:
|
1543 |
+
# x = rearrange(x, 'bhw f c -> bhw c f').contiguous()
|
1544 |
+
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
|
1545 |
+
x = self.proj_out(x)
|
1546 |
+
x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
|
1547 |
+
|
1548 |
+
if self.multiply_zero:
|
1549 |
+
x = 0.0 * x + x_in
|
1550 |
+
else:
|
1551 |
+
x = x + x_in
|
1552 |
+
return x
|
1553 |
+
|
1554 |
+
class TemporalAttentionMultiBlock(nn.Module):
|
1555 |
+
def __init__(
|
1556 |
+
self,
|
1557 |
+
dim,
|
1558 |
+
heads=4,
|
1559 |
+
dim_head=32,
|
1560 |
+
rotary_emb=None,
|
1561 |
+
use_image_dataset=False,
|
1562 |
+
use_sim_mask=False,
|
1563 |
+
temporal_attn_times=1,
|
1564 |
+
):
|
1565 |
+
super().__init__()
|
1566 |
+
self.att_layers = nn.ModuleList(
|
1567 |
+
[TemporalAttentionBlock(dim, heads, dim_head, rotary_emb, use_image_dataset, use_sim_mask)
|
1568 |
+
for _ in range(temporal_attn_times)]
|
1569 |
+
)
|
1570 |
+
|
1571 |
+
def forward(
|
1572 |
+
self,
|
1573 |
+
x,
|
1574 |
+
pos_bias = None,
|
1575 |
+
focus_present_mask = None,
|
1576 |
+
video_mask = None
|
1577 |
+
):
|
1578 |
+
for layer in self.att_layers:
|
1579 |
+
x = layer(x, pos_bias, focus_present_mask, video_mask)
|
1580 |
+
return x
|
1581 |
+
|
1582 |
+
|
1583 |
+
class InitTemporalConvBlock(nn.Module):
|
1584 |
+
|
1585 |
+
def __init__(self, in_dim, out_dim=None, dropout=0.0,use_image_dataset=False):
|
1586 |
+
super(InitTemporalConvBlock, self).__init__()
|
1587 |
+
if out_dim is None:
|
1588 |
+
out_dim = in_dim#int(1.5*in_dim)
|
1589 |
+
self.in_dim = in_dim
|
1590 |
+
self.out_dim = out_dim
|
1591 |
+
self.use_image_dataset = use_image_dataset
|
1592 |
+
|
1593 |
+
# conv layers
|
1594 |
+
self.conv = nn.Sequential(
|
1595 |
+
nn.GroupNorm(32, out_dim),
|
1596 |
+
nn.SiLU(),
|
1597 |
+
nn.Dropout(dropout),
|
1598 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
|
1599 |
+
|
1600 |
+
# zero out the last layer params,so the conv block is identity
|
1601 |
+
# nn.init.zeros_(self.conv1[-1].weight)
|
1602 |
+
# nn.init.zeros_(self.conv1[-1].bias)
|
1603 |
+
nn.init.zeros_(self.conv[-1].weight)
|
1604 |
+
nn.init.zeros_(self.conv[-1].bias)
|
1605 |
+
|
1606 |
+
def forward(self, x):
|
1607 |
+
identity = x
|
1608 |
+
x = self.conv(x)
|
1609 |
+
if self.use_image_dataset:
|
1610 |
+
x = identity + 0*x
|
1611 |
+
else:
|
1612 |
+
x = identity + x
|
1613 |
+
return x
|
1614 |
+
|
1615 |
+
class TemporalConvBlock(nn.Module):
|
1616 |
+
|
1617 |
+
def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset= False):
|
1618 |
+
super(TemporalConvBlock, self).__init__()
|
1619 |
+
if out_dim is None:
|
1620 |
+
out_dim = in_dim#int(1.5*in_dim)
|
1621 |
+
self.in_dim = in_dim
|
1622 |
+
self.out_dim = out_dim
|
1623 |
+
self.use_image_dataset = use_image_dataset
|
1624 |
+
|
1625 |
+
# conv layers
|
1626 |
+
self.conv1 = nn.Sequential(
|
1627 |
+
nn.GroupNorm(32, in_dim),
|
1628 |
+
nn.SiLU(),
|
1629 |
+
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0)))
|
1630 |
+
self.conv2 = nn.Sequential(
|
1631 |
+
nn.GroupNorm(32, out_dim),
|
1632 |
+
nn.SiLU(),
|
1633 |
+
nn.Dropout(dropout),
|
1634 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
|
1635 |
+
|
1636 |
+
# zero out the last layer params,so the conv block is identity
|
1637 |
+
# nn.init.zeros_(self.conv1[-1].weight)
|
1638 |
+
# nn.init.zeros_(self.conv1[-1].bias)
|
1639 |
+
nn.init.zeros_(self.conv2[-1].weight)
|
1640 |
+
nn.init.zeros_(self.conv2[-1].bias)
|
1641 |
+
|
1642 |
+
def forward(self, x):
|
1643 |
+
identity = x
|
1644 |
+
x = self.conv1(x)
|
1645 |
+
x = self.conv2(x)
|
1646 |
+
if self.use_image_dataset:
|
1647 |
+
x = identity + 0*x
|
1648 |
+
else:
|
1649 |
+
x = identity + x
|
1650 |
+
return x
|
1651 |
+
|
1652 |
+
class TemporalConvBlock_v2(nn.Module):
|
1653 |
+
def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False):
|
1654 |
+
super(TemporalConvBlock_v2, self).__init__()
|
1655 |
+
if out_dim is None:
|
1656 |
+
out_dim = in_dim # int(1.5*in_dim)
|
1657 |
+
self.in_dim = in_dim
|
1658 |
+
self.out_dim = out_dim
|
1659 |
+
self.use_image_dataset = use_image_dataset
|
1660 |
+
|
1661 |
+
# conv layers
|
1662 |
+
self.conv1 = nn.Sequential(
|
1663 |
+
nn.GroupNorm(32, in_dim),
|
1664 |
+
nn.SiLU(),
|
1665 |
+
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0)))
|
1666 |
+
self.conv2 = nn.Sequential(
|
1667 |
+
nn.GroupNorm(32, out_dim),
|
1668 |
+
nn.SiLU(),
|
1669 |
+
nn.Dropout(dropout),
|
1670 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
|
1671 |
+
self.conv3 = nn.Sequential(
|
1672 |
+
nn.GroupNorm(32, out_dim),
|
1673 |
+
nn.SiLU(),
|
1674 |
+
nn.Dropout(dropout),
|
1675 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
|
1676 |
+
self.conv4 = nn.Sequential(
|
1677 |
+
nn.GroupNorm(32, out_dim),
|
1678 |
+
nn.SiLU(),
|
1679 |
+
nn.Dropout(dropout),
|
1680 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0)))
|
1681 |
+
|
1682 |
+
# zero out the last layer params,so the conv block is identity
|
1683 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
1684 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
1685 |
+
|
1686 |
+
def forward(self, x):
|
1687 |
+
identity = x
|
1688 |
+
x = self.conv1(x)
|
1689 |
+
x = self.conv2(x)
|
1690 |
+
x = self.conv3(x)
|
1691 |
+
x = self.conv4(x)
|
1692 |
+
|
1693 |
+
if self.use_image_dataset:
|
1694 |
+
x = identity + 0.0 * x
|
1695 |
+
else:
|
1696 |
+
x = identity + x
|
1697 |
+
return x
|
1698 |
+
|
1699 |
+
|
1700 |
+
class DropPath(nn.Module):
|
1701 |
+
r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
|
1702 |
+
"""
|
1703 |
+
def __init__(self, p):
|
1704 |
+
super(DropPath, self).__init__()
|
1705 |
+
self.p = p
|
1706 |
+
|
1707 |
+
def forward(self, *args, zero=None, keep=None):
|
1708 |
+
if not self.training:
|
1709 |
+
return args[0] if len(args) == 1 else args
|
1710 |
+
|
1711 |
+
# params
|
1712 |
+
x = args[0]
|
1713 |
+
b = x.size(0)
|
1714 |
+
n = (torch.rand(b) < self.p).sum()
|
1715 |
+
|
1716 |
+
# non-zero and non-keep mask
|
1717 |
+
mask = x.new_ones(b, dtype=torch.bool)
|
1718 |
+
if keep is not None:
|
1719 |
+
mask[keep] = False
|
1720 |
+
if zero is not None:
|
1721 |
+
mask[zero] = False
|
1722 |
+
|
1723 |
+
# drop-path index
|
1724 |
+
index = torch.where(mask)[0]
|
1725 |
+
index = index[torch.randperm(len(index))[:n]]
|
1726 |
+
if zero is not None:
|
1727 |
+
index = torch.cat([index, torch.where(zero)[0]], dim=0)
|
1728 |
+
|
1729 |
+
# drop-path multiplier
|
1730 |
+
multiplier = x.new_ones(b)
|
1731 |
+
multiplier[index] = 0.0
|
1732 |
+
output = tuple(u * self.broadcast(multiplier, u) for u in args)
|
1733 |
+
return output[0] if len(args) == 1 else output
|
1734 |
+
|
1735 |
+
def broadcast(self, src, dst):
|
1736 |
+
assert src.size(0) == dst.size(0)
|
1737 |
+
shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
|
1738 |
+
return src.view(shape)
|
1739 |
+
|
1740 |
+
|
1741 |
+
|
UniAnimate/utils/__init__.py
ADDED
File without changes
|
UniAnimate/utils/assign_cfg.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, yaml
|
2 |
+
from copy import deepcopy, copy
|
3 |
+
|
4 |
+
|
5 |
+
# def get prior and ldm config
|
6 |
+
def assign_prior_mudule_cfg(cfg):
|
7 |
+
'''
|
8 |
+
'''
|
9 |
+
#
|
10 |
+
prior_cfg = deepcopy(cfg)
|
11 |
+
vldm_cfg = deepcopy(cfg)
|
12 |
+
|
13 |
+
with open(cfg.prior_cfg, 'r') as f:
|
14 |
+
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
15 |
+
# _cfg_update = _cfg_update.cfg_dict
|
16 |
+
for k, v in _cfg_update.items():
|
17 |
+
if isinstance(v, dict) and k in cfg:
|
18 |
+
prior_cfg[k].update(v)
|
19 |
+
else:
|
20 |
+
prior_cfg[k] = v
|
21 |
+
|
22 |
+
with open(cfg.vldm_cfg, 'r') as f:
|
23 |
+
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
24 |
+
# _cfg_update = _cfg_update.cfg_dict
|
25 |
+
for k, v in _cfg_update.items():
|
26 |
+
if isinstance(v, dict) and k in cfg:
|
27 |
+
vldm_cfg[k].update(v)
|
28 |
+
else:
|
29 |
+
vldm_cfg[k] = v
|
30 |
+
|
31 |
+
return prior_cfg, vldm_cfg
|
32 |
+
|
33 |
+
|
34 |
+
# def get prior and ldm config
|
35 |
+
def assign_vldm_vsr_mudule_cfg(cfg):
|
36 |
+
'''
|
37 |
+
'''
|
38 |
+
#
|
39 |
+
vldm_cfg = deepcopy(cfg)
|
40 |
+
vsr_cfg = deepcopy(cfg)
|
41 |
+
|
42 |
+
with open(cfg.vldm_cfg, 'r') as f:
|
43 |
+
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
44 |
+
# _cfg_update = _cfg_update.cfg_dict
|
45 |
+
for k, v in _cfg_update.items():
|
46 |
+
if isinstance(v, dict) and k in cfg:
|
47 |
+
vldm_cfg[k].update(v)
|
48 |
+
else:
|
49 |
+
vldm_cfg[k] = v
|
50 |
+
|
51 |
+
with open(cfg.vsr_cfg, 'r') as f:
|
52 |
+
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
53 |
+
# _cfg_update = _cfg_update.cfg_dict
|
54 |
+
for k, v in _cfg_update.items():
|
55 |
+
if isinstance(v, dict) and k in cfg:
|
56 |
+
vsr_cfg[k].update(v)
|
57 |
+
else:
|
58 |
+
vsr_cfg[k] = v
|
59 |
+
|
60 |
+
return vldm_cfg, vsr_cfg
|
61 |
+
|
62 |
+
|
63 |
+
# def get prior and ldm config
|
64 |
+
def assign_signle_cfg(cfg, _cfg_update, tname):
|
65 |
+
'''
|
66 |
+
'''
|
67 |
+
#
|
68 |
+
vldm_cfg = deepcopy(cfg)
|
69 |
+
if os.path.exists(_cfg_update[tname]):
|
70 |
+
with open(_cfg_update[tname], 'r') as f:
|
71 |
+
_cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
72 |
+
# _cfg_update = _cfg_update.cfg_dict
|
73 |
+
for k, v in _cfg_update.items():
|
74 |
+
if isinstance(v, dict) and k in cfg:
|
75 |
+
vldm_cfg[k].update(v)
|
76 |
+
else:
|
77 |
+
vldm_cfg[k] = v
|
78 |
+
return vldm_cfg
|
UniAnimate/utils/config.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import json
|
4 |
+
import copy
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import utils.logging as logging
|
8 |
+
logger = logging.get_logger(__name__)
|
9 |
+
|
10 |
+
class Config(object):
|
11 |
+
def __init__(self, load=True, cfg_dict=None, cfg_level=None):
|
12 |
+
self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "")
|
13 |
+
if load:
|
14 |
+
self.args = self._parse_args()
|
15 |
+
logger.info("Loading config from {}.".format(self.args.cfg_file))
|
16 |
+
self.need_initialization = True
|
17 |
+
cfg_base = self._load_yaml(self.args) # self._initialize_cfg()
|
18 |
+
cfg_dict = self._load_yaml(self.args)
|
19 |
+
cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
|
20 |
+
cfg_dict = self._update_from_args(cfg_dict)
|
21 |
+
self.cfg_dict = cfg_dict
|
22 |
+
self._update_dict(cfg_dict)
|
23 |
+
|
24 |
+
def _parse_args(self):
|
25 |
+
parser = argparse.ArgumentParser(
|
26 |
+
description="Argparser for configuring [code base name to think of] codebase"
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--cfg",
|
30 |
+
dest="cfg_file",
|
31 |
+
help="Path to the configuration file",
|
32 |
+
default='configs/UniAnimate_infer.yaml'
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"--init_method",
|
36 |
+
help="Initialization method, includes TCP or shared file-system",
|
37 |
+
default="tcp://localhost:9999",
|
38 |
+
type=str,
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
'--debug',
|
42 |
+
action='store_true',
|
43 |
+
default=False,
|
44 |
+
help='Into debug information'
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"opts",
|
48 |
+
help="other configurations",
|
49 |
+
default=None,
|
50 |
+
nargs=argparse.REMAINDER)
|
51 |
+
return parser.parse_args()
|
52 |
+
|
53 |
+
def _path_join(self, path_list):
|
54 |
+
path = ""
|
55 |
+
for p in path_list:
|
56 |
+
path+= p + '/'
|
57 |
+
return path[:-1]
|
58 |
+
|
59 |
+
def _update_from_args(self, cfg_dict):
|
60 |
+
args = self.args
|
61 |
+
for var in vars(args):
|
62 |
+
cfg_dict[var] = getattr(args, var)
|
63 |
+
return cfg_dict
|
64 |
+
|
65 |
+
def _initialize_cfg(self):
|
66 |
+
if self.need_initialization:
|
67 |
+
self.need_initialization = False
|
68 |
+
if os.path.exists('./configs/base.yaml'):
|
69 |
+
with open("./configs/base.yaml", 'r') as f:
|
70 |
+
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
71 |
+
else:
|
72 |
+
with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f:
|
73 |
+
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
74 |
+
return cfg
|
75 |
+
|
76 |
+
def _load_yaml(self, args, file_name=""):
|
77 |
+
assert args.cfg_file is not None
|
78 |
+
if not file_name == "": # reading from base file
|
79 |
+
with open(file_name, 'r') as f:
|
80 |
+
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
81 |
+
else:
|
82 |
+
if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]:
|
83 |
+
args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./")
|
84 |
+
with open(args.cfg_file, 'r') as f:
|
85 |
+
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
86 |
+
file_name = args.cfg_file
|
87 |
+
|
88 |
+
if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys():
|
89 |
+
# return cfg if the base file is being accessed
|
90 |
+
cfg = self._merge_cfg_from_command_update(args, cfg)
|
91 |
+
return cfg
|
92 |
+
|
93 |
+
if "_BASE" in cfg.keys():
|
94 |
+
if cfg["_BASE"][1] == '.':
|
95 |
+
prev_count = cfg["_BASE"].count('..')
|
96 |
+
cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:])
|
97 |
+
else:
|
98 |
+
cfg_base_file = cfg["_BASE"].replace(
|
99 |
+
"./",
|
100 |
+
args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
|
101 |
+
)
|
102 |
+
cfg_base = self._load_yaml(args, cfg_base_file)
|
103 |
+
cfg = self._merge_cfg_from_base(cfg_base, cfg)
|
104 |
+
else:
|
105 |
+
if "_BASE_RUN" in cfg.keys():
|
106 |
+
if cfg["_BASE_RUN"][1] == '.':
|
107 |
+
prev_count = cfg["_BASE_RUN"].count('..')
|
108 |
+
cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:])
|
109 |
+
else:
|
110 |
+
cfg_base_file = cfg["_BASE_RUN"].replace(
|
111 |
+
"./",
|
112 |
+
args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
|
113 |
+
)
|
114 |
+
cfg_base = self._load_yaml(args, cfg_base_file)
|
115 |
+
cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True)
|
116 |
+
if "_BASE_MODEL" in cfg.keys():
|
117 |
+
if cfg["_BASE_MODEL"][1] == '.':
|
118 |
+
prev_count = cfg["_BASE_MODEL"].count('..')
|
119 |
+
cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:])
|
120 |
+
else:
|
121 |
+
cfg_base_file = cfg["_BASE_MODEL"].replace(
|
122 |
+
"./",
|
123 |
+
args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
|
124 |
+
)
|
125 |
+
cfg_base = self._load_yaml(args, cfg_base_file)
|
126 |
+
cfg = self._merge_cfg_from_base(cfg_base, cfg)
|
127 |
+
cfg = self._merge_cfg_from_command(args, cfg)
|
128 |
+
return cfg
|
129 |
+
|
130 |
+
def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
|
131 |
+
for k,v in cfg_new.items():
|
132 |
+
if k in cfg_base.keys():
|
133 |
+
if isinstance(v, dict):
|
134 |
+
self._merge_cfg_from_base(cfg_base[k], v)
|
135 |
+
else:
|
136 |
+
cfg_base[k] = v
|
137 |
+
else:
|
138 |
+
if "BASE" not in k or preserve_base:
|
139 |
+
cfg_base[k] = v
|
140 |
+
return cfg_base
|
141 |
+
|
142 |
+
def _merge_cfg_from_command_update(self, args, cfg):
|
143 |
+
if len(args.opts) == 0:
|
144 |
+
return cfg
|
145 |
+
|
146 |
+
assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
|
147 |
+
args.opts, len(args.opts)
|
148 |
+
)
|
149 |
+
keys = args.opts[0::2]
|
150 |
+
vals = args.opts[1::2]
|
151 |
+
|
152 |
+
for key, val in zip(keys, vals):
|
153 |
+
cfg[key] = val
|
154 |
+
|
155 |
+
return cfg
|
156 |
+
|
157 |
+
def _merge_cfg_from_command(self, args, cfg):
|
158 |
+
assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
|
159 |
+
args.opts, len(args.opts)
|
160 |
+
)
|
161 |
+
keys = args.opts[0::2]
|
162 |
+
vals = args.opts[1::2]
|
163 |
+
|
164 |
+
# maximum supported depth 3
|
165 |
+
for idx, key in enumerate(keys):
|
166 |
+
key_split = key.split('.')
|
167 |
+
assert len(key_split) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
|
168 |
+
len(key_split)
|
169 |
+
)
|
170 |
+
assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format(
|
171 |
+
key_split[0]
|
172 |
+
)
|
173 |
+
if len(key_split) == 2:
|
174 |
+
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
|
175 |
+
key
|
176 |
+
)
|
177 |
+
elif len(key_split) == 3:
|
178 |
+
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
|
179 |
+
key
|
180 |
+
)
|
181 |
+
assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format(
|
182 |
+
key
|
183 |
+
)
|
184 |
+
elif len(key_split) == 4:
|
185 |
+
assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
|
186 |
+
key
|
187 |
+
)
|
188 |
+
assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format(
|
189 |
+
key
|
190 |
+
)
|
191 |
+
assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format(
|
192 |
+
key
|
193 |
+
)
|
194 |
+
if len(key_split) == 1:
|
195 |
+
cfg[key_split[0]] = vals[idx]
|
196 |
+
elif len(key_split) == 2:
|
197 |
+
cfg[key_split[0]][key_split[1]] = vals[idx]
|
198 |
+
elif len(key_split) == 3:
|
199 |
+
cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
|
200 |
+
elif len(key_split) == 4:
|
201 |
+
cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx]
|
202 |
+
return cfg
|
203 |
+
|
204 |
+
def _update_dict(self, cfg_dict):
|
205 |
+
def recur(key, elem):
|
206 |
+
if type(elem) is dict:
|
207 |
+
return key, Config(load=False, cfg_dict=elem, cfg_level=key)
|
208 |
+
else:
|
209 |
+
if type(elem) is str and elem[1:3]=="e-":
|
210 |
+
elem = float(elem)
|
211 |
+
return key, elem
|
212 |
+
dic = dict(recur(k, v) for k, v in cfg_dict.items())
|
213 |
+
self.__dict__.update(dic)
|
214 |
+
|
215 |
+
def get_args(self):
|
216 |
+
return self.args
|
217 |
+
|
218 |
+
def __repr__(self):
|
219 |
+
return "{}\n".format(self.dump())
|
220 |
+
|
221 |
+
def dump(self):
|
222 |
+
return json.dumps(self.cfg_dict, indent=2)
|
223 |
+
|
224 |
+
def deep_copy(self):
|
225 |
+
return copy.deepcopy(self)
|
226 |
+
|
227 |
+
if __name__ == '__main__':
|
228 |
+
# debug
|
229 |
+
cfg = Config(load=True)
|
230 |
+
print(cfg.DATA)
|
UniAnimate/utils/distributed.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.distributed as dist
|
7 |
+
import functools
|
8 |
+
import pickle
|
9 |
+
import numpy as np
|
10 |
+
from collections import OrderedDict
|
11 |
+
from torch.autograd import Function
|
12 |
+
|
13 |
+
__all__ = ['is_dist_initialized',
|
14 |
+
'get_world_size',
|
15 |
+
'get_rank',
|
16 |
+
'new_group',
|
17 |
+
'destroy_process_group',
|
18 |
+
'barrier',
|
19 |
+
'broadcast',
|
20 |
+
'all_reduce',
|
21 |
+
'reduce',
|
22 |
+
'gather',
|
23 |
+
'all_gather',
|
24 |
+
'reduce_dict',
|
25 |
+
'get_global_gloo_group',
|
26 |
+
'generalized_all_gather',
|
27 |
+
'generalized_gather',
|
28 |
+
'scatter',
|
29 |
+
'reduce_scatter',
|
30 |
+
'send',
|
31 |
+
'recv',
|
32 |
+
'isend',
|
33 |
+
'irecv',
|
34 |
+
'shared_random_seed',
|
35 |
+
'diff_all_gather',
|
36 |
+
'diff_all_reduce',
|
37 |
+
'diff_scatter',
|
38 |
+
'diff_copy',
|
39 |
+
'spherical_kmeans',
|
40 |
+
'sinkhorn']
|
41 |
+
|
42 |
+
#-------------------------------- Distributed operations --------------------------------#
|
43 |
+
|
44 |
+
def is_dist_initialized():
|
45 |
+
return dist.is_available() and dist.is_initialized()
|
46 |
+
|
47 |
+
def get_world_size(group=None):
|
48 |
+
return dist.get_world_size(group) if is_dist_initialized() else 1
|
49 |
+
|
50 |
+
def get_rank(group=None):
|
51 |
+
return dist.get_rank(group) if is_dist_initialized() else 0
|
52 |
+
|
53 |
+
def new_group(ranks=None, **kwargs):
|
54 |
+
if is_dist_initialized():
|
55 |
+
return dist.new_group(ranks, **kwargs)
|
56 |
+
return None
|
57 |
+
|
58 |
+
def destroy_process_group():
|
59 |
+
if is_dist_initialized():
|
60 |
+
dist.destroy_process_group()
|
61 |
+
|
62 |
+
def barrier(group=None, **kwargs):
|
63 |
+
if get_world_size(group) > 1:
|
64 |
+
dist.barrier(group, **kwargs)
|
65 |
+
|
66 |
+
def broadcast(tensor, src, group=None, **kwargs):
|
67 |
+
if get_world_size(group) > 1:
|
68 |
+
return dist.broadcast(tensor, src, group, **kwargs)
|
69 |
+
|
70 |
+
def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
|
71 |
+
if get_world_size(group) > 1:
|
72 |
+
return dist.all_reduce(tensor, op, group, **kwargs)
|
73 |
+
|
74 |
+
def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
|
75 |
+
if get_world_size(group) > 1:
|
76 |
+
return dist.reduce(tensor, dst, op, group, **kwargs)
|
77 |
+
|
78 |
+
def gather(tensor, dst=0, group=None, **kwargs):
|
79 |
+
rank = get_rank() # global rank
|
80 |
+
world_size = get_world_size(group)
|
81 |
+
if world_size == 1:
|
82 |
+
return [tensor]
|
83 |
+
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None
|
84 |
+
dist.gather(tensor, tensor_list, dst, group, **kwargs)
|
85 |
+
return tensor_list
|
86 |
+
|
87 |
+
def all_gather(tensor, uniform_size=True, group=None, **kwargs):
|
88 |
+
world_size = get_world_size(group)
|
89 |
+
if world_size == 1:
|
90 |
+
return [tensor]
|
91 |
+
assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()'
|
92 |
+
|
93 |
+
if uniform_size:
|
94 |
+
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
95 |
+
dist.all_gather(tensor_list, tensor, group, **kwargs)
|
96 |
+
return tensor_list
|
97 |
+
else:
|
98 |
+
# collect tensor shapes across GPUs
|
99 |
+
shape = tuple(tensor.shape)
|
100 |
+
shape_list = generalized_all_gather(shape, group)
|
101 |
+
|
102 |
+
# flatten the tensor
|
103 |
+
tensor = tensor.reshape(-1)
|
104 |
+
size = int(np.prod(shape))
|
105 |
+
size_list = [int(np.prod(u)) for u in shape_list]
|
106 |
+
max_size = max(size_list)
|
107 |
+
|
108 |
+
# pad to maximum size
|
109 |
+
if size != max_size:
|
110 |
+
padding = tensor.new_zeros(max_size - size)
|
111 |
+
tensor = torch.cat([tensor, padding], dim=0)
|
112 |
+
|
113 |
+
# all_gather
|
114 |
+
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
115 |
+
dist.all_gather(tensor_list, tensor, group, **kwargs)
|
116 |
+
|
117 |
+
# reshape tensors
|
118 |
+
tensor_list = [t[:n].view(s) for t, n, s in zip(
|
119 |
+
tensor_list, size_list, shape_list)]
|
120 |
+
return tensor_list
|
121 |
+
|
122 |
+
@torch.no_grad()
|
123 |
+
def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
|
124 |
+
assert reduction in ['mean', 'sum']
|
125 |
+
world_size = get_world_size(group)
|
126 |
+
if world_size == 1:
|
127 |
+
return input_dict
|
128 |
+
|
129 |
+
# ensure that the orders of keys are consistent across processes
|
130 |
+
if isinstance(input_dict, OrderedDict):
|
131 |
+
keys = list(input_dict.keys)
|
132 |
+
else:
|
133 |
+
keys = sorted(input_dict.keys())
|
134 |
+
vals = [input_dict[key] for key in keys]
|
135 |
+
vals = torch.stack(vals, dim=0)
|
136 |
+
dist.reduce(vals, dst=0, group=group, **kwargs)
|
137 |
+
if dist.get_rank(group) == 0 and reduction == 'mean':
|
138 |
+
vals /= world_size
|
139 |
+
dist.broadcast(vals, src=0, group=group, **kwargs)
|
140 |
+
reduced_dict = type(input_dict)([
|
141 |
+
(key, val) for key, val in zip(keys, vals)])
|
142 |
+
return reduced_dict
|
143 |
+
|
144 |
+
@functools.lru_cache()
|
145 |
+
def get_global_gloo_group():
|
146 |
+
backend = dist.get_backend()
|
147 |
+
assert backend in ['gloo', 'nccl']
|
148 |
+
if backend == 'nccl':
|
149 |
+
return dist.new_group(backend='gloo')
|
150 |
+
else:
|
151 |
+
return dist.group.WORLD
|
152 |
+
|
153 |
+
def _serialize_to_tensor(data, group):
|
154 |
+
backend = dist.get_backend(group)
|
155 |
+
assert backend in ['gloo', 'nccl']
|
156 |
+
device = torch.device('cpu' if backend == 'gloo' else 'cuda')
|
157 |
+
|
158 |
+
buffer = pickle.dumps(data)
|
159 |
+
if len(buffer) > 1024 ** 3:
|
160 |
+
logger = logging.getLogger(__name__)
|
161 |
+
logger.warning(
|
162 |
+
'Rank {} trying to all-gather {:.2f} GB of data on device'
|
163 |
+
'{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
|
164 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
165 |
+
tensor = torch.ByteTensor(storage).to(device=device)
|
166 |
+
return tensor
|
167 |
+
|
168 |
+
def _pad_to_largest_tensor(tensor, group):
|
169 |
+
world_size = dist.get_world_size(group=group)
|
170 |
+
assert world_size >= 1, \
|
171 |
+
'gather/all_gather must be called from ranks within' \
|
172 |
+
'the give group!'
|
173 |
+
local_size = torch.tensor(
|
174 |
+
[tensor.numel()], dtype=torch.int64, device=tensor.device)
|
175 |
+
size_list = [torch.zeros(
|
176 |
+
[1], dtype=torch.int64, device=tensor.device)
|
177 |
+
for _ in range(world_size)]
|
178 |
+
|
179 |
+
# gather tensors and compute the maximum size
|
180 |
+
dist.all_gather(size_list, local_size, group=group)
|
181 |
+
size_list = [int(size.item()) for size in size_list]
|
182 |
+
max_size = max(size_list)
|
183 |
+
|
184 |
+
# pad tensors to the same size
|
185 |
+
if local_size != max_size:
|
186 |
+
padding = torch.zeros(
|
187 |
+
(max_size - local_size, ),
|
188 |
+
dtype=torch.uint8, device=tensor.device)
|
189 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
190 |
+
return size_list, tensor
|
191 |
+
|
192 |
+
def generalized_all_gather(data, group=None):
|
193 |
+
if get_world_size(group) == 1:
|
194 |
+
return [data]
|
195 |
+
if group is None:
|
196 |
+
group = get_global_gloo_group()
|
197 |
+
|
198 |
+
tensor = _serialize_to_tensor(data, group)
|
199 |
+
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
200 |
+
max_size = max(size_list)
|
201 |
+
|
202 |
+
# receiving tensors from all ranks
|
203 |
+
tensor_list = [torch.empty(
|
204 |
+
(max_size, ), dtype=torch.uint8, device=tensor.device)
|
205 |
+
for _ in size_list]
|
206 |
+
dist.all_gather(tensor_list, tensor, group=group)
|
207 |
+
|
208 |
+
data_list = []
|
209 |
+
for size, tensor in zip(size_list, tensor_list):
|
210 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
211 |
+
data_list.append(pickle.loads(buffer))
|
212 |
+
return data_list
|
213 |
+
|
214 |
+
def generalized_gather(data, dst=0, group=None):
|
215 |
+
world_size = get_world_size(group)
|
216 |
+
if world_size == 1:
|
217 |
+
return [data]
|
218 |
+
if group is None:
|
219 |
+
group = get_global_gloo_group()
|
220 |
+
rank = dist.get_rank() # global rank
|
221 |
+
|
222 |
+
tensor = _serialize_to_tensor(data, group)
|
223 |
+
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
224 |
+
|
225 |
+
# receiving tensors from all ranks to dst
|
226 |
+
if rank == dst:
|
227 |
+
max_size = max(size_list)
|
228 |
+
tensor_list = [torch.empty(
|
229 |
+
(max_size, ), dtype=torch.uint8, device=tensor.device)
|
230 |
+
for _ in size_list]
|
231 |
+
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
232 |
+
|
233 |
+
data_list = []
|
234 |
+
for size, tensor in zip(size_list, tensor_list):
|
235 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
236 |
+
data_list.append(pickle.loads(buffer))
|
237 |
+
return data_list
|
238 |
+
else:
|
239 |
+
dist.gather(tensor, [], dst=dst, group=group)
|
240 |
+
return []
|
241 |
+
|
242 |
+
def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
|
243 |
+
r"""NOTE: only supports CPU tensor communication.
|
244 |
+
"""
|
245 |
+
if get_world_size(group) > 1:
|
246 |
+
return dist.scatter(data, scatter_list, src, group, **kwargs)
|
247 |
+
|
248 |
+
def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
|
249 |
+
if get_world_size(group) > 1:
|
250 |
+
return dist.reduce_scatter(output, input_list, op, group, **kwargs)
|
251 |
+
|
252 |
+
def send(tensor, dst, group=None, **kwargs):
|
253 |
+
if get_world_size(group) > 1:
|
254 |
+
assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
|
255 |
+
return dist.send(tensor, dst, group, **kwargs)
|
256 |
+
|
257 |
+
def recv(tensor, src=None, group=None, **kwargs):
|
258 |
+
if get_world_size(group) > 1:
|
259 |
+
assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
|
260 |
+
return dist.recv(tensor, src, group, **kwargs)
|
261 |
+
|
262 |
+
def isend(tensor, dst, group=None, **kwargs):
|
263 |
+
if get_world_size(group) > 1:
|
264 |
+
assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
|
265 |
+
return dist.isend(tensor, dst, group, **kwargs)
|
266 |
+
|
267 |
+
def irecv(tensor, src=None, group=None, **kwargs):
|
268 |
+
if get_world_size(group) > 1:
|
269 |
+
assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
|
270 |
+
return dist.irecv(tensor, src, group, **kwargs)
|
271 |
+
|
272 |
+
def shared_random_seed(group=None):
|
273 |
+
seed = np.random.randint(2 ** 31)
|
274 |
+
all_seeds = generalized_all_gather(seed, group)
|
275 |
+
return all_seeds[0]
|
276 |
+
|
277 |
+
#-------------------------------- Differentiable operations --------------------------------#
|
278 |
+
|
279 |
+
def _all_gather(x):
|
280 |
+
if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
|
281 |
+
return x
|
282 |
+
rank = dist.get_rank()
|
283 |
+
world_size = dist.get_world_size()
|
284 |
+
tensors = [torch.empty_like(x) for _ in range(world_size)]
|
285 |
+
tensors[rank] = x
|
286 |
+
dist.all_gather(tensors, x)
|
287 |
+
return torch.cat(tensors, dim=0).contiguous()
|
288 |
+
|
289 |
+
def _all_reduce(x):
|
290 |
+
if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
|
291 |
+
return x
|
292 |
+
dist.all_reduce(x)
|
293 |
+
return x
|
294 |
+
|
295 |
+
def _split(x):
|
296 |
+
if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
|
297 |
+
return x
|
298 |
+
rank = dist.get_rank()
|
299 |
+
world_size = dist.get_world_size()
|
300 |
+
return x.chunk(world_size, dim=0)[rank].contiguous()
|
301 |
+
|
302 |
+
class DiffAllGather(Function):
|
303 |
+
r"""Differentiable all-gather.
|
304 |
+
"""
|
305 |
+
@staticmethod
|
306 |
+
def symbolic(graph, input):
|
307 |
+
return _all_gather(input)
|
308 |
+
|
309 |
+
@staticmethod
|
310 |
+
def forward(ctx, input):
|
311 |
+
return _all_gather(input)
|
312 |
+
|
313 |
+
@staticmethod
|
314 |
+
def backward(ctx, grad_output):
|
315 |
+
return _split(grad_output)
|
316 |
+
|
317 |
+
class DiffAllReduce(Function):
|
318 |
+
r"""Differentiable all-reducd.
|
319 |
+
"""
|
320 |
+
@staticmethod
|
321 |
+
def symbolic(graph, input):
|
322 |
+
return _all_reduce(input)
|
323 |
+
|
324 |
+
@staticmethod
|
325 |
+
def forward(ctx, input):
|
326 |
+
return _all_reduce(input)
|
327 |
+
|
328 |
+
@staticmethod
|
329 |
+
def backward(ctx, grad_output):
|
330 |
+
return grad_output
|
331 |
+
|
332 |
+
class DiffScatter(Function):
|
333 |
+
r"""Differentiable scatter.
|
334 |
+
"""
|
335 |
+
@staticmethod
|
336 |
+
def symbolic(graph, input):
|
337 |
+
return _split(input)
|
338 |
+
|
339 |
+
@staticmethod
|
340 |
+
def symbolic(ctx, input):
|
341 |
+
return _split(input)
|
342 |
+
|
343 |
+
@staticmethod
|
344 |
+
def backward(ctx, grad_output):
|
345 |
+
return _all_gather(grad_output)
|
346 |
+
|
347 |
+
class DiffCopy(Function):
|
348 |
+
r"""Differentiable copy that reduces all gradients during backward.
|
349 |
+
"""
|
350 |
+
@staticmethod
|
351 |
+
def symbolic(graph, input):
|
352 |
+
return input
|
353 |
+
|
354 |
+
@staticmethod
|
355 |
+
def forward(ctx, input):
|
356 |
+
return input
|
357 |
+
|
358 |
+
@staticmethod
|
359 |
+
def backward(ctx, grad_output):
|
360 |
+
return _all_reduce(grad_output)
|
361 |
+
|
362 |
+
diff_all_gather = DiffAllGather.apply
|
363 |
+
diff_all_reduce = DiffAllReduce.apply
|
364 |
+
diff_scatter = DiffScatter.apply
|
365 |
+
diff_copy = DiffCopy.apply
|
366 |
+
|
367 |
+
#-------------------------------- Distributed algorithms --------------------------------#
|
368 |
+
|
369 |
+
@torch.no_grad()
|
370 |
+
def spherical_kmeans(feats, num_clusters, num_iters=10):
|
371 |
+
k, n, c = num_clusters, *feats.size()
|
372 |
+
ones = feats.new_ones(n, dtype=torch.long)
|
373 |
+
|
374 |
+
# distributed settings
|
375 |
+
rank = get_rank()
|
376 |
+
world_size = get_world_size()
|
377 |
+
|
378 |
+
# init clusters
|
379 |
+
rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
|
380 |
+
clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]
|
381 |
+
|
382 |
+
# variables
|
383 |
+
new_clusters = feats.new_zeros(k, c)
|
384 |
+
counts = feats.new_zeros(k, dtype=torch.long)
|
385 |
+
|
386 |
+
# iterative Expectation-Maximization
|
387 |
+
for step in range(num_iters + 1):
|
388 |
+
# Expectation step
|
389 |
+
simmat = torch.mm(feats, clusters.t())
|
390 |
+
scores, assigns = simmat.max(dim=1)
|
391 |
+
if step == num_iters:
|
392 |
+
break
|
393 |
+
|
394 |
+
# Maximization step
|
395 |
+
new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats)
|
396 |
+
all_reduce(new_clusters)
|
397 |
+
|
398 |
+
counts.zero_()
|
399 |
+
counts.index_add_(0, assigns, ones)
|
400 |
+
all_reduce(counts)
|
401 |
+
|
402 |
+
mask = (counts > 0)
|
403 |
+
clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
|
404 |
+
clusters = F.normalize(clusters, p=2, dim=1)
|
405 |
+
return clusters, assigns, scores
|
406 |
+
|
407 |
+
@torch.no_grad()
|
408 |
+
def sinkhorn(Q, eps=0.5, num_iters=3):
|
409 |
+
# normalize Q
|
410 |
+
Q = torch.exp(Q / eps).t()
|
411 |
+
sum_Q = Q.sum()
|
412 |
+
all_reduce(sum_Q)
|
413 |
+
Q /= sum_Q
|
414 |
+
|
415 |
+
# variables
|
416 |
+
n, m = Q.size()
|
417 |
+
u = Q.new_zeros(n)
|
418 |
+
r = Q.new_ones(n) / n
|
419 |
+
c = Q.new_ones(m) / (m * get_world_size())
|
420 |
+
|
421 |
+
# iterative update
|
422 |
+
cur_sum = Q.sum(dim=1)
|
423 |
+
all_reduce(cur_sum)
|
424 |
+
for i in range(num_iters):
|
425 |
+
u = cur_sum
|
426 |
+
Q *= (r / u).unsqueeze(1)
|
427 |
+
Q *= (c / Q.sum(dim=0)).unsqueeze(0)
|
428 |
+
cur_sum = Q.sum(dim=1)
|
429 |
+
all_reduce(cur_sum)
|
430 |
+
return (Q / Q.sum(dim=0, keepdim=True)).t().float()
|
UniAnimate/utils/logging.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
3 |
+
|
4 |
+
"""Logging."""
|
5 |
+
|
6 |
+
import builtins
|
7 |
+
import decimal
|
8 |
+
import functools
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import simplejson
|
13 |
+
# from fvcore.common.file_io import PathManager
|
14 |
+
|
15 |
+
import utils.distributed as du
|
16 |
+
|
17 |
+
|
18 |
+
def _suppress_print():
|
19 |
+
"""
|
20 |
+
Suppresses printing from the current process.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
|
24 |
+
pass
|
25 |
+
|
26 |
+
builtins.print = print_pass
|
27 |
+
|
28 |
+
|
29 |
+
# @functools.lru_cache(maxsize=None)
|
30 |
+
# def _cached_log_stream(filename):
|
31 |
+
# return PathManager.open(filename, "a")
|
32 |
+
|
33 |
+
|
34 |
+
def setup_logging(cfg, log_file):
|
35 |
+
"""
|
36 |
+
Sets up the logging for multiple processes. Only enable the logging for the
|
37 |
+
master process, and suppress logging for the non-master processes.
|
38 |
+
"""
|
39 |
+
if du.is_master_proc():
|
40 |
+
# Enable logging for the master process.
|
41 |
+
logging.root.handlers = []
|
42 |
+
else:
|
43 |
+
# Suppress logging for non-master processes.
|
44 |
+
_suppress_print()
|
45 |
+
|
46 |
+
logger = logging.getLogger()
|
47 |
+
logger.setLevel(logging.INFO)
|
48 |
+
logger.propagate = False
|
49 |
+
plain_formatter = logging.Formatter(
|
50 |
+
"[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
|
51 |
+
datefmt="%m/%d %H:%M:%S",
|
52 |
+
)
|
53 |
+
|
54 |
+
if du.is_master_proc():
|
55 |
+
ch = logging.StreamHandler(stream=sys.stdout)
|
56 |
+
ch.setLevel(logging.DEBUG)
|
57 |
+
ch.setFormatter(plain_formatter)
|
58 |
+
logger.addHandler(ch)
|
59 |
+
|
60 |
+
if log_file is not None and du.is_master_proc(du.get_world_size()):
|
61 |
+
filename = os.path.join(cfg.OUTPUT_DIR, log_file)
|
62 |
+
fh = logging.FileHandler(filename)
|
63 |
+
fh.setLevel(logging.DEBUG)
|
64 |
+
fh.setFormatter(plain_formatter)
|
65 |
+
logger.addHandler(fh)
|
66 |
+
|
67 |
+
|
68 |
+
def get_logger(name):
|
69 |
+
"""
|
70 |
+
Retrieve the logger with the specified name or, if name is None, return a
|
71 |
+
logger which is the root logger of the hierarchy.
|
72 |
+
Args:
|
73 |
+
name (string): name of the logger.
|
74 |
+
"""
|
75 |
+
return logging.getLogger(name)
|
76 |
+
|
77 |
+
|
78 |
+
def log_json_stats(stats):
|
79 |
+
"""
|
80 |
+
Logs json stats.
|
81 |
+
Args:
|
82 |
+
stats (dict): a dictionary of statistical information to log.
|
83 |
+
"""
|
84 |
+
stats = {
|
85 |
+
k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
|
86 |
+
for k, v in stats.items()
|
87 |
+
}
|
88 |
+
json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
|
89 |
+
logger = get_logger(__name__)
|
90 |
+
logger.info("{:s}".format(json_stats))
|
UniAnimate/utils/mp4_to_gif.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
# source_mp4_dir = "outputs/UniAnimate_infer"
|
6 |
+
# target_gif_dir = "outputs/UniAnimate_infer_gif"
|
7 |
+
|
8 |
+
source_mp4_dir = "outputs/UniAnimate_infer_long"
|
9 |
+
target_gif_dir = "outputs/UniAnimate_infer_long_gif"
|
10 |
+
|
11 |
+
os.makedirs(target_gif_dir, exist_ok=True)
|
12 |
+
for video in os.listdir(source_mp4_dir):
|
13 |
+
video_dir = os.path.join(source_mp4_dir, video)
|
14 |
+
gif_dir = os.path.join(target_gif_dir, video.replace(".mp4", ".gif"))
|
15 |
+
cmd = f'ffmpeg -i {video_dir} {gif_dir}'
|
16 |
+
os.system(cmd)
|
UniAnimate/utils/multi_port.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import socket
|
2 |
+
from contextlib import closing
|
3 |
+
|
4 |
+
def find_free_port():
|
5 |
+
""" https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """
|
6 |
+
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
7 |
+
s.bind(('', 0))
|
8 |
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
9 |
+
return str(s.getsockname()[1])
|
UniAnimate/utils/optim/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .lr_scheduler import *
|
2 |
+
from .adafactor import *
|
UniAnimate/utils/optim/adafactor.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch.optim import Optimizer
|
4 |
+
from torch.optim.lr_scheduler import LambdaLR
|
5 |
+
|
6 |
+
__all__ = ['Adafactor']
|
7 |
+
|
8 |
+
class Adafactor(Optimizer):
|
9 |
+
"""
|
10 |
+
AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
|
11 |
+
https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
12 |
+
Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
|
13 |
+
this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
|
14 |
+
`warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
|
15 |
+
`relative_step=False`.
|
16 |
+
Arguments:
|
17 |
+
params (`Iterable[nn.parameter.Parameter]`):
|
18 |
+
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
19 |
+
lr (`float`, *optional*):
|
20 |
+
The external learning rate.
|
21 |
+
eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)):
|
22 |
+
Regularization constants for square gradient and parameter scale respectively
|
23 |
+
clip_threshold (`float`, *optional*, defaults 1.0):
|
24 |
+
Threshold of root mean square of final gradient update
|
25 |
+
decay_rate (`float`, *optional*, defaults to -0.8):
|
26 |
+
Coefficient used to compute running averages of square
|
27 |
+
beta1 (`float`, *optional*):
|
28 |
+
Coefficient used for computing running averages of gradient
|
29 |
+
weight_decay (`float`, *optional*, defaults to 0):
|
30 |
+
Weight decay (L2 penalty)
|
31 |
+
scale_parameter (`bool`, *optional*, defaults to `True`):
|
32 |
+
If True, learning rate is scaled by root mean square
|
33 |
+
relative_step (`bool`, *optional*, defaults to `True`):
|
34 |
+
If True, time-dependent learning rate is computed instead of external learning rate
|
35 |
+
warmup_init (`bool`, *optional*, defaults to `False`):
|
36 |
+
Time-dependent learning rate computation depends on whether warm-up initialization is being used
|
37 |
+
This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
|
38 |
+
Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
|
39 |
+
- Training without LR warmup or clip_threshold is not recommended.
|
40 |
+
- use scheduled LR warm-up to fixed LR
|
41 |
+
- use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
|
42 |
+
- Disable relative updates
|
43 |
+
- Use scale_parameter=False
|
44 |
+
- Additional optimizer operations like gradient clipping should not be used alongside Adafactor
|
45 |
+
Example:
|
46 |
+
```python
|
47 |
+
Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
|
48 |
+
```
|
49 |
+
Others reported the following combination to work well:
|
50 |
+
```python
|
51 |
+
Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
52 |
+
```
|
53 |
+
When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
|
54 |
+
scheduler as following:
|
55 |
+
```python
|
56 |
+
from transformers.optimization import Adafactor, AdafactorSchedule
|
57 |
+
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
58 |
+
lr_scheduler = AdafactorSchedule(optimizer)
|
59 |
+
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
|
60 |
+
```
|
61 |
+
Usage:
|
62 |
+
```python
|
63 |
+
# replace AdamW with Adafactor
|
64 |
+
optimizer = Adafactor(
|
65 |
+
model.parameters(),
|
66 |
+
lr=1e-3,
|
67 |
+
eps=(1e-30, 1e-3),
|
68 |
+
clip_threshold=1.0,
|
69 |
+
decay_rate=-0.8,
|
70 |
+
beta1=None,
|
71 |
+
weight_decay=0.0,
|
72 |
+
relative_step=False,
|
73 |
+
scale_parameter=False,
|
74 |
+
warmup_init=False,
|
75 |
+
)
|
76 |
+
```"""
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
params,
|
81 |
+
lr=None,
|
82 |
+
eps=(1e-30, 1e-3),
|
83 |
+
clip_threshold=1.0,
|
84 |
+
decay_rate=-0.8,
|
85 |
+
beta1=None,
|
86 |
+
weight_decay=0.0,
|
87 |
+
scale_parameter=True,
|
88 |
+
relative_step=True,
|
89 |
+
warmup_init=False,
|
90 |
+
):
|
91 |
+
r"""require_version("torch>=1.5.0") # add_ with alpha
|
92 |
+
"""
|
93 |
+
if lr is not None and relative_step:
|
94 |
+
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
|
95 |
+
if warmup_init and not relative_step:
|
96 |
+
raise ValueError("`warmup_init=True` requires `relative_step=True`")
|
97 |
+
|
98 |
+
defaults = dict(
|
99 |
+
lr=lr,
|
100 |
+
eps=eps,
|
101 |
+
clip_threshold=clip_threshold,
|
102 |
+
decay_rate=decay_rate,
|
103 |
+
beta1=beta1,
|
104 |
+
weight_decay=weight_decay,
|
105 |
+
scale_parameter=scale_parameter,
|
106 |
+
relative_step=relative_step,
|
107 |
+
warmup_init=warmup_init,
|
108 |
+
)
|
109 |
+
super().__init__(params, defaults)
|
110 |
+
|
111 |
+
@staticmethod
|
112 |
+
def _get_lr(param_group, param_state):
|
113 |
+
rel_step_sz = param_group["lr"]
|
114 |
+
if param_group["relative_step"]:
|
115 |
+
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
|
116 |
+
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
|
117 |
+
param_scale = 1.0
|
118 |
+
if param_group["scale_parameter"]:
|
119 |
+
param_scale = max(param_group["eps"][1], param_state["RMS"])
|
120 |
+
return param_scale * rel_step_sz
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def _get_options(param_group, param_shape):
|
124 |
+
factored = len(param_shape) >= 2
|
125 |
+
use_first_moment = param_group["beta1"] is not None
|
126 |
+
return factored, use_first_moment
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def _rms(tensor):
|
130 |
+
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
131 |
+
|
132 |
+
@staticmethod
|
133 |
+
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
|
134 |
+
# copy from fairseq's adafactor implementation:
|
135 |
+
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
|
136 |
+
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
137 |
+
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
138 |
+
return torch.mul(r_factor, c_factor)
|
139 |
+
|
140 |
+
def step(self, closure=None):
|
141 |
+
"""
|
142 |
+
Performs a single optimization step
|
143 |
+
Arguments:
|
144 |
+
closure (callable, optional): A closure that reevaluates the model
|
145 |
+
and returns the loss.
|
146 |
+
"""
|
147 |
+
loss = None
|
148 |
+
if closure is not None:
|
149 |
+
loss = closure()
|
150 |
+
|
151 |
+
for group in self.param_groups:
|
152 |
+
for p in group["params"]:
|
153 |
+
if p.grad is None:
|
154 |
+
continue
|
155 |
+
grad = p.grad.data
|
156 |
+
if grad.dtype in {torch.float16, torch.bfloat16}:
|
157 |
+
grad = grad.float()
|
158 |
+
if grad.is_sparse:
|
159 |
+
raise RuntimeError("Adafactor does not support sparse gradients.")
|
160 |
+
|
161 |
+
state = self.state[p]
|
162 |
+
grad_shape = grad.shape
|
163 |
+
|
164 |
+
factored, use_first_moment = self._get_options(group, grad_shape)
|
165 |
+
# State Initialization
|
166 |
+
if len(state) == 0:
|
167 |
+
state["step"] = 0
|
168 |
+
|
169 |
+
if use_first_moment:
|
170 |
+
# Exponential moving average of gradient values
|
171 |
+
state["exp_avg"] = torch.zeros_like(grad)
|
172 |
+
if factored:
|
173 |
+
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
174 |
+
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
175 |
+
else:
|
176 |
+
state["exp_avg_sq"] = torch.zeros_like(grad)
|
177 |
+
|
178 |
+
state["RMS"] = 0
|
179 |
+
else:
|
180 |
+
if use_first_moment:
|
181 |
+
state["exp_avg"] = state["exp_avg"].to(grad)
|
182 |
+
if factored:
|
183 |
+
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
184 |
+
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
185 |
+
else:
|
186 |
+
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
187 |
+
|
188 |
+
p_data_fp32 = p.data
|
189 |
+
if p.data.dtype in {torch.float16, torch.bfloat16}:
|
190 |
+
p_data_fp32 = p_data_fp32.float()
|
191 |
+
|
192 |
+
state["step"] += 1
|
193 |
+
state["RMS"] = self._rms(p_data_fp32)
|
194 |
+
lr = self._get_lr(group, state)
|
195 |
+
|
196 |
+
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
197 |
+
update = (grad**2) + group["eps"][0]
|
198 |
+
if factored:
|
199 |
+
exp_avg_sq_row = state["exp_avg_sq_row"]
|
200 |
+
exp_avg_sq_col = state["exp_avg_sq_col"]
|
201 |
+
|
202 |
+
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
203 |
+
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
204 |
+
|
205 |
+
# Approximation of exponential moving average of square of gradient
|
206 |
+
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
207 |
+
update.mul_(grad)
|
208 |
+
else:
|
209 |
+
exp_avg_sq = state["exp_avg_sq"]
|
210 |
+
|
211 |
+
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
212 |
+
update = exp_avg_sq.rsqrt().mul_(grad)
|
213 |
+
|
214 |
+
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
215 |
+
update.mul_(lr)
|
216 |
+
|
217 |
+
if use_first_moment:
|
218 |
+
exp_avg = state["exp_avg"]
|
219 |
+
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
220 |
+
update = exp_avg
|
221 |
+
|
222 |
+
if group["weight_decay"] != 0:
|
223 |
+
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
224 |
+
|
225 |
+
p_data_fp32.add_(-update)
|
226 |
+
|
227 |
+
if p.data.dtype in {torch.float16, torch.bfloat16}:
|
228 |
+
p.data.copy_(p_data_fp32)
|
229 |
+
|
230 |
+
return loss
|
UniAnimate/utils/optim/lr_scheduler.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
3 |
+
|
4 |
+
__all__ = ['AnnealingLR']
|
5 |
+
|
6 |
+
class AnnealingLR(_LRScheduler):
|
7 |
+
|
8 |
+
def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', min_lr=0.0, last_step=-1):
|
9 |
+
assert decay_mode in ['linear', 'cosine', 'none']
|
10 |
+
self.optimizer = optimizer
|
11 |
+
self.base_lr = base_lr
|
12 |
+
self.warmup_steps = warmup_steps
|
13 |
+
self.total_steps = total_steps
|
14 |
+
self.decay_mode = decay_mode
|
15 |
+
self.min_lr = min_lr
|
16 |
+
self.current_step = last_step + 1
|
17 |
+
self.step(self.current_step)
|
18 |
+
|
19 |
+
def get_lr(self):
|
20 |
+
if self.warmup_steps > 0 and self.current_step <= self.warmup_steps:
|
21 |
+
return self.base_lr * self.current_step / self.warmup_steps
|
22 |
+
else:
|
23 |
+
ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
24 |
+
ratio = min(1.0, max(0.0, ratio))
|
25 |
+
if self.decay_mode == 'linear':
|
26 |
+
return self.base_lr * (1 - ratio)
|
27 |
+
elif self.decay_mode == 'cosine':
|
28 |
+
return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0
|
29 |
+
else:
|
30 |
+
return self.base_lr
|
31 |
+
|
32 |
+
def step(self, current_step=None):
|
33 |
+
if current_step is None:
|
34 |
+
current_step = self.current_step + 1
|
35 |
+
self.current_step = current_step
|
36 |
+
new_lr = max(self.min_lr, self.get_lr())
|
37 |
+
if isinstance(self.optimizer, list):
|
38 |
+
for o in self.optimizer:
|
39 |
+
for group in o.param_groups:
|
40 |
+
group['lr'] = new_lr
|
41 |
+
else:
|
42 |
+
for group in self.optimizer.param_groups:
|
43 |
+
group['lr'] = new_lr
|
44 |
+
|
45 |
+
def state_dict(self):
|
46 |
+
return {
|
47 |
+
'base_lr': self.base_lr,
|
48 |
+
'warmup_steps': self.warmup_steps,
|
49 |
+
'total_steps': self.total_steps,
|
50 |
+
'decay_mode': self.decay_mode,
|
51 |
+
'current_step': self.current_step}
|
52 |
+
|
53 |
+
def load_state_dict(self, state_dict):
|
54 |
+
self.base_lr = state_dict['base_lr']
|
55 |
+
self.warmup_steps = state_dict['warmup_steps']
|
56 |
+
self.total_steps = state_dict['total_steps']
|
57 |
+
self.decay_mode = state_dict['decay_mode']
|
58 |
+
self.current_step = state_dict['current_step']
|